use black format all files;

remove "return state" for functions which will be executed in vmap;
recover randkey as args in mutation methods
This commit is contained in:
wls2002
2024-05-26 15:46:04 +08:00
parent 79d53ea7af
commit cf69b916af
38 changed files with 932 additions and 582 deletions

View File

@@ -1,2 +1,2 @@
from .hyperneat import HyperNEAT
from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate
from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate

View File

@@ -10,20 +10,20 @@ from .substrate import *
class HyperNEAT(BaseAlgorithm):
def __init__(
self,
substrate: BaseSubstrate,
neat: NEAT,
below_threshold: float = 0.3,
max_weight: float = 5.,
activation=Act.sigmoid,
aggregation=Agg.sum,
activate_time: int = 10,
output_transform: Callable = Act.sigmoid,
self,
substrate: BaseSubstrate,
neat: NEAT,
below_threshold: float = 0.3,
max_weight: float = 5.0,
activation=Act.sigmoid,
aggregation=Agg.sum,
activate_time: int = 10,
output_transform: Callable = Act.sigmoid,
):
assert substrate.query_coors.shape[1] == neat.num_inputs, \
"Substrate input size should be equal to NEAT input size"
assert (
substrate.query_coors.shape[1] == neat.num_inputs
), "Substrate input size should be equal to NEAT input size"
self.substrate = substrate
self.neat = neat
@@ -37,39 +37,43 @@ class HyperNEAT(BaseAlgorithm):
node_gene=HyperNodeGene(activation, aggregation),
conn_gene=HyperNEATConnGene(),
activate_time=activate_time,
output_transform=output_transform
output_transform=output_transform,
)
def setup(self, randkey):
return State(
neat_state=self.neat.setup(randkey)
)
return State(neat_state=self.neat.setup(randkey))
def ask(self, state: State):
return self.neat.ask(state.neat_state)
def tell(self, state: State, fitness):
return state.update(
neat_state=self.neat.tell(state.neat_state, fitness)
)
return state.update(neat_state=self.neat.tell(state.neat_state, fitness))
def transform(self, individual):
transformed = self.neat.transform(individual)
query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(self.substrate.query_coors, transformed)
query_res = jax.vmap(self.neat.forward, in_axes=(0, None))(
self.substrate.query_coors, transformed
)
# mute the connection with weight below threshold
query_res = jnp.where(
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
0.,
query_res
0.0,
query_res,
)
# make query res in range [-max_weight, max_weight]
query_res = jnp.where(query_res > 0, query_res - self.below_threshold, query_res)
query_res = jnp.where(query_res < 0, query_res + self.below_threshold, query_res)
query_res = jnp.where(
query_res > 0, query_res - self.below_threshold, query_res
)
query_res = jnp.where(
query_res < 0, query_res + self.below_threshold, query_res
)
query_res = query_res / (1 - self.below_threshold) * self.max_weight
h_nodes, h_conns = self.substrate.make_nodes(query_res), self.substrate.make_conn(query_res)
h_nodes, h_conns = self.substrate.make_nodes(
query_res
), self.substrate.make_conn(query_res)
return self.hyper_genome.transform(h_nodes, h_conns)
def forward(self, inputs, transformed):
@@ -97,11 +101,11 @@ class HyperNEAT(BaseAlgorithm):
class HyperNodeGene(BaseNodeGene):
def __init__(self,
activation=Act.sigmoid,
aggregation=Agg.sum,
):
def __init__(
self,
activation=Act.sigmoid,
aggregation=Agg.sum,
):
super().__init__()
self.activation = activation
self.aggregation = aggregation
@@ -110,12 +114,12 @@ class HyperNodeGene(BaseNodeGene):
return jax.lax.cond(
is_output_node,
lambda: self.aggregation(inputs), # output node does not need activation
lambda: self.activation(self.aggregation(inputs))
lambda: self.activation(self.aggregation(inputs)),
)
class HyperNEATConnGene(BaseConnGene):
custom_attrs = ['weight']
custom_attrs = ["weight"]
def forward(self, attrs, inputs):
weight = attrs[0]

View File

@@ -1,5 +1,4 @@
class BaseSubstrate:
def make_nodes(self, query_res):
raise NotImplementedError

View File

@@ -3,7 +3,6 @@ from . import BaseSubstrate
class DefaultSubstrate(BaseSubstrate):
def __init__(self, num_inputs, num_outputs, coors, nodes, conns):
self.inputs = num_inputs
self.outputs = num_outputs

View File

@@ -3,20 +3,16 @@ from .default import DefaultSubstrate
class FullSubstrate(DefaultSubstrate):
def __init__(self,
input_coors=((-1, -1), (0, -1), (1, -1)),
hidden_coors=((-1, 0), (0, 0), (1, 0)),
output_coors=((0, 1),),
):
query_coors, nodes, conns = analysis_substrate(input_coors, output_coors, hidden_coors)
super().__init__(
len(input_coors),
len(output_coors),
query_coors,
nodes,
conns
def __init__(
self,
input_coors=((-1, -1), (0, -1), (1, -1)),
hidden_coors=((-1, 0), (0, 0), (1, 0)),
output_coors=((0, 1),),
):
query_coors, nodes, conns = analysis_substrate(
input_coors, output_coors, hidden_coors
)
super().__init__(len(input_coors), len(output_coors), query_coors, nodes, conns)
def analysis_substrate(input_coors, output_coors, hidden_coors):
@@ -38,22 +34,30 @@ def analysis_substrate(input_coors, output_coors, hidden_coors):
correspond_keys = np.zeros((total_conns, 2))
# connect input to hidden
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, input_coors, hidden_coors)
query_coors[0: si * sh, :] = aux_coors
correspond_keys[0: si * sh, :] = aux_keys
aux_coors, aux_keys = cartesian_product(
input_idx, hidden_idx, input_coors, hidden_coors
)
query_coors[0 : si * sh, :] = aux_coors
correspond_keys[0 : si * sh, :] = aux_keys
# connect hidden to hidden
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, hidden_coors, hidden_coors)
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
aux_coors, aux_keys = cartesian_product(
hidden_idx, hidden_idx, hidden_coors, hidden_coors
)
query_coors[si * sh : si * sh + sh * sh, :] = aux_coors
correspond_keys[si * sh : si * sh + sh * sh, :] = aux_keys
# connect hidden to output
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, hidden_coors, output_coors)
query_coors[si * sh + sh * sh:, :] = aux_coors
correspond_keys[si * sh + sh * sh:, :] = aux_keys
aux_coors, aux_keys = cartesian_product(
hidden_idx, output_idx, hidden_coors, output_coors
)
query_coors[si * sh + sh * sh :, :] = aux_coors
correspond_keys[si * sh + sh * sh :, :] = aux_keys
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
conns = np.zeros((correspond_keys.shape[0], 4), dtype=np.float32) # input_idx, output_idx, enabled, weight
conns = np.zeros(
(correspond_keys.shape[0], 4), dtype=np.float32
) # input_idx, output_idx, enabled, weight
conns[:, 0:2] = correspond_keys
conns[:, 2] = 1 # enabled is True