All functions that take a state now update the state and return it.

Remove the randkey argument from functions that take a state, since the randkey can be carried in the state itself.
This commit is contained in:
wls2002
2024-05-25 20:45:57 +08:00
parent 5626fddf41
commit 79d53ea7af
12 changed files with 84 additions and 70 deletions

View File

@@ -3,8 +3,8 @@ from utils import State
class BaseCrossover: class BaseCrossover:
def setup(self, key, state=State()): def setup(self, state=State()):
return state return state
def __call__(self, state, key, genome, nodes1, nodes2, conns1, conns2): def __call__(self, state, genome, nodes1, nodes2, conns1, conns2):
raise NotImplementedError raise NotImplementedError

View File

@@ -5,12 +5,12 @@ from .base import BaseCrossover
class DefaultCrossover(BaseCrossover): class DefaultCrossover(BaseCrossover):
def __call__(self, state, key, genome, nodes1, conns1, nodes2, conns2): def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
""" """
use genome1 and genome2 to generate a new genome use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!) notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
""" """
randkey_1, randkey_2, key = jax.random.split(key, 3) randkey1, randkey2, randkey = jax.random.split(state.randkey, 3)
# crossover nodes # crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0] keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
@@ -20,16 +20,16 @@ class DefaultCrossover(BaseCrossover):
# For not homologous genes, use the value of nodes1(winner) # For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2 # For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False)) self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False))
# crossover connections # crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True) conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True)) self.crossover_gene(randkey2, conns1, conns2, is_conn=True))
return new_nodes, new_conns return state.update(randkey=randkey), new_nodes, new_conns
def align_array(self, seq1, seq2, ar2, is_conn: bool): def align_array(self, seq1, seq2, ar2, is_conn: bool):
""" """

View File

@@ -3,8 +3,8 @@ from utils import State
class BaseMutation: class BaseMutation:
def setup(self, key, state=State()): def setup(self, state=State()):
return state return state
def __call__(self, state, key, genome, nodes, conns, new_node_key): def __call__(self, state, genome, nodes, conns, new_node_key):
raise NotImplementedError raise NotImplementedError

View File

@@ -17,13 +17,13 @@ class DefaultMutation(BaseMutation):
self.node_add = node_add self.node_add = node_add
self.node_delete = node_delete self.node_delete = node_delete
def __call__(self, state, key, genome, nodes, conns, new_node_key): def __call__(self, state, genome, nodes, conns, new_node_key):
k1, k2 = jax.random.split(key) k1, k2, randkey = jax.random.split(state.randkey)
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key) nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
nodes, conns = self.mutate_values(k2, genome, nodes, conns) nodes, conns = self.mutate_values(k2, genome, nodes, conns)
return nodes, conns return state.update(randkey=randkey), nodes, conns
def mutate_structure(self, key, genome, nodes, conns, new_node_key): def mutate_structure(self, key, genome, nodes, conns, new_node_key):
def mutate_add_node(key_, nodes_, conns_): def mutate_add_node(key_, nodes_, conns_):

View File

@@ -9,13 +9,13 @@ class BaseGene:
def __init__(self): def __init__(self):
pass pass
def setup(self, key, state=State()): def setup(self, state=State()):
return state return state
def new_attrs(self, state, key): def new_attrs(self, state):
raise NotImplementedError raise NotImplementedError
def mutate(self, state, key, gene): def mutate(self, state, gene):
raise NotImplementedError raise NotImplementedError
def distance(self, state, gene1, gene2): def distance(self, state, gene1, gene2):

View File

@@ -1,4 +1,5 @@
import jax.numpy as jnp import jax.numpy as jnp
import jax.random
from utils import mutate_float from utils import mutate_float
from . import BaseConnGene from . import BaseConnGene
@@ -24,14 +25,15 @@ class DefaultConnGene(BaseConnGene):
self.weight_mutate_rate = weight_mutate_rate self.weight_mutate_rate = weight_mutate_rate
self.weight_replace_rate = weight_replace_rate self.weight_replace_rate = weight_replace_rate
def new_attrs(self, state, key): def new_attrs(self, state):
return jnp.array([self.weight_init_mean]) return state, jnp.array([self.weight_init_mean])
def mutate(self, state, key, conn): def mutate(self, state, conn):
randkey_, randkey = jax.random.split(state.randkey, 2)
input_index = conn[0] input_index = conn[0]
output_index = conn[1] output_index = conn[1]
enabled = conn[2] enabled = conn[2]
weight = mutate_float(key, weight = mutate_float(randkey_,
conn[3], conn[3],
self.weight_init_mean, self.weight_init_mean,
self.weight_init_std, self.weight_init_std,
@@ -40,11 +42,11 @@ class DefaultConnGene(BaseConnGene):
self.weight_replace_rate self.weight_replace_rate
) )
return jnp.array([input_index, output_index, enabled, weight]) return state.update(randkey=randkey), jnp.array([input_index, output_index, enabled, weight])
def distance(self, state, attrs1, attrs2): def distance(self, state, attrs1, attrs2):
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight return state, (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
def forward(self, state, attrs, inputs): def forward(self, state, attrs, inputs):
weight = attrs[0] weight = attrs[0]
return inputs * weight return state, inputs * weight

View File

@@ -56,13 +56,13 @@ class DefaultNodeGene(BaseNodeGene):
self.aggregation_indices = jnp.arange(len(aggregation_options)) self.aggregation_indices = jnp.arange(len(aggregation_options))
self.aggregation_replace_rate = aggregation_replace_rate self.aggregation_replace_rate = aggregation_replace_rate
def new_attrs(self, state, key): def new_attrs(self, state):
return jnp.array( return state, jnp.array(
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default] [self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
) )
def mutate(self, state, key, node): def mutate(self, state, node):
k1, k2, k3, k4 = jax.random.split(key, num=4) k1, k2, k3, k4, randkey = jax.random.split(state.randkey, num=5)
index = node[0] index = node[0]
bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std, bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
@@ -75,10 +75,10 @@ class DefaultNodeGene(BaseNodeGene):
agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate) agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)
return jnp.array([index, bias, res, act, agg]) return state.update(randkey=randkey), jnp.array([index, bias, res, act, agg])
def distance(self, state, node1, node2): def distance(self, state, node1, node2):
return ( return state, (
jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[1] - node2[1]) +
jnp.abs(node1[2] - node2[2]) + jnp.abs(node1[2] - node2[2]) +
(node1[3] != node2[3]) + (node1[3] != node2[3]) +
@@ -98,4 +98,4 @@ class DefaultNodeGene(BaseNodeGene):
lambda: act(act_idx, z, self.activation_options) lambda: act(act_idx, z, self.activation_options)
) )
return z return state, z

View File

@@ -24,7 +24,7 @@ class BaseGenome:
self.node_gene = node_gene self.node_gene = node_gene
self.conn_gene = conn_gene self.conn_gene = conn_gene
def setup(self, key, state=State()): def setup(self, state=State()):
return state return state
def transform(self, state, nodes, conns): def transform(self, state, nodes, conns):

View File

@@ -38,7 +38,7 @@ class DefaultGenome(BaseGenome):
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(nodes, conn_enable) seqs = topological_sort(nodes, conn_enable)
return seqs, nodes, u_conns return state, seqs, nodes, u_conns
def forward(self, state, inputs, transformed): def forward(self, state, inputs, transformed):
cal_seqs, nodes, conns = transformed cal_seqs, nodes, conns = transformed
@@ -49,34 +49,32 @@ class DefaultGenome(BaseGenome):
nodes_attrs = nodes[:, 1:] nodes_attrs = nodes[:, 1:]
def cond_fun(carry): def cond_fun(carry):
values, idx = carry state_, values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT) return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry): def body_func(carry):
values, idx = carry state_, values, idx = carry
i = cal_seqs[idx] i = cal_seqs[idx]
def hit(): def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(conns[:, :, i], values) s, ins = jax.vmap(self.conn_gene.forward,
z = self.node_gene.forward(state, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx)) in_axes=(None, 1, 0), out_axes=(None, 0))(state_, conns[:, :, i], values)
s, z = self.node_gene.forward(s, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
new_values = values.at[i].set(z) new_values = values.at[i].set(z)
return new_values return s, new_values
def miss():
return values
# the val of input nodes is obtained by the task, not by calculation # the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond( state_, values = jax.lax.cond(
jnp.isin(i, self.input_idx), jnp.isin(i, self.input_idx),
miss, lambda: (state_, values),
hit hit
) )
return values, idx + 1 return state_, values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) state, vals, _ = jax.lax.while_loop(cond_fun, body_func, (state, ini_vals, 0))
if self.output_transform is None: if self.output_transform is None:
return vals[self.output_idx] return state, vals[self.output_idx]
else: else:
return self.output_transform(vals[self.output_idx]) return state, self.output_transform(vals[self.output_idx])

View File

@@ -39,7 +39,7 @@ class RecurrentGenome(BaseGenome):
conn_enable = u_conns[0] == 1 conn_enable = u_conns[0] == 1
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan) u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return nodes, u_conns return state, nodes, u_conns
def forward(self, state, inputs, transformed): def forward(self, state, inputs, transformed):
nodes, conns = transformed nodes, conns = transformed
@@ -48,27 +48,36 @@ class RecurrentGenome(BaseGenome):
vals = jnp.full((N,), jnp.nan) vals = jnp.full((N,), jnp.nan)
nodes_attrs = nodes[:, 1:] nodes_attrs = nodes[:, 1:]
def body_func(_, values): def body_func(_, carry):
state_, values = carry
# set input values # set input values
values = values.at[self.input_idx].set(inputs) values = values.at[self.input_idx].set(inputs)
# calculate connections # calculate connections
node_ins = jax.vmap( state_, node_ins = jax.vmap(
jax.vmap( jax.vmap(
self.conn_gene.forward, self.conn_gene.forward,
in_axes=(None, 1, None) in_axes=(None, 1, None),
out_axes=(None, 0)
), ),
in_axes=(None, 1, 0) in_axes=(None, 1, 0),
)(state, conns, values) out_axes=(None, 0)
)(state_, conns, values)
# calculate nodes # calculate nodes
is_output_nodes = jnp.isin( is_output_nodes = jnp.isin(
jnp.arange(N), jnp.arange(N),
self.output_idx self.output_idx
) )
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0))(nodes_attrs, node_ins.T, is_output_nodes) state_, values = jax.vmap(
return values self.node_gene.forward,
in_axes=(None, 0, 0, 0),
out_axes=(None, 0)
)(state_, nodes_attrs, node_ins.T, is_output_nodes)
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals) return state_, values
return vals[self.output_idx] state, vals = jax.lax.fori_loop(0, self.activate_time, body_func, (state, vals))
return state, vals[self.output_idx]

View File

@@ -40,8 +40,8 @@ class DefaultSpecies(BaseSpecies):
self.species_arange = jnp.arange(self.species_size) self.species_arange = jnp.arange(self.species_size)
def setup(self, randkey): def setup(self, key, state=State()):
k1, k2 = jax.random.split(randkey, 2) k1, k2 = jax.random.split(key, 2)
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, self.initialize_method) pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, self.initialize_method)
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
@@ -65,8 +65,8 @@ class DefaultSpecies(BaseSpecies):
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns)) pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
return State( return state.register(
randkey=k2, species_randkey=k2,
pop_nodes=pop_nodes, pop_nodes=pop_nodes,
pop_conns=pop_conns, pop_conns=pop_conns,
species_keys=species_keys, species_keys=species_keys,
@@ -134,8 +134,7 @@ class DefaultSpecies(BaseSpecies):
def check_stagnation(idx): def check_stagnation(idx):
# determine whether the species stagnation # determine whether the species stagnation
st = ( st = (
(species_fitness[idx] <= state.best_fitness[ (species_fitness[idx] <= state.best_fitness[idx]) & # not better than the best fitness of the species
idx]) & # not better than the best fitness of the species
(generation - state.last_improved[idx] > self.max_stagnation) # for a long time (generation - state.last_improved[idx] > self.max_stagnation) # for a long time
) )

View File

@@ -1,5 +1,5 @@
from algorithm.neat import * from algorithm.neat import *
from utils import Act, Agg from utils import Act, Agg, State
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
@@ -36,11 +36,14 @@ def test_default():
), ),
) )
transformed = genome.transform(nodes, conns) state = genome.setup(State(randkey=jax.random.key(0)))
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n') print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed) state, outputs = jax.jit(jax.vmap(genome.forward,
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
print(outputs) print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]])) assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]] # expected: [[0.5], [0.75], [0.75], [1]]
@@ -50,11 +53,11 @@ def test_default():
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0] conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns) print(conns)
transformed = genome.transform(nodes, conns) state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n') print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed) state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
print(outputs) print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]])) assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]] # expected: [[0.5], [0.75], [0.5], [0.75]]
@@ -93,11 +96,14 @@ def test_recurrent():
activate_time=3, activate_time=3,
) )
transformed = genome.transform(nodes, conns) state = genome.setup(State(randkey=jax.random.key(0)))
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n') print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed) state, outputs = jax.jit(jax.vmap(genome.forward,
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
print(outputs) print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]])) assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]] # expected: [[0.5], [0.75], [0.75], [1]]
@@ -107,11 +113,11 @@ def test_recurrent():
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0] conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns) print(conns)
transformed = genome.transform(nodes, conns) state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n') print(*transformed, sep='\n')
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]]) inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed) state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
print(outputs) print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]])) assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]] # expected: [[0.5], [0.75], [0.5], [0.75]]