update recurrent genome
This commit is contained in:
@@ -1,10 +1,21 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from tensorneat.algorithm import NEAT
|
||||
from tensorneat.algorithm.neat import DefaultGenome
|
||||
from tensorneat.algorithm.neat import DefaultGenome, RecurrentGenome
|
||||
|
||||
key = jax.random.key(0)
|
||||
genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=())
|
||||
genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=(1, 2 ,3))
|
||||
state = genome.setup()
|
||||
nodes, conns = genome.initialize(state, key)
|
||||
print(genome.repr(state, nodes, conns))
|
||||
|
||||
inputs = jnp.array([1, 2, 3, 4, 5])
|
||||
transformed = genome.transform(state, nodes, conns)
|
||||
outputs = genome.forward(state, transformed, inputs)
|
||||
|
||||
print(outputs)
|
||||
|
||||
network = genome.network_dict(state, nodes, conns)
|
||||
print(network)
|
||||
|
||||
genome.visualize(network)
|
||||
|
||||
16
examples/tmp2.py
Normal file
16
examples/tmp2.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
arr = jnp.ones((10, 10))
|
||||
a = jnp.array([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6]
|
||||
])
|
||||
|
||||
def attach_with_inf(arr, idx):
|
||||
target_dim = arr.ndim + idx.ndim - 1
|
||||
expand_idx = jnp.expand_dims(idx, axis=tuple(range(idx.ndim, target_dim)))
|
||||
|
||||
return jnp.where(expand_idx == 1, jnp.nan, arr[idx])
|
||||
|
||||
b = attach_with_inf(arr, a)
|
||||
print(b)
|
||||
Reference in New Issue
Block a user