All function with state will update the state and return it.
Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
@@ -3,8 +3,8 @@ from utils import State
|
||||
|
||||
class BaseCrossover:
|
||||
|
||||
def setup(self, key, state=State()):
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
def __call__(self, state, key, genome, nodes1, nodes2, conns1, conns2):
|
||||
def __call__(self, state, genome, nodes1, nodes2, conns1, conns2):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -5,12 +5,12 @@ from .base import BaseCrossover
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
|
||||
def __call__(self, state, key, genome, nodes1, conns1, nodes2, conns2):
|
||||
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
"""
|
||||
randkey_1, randkey_2, key = jax.random.split(key, 3)
|
||||
randkey1, randkey2, randkey = jax.random.split(state.randkey, 3)
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
@@ -20,16 +20,16 @@ class DefaultCrossover(BaseCrossover):
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
|
||||
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
|
||||
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False))
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||
|
||||
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
|
||||
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
|
||||
self.crossover_gene(randkey2, conns1, conns2, is_conn=True))
|
||||
|
||||
return new_nodes, new_conns
|
||||
return state.update(randkey=randkey), new_nodes, new_conns
|
||||
|
||||
def align_array(self, seq1, seq2, ar2, is_conn: bool):
|
||||
"""
|
||||
|
||||
@@ -3,8 +3,8 @@ from utils import State
|
||||
|
||||
class BaseMutation:
|
||||
|
||||
def setup(self, key, state=State()):
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
def __call__(self, state, key, genome, nodes, conns, new_node_key):
|
||||
def __call__(self, state, genome, nodes, conns, new_node_key):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -17,13 +17,13 @@ class DefaultMutation(BaseMutation):
|
||||
self.node_add = node_add
|
||||
self.node_delete = node_delete
|
||||
|
||||
def __call__(self, state, key, genome, nodes, conns, new_node_key):
|
||||
k1, k2 = jax.random.split(key)
|
||||
def __call__(self, state, genome, nodes, conns, new_node_key):
|
||||
k1, k2, randkey = jax.random.split(state.randkey)
|
||||
|
||||
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
|
||||
nodes, conns = self.mutate_values(k2, genome, nodes, conns)
|
||||
|
||||
return nodes, conns
|
||||
return state.update(randkey=randkey), nodes, conns
|
||||
|
||||
def mutate_structure(self, key, genome, nodes, conns, new_node_key):
|
||||
def mutate_add_node(key_, nodes_, conns_):
|
||||
|
||||
@@ -9,13 +9,13 @@ class BaseGene:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def setup(self, key, state=State()):
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
def new_attrs(self, state, key):
|
||||
def new_attrs(self, state):
|
||||
raise NotImplementedError
|
||||
|
||||
def mutate(self, state, key, gene):
|
||||
def mutate(self, state, gene):
|
||||
raise NotImplementedError
|
||||
|
||||
def distance(self, state, gene1, gene2):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import jax.numpy as jnp
|
||||
import jax.random
|
||||
|
||||
from utils import mutate_float
|
||||
from . import BaseConnGene
|
||||
@@ -24,14 +25,15 @@ class DefaultConnGene(BaseConnGene):
|
||||
self.weight_mutate_rate = weight_mutate_rate
|
||||
self.weight_replace_rate = weight_replace_rate
|
||||
|
||||
def new_attrs(self, state, key):
|
||||
return jnp.array([self.weight_init_mean])
|
||||
def new_attrs(self, state):
|
||||
return state, jnp.array([self.weight_init_mean])
|
||||
|
||||
def mutate(self, state, key, conn):
|
||||
def mutate(self, state, conn):
|
||||
randkey_, randkey = jax.random.split(state.randkey, 2)
|
||||
input_index = conn[0]
|
||||
output_index = conn[1]
|
||||
enabled = conn[2]
|
||||
weight = mutate_float(key,
|
||||
weight = mutate_float(randkey_,
|
||||
conn[3],
|
||||
self.weight_init_mean,
|
||||
self.weight_init_std,
|
||||
@@ -40,11 +42,11 @@ class DefaultConnGene(BaseConnGene):
|
||||
self.weight_replace_rate
|
||||
)
|
||||
|
||||
return jnp.array([input_index, output_index, enabled, weight])
|
||||
return state.update(randkey=randkey), jnp.array([input_index, output_index, enabled, weight])
|
||||
|
||||
def distance(self, state, attrs1, attrs2):
|
||||
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
|
||||
return state, (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
|
||||
|
||||
def forward(self, state, attrs, inputs):
|
||||
weight = attrs[0]
|
||||
return inputs * weight
|
||||
return state, inputs * weight
|
||||
|
||||
@@ -56,13 +56,13 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
||||
self.aggregation_replace_rate = aggregation_replace_rate
|
||||
|
||||
def new_attrs(self, state, key):
|
||||
return jnp.array(
|
||||
def new_attrs(self, state):
|
||||
return state, jnp.array(
|
||||
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
||||
)
|
||||
|
||||
def mutate(self, state, key, node):
|
||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||
def mutate(self, state, node):
|
||||
k1, k2, k3, k4, randkey = jax.random.split(state.randkey, num=5)
|
||||
index = node[0]
|
||||
|
||||
bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
|
||||
@@ -75,10 +75,10 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
|
||||
agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)
|
||||
|
||||
return jnp.array([index, bias, res, act, agg])
|
||||
return state.update(randkey=randkey), jnp.array([index, bias, res, act, agg])
|
||||
|
||||
def distance(self, state, node1, node2):
|
||||
return (
|
||||
return state, (
|
||||
jnp.abs(node1[1] - node2[1]) +
|
||||
jnp.abs(node1[2] - node2[2]) +
|
||||
(node1[3] != node2[3]) +
|
||||
@@ -98,4 +98,4 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
lambda: act(act_idx, z, self.activation_options)
|
||||
)
|
||||
|
||||
return z
|
||||
return state, z
|
||||
|
||||
@@ -24,7 +24,7 @@ class BaseGenome:
|
||||
self.node_gene = node_gene
|
||||
self.conn_gene = conn_gene
|
||||
|
||||
def setup(self, key, state=State()):
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
def transform(self, state, nodes, conns):
|
||||
|
||||
@@ -38,7 +38,7 @@ class DefaultGenome(BaseGenome):
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
seqs = topological_sort(nodes, conn_enable)
|
||||
|
||||
return seqs, nodes, u_conns
|
||||
return state, seqs, nodes, u_conns
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
cal_seqs, nodes, conns = transformed
|
||||
@@ -49,34 +49,32 @@ class DefaultGenome(BaseGenome):
|
||||
nodes_attrs = nodes[:, 1:]
|
||||
|
||||
def cond_fun(carry):
|
||||
values, idx = carry
|
||||
state_, values, idx = carry
|
||||
return (idx < N) & (cal_seqs[idx] != I_INT)
|
||||
|
||||
def body_func(carry):
|
||||
values, idx = carry
|
||||
state_, values, idx = carry
|
||||
i = cal_seqs[idx]
|
||||
|
||||
def hit():
|
||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(conns[:, :, i], values)
|
||||
z = self.node_gene.forward(state, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
||||
s, ins = jax.vmap(self.conn_gene.forward,
|
||||
in_axes=(None, 1, 0), out_axes=(None, 0))(state_, conns[:, :, i], values)
|
||||
s, z = self.node_gene.forward(s, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
def miss():
|
||||
return values
|
||||
return s, new_values
|
||||
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
values = jax.lax.cond(
|
||||
state_, values = jax.lax.cond(
|
||||
jnp.isin(i, self.input_idx),
|
||||
miss,
|
||||
lambda: (state_, values),
|
||||
hit
|
||||
)
|
||||
|
||||
return values, idx + 1
|
||||
return state_, values, idx + 1
|
||||
|
||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
||||
state, vals, _ = jax.lax.while_loop(cond_fun, body_func, (state, ini_vals, 0))
|
||||
|
||||
if self.output_transform is None:
|
||||
return vals[self.output_idx]
|
||||
return state, vals[self.output_idx]
|
||||
else:
|
||||
return self.output_transform(vals[self.output_idx])
|
||||
return state, self.output_transform(vals[self.output_idx])
|
||||
|
||||
@@ -39,7 +39,7 @@ class RecurrentGenome(BaseGenome):
|
||||
conn_enable = u_conns[0] == 1
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
|
||||
return nodes, u_conns
|
||||
return state, nodes, u_conns
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
nodes, conns = transformed
|
||||
@@ -48,27 +48,36 @@ class RecurrentGenome(BaseGenome):
|
||||
vals = jnp.full((N,), jnp.nan)
|
||||
nodes_attrs = nodes[:, 1:]
|
||||
|
||||
def body_func(_, values):
|
||||
def body_func(_, carry):
|
||||
state_, values = carry
|
||||
|
||||
# set input values
|
||||
values = values.at[self.input_idx].set(inputs)
|
||||
|
||||
# calculate connections
|
||||
node_ins = jax.vmap(
|
||||
state_, node_ins = jax.vmap(
|
||||
jax.vmap(
|
||||
self.conn_gene.forward,
|
||||
in_axes=(None, 1, None)
|
||||
in_axes=(None, 1, None),
|
||||
out_axes=(None, 0)
|
||||
),
|
||||
in_axes=(None, 1, 0)
|
||||
)(state, conns, values)
|
||||
in_axes=(None, 1, 0),
|
||||
out_axes=(None, 0)
|
||||
)(state_, conns, values)
|
||||
|
||||
# calculate nodes
|
||||
is_output_nodes = jnp.isin(
|
||||
jnp.arange(N),
|
||||
self.output_idx
|
||||
)
|
||||
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0))(nodes_attrs, node_ins.T, is_output_nodes)
|
||||
return values
|
||||
state_, values = jax.vmap(
|
||||
self.node_gene.forward,
|
||||
in_axes=(None, 0, 0, 0),
|
||||
out_axes=(None, 0)
|
||||
)(state_, nodes_attrs, node_ins.T, is_output_nodes)
|
||||
|
||||
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
|
||||
return state_, values
|
||||
|
||||
return vals[self.output_idx]
|
||||
state, vals = jax.lax.fori_loop(0, self.activate_time, body_func, (state, vals))
|
||||
|
||||
return state, vals[self.output_idx]
|
||||
|
||||
@@ -40,8 +40,8 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
self.species_arange = jnp.arange(self.species_size)
|
||||
|
||||
def setup(self, randkey):
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
def setup(self, key, state=State()):
|
||||
k1, k2 = jax.random.split(key, 2)
|
||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, self.initialize_method)
|
||||
|
||||
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
|
||||
@@ -65,8 +65,8 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||
|
||||
return State(
|
||||
randkey=k2,
|
||||
return state.register(
|
||||
species_randkey=k2,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
species_keys=species_keys,
|
||||
@@ -134,8 +134,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
def check_stagnation(idx):
|
||||
# determine whether the species stagnation
|
||||
st = (
|
||||
(species_fitness[idx] <= state.best_fitness[
|
||||
idx]) & # not better than the best fitness of the species
|
||||
(species_fitness[idx] <= state.best_fitness[idx]) & # not better than the best fitness of the species
|
||||
(generation - state.last_improved[idx] > self.max_stagnation) # for a long time
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from algorithm.neat import *
|
||||
from utils import Act, Agg
|
||||
from utils import Act, Agg, State
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
@@ -36,11 +36,14 @@ def test_default():
|
||||
),
|
||||
)
|
||||
|
||||
transformed = genome.transform(nodes, conns)
|
||||
state = genome.setup(State(randkey=jax.random.key(0)))
|
||||
|
||||
state, *transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
|
||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
|
||||
state, outputs = jax.jit(jax.vmap(genome.forward,
|
||||
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
||||
# expected: [[0.5], [0.75], [0.75], [1]]
|
||||
@@ -50,11 +53,11 @@ def test_default():
|
||||
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
||||
print(conns)
|
||||
|
||||
transformed = genome.transform(nodes, conns)
|
||||
state, *transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
|
||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
|
||||
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
||||
@@ -93,11 +96,14 @@ def test_recurrent():
|
||||
activate_time=3,
|
||||
)
|
||||
|
||||
transformed = genome.transform(nodes, conns)
|
||||
state = genome.setup(State(randkey=jax.random.key(0)))
|
||||
|
||||
state, *transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
|
||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
|
||||
state, outputs = jax.jit(jax.vmap(genome.forward,
|
||||
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
||||
# expected: [[0.5], [0.75], [0.75], [1]]
|
||||
@@ -107,11 +113,11 @@ def test_recurrent():
|
||||
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
||||
print(conns)
|
||||
|
||||
transformed = genome.transform(nodes, conns)
|
||||
state, *transformed = genome.transform(state, nodes, conns)
|
||||
print(*transformed, sep='\n')
|
||||
|
||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
|
||||
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
|
||||
print(outputs)
|
||||
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
||||
Reference in New Issue
Block a user