update recurrent genome
This commit is contained in:
@@ -34,9 +34,6 @@ class BaseGene(StatefulBaseClass):
|
||||
def forward(self, state, attrs, inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_by_batch(self, state, attrs, batch_inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def length(self):
|
||||
return len(self.fixed_attrs) + len(self.custom_attrs)
|
||||
|
||||
@@ -31,7 +31,7 @@ class BaseGenome(StatefulBaseClass):
|
||||
input_transform: Callable = None,
|
||||
init_hidden_layers: Sequence[int] = (),
|
||||
):
|
||||
|
||||
|
||||
# check transform functions
|
||||
if input_transform is not None:
|
||||
try:
|
||||
@@ -64,7 +64,7 @@ class BaseGenome(StatefulBaseClass):
|
||||
all_init_conns_in_idx.append(in_idx)
|
||||
all_init_conns_out_idx.append(out_idx)
|
||||
all_init_nodes.extend(in_layer)
|
||||
all_init_nodes.extend(layer_indices[-1])
|
||||
all_init_nodes.extend(layer_indices[-1]) # output layer
|
||||
|
||||
if max_nodes < len(all_init_nodes):
|
||||
raise ValueError(
|
||||
@@ -75,7 +75,7 @@ class BaseGenome(StatefulBaseClass):
|
||||
raise ValueError(
|
||||
f"max_conns={max_conns} must be greater than or equal to the number of initial connections={len(all_init_conns_in_idx)}"
|
||||
)
|
||||
|
||||
|
||||
self.num_inputs = num_inputs
|
||||
self.num_outputs = num_outputs
|
||||
self.max_nodes = max_nodes
|
||||
|
||||
@@ -78,31 +78,34 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
def cond_fun(carry):
|
||||
values, idx = carry
|
||||
return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)
|
||||
return (idx < self.max_nodes) & (
|
||||
cal_seqs[idx] != I_INF
|
||||
) # not out of bounds and next node exists
|
||||
|
||||
def body_func(carry):
|
||||
values, idx = carry
|
||||
i = cal_seqs[idx]
|
||||
|
||||
def input_node():
|
||||
z = self.node_gene.input_transform(state, nodes_attrs[i], values[i])
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
return values
|
||||
|
||||
def otherwise():
|
||||
# calculate connections
|
||||
conn_indices = u_conns[:, i]
|
||||
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
|
||||
hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs
|
||||
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
|
||||
state, hit_attrs, values
|
||||
)
|
||||
|
||||
# calculate nodes
|
||||
z = self.node_gene.forward(
|
||||
state,
|
||||
nodes_attrs[i],
|
||||
ins,
|
||||
is_output_node=jnp.isin(i, self.output_idx),
|
||||
is_output_node=jnp.isin(nodes[0], self.output_idx), # nodes[0] -> the key of nodes
|
||||
)
|
||||
|
||||
# set new value
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
from .utils import unflatten_conns
|
||||
|
||||
from . import BaseGenome
|
||||
from .base import BaseGenome
|
||||
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
|
||||
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
|
||||
from ..gene import DefaultNodeGene, DefaultConnGene
|
||||
from .operations import DefaultMutation, DefaultCrossover
|
||||
|
||||
from tensorneat.common import attach_with_inf
|
||||
|
||||
class RecurrentGenome(BaseGenome):
|
||||
"""Default genome class, with the same behavior as the NEAT-Python"""
|
||||
@@ -17,14 +18,17 @@ class RecurrentGenome(BaseGenome):
|
||||
self,
|
||||
num_inputs: int,
|
||||
num_outputs: int,
|
||||
max_nodes = 50,
|
||||
max_conns = 100,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
node_gene=DefaultNodeGene(),
|
||||
conn_gene=DefaultConnGene(),
|
||||
mutation=DefaultMutation(),
|
||||
crossover=DefaultCrossover(),
|
||||
distance=DefaultDistance(),
|
||||
output_transform=None,
|
||||
input_transform=None,
|
||||
init_hidden_layers=(),
|
||||
activate_time=10,
|
||||
output_transform: Callable = None,
|
||||
):
|
||||
super().__init__(
|
||||
num_inputs,
|
||||
@@ -35,29 +39,25 @@ class RecurrentGenome(BaseGenome):
|
||||
conn_gene,
|
||||
mutation,
|
||||
crossover,
|
||||
distance,
|
||||
output_transform,
|
||||
input_transform,
|
||||
init_hidden_layers,
|
||||
)
|
||||
self.activate_time = activate_time
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
_ = output_transform(jnp.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.output_transform = output_transform
|
||||
|
||||
def transform(self, state, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
return nodes, conns, u_conns
|
||||
|
||||
def restore(self, state, transformed):
|
||||
def forward(self, state, transformed, inputs):
|
||||
nodes, conns, u_conns = transformed
|
||||
return nodes, conns
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
nodes, conns = transformed
|
||||
|
||||
vals = jnp.full((self.max_nodes,), jnp.nan)
|
||||
nodes_attrs = nodes[:, 1:] # remove index
|
||||
|
||||
nodes_attrs = vmap(extract_node_attrs)(nodes)
|
||||
conns_attrs = vmap(extract_conn_attrs)(conns)
|
||||
expand_conns_attrs = attach_with_inf(conns_attrs, u_conns)
|
||||
|
||||
def body_func(_, values):
|
||||
|
||||
@@ -65,14 +65,14 @@ class RecurrentGenome(BaseGenome):
|
||||
values = values.at[self.input_idx].set(inputs)
|
||||
|
||||
# calculate connections
|
||||
node_ins = jax.vmap(
|
||||
jax.vmap(self.conn_gene.forward, in_axes=(None, 1, None)),
|
||||
in_axes=(None, 1, 0),
|
||||
)(state, conns, values)
|
||||
node_ins = vmap(
|
||||
vmap(self.conn_gene.forward, in_axes=(None, 0, None)),
|
||||
in_axes=(None, 0, 0),
|
||||
)(state, expand_conns_attrs, values)
|
||||
|
||||
# calculate nodes
|
||||
is_output_nodes = jnp.isin(jnp.arange(self.max_nodes), self.output_idx)
|
||||
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0, 0))(
|
||||
is_output_nodes = jnp.isin(nodes[:, 0], self.output_idx)
|
||||
values = vmap(self.node_gene.forward, in_axes=(None, 0, 0, 0))(
|
||||
state, nodes_attrs, node_ins.T, is_output_nodes
|
||||
)
|
||||
|
||||
@@ -87,3 +87,6 @@ class RecurrentGenome(BaseGenome):
|
||||
|
||||
def sympy_func(self, state, network, precision=3):
|
||||
raise ValueError("Sympy function is not supported for Recurrent Network!")
|
||||
|
||||
def visualize(self, network):
|
||||
raise ValueError("Visualize function is not supported for Recurrent Network!")
|
||||
|
||||
Reference in New Issue
Block a user