This commit is contained in:
wls2002
2024-06-20 16:32:52 +08:00
parent 9f72813c35
commit 075460f896
17 changed files with 224 additions and 140 deletions

View File

@@ -1,5 +1,5 @@
import jax, jax.numpy as jnp
from utils import State, StatefulBaseClass
from utils import State, StatefulBaseClass, hash_array
class BaseGene(StatefulBaseClass):
@@ -43,3 +43,6 @@ class BaseGene(StatefulBaseClass):
def repr(self, state, gene, precision=2):
raise NotImplementedError
def hash(self, gene):
return hash_array(gene)

View File

@@ -2,7 +2,7 @@ import numpy as np
import jax, jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene
from ..ga import BaseMutation, BaseCrossover
from utils import State, StatefulBaseClass, topological_sort_python
from utils import State, StatefulBaseClass, topological_sort_python, hash_array
class BaseGenome(StatefulBaseClass):
@@ -255,10 +255,14 @@ class BaseGenome(StatefulBaseClass):
nx.draw(
G,
with_labels=True,
pos=rotated_pos,
node_size=node_sizes,
node_color=node_colors,
**kwargs,
)
plt.savefig(save_path, dpi=save_dpi)
def hash(self, nodes, conns):
nodes_hashs = jax.vmap(hash_array)(nodes)
conns_hashs = jax.vmap(hash_array)(conns)
return hash_array(jnp.concatenate([nodes_hashs, conns_hashs]))

View File

@@ -210,7 +210,14 @@ class DefaultGenome(BaseGenome):
new_transformed,
)
def sympy_func(self, state, network, sympy_input_transform=None, sympy_output_transform=None, backend="jax"):
def sympy_func(
self,
state,
network,
sympy_input_transform=None,
sympy_output_transform=None,
backend="jax",
):
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
@@ -219,6 +226,10 @@ class DefaultGenome(BaseGenome):
warnings.warn(
"genome.input_transform is not None but sympy_input_transform is None!"
)
if sympy_input_transform is None:
sympy_input_transform = lambda x: x
if sympy_input_transform is not None:
if not isinstance(sympy_input_transform, list):
sympy_input_transform = [sympy_input_transform] * self.num_inputs
@@ -231,11 +242,14 @@ class DefaultGenome(BaseGenome):
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
hidden_idx = [i for i in network["nodes"] if i not in input_idx and i not in output_idx]
hidden_idx = [
i for i in network["nodes"] if i not in input_idx and i not in output_idx
]
symbols = {}
for i in network["nodes"]:
if i in input_idx:
symbols[i] = sp.Symbol(f"i{i - min(input_idx)}")
symbols[-i - 1] = sp.Symbol(f"i{i - min(input_idx)}") # origin_i
symbols[i] = sp.Symbol(f"norm{i - min(input_idx)}")
elif i in output_idx:
symbols[i] = sp.Symbol(f"o{i - min(output_idx)}")
else: # hidden
@@ -246,10 +260,9 @@ class DefaultGenome(BaseGenome):
for i in order:
if i in input_idx:
if sympy_input_transform is not None:
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[i])
else:
nodes_exprs[symbols[i]] = symbols[i]
nodes_exprs[symbols[-i - 1]] = symbols[-i - 1] # origin equal to its symbol
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[-i - 1]) # normed i
else:
in_conns = [c for c in network["conns"] if c[1] == i]
node_inputs = []
@@ -270,12 +283,13 @@ class DefaultGenome(BaseGenome):
is_output_node=(i in output_idx),
)
args_symbols.update(a_s)
if i in output_idx and sympy_output_transform is not None:
nodes_exprs[symbols[i]] = sympy_output_transform(
nodes_exprs[symbols[i]]
)
input_symbols = [v for k, v in symbols.items() if k in input_idx]
input_symbols = [symbols[-i - 1] for i in input_idx]
reduced_exprs = nodes_exprs.copy()
for i in order:
reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
@@ -299,7 +313,9 @@ class DefaultGenome(BaseGenome):
fixed_args_output_funcs.append(f)
forward_func = lambda inputs: jnp.array([f(inputs) for f in fixed_args_output_funcs])
forward_func = lambda inputs: jnp.array(
[f(inputs) for f in fixed_args_output_funcs]
)
return (
symbols,

View File

@@ -2,8 +2,6 @@ import jax, jax.numpy as jnp
from utils import State
from .. import BaseAlgorithm
from .species import *
from .ga import *
from .genome import *
class NEAT(BaseAlgorithm):
@@ -16,28 +14,13 @@ class NEAT(BaseAlgorithm):
def setup(self, state=State()):
state = self.species.setup(state)
state = state.register(
generation=jnp.array(0.0),
next_node_key=jnp.array(
max(*self.genome.input_idx, *self.genome.output_idx) + 2,
dtype=jnp.float32,
),
)
return state
def ask(self, state: State):
return self.species.ask(state)
def tell(self, state: State, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(generation=state.generation + 1, randkey=randkey)
state, winner, loser, elite_mask = self.species.update_species(state, fitness)
state = self.create_next_generation(state, winner, loser, elite_mask)
state = self.species.speciate(state)
return state
return self.species.tell(state, fitness)
def transform(self, state, individual):
"""transform the genome into a neural network"""
@@ -65,50 +48,6 @@ class NEAT(BaseAlgorithm):
def pop_size(self):
return self.species.pop_size
def create_next_generation(self, state, winner, loser, elite_mask):
# prepare random keys
pop_size = self.species.pop_size
new_node_keys = jnp.arange(pop_size) + state.next_node_key
k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_randkeys = jax.random.split(k1, pop_size)
mutate_randkeys = jax.random.split(k2, pop_size)
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
# batch crossover
n_nodes, n_conns = jax.vmap(
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
)(
state, crossover_randkeys, wpn, wpc, lpn, lpc
) # new_nodes, new_conns
# batch mutation
m_n_nodes, m_n_conns = jax.vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
)(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
# update next node key
all_nodes_keys = pop_nodes[:, :, 0]
max_node_key = jnp.max(
jnp.where(jnp.isnan(all_nodes_keys), -jnp.inf, all_nodes_keys)
)
next_node_key = max_node_key + 1
return state.update(
randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
next_node_key=next_node_key,
)
def member_count(self, state: State):
return state.member_count

View File

@@ -10,6 +10,9 @@ class BaseSpecies(StatefulBaseClass):
def ask(self, state: State):
raise NotImplementedError
def tell(self, state: State, fitness):
raise NotImplementedError
def update_species(self, state, fitness):
raise NotImplementedError

View File

@@ -113,12 +113,23 @@ class DefaultSpecies(BaseSpecies):
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
next_species_key=jnp.array(1), # 0 is reserved for the first species
next_species_key=jnp.float32(1), # 0 is reserved for the first species
generation=jnp.float32(0),
)
def ask(self, state):
return state.pop_nodes, state.pop_conns
def tell(self, state, fitness):
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(generation=state.generation + 1, randkey=randkey)
state, winner, loser, elite_mask = self.update_species(state, fitness)
state = self.create_next_generation(state, winner, loser, elite_mask)
state = self.speciate(state)
return state
def update_species(self, state, fitness):
# update the fitness of each species
state, species_fitness = self.update_species_fitness(state, fitness)
@@ -619,3 +630,43 @@ class DefaultSpecies(BaseSpecies):
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
return val
def create_next_generation(self, state, winner, loser, elite_mask):
# find next node key
all_nodes_keys = state.pop_nodes[:, :, 0]
max_node_key = jnp.max(all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0)
next_node_key = max_node_key + 1
new_node_keys = jnp.arange(self.pop_size) + next_node_key
# prepare random keys
k1, k2, randkey = jax.random.split(state.randkey, 3)
crossover_randkeys = jax.random.split(k1, self.pop_size)
mutate_randkeys = jax.random.split(k2, self.pop_size)
wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]
# batch crossover
n_nodes, n_conns = jax.vmap(
self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
)(
state, crossover_randkeys, wpn, wpc, lpn, lpc
) # new_nodes, new_conns
# batch mutation
m_n_nodes, m_n_conns = jax.vmap(
self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
)(
state, mutate_randkeys, n_nodes, n_conns, new_node_keys
) # mutated_new_nodes, mutated_new_conns
# elitism don't mutate
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)
return state.update(
randkey=randkey,
pop_nodes=pop_nodes,
pop_conns=pop_conns,
)