finish all refactoring
This commit is contained in:
@@ -16,9 +16,30 @@ class BaseAlgorithm:
|
||||
"""update the state of the algorithm"""
|
||||
raise NotImplementedError
|
||||
|
||||
def transform(self, state: State):
|
||||
def transform(self, individual):
|
||||
"""transform the genome into a neural network"""
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, inputs, transformed):
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def pop_size(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def member_count(self, state: State):
|
||||
# to analysis the species
|
||||
raise NotImplementedError
|
||||
|
||||
def generation(self, state: State):
|
||||
# to analysis the algorithm
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
2
algorithm/hyperneat/__init__.py
Normal file
2
algorithm/hyperneat/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .hyperneat import HyperNEAT
|
||||
from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate
|
||||
116
algorithm/hyperneat/hyperneat.py
Normal file
116
algorithm/hyperneat/hyperneat.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import State, Act, Agg
|
||||
from .. import BaseAlgorithm, NEAT
|
||||
from ..neat.gene import BaseNodeGene, BaseConnGene
|
||||
from ..neat.genome import RecurrentGenome
|
||||
from .substrate import *
|
||||
|
||||
|
||||
class HyperNEAT(BaseAlgorithm):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
substrate: BaseSubstrate,
|
||||
neat: NEAT,
|
||||
below_threshold: float = 0.3,
|
||||
max_weight: float = 5.,
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
activate_time: int = 10,
|
||||
):
|
||||
assert substrate.query_coors.shape[1] == neat.num_inputs, \
|
||||
"Substrate input size should be equal to NEAT input size"
|
||||
|
||||
self.substrate = substrate
|
||||
self.neat = neat
|
||||
self.below_threshold = below_threshold
|
||||
self.max_weight = max_weight
|
||||
self.hyper_genome = RecurrentGenome(
|
||||
num_inputs=substrate.num_inputs,
|
||||
num_outputs=substrate.num_outputs,
|
||||
max_nodes=substrate.nodes_cnt,
|
||||
max_conns=substrate.conns_cnt,
|
||||
node_gene=HyperNodeGene(activation, aggregation),
|
||||
conn_gene=HyperNEATConnGene(),
|
||||
activate_time=activate_time,
|
||||
)
|
||||
|
||||
def setup(self, randkey):
|
||||
return State(
|
||||
neat_state=self.neat.setup(randkey)
|
||||
)
|
||||
|
||||
def ask(self, state: State):
|
||||
return self.neat.ask(state.neat_state)
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
return state.update(
|
||||
neat_state=self.neat.tell(state.neat_state, fitness)
|
||||
)
|
||||
|
||||
def transform(self, individual):
|
||||
transformed = self.neat.transform(individual)
|
||||
query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(self.substrate.query_coors, transformed)
|
||||
|
||||
# mute the connection with weight below threshold
|
||||
query_res = jnp.where(
|
||||
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
|
||||
0.,
|
||||
query_res
|
||||
)
|
||||
|
||||
# make query res in range [-max_weight, max_weight]
|
||||
query_res = jnp.where(query_res > 0, query_res - self.below_threshold, query_res)
|
||||
query_res = jnp.where(query_res < 0, query_res + self.below_threshold, query_res)
|
||||
query_res = query_res / (1 - self.below_threshold) * self.max_weight
|
||||
|
||||
h_nodes, h_conns = self.substrate.make_nodes(query_res), self.substrate.make_conn(query_res)
|
||||
return self.hyper_genome.transform(h_nodes, h_conns)
|
||||
|
||||
def forward(self, inputs, transformed):
|
||||
# add bias
|
||||
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
|
||||
return self.hyper_genome.forward(inputs_with_bias, transformed)
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
return self.substrate.num_inputs - 1 # remove bias
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
return self.substrate.num_outputs
|
||||
|
||||
@property
|
||||
def pop_size(self):
|
||||
return self.neat.pop_size
|
||||
|
||||
def member_count(self, state: State):
|
||||
return self.neat.member_count(state.neat_state)
|
||||
|
||||
def generation(self, state: State):
|
||||
return self.neat.generation(state.neat_state)
|
||||
|
||||
|
||||
class HyperNodeGene(BaseNodeGene):
|
||||
|
||||
def __init__(self,
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
):
|
||||
super().__init__()
|
||||
self.activation = activation
|
||||
self.aggregation = aggregation
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
return self.activation(
|
||||
self.aggregation(inputs)
|
||||
)
|
||||
|
||||
|
||||
class HyperNEATConnGene(BaseConnGene):
|
||||
custom_attrs = ['weight']
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
weight = attrs[0]
|
||||
return inputs * weight
|
||||
3
algorithm/hyperneat/substrate/__init__.py
Normal file
3
algorithm/hyperneat/substrate/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseSubstrate
|
||||
from .default import DefaultSubstrate
|
||||
from .full import FullSubstrate
|
||||
27
algorithm/hyperneat/substrate/base.py
Normal file
27
algorithm/hyperneat/substrate/base.py
Normal file
@@ -0,0 +1,27 @@
|
||||
class BaseSubstrate:
|
||||
|
||||
def make_nodes(self, query_res):
|
||||
raise NotImplementedError
|
||||
|
||||
def make_conn(self, query_res):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def query_coors(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def nodes_cnt(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def conns_cnt(self):
|
||||
raise NotImplementedError
|
||||
38
algorithm/hyperneat/substrate/default.py
Normal file
38
algorithm/hyperneat/substrate/default.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import jax.numpy as jnp
|
||||
from . import BaseSubstrate
|
||||
|
||||
|
||||
class DefaultSubstrate(BaseSubstrate):
|
||||
|
||||
def __init__(self, num_inputs, num_outputs, coors, nodes, conns):
|
||||
self.inputs = num_inputs
|
||||
self.outputs = num_outputs
|
||||
self.coors = jnp.array(coors)
|
||||
self.nodes = jnp.array(nodes)
|
||||
self.conns = jnp.array(conns)
|
||||
|
||||
def make_nodes(self, query_res):
|
||||
return self.nodes
|
||||
|
||||
def make_conn(self, query_res):
|
||||
return self.conns.at[:, 3:].set(query_res) # change weight
|
||||
|
||||
@property
|
||||
def query_coors(self):
|
||||
return self.coors
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
return self.inputs
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
return self.outputs
|
||||
|
||||
@property
|
||||
def nodes_cnt(self):
|
||||
return self.nodes.shape[0]
|
||||
|
||||
@property
|
||||
def conns_cnt(self):
|
||||
return self.conns.shape[0]
|
||||
76
algorithm/hyperneat/substrate/full.py
Normal file
76
algorithm/hyperneat/substrate/full.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import numpy as np
|
||||
from .default import DefaultSubstrate
|
||||
|
||||
|
||||
class FullSubstrate(DefaultSubstrate):
|
||||
|
||||
def __init__(self,
|
||||
input_coors=((-1, -1), (0, -1), (1, -1)),
|
||||
hidden_coors=((-1, 0), (0, 0), (1, 0)),
|
||||
output_coors=((0, 1),),
|
||||
):
|
||||
query_coors, nodes, conns = analysis_substrate(input_coors, output_coors, hidden_coors)
|
||||
super().__init__(
|
||||
len(input_coors),
|
||||
len(output_coors),
|
||||
query_coors,
|
||||
nodes,
|
||||
conns
|
||||
)
|
||||
|
||||
|
||||
def analysis_substrate(input_coors, output_coors, hidden_coors):
|
||||
input_coors = np.array(input_coors)
|
||||
output_coors = np.array(output_coors)
|
||||
hidden_coors = np.array(hidden_coors)
|
||||
|
||||
cd = input_coors.shape[1] # coordinate dimensions
|
||||
si = input_coors.shape[0] # input coordinate size
|
||||
so = output_coors.shape[0] # output coordinate size
|
||||
sh = hidden_coors.shape[0] # hidden coordinate size
|
||||
|
||||
input_idx = np.arange(si)
|
||||
output_idx = np.arange(si, si + so)
|
||||
hidden_idx = np.arange(si + so, si + so + sh)
|
||||
|
||||
total_conns = si * sh + sh * sh + sh * so
|
||||
query_coors = np.zeros((total_conns, cd * 2))
|
||||
correspond_keys = np.zeros((total_conns, 2))
|
||||
|
||||
# connect input to hidden
|
||||
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, input_coors, hidden_coors)
|
||||
query_coors[0: si * sh, :] = aux_coors
|
||||
correspond_keys[0: si * sh, :] = aux_keys
|
||||
|
||||
# connect hidden to hidden
|
||||
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, hidden_coors, hidden_coors)
|
||||
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
|
||||
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
|
||||
|
||||
# connect hidden to output
|
||||
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, hidden_coors, output_coors)
|
||||
query_coors[si * sh + sh * sh:, :] = aux_coors
|
||||
correspond_keys[si * sh + sh * sh:, :] = aux_keys
|
||||
|
||||
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
|
||||
conns = np.zeros((correspond_keys.shape[0], 4), dtype=np.float32) # input_idx, output_idx, enabled, weight
|
||||
conns[:, 0:2] = correspond_keys
|
||||
conns[:, 2] = 1 # enabled is True
|
||||
|
||||
return query_coors, nodes, conns
|
||||
|
||||
|
||||
def cartesian_product(keys1, keys2, coors1, coors2):
|
||||
len1 = keys1.shape[0]
|
||||
len2 = keys2.shape[0]
|
||||
|
||||
repeated_coors1 = np.repeat(coors1, len2, axis=0)
|
||||
repeated_keys1 = np.repeat(keys1, len2)
|
||||
|
||||
tiled_coors2 = np.tile(coors2, (len1, 1))
|
||||
tiled_keys2 = np.tile(keys2, len1)
|
||||
|
||||
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
|
||||
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
|
||||
|
||||
return new_coors, correspond_keys
|
||||
@@ -1,3 +1,5 @@
|
||||
from .gene import *
|
||||
from .genome import *
|
||||
from .species import *
|
||||
from .neat import NEAT
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@ import jax, jax.numpy as jnp
|
||||
from .base import BaseCrossover
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2):
|
||||
|
||||
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
|
||||
@@ -92,7 +92,7 @@ class DefaultMutation(BaseMutation):
|
||||
return nodes_, conns_
|
||||
|
||||
def successful():
|
||||
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conns.new_custom_attrs())
|
||||
return nodes_, genome.add_conn(conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs())
|
||||
|
||||
def already_exist():
|
||||
return nodes_, conns_.at[conn_pos, 2].set(True)
|
||||
@@ -105,11 +105,12 @@ class DefaultMutation(BaseMutation):
|
||||
return jax.lax.cond(
|
||||
is_already_exist,
|
||||
already_exist,
|
||||
jax.lax.cond(
|
||||
is_cycle,
|
||||
nothing,
|
||||
successful
|
||||
)
|
||||
lambda:
|
||||
jax.lax.cond(
|
||||
is_cycle,
|
||||
nothing,
|
||||
successful
|
||||
)
|
||||
)
|
||||
|
||||
elif genome.network_type == 'recurrent':
|
||||
@@ -138,23 +139,23 @@ class DefaultMutation(BaseMutation):
|
||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
|
||||
|
||||
def no(k, g):
|
||||
return g
|
||||
def no(key_, nodes_, conns_):
|
||||
return nodes_, conns_
|
||||
|
||||
genome = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
|
||||
genome = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
|
||||
genome = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
|
||||
genome = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
|
||||
nodes, conns = jax.lax.cond(r1 < self.node_add, mutate_add_node, no, k1, nodes, conns)
|
||||
nodes, conns = jax.lax.cond(r2 < self.node_delete, mutate_delete_node, no, k2, nodes, conns)
|
||||
nodes, conns = jax.lax.cond(r3 < self.conn_add, mutate_add_conn, no, k3, nodes, conns)
|
||||
nodes, conns = jax.lax.cond(r4 < self.conn_delete, mutate_delete_conn, no, k4, nodes, conns)
|
||||
|
||||
return genome
|
||||
return nodes, conns
|
||||
|
||||
def mutate_values(self, randkey, genome, nodes, conns):
|
||||
k1, k2 = jax.random.split(randkey, num=2)
|
||||
nodes_keys = jax.random.split(k1, num=genome.nodes.shape[0])
|
||||
conns_keys = jax.random.split(k2, num=genome.conns.shape[0])
|
||||
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
|
||||
conns_keys = jax.random.split(k2, num=conns.shape[0])
|
||||
|
||||
new_nodes = jax.vmap(genome.nodes.mutate, in_axes=(0, 0))(nodes_keys, nodes)
|
||||
new_conns = jax.vmap(genome.conns.mutate, in_axes=(0, 0))(conns_keys, conns)
|
||||
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
|
||||
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
|
||||
|
||||
# nan nodes not changed
|
||||
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
||||
|
||||
@@ -7,8 +7,7 @@ from . import BaseConnGene
|
||||
class DefaultConnGene(BaseConnGene):
|
||||
"Default connection gene, with the same behavior as in NEAT-python."
|
||||
|
||||
fixed_attrs = ['input_index', 'output_index', 'enabled']
|
||||
attrs = ['weight']
|
||||
custom_attrs = ['weight']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -9,7 +9,6 @@ from . import BaseNodeGene
|
||||
class DefaultNodeGene(BaseNodeGene):
|
||||
"Default node gene, with the same behavior as in NEAT-python."
|
||||
|
||||
fixed_attrs = ['index']
|
||||
custom_attrs = ['bias', 'response', 'aggregation', 'activation']
|
||||
|
||||
def __init__(
|
||||
@@ -82,8 +81,8 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
return (
|
||||
jnp.abs(node1[1] - node2[1]) +
|
||||
jnp.abs(node1[2] - node2[2]) +
|
||||
node1[3] != node2[3] +
|
||||
node1[4] != node2[4]
|
||||
(node1[3] != node2[3]) +
|
||||
(node1[4] != node2[4])
|
||||
)
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
|
||||
@@ -4,7 +4,6 @@ from utils import fetch_first
|
||||
|
||||
|
||||
class BaseGenome:
|
||||
|
||||
network_type = None
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import unflatten_conns, topological_sort, I_INT
|
||||
|
||||
@@ -13,10 +15,20 @@ class DefaultGenome(BaseGenome):
|
||||
def __init__(self,
|
||||
num_inputs: int,
|
||||
num_outputs: int,
|
||||
max_nodes=5,
|
||||
max_conns=4,
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
output_transform: Callable = None
|
||||
):
|
||||
super().__init__(num_inputs, num_outputs, node_gene, conn_gene)
|
||||
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
aux = output_transform(jnp.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.output_transform = output_transform
|
||||
|
||||
def transform(self, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
@@ -72,4 +84,7 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
||||
|
||||
return vals[self.output_idx]
|
||||
if self.output_transform is None:
|
||||
return vals[self.output_idx]
|
||||
else:
|
||||
return self.output_transform(vals[self.output_idx])
|
||||
|
||||
@@ -13,11 +13,13 @@ class RecurrentGenome(BaseGenome):
|
||||
def __init__(self,
|
||||
num_inputs: int,
|
||||
num_outputs: int,
|
||||
max_nodes: int,
|
||||
max_conns: int,
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
activate_time: int = 10,
|
||||
):
|
||||
super().__init__(num_inputs, num_outputs, node_gene, conn_gene)
|
||||
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
|
||||
self.activate_time = activate_time
|
||||
|
||||
def transform(self, nodes, conns):
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import State
|
||||
from .. import BaseAlgorithm
|
||||
from .genome import *
|
||||
from .species import *
|
||||
from .ga import *
|
||||
|
||||
|
||||
class NEAT(BaseAlgorithm):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
genome: BaseGenome,
|
||||
species: BaseSpecies,
|
||||
mutation: BaseMutation = DefaultMutation(),
|
||||
crossover: BaseCrossover = DefaultCrossover(),
|
||||
):
|
||||
self.genome = genome
|
||||
self.genome = species.genome
|
||||
self.species = species
|
||||
self.mutation = mutation
|
||||
self.crossover = crossover
|
||||
@@ -23,14 +22,14 @@ class NEAT(BaseAlgorithm):
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
return State(
|
||||
randkey=k1,
|
||||
generation=0,
|
||||
next_node_key=max(*self.genome.input_idx, *self.genome.output_idx) + 2,
|
||||
generation=jnp.array(0.),
|
||||
next_node_key=jnp.array(max(*self.genome.input_idx, *self.genome.output_idx) + 2, dtype=jnp.float32),
|
||||
# inputs nodes, output nodes, 1 hidden node
|
||||
species=self.species.setup(k2),
|
||||
)
|
||||
|
||||
def ask(self, state: State):
|
||||
return self.species.ask(state)
|
||||
return self.species.ask(state.species)
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||
@@ -40,25 +39,39 @@ class NEAT(BaseAlgorithm):
|
||||
randkey=randkey
|
||||
)
|
||||
|
||||
state, winner, loser, elite_mask = self.species.update_species(state, fitness, state.generation)
|
||||
species_state, winner, loser, elite_mask = self.species.update_species(state.species, fitness, state.generation)
|
||||
state = state.update(species=species_state)
|
||||
|
||||
state = self.create_next_generation(k2, state, winner, loser, elite_mask)
|
||||
|
||||
state = self.species.speciate(state, state.generation)
|
||||
|
||||
species_state = self.species.speciate(state.species, state.generation)
|
||||
state = state.update(species=species_state)
|
||||
return state
|
||||
|
||||
def transform(self, state: State):
|
||||
def transform(self, individual):
|
||||
"""transform the genome into a neural network"""
|
||||
raise NotImplementedError
|
||||
nodes, conns = individual
|
||||
return self.genome.transform(nodes, conns)
|
||||
|
||||
def forward(self, inputs, transformed):
|
||||
raise NotImplementedError
|
||||
return self.genome.forward(inputs, transformed)
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
return self.genome.num_inputs
|
||||
|
||||
@property
|
||||
def num_outputs(self):
|
||||
return self.genome.num_outputs
|
||||
|
||||
@property
|
||||
def pop_size(self):
|
||||
return self.species.pop_size
|
||||
|
||||
def create_next_generation(self, randkey, state, winner, loser, elite_mask):
|
||||
# prepare random keys
|
||||
pop_size = self.species.pop_size
|
||||
new_node_keys = jnp.arange(pop_size) + state.species.next_node_key
|
||||
new_node_keys = jnp.arange(pop_size) + state.next_node_key
|
||||
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
crossover_rand_keys = jax.random.split(k1, pop_size)
|
||||
@@ -69,11 +82,11 @@ class NEAT(BaseAlgorithm):
|
||||
|
||||
# batch crossover
|
||||
n_nodes, n_conns = (jax.vmap(self.crossover, in_axes=(0, None, 0, 0, 0, 0))
|
||||
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
|
||||
(crossover_rand_keys, self.genome, wpn, wpc, lpn, lpc))
|
||||
|
||||
# batch mutation
|
||||
m_n_nodes, m_n_conns = (jax.vmap(self.mutation, in_axes=(0, None, 0, 0, 0))
|
||||
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
|
||||
(mutate_rand_keys, self.genome, n_nodes, n_conns, new_node_keys))
|
||||
|
||||
# elitism don't mutate
|
||||
pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
|
||||
@@ -92,3 +105,9 @@ class NEAT(BaseAlgorithm):
|
||||
next_node_key=next_node_key,
|
||||
)
|
||||
|
||||
def member_count(self, state: State):
|
||||
return state.species.member_count
|
||||
|
||||
def generation(self, state: State):
|
||||
# to analysis the algorithm
|
||||
return state.generation
|
||||
|
||||
@@ -2,9 +2,10 @@ import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import State, rank_elements, argmin_with_mask, fetch_first
|
||||
from ..genome import BaseGenome
|
||||
from .base import BaseSpecies
|
||||
|
||||
|
||||
class DefaultSpecies:
|
||||
class DefaultSpecies(BaseSpecies):
|
||||
|
||||
def __init__(self,
|
||||
genome: BaseGenome,
|
||||
@@ -18,9 +19,8 @@ class DefaultSpecies:
|
||||
genome_elitism: int = 2,
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.5
|
||||
compatibility_threshold: float = 3.
|
||||
):
|
||||
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
self.species_size = species_size
|
||||
@@ -59,8 +59,12 @@ class DefaultSpecies:
|
||||
center_nodes = center_nodes.at[0].set(pop_nodes[0])
|
||||
center_conns = center_conns.at[0].set(pop_conns[0])
|
||||
|
||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||
|
||||
return State(
|
||||
randkey=randkey,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
@@ -68,7 +72,7 @@ class DefaultSpecies:
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
next_species_key=1, # 0 is reserved for the first species
|
||||
next_species_key=jnp.array(1), # 0 is reserved for the first species
|
||||
)
|
||||
|
||||
def ask(self, state):
|
||||
@@ -99,7 +103,7 @@ class DefaultSpecies:
|
||||
# crossover info
|
||||
winner, loser, elite_mask = self.create_crossover_pair(state, k1, spawn_number, fitness)
|
||||
|
||||
return state(randkey=k2), winner, loser, elite_mask
|
||||
return state.update(randkey=k2), winner, loser, elite_mask
|
||||
|
||||
def update_species_fitness(self, state, fitness):
|
||||
"""
|
||||
@@ -156,17 +160,17 @@ class DefaultSpecies:
|
||||
jnp.nan, # last_improved
|
||||
jnp.nan, # member_count
|
||||
-jnp.inf, # species_fitness
|
||||
jnp.full_like(center_nodes[idx], jnp.nan), # center_nodes
|
||||
jnp.full_like(center_conns[idx], jnp.nan), # center_conns
|
||||
jnp.full_like(state.center_nodes[idx], jnp.nan), # center_nodes
|
||||
jnp.full_like(state.center_conns[idx], jnp.nan), # center_conns
|
||||
), # stagnation species
|
||||
lambda: (
|
||||
species_keys[idx],
|
||||
state.species_keys[idx],
|
||||
best_fitness[idx],
|
||||
last_improved[idx],
|
||||
state.member_count[idx],
|
||||
species_fitness[idx],
|
||||
center_nodes[idx],
|
||||
center_conns[idx]
|
||||
state.center_nodes[idx],
|
||||
state.center_conns[idx]
|
||||
) # not stagnation species
|
||||
)
|
||||
|
||||
@@ -216,7 +220,7 @@ class DefaultSpecies:
|
||||
spawn_number = spawn_number.astype(jnp.int32)
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
error = state.P - jnp.sum(spawn_number)
|
||||
error = self.pop_size - jnp.sum(spawn_number)
|
||||
|
||||
# add error to the first species to control the sum of spawn_number
|
||||
spawn_number = spawn_number.at[0].add(error)
|
||||
@@ -287,14 +291,14 @@ class DefaultSpecies:
|
||||
def body_func(carry):
|
||||
i, i2s, cns, ccs, o2c = carry
|
||||
|
||||
distances = o2p_distance_func(cns, ccs, state.pop_nodes, state.pop_conns)
|
||||
distances = o2p_distance_func(cns[i], ccs[i], state.pop_nodes, state.pop_conns)
|
||||
|
||||
# find the closest one
|
||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||
|
||||
i2s = i2s.at[closest_idx].set(state.species_info.species_keys[i])
|
||||
cns = cns.set(i, state.pop_nodes[closest_idx])
|
||||
ccs = ccs.set(i, state.pop_conns[closest_idx])
|
||||
i2s = i2s.at[closest_idx].set(state.species_keys[i])
|
||||
cns = cns.at[i].set(state.pop_nodes[closest_idx])
|
||||
ccs = ccs.at[i].set(state.pop_conns[closest_idx])
|
||||
|
||||
# the genome with closest_idx will become the new center, thus its distance to center is 0.
|
||||
o2c = o2c.at[closest_idx].set(0)
|
||||
@@ -346,8 +350,8 @@ class DefaultSpecies:
|
||||
o2c = o2c.at[idx].set(0)
|
||||
|
||||
# update center genomes
|
||||
cns = cns.set(i, state.pop_nodes[idx])
|
||||
ccs = ccs.set(i, state.pop_conns[idx])
|
||||
cns = cns.at[i].set(state.pop_nodes[idx])
|
||||
ccs = ccs.at[i].set(state.pop_conns[idx])
|
||||
|
||||
# find the members for the new species
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
|
||||
@@ -384,7 +388,7 @@ class DefaultSpecies:
|
||||
_, idx2species, center_nodes, center_conns, species_keys, _, next_species_key = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(0, state.idx2species, state.center_nodes, center_conns, state.species_info.species_keys, o2c_distances,
|
||||
(0, state.idx2species, center_nodes, center_conns, state.species_keys, o2c_distances,
|
||||
state.next_species_key)
|
||||
)
|
||||
|
||||
@@ -401,8 +405,8 @@ class DefaultSpecies:
|
||||
def count_members(idx):
|
||||
return jax.lax.cond(
|
||||
jnp.isnan(species_keys[idx]), # if the species is not existing
|
||||
lambda _: jnp.nan, # nan
|
||||
lambda _: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
|
||||
lambda: jnp.nan, # nan
|
||||
lambda: jnp.sum(idx2species == species_keys[idx], dtype=jnp.float32) # count members
|
||||
)
|
||||
|
||||
member_count = jax.vmap(count_members)(self.species_arange)
|
||||
@@ -422,7 +426,8 @@ class DefaultSpecies:
|
||||
"""
|
||||
The distance between two genomes
|
||||
"""
|
||||
return self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
|
||||
d = self.node_distance(nodes1, nodes2) + self.conn_distance(conns1, conns2)
|
||||
return d
|
||||
|
||||
def node_distance(self, nodes1, nodes2):
|
||||
"""
|
||||
@@ -494,18 +499,18 @@ def initialize_population(pop_size, genome):
|
||||
o_nodes[input_idx, 0] = genome.input_idx
|
||||
o_nodes[output_idx, 0] = genome.output_idx
|
||||
o_nodes[new_node_key, 0] = new_node_key # one hidden node
|
||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_attrs()
|
||||
o_nodes[new_node_key, 1:] = genome.node_gene.new_attrs() # one hidden node
|
||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs()
|
||||
o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node
|
||||
|
||||
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
|
||||
o_conns[input_idx, 0:2] = input_conns # in key, out key
|
||||
o_conns[input_idx, 2] = True # enabled
|
||||
o_conns[input_idx, 3:] = genome.conn_gene.new_conn_attrs()
|
||||
o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
||||
|
||||
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
|
||||
o_conns[output_idx, 0:2] = output_conns # in key, out key
|
||||
o_conns[output_idx, 2] = True # enabled
|
||||
o_conns[output_idx, 3:] = genome.conn_gene.new_conn_attrs()
|
||||
o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
||||
|
||||
# repeat origin genome for P times to create population
|
||||
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))
|
||||
|
||||
Reference in New Issue
Block a user