All function with state will update the state and return it.

Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
wls2002
2024-05-25 20:45:57 +08:00
parent 5626fddf41
commit 79d53ea7af
12 changed files with 84 additions and 70 deletions

View File

@@ -3,8 +3,8 @@ from utils import State
class BaseCrossover:
def setup(self, key, state=State()):
def setup(self, state=State()):
return state
def __call__(self, state, key, genome, nodes1, nodes2, conns1, conns2):
def __call__(self, state, genome, nodes1, nodes2, conns1, conns2):
raise NotImplementedError

View File

@@ -5,12 +5,12 @@ from .base import BaseCrossover
class DefaultCrossover(BaseCrossover):
def __call__(self, state, key, genome, nodes1, conns1, nodes2, conns2):
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
"""
use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
"""
randkey_1, randkey_2, key = jax.random.split(key, 3)
randkey1, randkey2, randkey = jax.random.split(state.randkey, 3)
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
@@ -20,16 +20,16 @@ class DefaultCrossover(BaseCrossover):
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False))
# crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
self.crossover_gene(randkey2, conns1, conns2, is_conn=True))
return new_nodes, new_conns
return state.update(randkey=randkey), new_nodes, new_conns
def align_array(self, seq1, seq2, ar2, is_conn: bool):
"""

View File

@@ -3,8 +3,8 @@ from utils import State
class BaseMutation:
def setup(self, key, state=State()):
def setup(self, state=State()):
return state
def __call__(self, state, key, genome, nodes, conns, new_node_key):
def __call__(self, state, genome, nodes, conns, new_node_key):
raise NotImplementedError

View File

@@ -17,13 +17,13 @@ class DefaultMutation(BaseMutation):
self.node_add = node_add
self.node_delete = node_delete
def __call__(self, state, key, genome, nodes, conns, new_node_key):
k1, k2 = jax.random.split(key)
def __call__(self, state, genome, nodes, conns, new_node_key):
k1, k2, randkey = jax.random.split(state.randkey)
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
nodes, conns = self.mutate_values(k2, genome, nodes, conns)
return nodes, conns
return state.update(randkey=randkey), nodes, conns
def mutate_structure(self, key, genome, nodes, conns, new_node_key):
def mutate_add_node(key_, nodes_, conns_):

View File

@@ -9,13 +9,13 @@ class BaseGene:
def __init__(self):
pass
def setup(self, key, state=State()):
def setup(self, state=State()):
return state
def new_attrs(self, state, key):
def new_attrs(self, state):
raise NotImplementedError
def mutate(self, state, key, gene):
def mutate(self, state, gene):
raise NotImplementedError
def distance(self, state, gene1, gene2):

View File

@@ -1,4 +1,5 @@
import jax.numpy as jnp
import jax.random
from utils import mutate_float
from . import BaseConnGene
@@ -24,14 +25,15 @@ class DefaultConnGene(BaseConnGene):
self.weight_mutate_rate = weight_mutate_rate
self.weight_replace_rate = weight_replace_rate
def new_attrs(self, state, key):
return jnp.array([self.weight_init_mean])
def new_attrs(self, state):
return state, jnp.array([self.weight_init_mean])
def mutate(self, state, key, conn):
def mutate(self, state, conn):
randkey_, randkey = jax.random.split(state.randkey, 2)
input_index = conn[0]
output_index = conn[1]
enabled = conn[2]
weight = mutate_float(key,
weight = mutate_float(randkey_,
conn[3],
self.weight_init_mean,
self.weight_init_std,
@@ -40,11 +42,11 @@ class DefaultConnGene(BaseConnGene):
self.weight_replace_rate
)
return jnp.array([input_index, output_index, enabled, weight])
return state.update(randkey=randkey), jnp.array([input_index, output_index, enabled, weight])
def distance(self, state, attrs1, attrs2):
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
return state, (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
def forward(self, state, attrs, inputs):
weight = attrs[0]
return inputs * weight
return state, inputs * weight

View File

@@ -56,13 +56,13 @@ class DefaultNodeGene(BaseNodeGene):
self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate
def new_attrs(self, state, key):
return jnp.array(
def new_attrs(self, state):
return state, jnp.array(
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
)
def mutate(self, state, key, node):
k1, k2, k3, k4 = jax.random.split(key, num=4)
def mutate(self, state, node):
k1, k2, k3, k4, randkey = jax.random.split(state.randkey, num=5)
index = node[0]
bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
@@ -75,10 +75,10 @@ class DefaultNodeGene(BaseNodeGene):
agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)
return jnp.array([index, bias, res, act, agg])
return state.update(randkey=randkey), jnp.array([index, bias, res, act, agg])
def distance(self, state, node1, node2):
return (
return state, (
jnp.abs(node1[1] - node2[1]) +
jnp.abs(node1[2] - node2[2]) +
(node1[3] != node2[3]) +
@@ -98,4 +98,4 @@ class DefaultNodeGene(BaseNodeGene):
lambda: act(act_idx, z, self.activation_options)
)
return z
return state, z

View File

@@ -24,7 +24,7 @@ class BaseGenome:
self.node_gene = node_gene
self.conn_gene = conn_gene
def setup(self, key, state=State()):
def setup(self, state=State()):
return state
def transform(self, state, nodes, conns):

View File

@@ -38,7 +38,7 @@ class DefaultGenome(BaseGenome):
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(nodes, conn_enable)
return seqs, nodes, u_conns
return state, seqs, nodes, u_conns
def forward(self, state, inputs, transformed):
cal_seqs, nodes, conns = transformed
@@ -49,34 +49,32 @@ class DefaultGenome(BaseGenome):
nodes_attrs = nodes[:, 1:]
def cond_fun(carry):
values, idx = carry
state_, values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry):
values, idx = carry
state_, values, idx = carry
i = cal_seqs[idx]
def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(conns[:, :, i], values)
z = self.node_gene.forward(state, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
s, ins = jax.vmap(self.conn_gene.forward,
in_axes=(None, 1, 0), out_axes=(None, 0))(state_, conns[:, :, i], values)
s, z = self.node_gene.forward(s, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
new_values = values.at[i].set(z)
return new_values
def miss():
return values
return s, new_values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(
state_, values = jax.lax.cond(
jnp.isin(i, self.input_idx),
miss,
lambda: (state_, values),
hit
)
return values, idx + 1
return state_, values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
state, vals, _ = jax.lax.while_loop(cond_fun, body_func, (state, ini_vals, 0))
if self.output_transform is None:
return vals[self.output_idx]
return state, vals[self.output_idx]
else:
return self.output_transform(vals[self.output_idx])
return state, self.output_transform(vals[self.output_idx])

View File

@@ -39,7 +39,7 @@ class RecurrentGenome(BaseGenome):
conn_enable = u_conns[0] == 1
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return nodes, u_conns
return state, nodes, u_conns
def forward(self, state, inputs, transformed):
nodes, conns = transformed
@@ -48,27 +48,36 @@ class RecurrentGenome(BaseGenome):
vals = jnp.full((N,), jnp.nan)
nodes_attrs = nodes[:, 1:]
def body_func(_, values):
def body_func(_, carry):
state_, values = carry
# set input values
values = values.at[self.input_idx].set(inputs)
# calculate connections
node_ins = jax.vmap(
state_, node_ins = jax.vmap(
jax.vmap(
self.conn_gene.forward,
in_axes=(None, 1, None)
in_axes=(None, 1, None),
out_axes=(None, 0)
),
in_axes=(None, 1, 0)
)(state, conns, values)
in_axes=(None, 1, 0),
out_axes=(None, 0)
)(state_, conns, values)
# calculate nodes
is_output_nodes = jnp.isin(
jnp.arange(N),
self.output_idx
)
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0))(nodes_attrs, node_ins.T, is_output_nodes)
return values
state_, values = jax.vmap(
self.node_gene.forward,
in_axes=(None, 0, 0, 0),
out_axes=(None, 0)
)(state_, nodes_attrs, node_ins.T, is_output_nodes)
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
return state_, values
return vals[self.output_idx]
state, vals = jax.lax.fori_loop(0, self.activate_time, body_func, (state, vals))
return state, vals[self.output_idx]

View File

@@ -40,8 +40,8 @@ class DefaultSpecies(BaseSpecies):
self.species_arange = jnp.arange(self.species_size)
def setup(self, randkey):
k1, k2 = jax.random.split(randkey, 2)
def setup(self, key, state=State()):
k1, k2 = jax.random.split(key, 2)
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, self.initialize_method)
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
@@ -65,8 +65,8 @@ class DefaultSpecies(BaseSpecies):
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
return State(
randkey=k2,
return state.register(
species_randkey=k2,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
species_keys=species_keys,
@@ -134,8 +134,7 @@ class DefaultSpecies(BaseSpecies):
def check_stagnation(idx):
# determine whether the species stagnation
st = (
(species_fitness[idx] <= state.best_fitness[
idx]) & # not better than the best fitness of the species
(species_fitness[idx] <= state.best_fitness[idx]) & # not better than the best fitness of the species
(generation - state.last_improved[idx] > self.max_stagnation) # for a long time
)

View File

@@ -1,5 +1,5 @@
from algorithm.neat import *
from utils import Act, Agg
from utils import Act, Agg, State
import jax, jax.numpy as jnp
@@ -36,11 +36,14 @@ def test_default():
),
)
transformed = genome.transform(nodes, conns)
state = genome.setup(State(randkey=jax.random.key(0)))
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
state, outputs = jax.jit(jax.vmap(genome.forward,
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
@@ -50,11 +53,11 @@ def test_default():
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(nodes, conns)
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]
@@ -93,11 +96,14 @@ def test_recurrent():
activate_time=3,
)
transformed = genome.transform(nodes, conns)
state = genome.setup(State(randkey=jax.random.key(0)))
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
state, outputs = jax.jit(jax.vmap(genome.forward,
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
@@ -107,11 +113,11 @@ def test_recurrent():
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
transformed = genome.transform(nodes, conns)
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]