diff --git a/tensorneat/algorithm/neat/ga/crossover/base.py b/tensorneat/algorithm/neat/ga/crossover/base.py index 1849deb..206a6f8 100644 --- a/tensorneat/algorithm/neat/ga/crossover/base.py +++ b/tensorneat/algorithm/neat/ga/crossover/base.py @@ -3,8 +3,8 @@ from utils import State class BaseCrossover: - def setup(self, key, state=State()): + def setup(self, state=State()): return state - def __call__(self, state, key, genome, nodes1, nodes2, conns1, conns2): + def __call__(self, state, genome, nodes1, nodes2, conns1, conns2): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/ga/crossover/default.py b/tensorneat/algorithm/neat/ga/crossover/default.py index c7eb531..c6e3e37 100644 --- a/tensorneat/algorithm/neat/ga/crossover/default.py +++ b/tensorneat/algorithm/neat/ga/crossover/default.py @@ -5,12 +5,12 @@ from .base import BaseCrossover class DefaultCrossover(BaseCrossover): - def __call__(self, state, key, genome, nodes1, conns1, nodes2, conns2): + def __call__(self, state, genome, nodes1, conns1, nodes2, conns2): """ use genome1 and genome2 to generate a new genome notice that genome1 should have higher fitness than genome2 (genome1 is winner!) 
""" - randkey_1, randkey_2, key = jax.random.split(key, 3) + randkey1, randkey2, randkey = jax.random.split(state.randkey, 3) # crossover nodes keys1, keys2 = nodes1[:, 0], nodes2[:, 0] @@ -20,16 +20,16 @@ class DefaultCrossover(BaseCrossover): # For not homologous genes, use the value of nodes1(winner) # For homologous genes, use the crossover result between nodes1 and nodes2 new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, - self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False)) + self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False)) # crossover connections con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True) new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, - self.crossover_gene(randkey_2, conns1, conns2, is_conn=True)) + self.crossover_gene(randkey2, conns1, conns2, is_conn=True)) - return new_nodes, new_conns + return state.update(randkey=randkey), new_nodes, new_conns def align_array(self, seq1, seq2, ar2, is_conn: bool): """ diff --git a/tensorneat/algorithm/neat/ga/mutation/base.py b/tensorneat/algorithm/neat/ga/mutation/base.py index 2322f85..4e4a0b3 100644 --- a/tensorneat/algorithm/neat/ga/mutation/base.py +++ b/tensorneat/algorithm/neat/ga/mutation/base.py @@ -3,8 +3,8 @@ from utils import State class BaseMutation: - def setup(self, key, state=State()): + def setup(self, state=State()): return state - def __call__(self, state, key, genome, nodes, conns, new_node_key): + def __call__(self, state, genome, nodes, conns, new_node_key): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/ga/mutation/default.py b/tensorneat/algorithm/neat/ga/mutation/default.py index aea3304..b0fc047 100644 --- a/tensorneat/algorithm/neat/ga/mutation/default.py +++ b/tensorneat/algorithm/neat/ga/mutation/default.py @@ -17,13 +17,13 @@ class DefaultMutation(BaseMutation): self.node_add = node_add self.node_delete = node_delete - def 
__call__(self, state, key, genome, nodes, conns, new_node_key): - k1, k2 = jax.random.split(key) + def __call__(self, state, genome, nodes, conns, new_node_key): + k1, k2, randkey = jax.random.split(state.randkey, 3) nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key) nodes, conns = self.mutate_values(k2, genome, nodes, conns) - return nodes, conns + return state.update(randkey=randkey), nodes, conns def mutate_structure(self, key, genome, nodes, conns, new_node_key): def mutate_add_node(key_, nodes_, conns_): diff --git a/tensorneat/algorithm/neat/gene/base.py b/tensorneat/algorithm/neat/gene/base.py index abb2d52..4e3a49a 100644 --- a/tensorneat/algorithm/neat/gene/base.py +++ b/tensorneat/algorithm/neat/gene/base.py @@ -9,13 +9,13 @@ class BaseGene: def __init__(self): pass - def setup(self, key, state=State()): + def setup(self, state=State()): return state - def new_attrs(self, state, key): + def new_attrs(self, state): raise NotImplementedError - def mutate(self, state, key, gene): + def mutate(self, state, gene): raise NotImplementedError def distance(self, state, gene1, gene2): diff --git a/tensorneat/algorithm/neat/gene/conn/default.py b/tensorneat/algorithm/neat/gene/conn/default.py index d8834ed..8f4d0e8 100644 --- a/tensorneat/algorithm/neat/gene/conn/default.py +++ b/tensorneat/algorithm/neat/gene/conn/default.py @@ -1,4 +1,5 @@ import jax.numpy as jnp +import jax.random from utils import mutate_float from . 
import BaseConnGene @@ -24,14 +25,15 @@ class DefaultConnGene(BaseConnGene): self.weight_mutate_rate = weight_mutate_rate self.weight_replace_rate = weight_replace_rate - def new_attrs(self, state, key): - return jnp.array([self.weight_init_mean]) + def new_attrs(self, state): + return state, jnp.array([self.weight_init_mean]) - def mutate(self, state, key, conn): + def mutate(self, state, conn): + randkey_, randkey = jax.random.split(state.randkey, 2) input_index = conn[0] output_index = conn[1] enabled = conn[2] - weight = mutate_float(key, + weight = mutate_float(randkey_, conn[3], self.weight_init_mean, self.weight_init_std, @@ -40,11 +42,11 @@ class DefaultConnGene(BaseConnGene): self.weight_replace_rate ) - return jnp.array([input_index, output_index, enabled, weight]) + return state.update(randkey=randkey), jnp.array([input_index, output_index, enabled, weight]) def distance(self, state, attrs1, attrs2): - return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight + return state, (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight def forward(self, state, attrs, inputs): weight = attrs[0] - return inputs * weight + return state, inputs * weight diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index 1c46e17..6e118ce 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -56,13 +56,13 @@ class DefaultNodeGene(BaseNodeGene): self.aggregation_indices = jnp.arange(len(aggregation_options)) self.aggregation_replace_rate = aggregation_replace_rate - def new_attrs(self, state, key): - return jnp.array( + def new_attrs(self, state): + return state, jnp.array( [self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default] ) - def mutate(self, state, key, node): - k1, k2, k3, k4 = jax.random.split(key, num=4) + def mutate(self, state, node): + k1, k2, k3, k4, randkey = 
jax.random.split(state.randkey, num=5) index = node[0] bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std, @@ -75,10 +75,10 @@ class DefaultNodeGene(BaseNodeGene): agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate) - return jnp.array([index, bias, res, act, agg]) + return state.update(randkey=randkey), jnp.array([index, bias, res, act, agg]) def distance(self, state, node1, node2): - return ( + return state, ( jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + (node1[3] != node2[3]) + @@ -98,4 +98,4 @@ class DefaultNodeGene(BaseNodeGene): lambda: act(act_idx, z, self.activation_options) ) - return z + return state, z diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index 9a3ac81..fd711ab 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -24,7 +24,7 @@ class BaseGenome: self.node_gene = node_gene self.conn_gene = conn_gene - def setup(self, key, state=State()): + def setup(self, state=State()): return state def transform(self, state, nodes, conns): diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index 0425938..98445bf 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -38,7 +38,7 @@ class DefaultGenome(BaseGenome): u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) seqs = topological_sort(nodes, conn_enable) - return seqs, nodes, u_conns + return state, seqs, nodes, u_conns def forward(self, state, inputs, transformed): cal_seqs, nodes, conns = transformed @@ -49,34 +49,32 @@ class DefaultGenome(BaseGenome): nodes_attrs = nodes[:, 1:] def cond_fun(carry): - values, idx = carry + state_, values, idx = carry return (idx < N) & (cal_seqs[idx] != I_INT) def body_func(carry): - values, idx = carry + state_, values, idx = carry i = cal_seqs[idx] def hit(): - ins = 
jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(conns[:, :, i], values) - z = self.node_gene.forward(state, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx)) + s, ins = jax.vmap(self.conn_gene.forward, + in_axes=(None, 1, 0), out_axes=(None, 0))(state_, conns[:, :, i], values) + s, z = self.node_gene.forward(s, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx)) new_values = values.at[i].set(z) - return new_values - - def miss(): - return values + return s, new_values # the val of input nodes is obtained by the task, not by calculation - values = jax.lax.cond( + state_, values = jax.lax.cond( jnp.isin(i, self.input_idx), - miss, + lambda: (state_, values), hit ) - return values, idx + 1 + return state_, values, idx + 1 - vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) + state, vals, _ = jax.lax.while_loop(cond_fun, body_func, (state, ini_vals, 0)) if self.output_transform is None: - return vals[self.output_idx] + return state, vals[self.output_idx] else: - return self.output_transform(vals[self.output_idx]) + return state, self.output_transform(vals[self.output_idx]) diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index 8469884..93b3614 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -39,7 +39,7 @@ class RecurrentGenome(BaseGenome): conn_enable = u_conns[0] == 1 u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) - return nodes, u_conns + return state, nodes, u_conns def forward(self, state, inputs, transformed): nodes, conns = transformed @@ -48,27 +48,36 @@ class RecurrentGenome(BaseGenome): vals = jnp.full((N,), jnp.nan) nodes_attrs = nodes[:, 1:] - def body_func(_, values): + def body_func(_, carry): + state_, values = carry + # set input values values = values.at[self.input_idx].set(inputs) # calculate connections - node_ins = jax.vmap( + state_, node_ins = jax.vmap( jax.vmap( 
self.conn_gene.forward, - in_axes=(None, 1, None) + in_axes=(None, 1, None), + out_axes=(None, 0) ), - in_axes=(None, 1, 0) - )(state, conns, values) + in_axes=(None, 1, 0), + out_axes=(None, 0) + )(state_, conns, values) # calculate nodes is_output_nodes = jnp.isin( jnp.arange(N), self.output_idx ) - values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0))(nodes_attrs, node_ins.T, is_output_nodes) - return values + state_, values = jax.vmap( + self.node_gene.forward, + in_axes=(None, 0, 0, 0), + out_axes=(None, 0) + )(state_, nodes_attrs, node_ins.T, is_output_nodes) - vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals) + return state_, values - return vals[self.output_idx] + state, vals = jax.lax.fori_loop(0, self.activate_time, body_func, (state, vals)) + + return state, vals[self.output_idx] diff --git a/tensorneat/algorithm/neat/species/default.py b/tensorneat/algorithm/neat/species/default.py index d5b297d..0bf2eec 100644 --- a/tensorneat/algorithm/neat/species/default.py +++ b/tensorneat/algorithm/neat/species/default.py @@ -40,8 +40,8 @@ class DefaultSpecies(BaseSpecies): self.species_arange = jnp.arange(self.species_size) - def setup(self, randkey): - k1, k2 = jax.random.split(randkey, 2) + def setup(self, key, state=State()): + k1, k2 = jax.random.split(key, 2) pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, self.initialize_method) species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species @@ -65,8 +65,8 @@ class DefaultSpecies(BaseSpecies): pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns)) - return State( - randkey=k2, + return state.register( + species_randkey=k2, pop_nodes=pop_nodes, pop_conns=pop_conns, species_keys=species_keys, @@ -134,8 +134,7 @@ class DefaultSpecies(BaseSpecies): def check_stagnation(idx): # determine whether the species stagnation st = ( - (species_fitness[idx] <= state.best_fitness[ - idx]) & # not better than the best 
fitness of the species + (species_fitness[idx] <= state.best_fitness[idx]) & # not better than the best fitness of the species (generation - state.last_improved[idx] > self.max_stagnation) # for a long time ) diff --git a/tensorneat/test/test_genome.py b/tensorneat/test/test_genome.py index 712d272..d889097 100644 --- a/tensorneat/test/test_genome.py +++ b/tensorneat/test/test_genome.py @@ -1,5 +1,5 @@ from algorithm.neat import * -from utils import Act, Agg +from utils import Act, Agg, State import jax, jax.numpy as jnp @@ -36,11 +36,14 @@ def test_default(): ), ) - transformed = genome.transform(nodes, conns) + state = genome.setup(State(randkey=jax.random.key(0))) + + state, *transformed = genome.transform(state, nodes, conns) print(*transformed, sep='\n') inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) - outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed) + state, outputs = jax.jit(jax.vmap(genome.forward, + in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed) print(outputs) assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]])) # expected: [[0.5], [0.75], [0.75], [1]] @@ -50,11 +53,11 @@ def test_default(): conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0] print(conns) - transformed = genome.transform(nodes, conns) + state, *transformed = genome.transform(state, nodes, conns) print(*transformed, sep='\n') inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) - outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed) + state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed) print(outputs) assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]])) # expected: [[0.5], [0.75], [0.5], [0.75]] @@ -93,11 +96,14 @@ def test_recurrent(): activate_time=3, ) - transformed = genome.transform(nodes, conns) + state = genome.setup(State(randkey=jax.random.key(0))) + + state, *transformed = 
genome.transform(state, nodes, conns) print(*transformed, sep='\n') inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) - outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed) + state, outputs = jax.jit(jax.vmap(genome.forward, + in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed) print(outputs) assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]])) # expected: [[0.5], [0.75], [0.75], [1]] @@ -107,11 +113,11 @@ def test_recurrent(): conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0] print(conns) - transformed = genome.transform(nodes, conns) + state, *transformed = genome.transform(state, nodes, conns) print(*transformed, sep='\n') inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) - outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed) + state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed) print(outputs) assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]])) # expected: [[0.5], [0.75], [0.5], [0.75]] \ No newline at end of file