complete HyperNEAT!

This commit is contained in:
wls2002
2023-07-21 15:03:12 +08:00
parent 80ee5ea2ea
commit 48f90c7eef
32 changed files with 432 additions and 136 deletions

View File

@@ -1,3 +1,4 @@
from .base import Algorithm
from .state import State from .state import State
from .neat import NEAT from .neat import NEAT
from .config import Configer from .hyperneat import HyperNEAT

17
algorithm/base.py Normal file
View File

@@ -0,0 +1,17 @@
from typing import Callable
from .state import State
def EMPTY(*args):
    """Placeholder callable that simply echoes its arguments.

    Using a named ``def`` instead of an assigned lambda (PEP 8 / E731)
    gives readable tracebacks when an unset hook is called by mistake.
    """
    return args


class Algorithm:
    """Abstract base class for evolutionary algorithms.

    Subclasses are expected to replace the callable hooks assigned in
    ``__init__`` (``tell``/``ask``/``forward``/``forward_transform``) and to
    override :meth:`setup`.
    """

    def __init__(self):
        # Hooks default to EMPTY so calling an unconfigured algorithm is
        # harmless (it just echoes its arguments) rather than raising.
        self.tell: Callable = EMPTY
        self.ask: Callable = EMPTY
        self.forward: Callable = EMPTY
        self.forward_transform: Callable = EMPTY

    def setup(self, randkey, state=State()):
        """Initialize algorithm state; concrete subclasses override this.

        :param randkey: PRNG key for stochastic initialization.
        :param state: base State to extend (default evaluated once at
            definition time; State appears to be immutable — updates return
            new instances — so the shared default is safe here).
        """
        pass

View File

@@ -0,0 +1,2 @@
from .hyperneat import HyperNEAT
from .substrate import BaseSubstrate

View File

@@ -0,0 +1,70 @@
from typing import Type
import jax
import numpy as np
from .substrate import BaseSubstrate, analysis_substrate
from .hyperneat_gene import HyperNEATGene
from algorithm import State, Algorithm, neat
class HyperNEAT(Algorithm):
    """HyperNEAT: an inner NEAT evolves CPPN genomes whose query outputs
    become the connection weights of a fixed substrate network."""

    def __init__(self, config, gene_type: Type[neat.BaseGene], substrate: Type[BaseSubstrate]):
        super().__init__()
        self.config = config
        self.gene_type = gene_type  # CPPN gene type (a NEAT gene class)
        self.substrate = substrate  # substrate class providing node coordinates
        # Inner NEAT instance evolves the CPPN population.
        self.neat = neat.NEAT(config, gene_type)
        # tell delegates fitness to NEAT; forward/forward_transform run the substrate net.
        self.tell = create_tell(self.neat)
        self.forward_transform = create_forward_transform(config, self.neat)
        self.forward = HyperNEATGene.create_forward(config)

    def setup(self, randkey, state=State()):
        """Build substrate geometry into state, then delegate to NEAT.setup.

        :param randkey: JAX PRNG key consumed by the inner NEAT setup.
        :param state: initial State; extended with substrate tables below.
        :return: fully initialized State.
        """
        state = state.update(
            below_threshold=self.config['below_threshold'],
            max_weight=self.config['max_weight']
        )
        state = self.substrate.setup(state, self.config)
        # Pre-compute node index ranges, CPPN query coordinates and the
        # (in_key, out_key) pair for every potential substrate connection.
        h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
        # Substrate node table: shape (N, 1), single column holding the node key.
        h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
        # Connection table: [in_key, out_key, weight]; weights are filled in
        # per-genome by forward_transform, zeros here are placeholders.
        h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
        h_conns[:, 0:2] = correspond_keys
        state = state.update(
            # h is short for hyperneat
            h_input_idx=h_input_idx,
            h_output_idx=h_output_idx,
            h_hidden_idx=h_hidden_idx,
            query_coors=query_coors,
            correspond_keys=correspond_keys,
            h_nodes=h_nodes,
            h_conns=h_conns
        )
        state = self.neat.setup(randkey, state=state)
        # Expose substrate indices to the forward function through config,
        # since create_forward reads them from config, not from state.
        self.config['h_input_idx'] = h_input_idx
        self.config['h_output_idx'] = h_output_idx
        return state
def create_tell(neat_instance):
    """Build a ``tell`` callable that forwards fitness updates to the
    wrapped NEAT instance unchanged."""
    return lambda state, fitness: neat_instance.tell(state, fitness)
def create_forward_transform(config, neat_instance):
    """Build the HyperNEAT ``forward_transform`` closure.

    The returned function transforms a CPPN genome, queries it once per
    substrate coordinate pair, and writes the resulting weights into the
    fixed substrate connection table before densifying it.
    """
    def forward_transform(state, nodes, conns):
        # Transform the CPPN genome into its executable form.
        cppn = neat_instance.forward_transform(state, nodes, conns)
        # Query the CPPN at every substrate coordinate pair in one vmap.
        query = jax.vmap(neat_instance.forward, in_axes=(0, None))
        conn_weights = query(state.query_coors, cppn)  # hyperneat connections
        substrate_conns = state.h_conns.at[:, 2:].set(conn_weights)
        return HyperNEATGene.forward_transform(state, state.h_nodes, substrate_conns)
    return forward_transform

View File

@@ -0,0 +1,54 @@
import jax
from jax import numpy as jnp, vmap
from algorithm.neat import BaseGene
from algorithm.neat.gene import Activation
from algorithm.neat.gene import Aggregation
class HyperNEATGene(BaseGene):
    """Gene describing the fixed substrate network built by HyperNEAT.

    The substrate topology is static; only connection weights (queried
    from the evolved CPPN) vary, so no node attributes are evolved here.
    """
    node_attrs = []  # no node attributes
    conn_attrs = ['weight']

    @staticmethod
    def forward_transform(state, nodes, conns):
        """Densify the flat connection list into an N x N weight matrix."""
        N = nodes.shape[0]
        u_conns = jnp.zeros((N, N), dtype=jnp.float32)
        in_keys = jnp.asarray(conns[:, 0], jnp.int32)
        out_keys = jnp.asarray(conns[:, 1], jnp.int32)
        weights = conns[:, 2]
        # Entry [in, out] holds the connection weight (missing pairs stay 0).
        u_conns = u_conns.at[in_keys, out_keys].set(weights)
        return nodes, u_conns

    @staticmethod
    def create_forward(config):
        """Build the substrate network's forward function from config."""
        act = Activation.name2func[config['h_activation']]
        agg = Aggregation.name2func[config['h_aggregation']]
        batch_act, batch_agg = vmap(act), vmap(agg)

        def forward(inputs, transform):
            # Appends a constant 1.0 — assumes the substrate's input layer
            # includes a bias node as its last input coordinate; TODO confirm
            # against input_coors in the config.
            inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
            nodes, weights = transform
            input_idx = config['h_input_idx']
            output_idx = config['h_output_idx']
            N = nodes.shape[0]
            vals = jnp.full((N,), 0.)

            def body_func(i, values):
                # Re-clamp input node values on every activation step.
                values = values.at[input_idx].set(inputs_with_bias)
                nodes_ins = values * weights.T
                values = batch_agg(nodes_ins)  # z = agg(ins)
                # NOTE(review): h_nodes built in HyperNEAT.setup has shape
                # (N, 1), so nodes[:, 2] / nodes[:, 1] look out of bounds
                # (JAX clamps out-of-range indices, yielding column 0 = the
                # node key) — suspected bug; confirm the intended source of
                # response/bias for substrate nodes.
                values = values * nodes[:, 2] + nodes[:, 1]  # z = z * response + bias
                values = batch_act(values)  # z = act(z)
                return values

            # Iterate activation a fixed number of times so signals can
            # propagate through the (possibly recurrent) substrate.
            vals = jax.lax.fori_loop(0, config['h_activate_times'], body_func, vals)
            return vals[output_idx]

        return forward

View File

@@ -0,0 +1,2 @@
from .base import BaseSubstrate
from .tools import analysis_substrate

View File

@@ -0,0 +1,12 @@
import numpy as np
class BaseSubstrate:
    """Substrate whose node coordinates come directly from the config."""

    @staticmethod
    def setup(state, config):
        """Store input/output/hidden node coordinates (as float32) in state."""
        coordinate_arrays = {
            key: np.asarray(config[key], dtype=np.float32)
            for key in ('input_coors', 'output_coors', 'hidden_coors')
        }
        return state.update(**coordinate_arrays)

View File

@@ -0,0 +1,53 @@
from typing import Type
import numpy as np
from .base import BaseSubstrate
def analysis_substrate(state):
    """Pre-compute substrate node indices, CPPN query coordinates and
    connection key pairs for the fixed layered topology.

    Connections are enumerated as input->hidden, hidden->hidden,
    hidden->output, in that order.

    :param state: must carry input_coors, output_coors, hidden_coors arrays.
    :return: (input_idx, output_idx, hidden_idx, query_coors, correspond_keys)
    """
    coor_dim = state.input_coors.shape[1]  # coordinate dimensions
    n_in = state.input_coors.shape[0]      # input coordinate size
    n_out = state.output_coors.shape[0]    # output coordinate size
    n_hid = state.hidden_coors.shape[0]    # hidden coordinate size

    # Node keys: inputs first, then outputs, then hidden nodes.
    input_idx = np.arange(n_in)
    output_idx = np.arange(n_in, n_in + n_out)
    hidden_idx = np.arange(n_in + n_out, n_in + n_out + n_hid)

    total_conns = n_in * n_hid + n_hid * n_hid + n_hid * n_out
    query_coors = np.zeros((total_conns, coor_dim * 2))
    correspond_keys = np.zeros((total_conns, 2))

    # Fill the three layer-to-layer segments back to back.
    segments = (
        (input_idx, hidden_idx, state.input_coors, state.hidden_coors),
        (hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors),
        (hidden_idx, output_idx, state.hidden_coors, state.output_coors),
    )
    offset = 0
    for src_keys, dst_keys, src_coors, dst_coors in segments:
        aux_coors, aux_keys = cartesian_product(src_keys, dst_keys, src_coors, dst_coors)
        size = aux_keys.shape[0]
        query_coors[offset:offset + size, :] = aux_coors
        correspond_keys[offset:offset + size, :] = aux_keys
        offset += size

    return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
def cartesian_product(keys1, keys2, coors1, coors2):
    """All ordered (key1, key2) pairs with their concatenated coordinates.

    Row k of ``new_coors`` is ``[coors1[i], coors2[j]]`` and row k of
    ``correspond_keys`` is ``[keys1[i], keys2[j]]``, with j varying fastest
    (i.e. the order of a nested for-i/for-j loop).
    """
    n1 = keys1.shape[0]
    n2 = keys2.shape[0]
    # Repeat the first operand element-wise and tile the second as a whole
    # so the pairing order matches the nested-loop enumeration.
    left_coors = np.repeat(coors1, n2, axis=0)
    right_coors = np.tile(coors2, (n1, 1))
    new_coors = np.hstack((left_coors, right_coors))
    correspond_keys = np.column_stack((np.repeat(keys1, n2), np.tile(keys2, n1)))
    return new_coors, correspond_keys

View File

@@ -1,3 +1,2 @@
from .neat import NEAT from .neat import NEAT
from .gene import NormalGene, RecurrentGene from .gene import BaseGene, NormalGene, RecurrentGene
from .pipeline import Pipeline

View File

@@ -33,12 +33,10 @@ class BaseGene:
def distance_conn(state, conn1: Array, conn2: Array): def distance_conn(state, conn1: Array, conn2: Array):
return conn1 return conn1
@staticmethod @staticmethod
def forward_transform(nodes, conns): def forward_transform(state, nodes, conns):
return nodes, conns return nodes, conns
@staticmethod @staticmethod
def create_forward(config): def create_forward(config):
return None return None

View File

@@ -4,7 +4,7 @@ from jax import Array, numpy as jnp
from .base import BaseGene from .base import BaseGene
from .activation import Activation from .activation import Activation
from .aggregation import Aggregation from .aggregation import Aggregation
from ..utils import unflatten_connections, I_INT from algorithm.utils import unflatten_connections, I_INT
from ..genome import topological_sort from ..genome import topological_sort
@@ -84,7 +84,7 @@ class NormalGene(BaseGene):
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
@staticmethod @staticmethod
def forward_transform(nodes, conns): def forward_transform(state, nodes, conns):
u_conns = unflatten_connections(nodes, conns) u_conns = unflatten_connections(nodes, conns)
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False) conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)

View File

@@ -4,13 +4,13 @@ from jax import Array, numpy as jnp, vmap
from .normal import NormalGene from .normal import NormalGene
from .activation import Activation from .activation import Activation
from .aggregation import Aggregation from .aggregation import Aggregation
from ..utils import unflatten_connections, I_INT from algorithm.utils import unflatten_connections
class RecurrentGene(NormalGene): class RecurrentGene(NormalGene):
@staticmethod @staticmethod
def forward_transform(nodes, conns): def forward_transform(state, nodes, conns):
u_conns = unflatten_connections(nodes, conns) u_conns = unflatten_connections(nodes, conns)
# remove un-enable connections and remove enable attr # remove un-enable connections and remove enable attr

View File

@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
from algorithm import State from algorithm import State
from ..gene import BaseGene from ..gene import BaseGene
from ..utils import fetch_first from algorithm.utils import fetch_first
def initialize_genomes(state: State, gene_type: Type[BaseGene]): def initialize_genomes(state: State, gene_type: Type[BaseGene]):
@@ -48,6 +48,7 @@ def count(nodes: Array, cons: Array):
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0])) cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
return node_cnt, cons_cnt return node_cnt, cons_cnt
def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]: def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]:
""" """
Add a new node to the genome. Add a new node to the genome.

View File

@@ -6,7 +6,7 @@ Only used in feed-forward networks.
import jax import jax
from jax import jit, Array, numpy as jnp from jax import jit, Array, numpy as jnp
from ..utils import fetch_first, I_INT from algorithm.utils import fetch_first, I_INT
@jit @jit

View File

@@ -4,9 +4,9 @@ import jax
from jax import Array, numpy as jnp, vmap from jax import Array, numpy as jnp, vmap
from algorithm import State from algorithm import State
from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx, count from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx
from .graph import check_cycles from .graph import check_cycles
from ..utils import fetch_random, fetch_first, I_INT, unflatten_connections from algorithm.utils import fetch_random, fetch_first, I_INT, unflatten_connections
from ..gene import BaseGene from ..gene import BaseGene

View File

@@ -3,22 +3,25 @@ from typing import Type
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from algorithm.state import State from algorithm import Algorithm, State
from .gene import BaseGene from .gene import BaseGene
from .genome import initialize_genomes from .genome import initialize_genomes
from .population import create_tell from .population import create_tell
class NEAT: class NEAT(Algorithm):
def __init__(self, config, gene_type: Type[BaseGene]): def __init__(self, config, gene_type: Type[BaseGene]):
super().__init__()
self.config = config self.config = config
self.gene_type = gene_type self.gene_type = gene_type
self.tell_func = jax.jit(create_tell(config, self.gene_type)) self.tell = create_tell(config, self.gene_type)
self.ask = None
self.forward = self.gene_type.create_forward(config)
self.forward_transform = self.gene_type.forward_transform
def setup(self, randkey): def setup(self, randkey, state=State()):
state = state.update(
state = State(
P=self.config['pop_size'], P=self.config['pop_size'],
N=self.config['maximum_nodes'], N=self.config['maximum_nodes'],
C=self.config['maximum_conns'], C=self.config['maximum_conns'],
@@ -70,6 +73,3 @@ class NEAT:
state = jax.device_put(state) state = jax.device_put(state)
return state return state
def step(self, state, fitness):
return self.tell_func(state, fitness)

View File

@@ -3,7 +3,7 @@ from typing import Type
import jax import jax
from jax import numpy as jnp, vmap from jax import numpy as jnp, vmap
from .utils import rank_elements, fetch_first from algorithm.utils import rank_elements, fetch_first
from .genome import create_mutate, create_distance, crossover from .genome import create_mutate, create_distance, crossover
from .gene import BaseGene from .gene import BaseGene

1
config/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .config import Configer

View File

@@ -4,6 +4,7 @@ import configparser
import numpy as np import numpy as np
class Configer: class Configer:
@classmethod @classmethod
@@ -28,7 +29,7 @@ class Configer:
def __check_redundant_config(cls, default_config, config): def __check_redundant_config(cls, default_config, config):
for key in config: for key in config:
if key not in default_config: if key not in default_config:
warnings.warn(f"Redundant config: {key} in {config.name}") warnings.warn(f"Redundant config: {key} in config!")
@classmethod @classmethod
def __complete_config(cls, default_config, config): def __complete_config(cls, default_config, config):

View File

@@ -1,26 +1,38 @@
[basic] [basic]
random_seed = 0
generation_limit = 1000
[problem]
fitness_threshold = 3.9999
num_inputs = 2 num_inputs = 2
num_outputs = 1 num_outputs = 1
maximum_nodes = 50
maximum_conns = 100 [neat]
maximum_species = 10
forward_way = "pop"
batch_size = 4
random_seed = 0
network_type = "feedforward" network_type = "feedforward"
activate_times = 10 activate_times = 5
maximum_nodes = 50
maximum_conns = 50
maximum_species = 10
[hyperneat]
below_threshold = 0.2
max_weight = 3
h_activation = "sigmoid"
h_aggregation = "sum"
h_activate_times = 5
[substrate]
input_coors = [[-1, 1], [0, 1], [1, 1]]
hidden_coors = [[-1, 0], [0, 0], [1, 0]]
output_coors = [[0, -1]]
[population] [population]
fitness_threshold = 3.9999 pop_size = 10
generation_limit = 1000
fitness_criterion = "max"
pop_size = 50000
[genome] [genome]
compatibility_disjoint = 1.0 compatibility_disjoint = 1.0
compatibility_weight = 0.5 compatibility_weight = 0.5
conn_add_prob = 0.4 conn_add_prob = 0.4
conn_add_trials = 1
conn_delete_prob = 0 conn_delete_prob = 0
node_add_prob = 0.2 node_add_prob = 0.2
node_delete_prob = 0 node_delete_prob = 0
@@ -34,39 +46,37 @@ survival_threshold = 0.2
min_species_size = 1 min_species_size = 1
spawn_number_change_rate = 0.5 spawn_number_change_rate = 0.5
[gene-bias] [gene]
# bias
bias_init_mean = 0.0 bias_init_mean = 0.0
bias_init_std = 1.0 bias_init_std = 1.0
bias_mutate_power = 0.5 bias_mutate_power = 0.5
bias_mutate_rate = 0.7 bias_mutate_rate = 0.7
bias_replace_rate = 0.1 bias_replace_rate = 0.1
[gene-response] # response
response_init_mean = 1.0 response_init_mean = 1.0
response_init_std = 0.0 response_init_std = 0.0
response_mutate_power = 0.0 response_mutate_power = 0.0
response_mutate_rate = 0.0 response_mutate_rate = 0.0
response_replace_rate = 0.0 response_replace_rate = 0.0
[gene-activation] # activation
activation_default = "sigmoid" activation_default = "sigmoid"
activation_option_names = ["sigmoid"] activation_option_names = ["tanh"]
activation_replace_rate = 0.0 activation_replace_rate = 0.0
[gene-aggregation] # aggregation
aggregation_default = "sum" aggregation_default = "sum"
aggregation_option_names = ["sum"] aggregation_option_names = ["sum"]
aggregation_replace_rate = 0.0 aggregation_replace_rate = 0.0
[gene-weight] # weight
weight_init_mean = 0.0 weight_init_mean = 0.0
weight_init_std = 1.0 weight_init_std = 1.0
weight_mutate_power = 0.5 weight_mutate_power = 0.5
weight_mutate_rate = 0.8 weight_mutate_rate = 0.8
weight_replace_rate = 0.1 weight_replace_rate = 0.1
[gene-enable]
enable_mutate_rate = 0.01
[visualize] [visualize]
renumber_nodes = True renumber_nodes = True

11
examples/a.py Normal file
View File

@@ -0,0 +1,11 @@
import numpy as np
import jax.numpy as jnp
# Minimal demo: scatter values into a 2-D array with JAX's functional
# .at[...].set(...) indexed-update API (JAX arrays are immutable, so the
# update returns a new array).
a = jnp.zeros((5, 5))
k1 = jnp.array([1, 2, 3])  # row indices
k2 = jnp.array([2, 3, 4])  # column indices
v = jnp.array([1, 1, 1])
# Pairwise scatter: sets a[1, 2], a[2, 3], a[3, 4] (not a cross product).
a = a.at[k1, k2].set(v)
print(a)

View File

@@ -1,13 +1,44 @@
import numpy as np import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
vals = np.array([1, 2]) @register_pytree_node_class
weights = np.array([[0, 4], [5, 0]]) class Genome:
def __init__(self, nodes, conns):
self.nodes = nodes
self.conns = conns
ins1 = vals * weights[:, 0] def update_nodes(self, nodes):
ins2 = vals * weights[:, 1] return Genome(nodes, self.conns)
ins_all = vals * weights.T
print(ins1) def update_conns(self, conns):
print(ins2) return Genome(self.nodes, conns)
print(ins_all)
def tree_flatten(self):
children = self.nodes, self.conns
aux_data = None
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
def __repr__(self):
return f"Genome ({self.nodes}, \n\t{self.conns})"
@jax.jit
def add_node(self, a: int):
nodes = self.nodes.at[0, :].set(a)
return self.update_nodes(nodes)
nodes, conns = jnp.array([[1, 2, 3, 4, 5]]), jnp.array([[1, 2, 3, 4]])
g = Genome(nodes, conns)
print(g)
g = g.add_node(1)
print(g)
g = jax.jit(g.add_node)(2)
print(g)

View File

@@ -1,15 +0,0 @@
import jax
from jax import numpy as jnp
from algorithm.state import State
# Demo: State behaves as a JAX pytree — it can flow through jit and vmap.
@jax.jit
def func(state: State, a):
    # Returns a new State with field `a` added/updated.
    return state.update(a=a)

state = State(c=1, b=2)
print(state)
# Broadcast the same State (in_axes=None) over a batch of values for `a`.
vmap_func = jax.vmap(func, in_axes=(None, 0))
print(vmap_func(state, jnp.array([1, 2, 3])))

View File

@@ -1,7 +1,12 @@
[basic] [basic]
forward_way = "common"
network_type = "recurrent"
activate_times = 5 activate_times = 5
fitness_threshold = 4
[population] [population]
fitness_threshold = 4 pop_size = 1000
[neat]
network_type = "recurrent"
num_inputs = 4
num_outputs = 1

View File

@@ -1,8 +1,10 @@
import jax import jax
import numpy as np import numpy as np
from algorithm import Configer, NEAT from pipeline import Pipeline
from algorithm.neat import NormalGene, RecurrentGene, Pipeline from config import Configer
from algorithm import NEAT
from algorithm.neat import RecurrentGene
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
@@ -21,7 +23,6 @@ def evaluate(forward_func):
def main(): def main():
config = Configer.load_config("xor.ini") config = Configer.load_config("xor.ini")
# algorithm = NEAT(config, NormalGene)
algorithm = NEAT(config, RecurrentGene) algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm) pipeline = Pipeline(config, algorithm)
best = pipeline.auto_run(evaluate) best = pipeline.auto_run(evaluate)

33
examples/xor_hyperneat.py Normal file
View File

@@ -0,0 +1,33 @@
import jax
import numpy as np
from pipeline import Pipeline
from config import Configer
from algorithm import NEAT, HyperNEAT
from algorithm.neat import RecurrentGene
from algorithm.hyperneat import BaseSubstrate
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
    """Score the whole population on the XOR task.

    :param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
    :return: per-genome fitness: 4 minus the summed squared error.
    """
    predictions = jax.device_get(forward_func(xor_inputs))
    squared_error = np.sum((predictions - xor_outputs) ** 2, axis=(1, 2))
    return 4 - squared_error
def main():
    # Load config, then run HyperNEAT on XOR: a RecurrentGene CPPN queried
    # over a BaseSubstrate's coordinates, driven by the standard Pipeline.
    config = Configer.load_config("xor.ini")
    algorithm = HyperNEAT(config, RecurrentGene, BaseSubstrate)
    pipeline = Pipeline(config, algorithm)
    pipeline.auto_run(evaluate)

if __name__ == '__main__':
    main()

View File

@@ -1,50 +0,0 @@
import jax
import numpy as np
from algorithm.config import Configer
from algorithm.neat import NEAT, NormalGene, RecurrentGene, Pipeline
from algorithm.neat.genome import create_mutate
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
def single_genome(func, nodes, conns):
    # Evaluate one genome on each XOR input separately and print the outputs.
    # NOTE(review): this (deleted) file still calls forward_transform with the
    # pre-refactor (nodes, conns) signature; elsewhere in this commit it takes
    # (state, nodes, conns).
    t = RecurrentGene.forward_transform(nodes, conns)
    out1 = func(xor_inputs[0], t)
    out2 = func(xor_inputs[1], t)
    out3 = func(xor_inputs[2], t)
    out4 = func(xor_inputs[3], t)
    print(out1, out2, out3, out4)
def batch_genome(func, nodes, conns):
    # Evaluate one genome on the whole XOR batch via vmap over inputs.
    t = NormalGene.forward_transform(nodes, conns)
    out = jax.vmap(func, in_axes=(0, None))(xor_inputs, t)
    print(out)
def pop_batch_genome(func, pop_nodes, pop_conns):
    # Evaluate the whole population on the whole batch:
    # inner vmap over inputs, outer vmap over population transforms.
    t = jax.vmap(NormalGene.forward_transform)(pop_nodes, pop_conns)
    func = jax.vmap(jax.vmap(func, in_axes=(0, None)), in_axes=(None, 0))
    out = func(xor_inputs, t)
    print(out)
if __name__ == '__main__':
    # Smoke test: set up NEAT, run one genome before and after a mutation.
    config = Configer.load_config("xor.ini")
    # neat = NEAT(config, NormalGene)
    neat = NEAT(config, RecurrentGene)
    randkey = jax.random.PRNGKey(42)
    state = neat.setup(randkey)
    forward_func = RecurrentGene.create_forward(config)
    mutate_func = create_mutate(config, RecurrentGene)
    # Take the first genome from the initialized population.
    nodes, conns = state.pop_nodes[0], state.pop_conns[0]
    single_genome(forward_func, nodes, conns)
    # batch_genome(forward_func, nodes, conns)
    # Mutate it (last arg presumably a new-node/innovation key; TODO confirm)
    # and evaluate again.
    nodes, conns = mutate_func(state, randkey, nodes, conns, 10000)
    single_genome(forward_func, nodes, conns)
    # batch_genome(forward_func, nodes, conns)
    #

View File

@@ -5,15 +5,18 @@ import jax
from jax import vmap, jit from jax import vmap, jit
import numpy as np import numpy as np
from algorithm import Algorithm
class Pipeline: class Pipeline:
""" """
Neat algorithm pipeline. Neat algorithm pipeline.
""" """
def __init__(self, config, algorithm): def __init__(self, config, algorithm: Algorithm):
self.config = config self.config = config
self.algorithm = algorithm self.algorithm = algorithm
randkey = jax.random.PRNGKey(config['random_seed']) randkey = jax.random.PRNGKey(config['random_seed'])
self.state = algorithm.setup(randkey) self.state = algorithm.setup(randkey)
@@ -23,18 +26,18 @@ class Pipeline:
self.evaluate_time = 0 self.evaluate_time = 0
self.forward_func = algorithm.gene_type.create_forward(config) self.forward_func = jit(self.algorithm.forward)
self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None))) self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None)))
self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0))) self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0)))
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0, 0)))
self.pop_transform_func = jit(vmap(algorithm.gene_type.forward_transform)) self.tell_func = jit(self.algorithm.tell)
def ask(self): def ask(self):
pop_transforms = self.pop_transform_func(self.state.pop_nodes, self.state.pop_conns) pop_transforms = self.forward_transform_func(self.state, self.state.pop_nodes, self.state.pop_conns)
return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms) return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms)
def tell(self, fitness): def tell(self, fitness):
self.state = self.algorithm.step(self.state, fitness) self.state = self.tell_func(self.state, fitness)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']): for _ in range(self.config['generation_limit']):

View File

@@ -0,0 +1,56 @@
import numpy as np
from algorithm.hyperneat.substrate.tools import cartesian_product
def test01():
    """cartesian_product must pair every key/coordinate of the first set
    with every one of the second, second index varying fastest, and return
    the concatenated coordinates plus the matching key pairs."""
    keys1 = np.array([1, 2, 3])
    keys2 = np.array([4, 5, 6, 7])
    coors1 = np.array([
        [1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]
    ])
    coors2 = np.array([
        [4, 4, 4],
        [5, 5, 5],
        [6, 6, 6],
        [7, 7, 7]
    ])
    # Expected: 3 x 4 = 12 rows, ordered like a nested for-i/for-j loop.
    target_coors = np.array([
        [1, 1, 1, 4, 4, 4],
        [1, 1, 1, 5, 5, 5],
        [1, 1, 1, 6, 6, 6],
        [1, 1, 1, 7, 7, 7],
        [2, 2, 2, 4, 4, 4],
        [2, 2, 2, 5, 5, 5],
        [2, 2, 2, 6, 6, 6],
        [2, 2, 2, 7, 7, 7],
        [3, 3, 3, 4, 4, 4],
        [3, 3, 3, 5, 5, 5],
        [3, 3, 3, 6, 6, 6],
        [3, 3, 3, 7, 7, 7]
    ])
    target_keys = np.array([
        [1, 4],
        [1, 5],
        [1, 6],
        [1, 7],
        [2, 4],
        [2, 5],
        [2, 6],
        [2, 7],
        [3, 4],
        [3, 5],
        [3, 6],
        [3, 7]
    ])
    new_coors, correspond_keys = cartesian_product(keys1, keys2, coors1, coors2)
    assert np.array_equal(new_coors, target_coors)
    assert np.array_equal(correspond_keys, target_keys)

View File

@@ -1,7 +1,7 @@
import jax.numpy as jnp import jax.numpy as jnp
from algorithm.neat.genome.graph import topological_sort, check_cycles from algorithm.neat.genome.graph import topological_sort, check_cycles
from algorithm.neat.utils import I_INT from algorithm.utils import I_INT
nodes = jnp.array([ nodes = jnp.array([
[0], [0],

View File

@@ -1,5 +1,5 @@
import jax.numpy as jnp import jax.numpy as jnp
from algorithm.neat.utils import unflatten_connections from algorithm.utils import unflatten_connections
def test_unflatten(): def test_unflatten():