complete HyperNEAT!
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from .base import Algorithm
|
||||
from .state import State
|
||||
from .neat import NEAT
|
||||
from .config import Configer
|
||||
from .hyperneat import HyperNEAT
|
||||
|
||||
17
algorithm/base.py
Normal file
17
algorithm/base.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Callable
|
||||
|
||||
from .state import State
|
||||
|
||||
EMPTY = lambda *args: args
|
||||
|
||||
|
||||
class Algorithm:
|
||||
|
||||
def __init__(self):
|
||||
self.tell: Callable = EMPTY
|
||||
self.ask: Callable = EMPTY
|
||||
self.forward: Callable = EMPTY
|
||||
self.forward_transform: Callable = EMPTY
|
||||
|
||||
def setup(self, randkey, state=State()):
|
||||
pass
|
||||
@@ -0,0 +1,2 @@
|
||||
from .hyperneat import HyperNEAT
|
||||
from .substrate import BaseSubstrate
|
||||
|
||||
70
algorithm/hyperneat/hyperneat.py
Normal file
70
algorithm/hyperneat/hyperneat.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from typing import Type
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from .substrate import BaseSubstrate, analysis_substrate
|
||||
from .hyperneat_gene import HyperNEATGene
|
||||
from algorithm import State, Algorithm, neat
|
||||
|
||||
|
||||
class HyperNEAT(Algorithm):
|
||||
|
||||
def __init__(self, config, gene_type: Type[neat.BaseGene], substrate: Type[BaseSubstrate]):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.gene_type = gene_type
|
||||
self.substrate = substrate
|
||||
self.neat = neat.NEAT(config, gene_type)
|
||||
|
||||
self.tell = create_tell(self.neat)
|
||||
self.forward_transform = create_forward_transform(config, self.neat)
|
||||
self.forward = HyperNEATGene.create_forward(config)
|
||||
|
||||
def setup(self, randkey, state=State()):
|
||||
state = state.update(
|
||||
below_threshold=self.config['below_threshold'],
|
||||
max_weight=self.config['max_weight']
|
||||
)
|
||||
|
||||
state = self.substrate.setup(state, self.config)
|
||||
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
|
||||
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
|
||||
h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
|
||||
h_conns[:, 0:2] = correspond_keys
|
||||
|
||||
state = state.update(
|
||||
# h is short for hyperneat
|
||||
h_input_idx=h_input_idx,
|
||||
h_output_idx=h_output_idx,
|
||||
h_hidden_idx=h_hidden_idx,
|
||||
query_coors=query_coors,
|
||||
correspond_keys=correspond_keys,
|
||||
h_nodes=h_nodes,
|
||||
h_conns=h_conns
|
||||
)
|
||||
state = self.neat.setup(randkey, state=state)
|
||||
|
||||
self.config['h_input_idx'] = h_input_idx
|
||||
self.config['h_output_idx'] = h_output_idx
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def create_tell(neat_instance):
|
||||
def tell(state, fitness):
|
||||
return neat_instance.tell(state, fitness)
|
||||
|
||||
return tell
|
||||
|
||||
|
||||
def create_forward_transform(config, neat_instance):
|
||||
def forward_transform(state, nodes, conns):
|
||||
t = neat_instance.forward_transform(state, nodes, conns)
|
||||
batch_forward_func = jax.vmap(neat_instance.forward, in_axes=(0, None))
|
||||
query_res = batch_forward_func(state.query_coors, t) # hyperneat connections
|
||||
h_nodes = state.h_nodes
|
||||
h_conns = state.h_conns.at[:, 2:].set(query_res)
|
||||
return HyperNEATGene.forward_transform(state, h_nodes, h_conns)
|
||||
|
||||
return forward_transform
|
||||
54
algorithm/hyperneat/hyperneat_gene.py
Normal file
54
algorithm/hyperneat/hyperneat_gene.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import jax
|
||||
from jax import numpy as jnp, vmap
|
||||
|
||||
from algorithm.neat import BaseGene
|
||||
from algorithm.neat.gene import Activation
|
||||
from algorithm.neat.gene import Aggregation
|
||||
|
||||
|
||||
class HyperNEATGene(BaseGene):
|
||||
node_attrs = [] # no node attributes
|
||||
conn_attrs = ['weight']
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(state, nodes, conns):
|
||||
N = nodes.shape[0]
|
||||
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
|
||||
|
||||
in_keys = jnp.asarray(conns[:, 0], jnp.int32)
|
||||
out_keys = jnp.asarray(conns[:, 1], jnp.int32)
|
||||
weights = conns[:, 2]
|
||||
|
||||
u_conns = u_conns.at[in_keys, out_keys].set(weights)
|
||||
return nodes, u_conns
|
||||
|
||||
@staticmethod
|
||||
def create_forward(config):
|
||||
act = Activation.name2func[config['h_activation']]
|
||||
agg = Aggregation.name2func[config['h_aggregation']]
|
||||
|
||||
batch_act, batch_agg = vmap(act), vmap(agg)
|
||||
|
||||
def forward(inputs, transform):
|
||||
|
||||
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
|
||||
nodes, weights = transform
|
||||
|
||||
input_idx = config['h_input_idx']
|
||||
output_idx = config['h_output_idx']
|
||||
|
||||
N = nodes.shape[0]
|
||||
vals = jnp.full((N,), 0.)
|
||||
|
||||
def body_func(i, values):
|
||||
values = values.at[input_idx].set(inputs_with_bias)
|
||||
nodes_ins = values * weights.T
|
||||
values = batch_agg(nodes_ins) # z = agg(ins)
|
||||
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||
values = batch_act(values) # z = act(z)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, config['h_activate_times'], body_func, vals)
|
||||
return vals[output_idx]
|
||||
|
||||
return forward
|
||||
2
algorithm/hyperneat/substrate/__init__.py
Normal file
2
algorithm/hyperneat/substrate/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .base import BaseSubstrate
|
||||
from .tools import analysis_substrate
|
||||
12
algorithm/hyperneat/substrate/base.py
Normal file
12
algorithm/hyperneat/substrate/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseSubstrate:
|
||||
|
||||
@staticmethod
|
||||
def setup(state, config):
|
||||
return state.update(
|
||||
input_coors=np.asarray(config['input_coors'], dtype=np.float32),
|
||||
output_coors=np.asarray(config['output_coors'], dtype=np.float32),
|
||||
hidden_coors=np.asarray(config['hidden_coors'], dtype=np.float32),
|
||||
)
|
||||
53
algorithm/hyperneat/substrate/tools.py
Normal file
53
algorithm/hyperneat/substrate/tools.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Type
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseSubstrate
|
||||
|
||||
|
||||
def analysis_substrate(state):
|
||||
cd = state.input_coors.shape[1] # coordinate dimensions
|
||||
si = state.input_coors.shape[0] # input coordinate size
|
||||
so = state.output_coors.shape[0] # output coordinate size
|
||||
sh = state.hidden_coors.shape[0] # hidden coordinate size
|
||||
|
||||
input_idx = np.arange(si)
|
||||
output_idx = np.arange(si, si + so)
|
||||
hidden_idx = np.arange(si + so, si + so + sh)
|
||||
|
||||
total_conns = si * sh + sh * sh + sh * so
|
||||
query_coors = np.zeros((total_conns, cd * 2))
|
||||
correspond_keys = np.zeros((total_conns, 2))
|
||||
|
||||
# connect input to hidden
|
||||
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors)
|
||||
query_coors[0: si * sh, :] = aux_coors
|
||||
correspond_keys[0: si * sh, :] = aux_keys
|
||||
|
||||
# connect hidden to hidden
|
||||
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors)
|
||||
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
|
||||
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
|
||||
|
||||
# connect hidden to output
|
||||
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors)
|
||||
query_coors[si * sh + sh * sh:, :] = aux_coors
|
||||
correspond_keys[si * sh + sh * sh:, :] = aux_keys
|
||||
|
||||
return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
|
||||
|
||||
|
||||
def cartesian_product(keys1, keys2, coors1, coors2):
|
||||
len1 = keys1.shape[0]
|
||||
len2 = keys2.shape[0]
|
||||
|
||||
repeated_coors1 = np.repeat(coors1, len2, axis=0)
|
||||
repeated_keys1 = np.repeat(keys1, len2)
|
||||
|
||||
tiled_coors2 = np.tile(coors2, (len1, 1))
|
||||
tiled_keys2 = np.tile(keys2, len1)
|
||||
|
||||
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
|
||||
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
|
||||
|
||||
return new_coors, correspond_keys
|
||||
@@ -1,3 +1,2 @@
|
||||
from .neat import NEAT
|
||||
from .gene import NormalGene, RecurrentGene
|
||||
from .pipeline import Pipeline
|
||||
from .gene import BaseGene, NormalGene, RecurrentGene
|
||||
|
||||
@@ -33,12 +33,10 @@ class BaseGene:
|
||||
def distance_conn(state, conn1: Array, conn2: Array):
|
||||
return conn1
|
||||
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(nodes, conns):
|
||||
def forward_transform(state, nodes, conns):
|
||||
return nodes, conns
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create_forward(config):
|
||||
return None
|
||||
@@ -4,7 +4,7 @@ from jax import Array, numpy as jnp
|
||||
from .base import BaseGene
|
||||
from .activation import Activation
|
||||
from .aggregation import Aggregation
|
||||
from ..utils import unflatten_connections, I_INT
|
||||
from algorithm.utils import unflatten_connections, I_INT
|
||||
from ..genome import topological_sort
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ class NormalGene(BaseGene):
|
||||
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(nodes, conns):
|
||||
def forward_transform(state, nodes, conns):
|
||||
u_conns = unflatten_connections(nodes, conns)
|
||||
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||
|
||||
|
||||
@@ -4,13 +4,13 @@ from jax import Array, numpy as jnp, vmap
|
||||
from .normal import NormalGene
|
||||
from .activation import Activation
|
||||
from .aggregation import Aggregation
|
||||
from ..utils import unflatten_connections, I_INT
|
||||
from algorithm.utils import unflatten_connections
|
||||
|
||||
|
||||
class RecurrentGene(NormalGene):
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(nodes, conns):
|
||||
def forward_transform(state, nodes, conns):
|
||||
u_conns = unflatten_connections(nodes, conns)
|
||||
|
||||
# remove un-enable connections and remove enable attr
|
||||
|
||||
@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
|
||||
|
||||
from algorithm import State
|
||||
from ..gene import BaseGene
|
||||
from ..utils import fetch_first
|
||||
from algorithm.utils import fetch_first
|
||||
|
||||
|
||||
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
|
||||
@@ -48,6 +48,7 @@ def count(nodes: Array, cons: Array):
|
||||
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
|
||||
return node_cnt, cons_cnt
|
||||
|
||||
|
||||
def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Add a new node to the genome.
|
||||
|
||||
@@ -6,7 +6,7 @@ Only used in feed-forward networks.
|
||||
import jax
|
||||
from jax import jit, Array, numpy as jnp
|
||||
|
||||
from ..utils import fetch_first, I_INT
|
||||
from algorithm.utils import fetch_first, I_INT
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
@@ -4,9 +4,9 @@ import jax
|
||||
from jax import Array, numpy as jnp, vmap
|
||||
|
||||
from algorithm import State
|
||||
from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx, count
|
||||
from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx
|
||||
from .graph import check_cycles
|
||||
from ..utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||
from algorithm.utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||
from ..gene import BaseGene
|
||||
|
||||
|
||||
|
||||
@@ -3,22 +3,25 @@ from typing import Type
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from algorithm.state import State
|
||||
from algorithm import Algorithm, State
|
||||
from .gene import BaseGene
|
||||
from .genome import initialize_genomes
|
||||
from .population import create_tell
|
||||
|
||||
|
||||
class NEAT:
|
||||
class NEAT(Algorithm):
|
||||
def __init__(self, config, gene_type: Type[BaseGene]):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.gene_type = gene_type
|
||||
|
||||
self.tell_func = jax.jit(create_tell(config, self.gene_type))
|
||||
self.tell = create_tell(config, self.gene_type)
|
||||
self.ask = None
|
||||
self.forward = self.gene_type.create_forward(config)
|
||||
self.forward_transform = self.gene_type.forward_transform
|
||||
|
||||
def setup(self, randkey):
|
||||
|
||||
state = State(
|
||||
def setup(self, randkey, state=State()):
|
||||
state = state.update(
|
||||
P=self.config['pop_size'],
|
||||
N=self.config['maximum_nodes'],
|
||||
C=self.config['maximum_conns'],
|
||||
@@ -70,6 +73,3 @@ class NEAT:
|
||||
state = jax.device_put(state)
|
||||
|
||||
return state
|
||||
|
||||
def step(self, state, fitness):
|
||||
return self.tell_func(state, fitness)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Type
|
||||
import jax
|
||||
from jax import numpy as jnp, vmap
|
||||
|
||||
from .utils import rank_elements, fetch_first
|
||||
from algorithm.utils import rank_elements, fetch_first
|
||||
from .genome import create_mutate, create_distance, crossover
|
||||
from .gene import BaseGene
|
||||
|
||||
|
||||
1
config/__init__.py
Normal file
1
config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .config import Configer
|
||||
@@ -4,6 +4,7 @@ import configparser
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Configer:
|
||||
|
||||
@classmethod
|
||||
@@ -28,7 +29,7 @@ class Configer:
|
||||
def __check_redundant_config(cls, default_config, config):
|
||||
for key in config:
|
||||
if key not in default_config:
|
||||
warnings.warn(f"Redundant config: {key} in {config.name}")
|
||||
warnings.warn(f"Redundant config: {key} in config!")
|
||||
|
||||
@classmethod
|
||||
def __complete_config(cls, default_config, config):
|
||||
@@ -1,26 +1,38 @@
|
||||
[basic]
|
||||
random_seed = 0
|
||||
generation_limit = 1000
|
||||
|
||||
[problem]
|
||||
fitness_threshold = 3.9999
|
||||
num_inputs = 2
|
||||
num_outputs = 1
|
||||
maximum_nodes = 50
|
||||
maximum_conns = 100
|
||||
maximum_species = 10
|
||||
forward_way = "pop"
|
||||
batch_size = 4
|
||||
random_seed = 0
|
||||
|
||||
[neat]
|
||||
network_type = "feedforward"
|
||||
activate_times = 10
|
||||
activate_times = 5
|
||||
maximum_nodes = 50
|
||||
maximum_conns = 50
|
||||
maximum_species = 10
|
||||
|
||||
[hyperneat]
|
||||
below_threshold = 0.2
|
||||
max_weight = 3
|
||||
h_activation = "sigmoid"
|
||||
h_aggregation = "sum"
|
||||
h_activate_times = 5
|
||||
|
||||
[substrate]
|
||||
input_coors = [[-1, 1], [0, 1], [1, 1]]
|
||||
hidden_coors = [[-1, 0], [0, 0], [1, 0]]
|
||||
output_coors = [[0, -1]]
|
||||
|
||||
[population]
|
||||
fitness_threshold = 3.9999
|
||||
generation_limit = 1000
|
||||
fitness_criterion = "max"
|
||||
pop_size = 50000
|
||||
pop_size = 10
|
||||
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
compatibility_weight = 0.5
|
||||
conn_add_prob = 0.4
|
||||
conn_add_trials = 1
|
||||
conn_delete_prob = 0
|
||||
node_add_prob = 0.2
|
||||
node_delete_prob = 0
|
||||
@@ -34,39 +46,37 @@ survival_threshold = 0.2
|
||||
min_species_size = 1
|
||||
spawn_number_change_rate = 0.5
|
||||
|
||||
[gene-bias]
|
||||
[gene]
|
||||
# bias
|
||||
bias_init_mean = 0.0
|
||||
bias_init_std = 1.0
|
||||
bias_mutate_power = 0.5
|
||||
bias_mutate_rate = 0.7
|
||||
bias_replace_rate = 0.1
|
||||
|
||||
[gene-response]
|
||||
# response
|
||||
response_init_mean = 1.0
|
||||
response_init_std = 0.0
|
||||
response_mutate_power = 0.0
|
||||
response_mutate_rate = 0.0
|
||||
response_replace_rate = 0.0
|
||||
|
||||
[gene-activation]
|
||||
# activation
|
||||
activation_default = "sigmoid"
|
||||
activation_option_names = ["sigmoid"]
|
||||
activation_option_names = ["tanh"]
|
||||
activation_replace_rate = 0.0
|
||||
|
||||
[gene-aggregation]
|
||||
# aggregation
|
||||
aggregation_default = "sum"
|
||||
aggregation_option_names = ["sum"]
|
||||
aggregation_replace_rate = 0.0
|
||||
|
||||
[gene-weight]
|
||||
# weight
|
||||
weight_init_mean = 0.0
|
||||
weight_init_std = 1.0
|
||||
weight_mutate_power = 0.5
|
||||
weight_mutate_rate = 0.8
|
||||
weight_replace_rate = 0.1
|
||||
|
||||
[gene-enable]
|
||||
enable_mutate_rate = 0.01
|
||||
|
||||
[visualize]
|
||||
renumber_nodes = True
|
||||
11
examples/a.py
Normal file
11
examples/a.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
|
||||
a = jnp.zeros((5, 5))
|
||||
k1 = jnp.array([1, 2, 3])
|
||||
k2 = jnp.array([2, 3, 4])
|
||||
v = jnp.array([1, 1, 1])
|
||||
|
||||
a = a.at[k1, k2].set(v)
|
||||
|
||||
print(a)
|
||||
@@ -1,13 +1,44 @@
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
|
||||
|
||||
vals = np.array([1, 2])
|
||||
weights = np.array([[0, 4], [5, 0]])
|
||||
@register_pytree_node_class
|
||||
class Genome:
|
||||
def __init__(self, nodes, conns):
|
||||
self.nodes = nodes
|
||||
self.conns = conns
|
||||
|
||||
ins1 = vals * weights[:, 0]
|
||||
ins2 = vals * weights[:, 1]
|
||||
ins_all = vals * weights.T
|
||||
def update_nodes(self, nodes):
|
||||
return Genome(nodes, self.conns)
|
||||
|
||||
print(ins1)
|
||||
print(ins2)
|
||||
print(ins_all)
|
||||
def update_conns(self, conns):
|
||||
return Genome(self.nodes, conns)
|
||||
|
||||
def tree_flatten(self):
|
||||
children = self.nodes, self.conns
|
||||
aux_data = None
|
||||
return children, aux_data
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Genome ({self.nodes}, \n\t{self.conns})"
|
||||
|
||||
@jax.jit
|
||||
def add_node(self, a: int):
|
||||
nodes = self.nodes.at[0, :].set(a)
|
||||
return self.update_nodes(nodes)
|
||||
|
||||
|
||||
nodes, conns = jnp.array([[1, 2, 3, 4, 5]]), jnp.array([[1, 2, 3, 4]])
|
||||
g = Genome(nodes, conns)
|
||||
print(g)
|
||||
|
||||
g = g.add_node(1)
|
||||
print(g)
|
||||
|
||||
g = jax.jit(g.add_node)(2)
|
||||
print(g)
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from algorithm.state import State
|
||||
|
||||
|
||||
@jax.jit
|
||||
def func(state: State, a):
|
||||
return state.update(a=a)
|
||||
|
||||
|
||||
state = State(c=1, b=2)
|
||||
print(state)
|
||||
|
||||
vmap_func = jax.vmap(func, in_axes=(None, 0))
|
||||
print(vmap_func(state, jnp.array([1, 2, 3])))
|
||||
@@ -1,7 +1,12 @@
|
||||
[basic]
|
||||
forward_way = "common"
|
||||
network_type = "recurrent"
|
||||
activate_times = 5
|
||||
fitness_threshold = 4
|
||||
|
||||
[population]
|
||||
fitness_threshold = 4
|
||||
pop_size = 1000
|
||||
|
||||
[neat]
|
||||
network_type = "recurrent"
|
||||
num_inputs = 4
|
||||
num_outputs = 1
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from algorithm import Configer, NEAT
|
||||
from algorithm.neat import NormalGene, RecurrentGene, Pipeline
|
||||
from pipeline import Pipeline
|
||||
from config import Configer
|
||||
from algorithm import NEAT
|
||||
from algorithm.neat import RecurrentGene
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
@@ -21,7 +23,6 @@ def evaluate(forward_func):
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
# algorithm = NEAT(config, NormalGene)
|
||||
algorithm = NEAT(config, RecurrentGene)
|
||||
pipeline = Pipeline(config, algorithm)
|
||||
best = pipeline.auto_run(evaluate)
|
||||
|
||||
33
examples/xor_hyperneat.py
Normal file
33
examples/xor_hyperneat.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from pipeline import Pipeline
|
||||
from config import Configer
|
||||
from algorithm import NEAT, HyperNEAT
|
||||
from algorithm.neat import RecurrentGene
|
||||
from algorithm.hyperneat import BaseSubstrate
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
|
||||
def evaluate(forward_func):
|
||||
"""
|
||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return fitnesses
|
||||
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
algorithm = HyperNEAT(config, RecurrentGene, BaseSubstrate)
|
||||
pipeline = Pipeline(config, algorithm)
|
||||
pipeline.auto_run(evaluate)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,50 +0,0 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from algorithm.config import Configer
|
||||
from algorithm.neat import NEAT, NormalGene, RecurrentGene, Pipeline
|
||||
from algorithm.neat.genome import create_mutate
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
|
||||
|
||||
def single_genome(func, nodes, conns):
|
||||
t = RecurrentGene.forward_transform(nodes, conns)
|
||||
out1 = func(xor_inputs[0], t)
|
||||
out2 = func(xor_inputs[1], t)
|
||||
out3 = func(xor_inputs[2], t)
|
||||
out4 = func(xor_inputs[3], t)
|
||||
print(out1, out2, out3, out4)
|
||||
|
||||
|
||||
def batch_genome(func, nodes, conns):
|
||||
t = NormalGene.forward_transform(nodes, conns)
|
||||
out = jax.vmap(func, in_axes=(0, None))(xor_inputs, t)
|
||||
print(out)
|
||||
|
||||
|
||||
def pop_batch_genome(func, pop_nodes, pop_conns):
|
||||
t = jax.vmap(NormalGene.forward_transform)(pop_nodes, pop_conns)
|
||||
func = jax.vmap(jax.vmap(func, in_axes=(0, None)), in_axes=(None, 0))
|
||||
out = func(xor_inputs, t)
|
||||
print(out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = Configer.load_config("xor.ini")
|
||||
# neat = NEAT(config, NormalGene)
|
||||
neat = NEAT(config, RecurrentGene)
|
||||
randkey = jax.random.PRNGKey(42)
|
||||
state = neat.setup(randkey)
|
||||
forward_func = RecurrentGene.create_forward(config)
|
||||
mutate_func = create_mutate(config, RecurrentGene)
|
||||
|
||||
nodes, conns = state.pop_nodes[0], state.pop_conns[0]
|
||||
single_genome(forward_func, nodes, conns)
|
||||
# batch_genome(forward_func, nodes, conns)
|
||||
|
||||
nodes, conns = mutate_func(state, randkey, nodes, conns, 10000)
|
||||
single_genome(forward_func, nodes, conns)
|
||||
|
||||
# batch_genome(forward_func, nodes, conns)
|
||||
#
|
||||
@@ -5,15 +5,18 @@ import jax
|
||||
from jax import vmap, jit
|
||||
import numpy as np
|
||||
|
||||
from algorithm import Algorithm
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
Neat algorithm pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config, algorithm):
|
||||
def __init__(self, config, algorithm: Algorithm):
|
||||
self.config = config
|
||||
self.algorithm = algorithm
|
||||
|
||||
randkey = jax.random.PRNGKey(config['random_seed'])
|
||||
self.state = algorithm.setup(randkey)
|
||||
|
||||
@@ -23,18 +26,18 @@ class Pipeline:
|
||||
|
||||
self.evaluate_time = 0
|
||||
|
||||
self.forward_func = algorithm.gene_type.create_forward(config)
|
||||
self.forward_func = jit(self.algorithm.forward)
|
||||
self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None)))
|
||||
self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0)))
|
||||
|
||||
self.pop_transform_func = jit(vmap(algorithm.gene_type.forward_transform))
|
||||
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0, 0)))
|
||||
self.tell_func = jit(self.algorithm.tell)
|
||||
|
||||
def ask(self):
|
||||
pop_transforms = self.pop_transform_func(self.state.pop_nodes, self.state.pop_conns)
|
||||
pop_transforms = self.forward_transform_func(self.state, self.state.pop_nodes, self.state.pop_conns)
|
||||
return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms)
|
||||
|
||||
def tell(self, fitness):
|
||||
self.state = self.algorithm.step(self.state, fitness)
|
||||
self.state = self.tell_func(self.state, fitness)
|
||||
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config['generation_limit']):
|
||||
56
test/unit/test_cartesian_product.py
Normal file
56
test/unit/test_cartesian_product.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import numpy as np
|
||||
|
||||
from algorithm.hyperneat.substrate.tools import cartesian_product
|
||||
|
||||
|
||||
def test01():
|
||||
keys1 = np.array([1, 2, 3])
|
||||
keys2 = np.array([4, 5, 6, 7])
|
||||
|
||||
coors1 = np.array([
|
||||
[1, 1, 1],
|
||||
[2, 2, 2],
|
||||
[3, 3, 3]
|
||||
])
|
||||
|
||||
coors2 = np.array([
|
||||
[4, 4, 4],
|
||||
[5, 5, 5],
|
||||
[6, 6, 6],
|
||||
[7, 7, 7]
|
||||
])
|
||||
|
||||
target_coors = np.array([
|
||||
[1, 1, 1, 4, 4, 4],
|
||||
[1, 1, 1, 5, 5, 5],
|
||||
[1, 1, 1, 6, 6, 6],
|
||||
[1, 1, 1, 7, 7, 7],
|
||||
[2, 2, 2, 4, 4, 4],
|
||||
[2, 2, 2, 5, 5, 5],
|
||||
[2, 2, 2, 6, 6, 6],
|
||||
[2, 2, 2, 7, 7, 7],
|
||||
[3, 3, 3, 4, 4, 4],
|
||||
[3, 3, 3, 5, 5, 5],
|
||||
[3, 3, 3, 6, 6, 6],
|
||||
[3, 3, 3, 7, 7, 7]
|
||||
])
|
||||
|
||||
target_keys = np.array([
|
||||
[1, 4],
|
||||
[1, 5],
|
||||
[1, 6],
|
||||
[1, 7],
|
||||
[2, 4],
|
||||
[2, 5],
|
||||
[2, 6],
|
||||
[2, 7],
|
||||
[3, 4],
|
||||
[3, 5],
|
||||
[3, 6],
|
||||
[3, 7]
|
||||
])
|
||||
|
||||
new_coors, correspond_keys = cartesian_product(keys1, keys2, coors1, coors2)
|
||||
|
||||
assert np.array_equal(new_coors, target_coors)
|
||||
assert np.array_equal(correspond_keys, target_keys)
|
||||
@@ -1,7 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from algorithm.neat.genome.graph import topological_sort, check_cycles
|
||||
from algorithm.neat.utils import I_INT
|
||||
from algorithm.utils import I_INT
|
||||
|
||||
nodes = jnp.array([
|
||||
[0],
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import jax.numpy as jnp
|
||||
from algorithm.neat.utils import unflatten_connections
|
||||
from algorithm.utils import unflatten_connections
|
||||
|
||||
|
||||
def test_unflatten():
|
||||
|
||||
Reference in New Issue
Block a user