complete HyperNEAT!
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
from .base import Algorithm
|
||||||
from .state import State
|
from .state import State
|
||||||
from .neat import NEAT
|
from .neat import NEAT
|
||||||
from .config import Configer
|
from .hyperneat import HyperNEAT
|
||||||
|
|||||||
17
algorithm/base.py
Normal file
17
algorithm/base.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from .state import State
|
||||||
|
|
||||||
|
EMPTY = lambda *args: args
|
||||||
|
|
||||||
|
|
||||||
|
class Algorithm:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.tell: Callable = EMPTY
|
||||||
|
self.ask: Callable = EMPTY
|
||||||
|
self.forward: Callable = EMPTY
|
||||||
|
self.forward_transform: Callable = EMPTY
|
||||||
|
|
||||||
|
def setup(self, randkey, state=State()):
|
||||||
|
pass
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
from .hyperneat import HyperNEAT
|
||||||
|
from .substrate import BaseSubstrate
|
||||||
|
|||||||
70
algorithm/hyperneat/hyperneat.py
Normal file
70
algorithm/hyperneat/hyperneat.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .substrate import BaseSubstrate, analysis_substrate
|
||||||
|
from .hyperneat_gene import HyperNEATGene
|
||||||
|
from algorithm import State, Algorithm, neat
|
||||||
|
|
||||||
|
|
||||||
|
class HyperNEAT(Algorithm):
|
||||||
|
|
||||||
|
def __init__(self, config, gene_type: Type[neat.BaseGene], substrate: Type[BaseSubstrate]):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.gene_type = gene_type
|
||||||
|
self.substrate = substrate
|
||||||
|
self.neat = neat.NEAT(config, gene_type)
|
||||||
|
|
||||||
|
self.tell = create_tell(self.neat)
|
||||||
|
self.forward_transform = create_forward_transform(config, self.neat)
|
||||||
|
self.forward = HyperNEATGene.create_forward(config)
|
||||||
|
|
||||||
|
def setup(self, randkey, state=State()):
|
||||||
|
state = state.update(
|
||||||
|
below_threshold=self.config['below_threshold'],
|
||||||
|
max_weight=self.config['max_weight']
|
||||||
|
)
|
||||||
|
|
||||||
|
state = self.substrate.setup(state, self.config)
|
||||||
|
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
|
||||||
|
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
|
||||||
|
h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
|
||||||
|
h_conns[:, 0:2] = correspond_keys
|
||||||
|
|
||||||
|
state = state.update(
|
||||||
|
# h is short for hyperneat
|
||||||
|
h_input_idx=h_input_idx,
|
||||||
|
h_output_idx=h_output_idx,
|
||||||
|
h_hidden_idx=h_hidden_idx,
|
||||||
|
query_coors=query_coors,
|
||||||
|
correspond_keys=correspond_keys,
|
||||||
|
h_nodes=h_nodes,
|
||||||
|
h_conns=h_conns
|
||||||
|
)
|
||||||
|
state = self.neat.setup(randkey, state=state)
|
||||||
|
|
||||||
|
self.config['h_input_idx'] = h_input_idx
|
||||||
|
self.config['h_output_idx'] = h_output_idx
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def create_tell(neat_instance):
|
||||||
|
def tell(state, fitness):
|
||||||
|
return neat_instance.tell(state, fitness)
|
||||||
|
|
||||||
|
return tell
|
||||||
|
|
||||||
|
|
||||||
|
def create_forward_transform(config, neat_instance):
|
||||||
|
def forward_transform(state, nodes, conns):
|
||||||
|
t = neat_instance.forward_transform(state, nodes, conns)
|
||||||
|
batch_forward_func = jax.vmap(neat_instance.forward, in_axes=(0, None))
|
||||||
|
query_res = batch_forward_func(state.query_coors, t) # hyperneat connections
|
||||||
|
h_nodes = state.h_nodes
|
||||||
|
h_conns = state.h_conns.at[:, 2:].set(query_res)
|
||||||
|
return HyperNEATGene.forward_transform(state, h_nodes, h_conns)
|
||||||
|
|
||||||
|
return forward_transform
|
||||||
54
algorithm/hyperneat/hyperneat_gene.py
Normal file
54
algorithm/hyperneat/hyperneat_gene.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import jax
|
||||||
|
from jax import numpy as jnp, vmap
|
||||||
|
|
||||||
|
from algorithm.neat import BaseGene
|
||||||
|
from algorithm.neat.gene import Activation
|
||||||
|
from algorithm.neat.gene import Aggregation
|
||||||
|
|
||||||
|
|
||||||
|
class HyperNEATGene(BaseGene):
|
||||||
|
node_attrs = [] # no node attributes
|
||||||
|
conn_attrs = ['weight']
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward_transform(state, nodes, conns):
|
||||||
|
N = nodes.shape[0]
|
||||||
|
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
|
||||||
|
|
||||||
|
in_keys = jnp.asarray(conns[:, 0], jnp.int32)
|
||||||
|
out_keys = jnp.asarray(conns[:, 1], jnp.int32)
|
||||||
|
weights = conns[:, 2]
|
||||||
|
|
||||||
|
u_conns = u_conns.at[in_keys, out_keys].set(weights)
|
||||||
|
return nodes, u_conns
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_forward(config):
|
||||||
|
act = Activation.name2func[config['h_activation']]
|
||||||
|
agg = Aggregation.name2func[config['h_aggregation']]
|
||||||
|
|
||||||
|
batch_act, batch_agg = vmap(act), vmap(agg)
|
||||||
|
|
||||||
|
def forward(inputs, transform):
|
||||||
|
|
||||||
|
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
|
||||||
|
nodes, weights = transform
|
||||||
|
|
||||||
|
input_idx = config['h_input_idx']
|
||||||
|
output_idx = config['h_output_idx']
|
||||||
|
|
||||||
|
N = nodes.shape[0]
|
||||||
|
vals = jnp.full((N,), 0.)
|
||||||
|
|
||||||
|
def body_func(i, values):
|
||||||
|
values = values.at[input_idx].set(inputs_with_bias)
|
||||||
|
nodes_ins = values * weights.T
|
||||||
|
values = batch_agg(nodes_ins) # z = agg(ins)
|
||||||
|
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||||
|
values = batch_act(values) # z = act(z)
|
||||||
|
return values
|
||||||
|
|
||||||
|
vals = jax.lax.fori_loop(0, config['h_activate_times'], body_func, vals)
|
||||||
|
return vals[output_idx]
|
||||||
|
|
||||||
|
return forward
|
||||||
2
algorithm/hyperneat/substrate/__init__.py
Normal file
2
algorithm/hyperneat/substrate/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .base import BaseSubstrate
|
||||||
|
from .tools import analysis_substrate
|
||||||
12
algorithm/hyperneat/substrate/base.py
Normal file
12
algorithm/hyperneat/substrate/base.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSubstrate:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def setup(state, config):
|
||||||
|
return state.update(
|
||||||
|
input_coors=np.asarray(config['input_coors'], dtype=np.float32),
|
||||||
|
output_coors=np.asarray(config['output_coors'], dtype=np.float32),
|
||||||
|
hidden_coors=np.asarray(config['hidden_coors'], dtype=np.float32),
|
||||||
|
)
|
||||||
53
algorithm/hyperneat/substrate/tools.py
Normal file
53
algorithm/hyperneat/substrate/tools.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .base import BaseSubstrate
|
||||||
|
|
||||||
|
|
||||||
|
def analysis_substrate(state):
|
||||||
|
cd = state.input_coors.shape[1] # coordinate dimensions
|
||||||
|
si = state.input_coors.shape[0] # input coordinate size
|
||||||
|
so = state.output_coors.shape[0] # output coordinate size
|
||||||
|
sh = state.hidden_coors.shape[0] # hidden coordinate size
|
||||||
|
|
||||||
|
input_idx = np.arange(si)
|
||||||
|
output_idx = np.arange(si, si + so)
|
||||||
|
hidden_idx = np.arange(si + so, si + so + sh)
|
||||||
|
|
||||||
|
total_conns = si * sh + sh * sh + sh * so
|
||||||
|
query_coors = np.zeros((total_conns, cd * 2))
|
||||||
|
correspond_keys = np.zeros((total_conns, 2))
|
||||||
|
|
||||||
|
# connect input to hidden
|
||||||
|
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors)
|
||||||
|
query_coors[0: si * sh, :] = aux_coors
|
||||||
|
correspond_keys[0: si * sh, :] = aux_keys
|
||||||
|
|
||||||
|
# connect hidden to hidden
|
||||||
|
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors)
|
||||||
|
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
|
||||||
|
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
|
||||||
|
|
||||||
|
# connect hidden to output
|
||||||
|
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors)
|
||||||
|
query_coors[si * sh + sh * sh:, :] = aux_coors
|
||||||
|
correspond_keys[si * sh + sh * sh:, :] = aux_keys
|
||||||
|
|
||||||
|
return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
|
||||||
|
|
||||||
|
|
||||||
|
def cartesian_product(keys1, keys2, coors1, coors2):
|
||||||
|
len1 = keys1.shape[0]
|
||||||
|
len2 = keys2.shape[0]
|
||||||
|
|
||||||
|
repeated_coors1 = np.repeat(coors1, len2, axis=0)
|
||||||
|
repeated_keys1 = np.repeat(keys1, len2)
|
||||||
|
|
||||||
|
tiled_coors2 = np.tile(coors2, (len1, 1))
|
||||||
|
tiled_keys2 = np.tile(keys2, len1)
|
||||||
|
|
||||||
|
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
|
||||||
|
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
|
||||||
|
|
||||||
|
return new_coors, correspond_keys
|
||||||
@@ -1,3 +1,2 @@
|
|||||||
from .neat import NEAT
|
from .neat import NEAT
|
||||||
from .gene import NormalGene, RecurrentGene
|
from .gene import BaseGene, NormalGene, RecurrentGene
|
||||||
from .pipeline import Pipeline
|
|
||||||
|
|||||||
@@ -33,12 +33,10 @@ class BaseGene:
|
|||||||
def distance_conn(state, conn1: Array, conn2: Array):
|
def distance_conn(state, conn1: Array, conn2: Array):
|
||||||
return conn1
|
return conn1
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward_transform(nodes, conns):
|
def forward_transform(state, nodes, conns):
|
||||||
return nodes, conns
|
return nodes, conns
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_forward(config):
|
def create_forward(config):
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from jax import Array, numpy as jnp
|
|||||||
from .base import BaseGene
|
from .base import BaseGene
|
||||||
from .activation import Activation
|
from .activation import Activation
|
||||||
from .aggregation import Aggregation
|
from .aggregation import Aggregation
|
||||||
from ..utils import unflatten_connections, I_INT
|
from algorithm.utils import unflatten_connections, I_INT
|
||||||
from ..genome import topological_sort
|
from ..genome import topological_sort
|
||||||
|
|
||||||
|
|
||||||
@@ -84,7 +84,7 @@ class NormalGene(BaseGene):
|
|||||||
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
|
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward_transform(nodes, conns):
|
def forward_transform(state, nodes, conns):
|
||||||
u_conns = unflatten_connections(nodes, conns)
|
u_conns = unflatten_connections(nodes, conns)
|
||||||
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ from jax import Array, numpy as jnp, vmap
|
|||||||
from .normal import NormalGene
|
from .normal import NormalGene
|
||||||
from .activation import Activation
|
from .activation import Activation
|
||||||
from .aggregation import Aggregation
|
from .aggregation import Aggregation
|
||||||
from ..utils import unflatten_connections, I_INT
|
from algorithm.utils import unflatten_connections
|
||||||
|
|
||||||
|
|
||||||
class RecurrentGene(NormalGene):
|
class RecurrentGene(NormalGene):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward_transform(nodes, conns):
|
def forward_transform(state, nodes, conns):
|
||||||
u_conns = unflatten_connections(nodes, conns)
|
u_conns = unflatten_connections(nodes, conns)
|
||||||
|
|
||||||
# remove un-enable connections and remove enable attr
|
# remove un-enable connections and remove enable attr
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
|
|||||||
|
|
||||||
from algorithm import State
|
from algorithm import State
|
||||||
from ..gene import BaseGene
|
from ..gene import BaseGene
|
||||||
from ..utils import fetch_first
|
from algorithm.utils import fetch_first
|
||||||
|
|
||||||
|
|
||||||
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
|
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
|
||||||
@@ -48,6 +48,7 @@ def count(nodes: Array, cons: Array):
|
|||||||
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
|
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
|
||||||
return node_cnt, cons_cnt
|
return node_cnt, cons_cnt
|
||||||
|
|
||||||
|
|
||||||
def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]:
|
def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]:
|
||||||
"""
|
"""
|
||||||
Add a new node to the genome.
|
Add a new node to the genome.
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ Only used in feed-forward networks.
|
|||||||
import jax
|
import jax
|
||||||
from jax import jit, Array, numpy as jnp
|
from jax import jit, Array, numpy as jnp
|
||||||
|
|
||||||
from ..utils import fetch_first, I_INT
|
from algorithm.utils import fetch_first, I_INT
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ import jax
|
|||||||
from jax import Array, numpy as jnp, vmap
|
from jax import Array, numpy as jnp, vmap
|
||||||
|
|
||||||
from algorithm import State
|
from algorithm import State
|
||||||
from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx, count
|
from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx
|
||||||
from .graph import check_cycles
|
from .graph import check_cycles
|
||||||
from ..utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
from algorithm.utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||||
from ..gene import BaseGene
|
from ..gene import BaseGene
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,22 +3,25 @@ from typing import Type
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from algorithm.state import State
|
from algorithm import Algorithm, State
|
||||||
from .gene import BaseGene
|
from .gene import BaseGene
|
||||||
from .genome import initialize_genomes
|
from .genome import initialize_genomes
|
||||||
from .population import create_tell
|
from .population import create_tell
|
||||||
|
|
||||||
|
|
||||||
class NEAT:
|
class NEAT(Algorithm):
|
||||||
def __init__(self, config, gene_type: Type[BaseGene]):
|
def __init__(self, config, gene_type: Type[BaseGene]):
|
||||||
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.gene_type = gene_type
|
self.gene_type = gene_type
|
||||||
|
|
||||||
self.tell_func = jax.jit(create_tell(config, self.gene_type))
|
self.tell = create_tell(config, self.gene_type)
|
||||||
|
self.ask = None
|
||||||
|
self.forward = self.gene_type.create_forward(config)
|
||||||
|
self.forward_transform = self.gene_type.forward_transform
|
||||||
|
|
||||||
def setup(self, randkey):
|
def setup(self, randkey, state=State()):
|
||||||
|
state = state.update(
|
||||||
state = State(
|
|
||||||
P=self.config['pop_size'],
|
P=self.config['pop_size'],
|
||||||
N=self.config['maximum_nodes'],
|
N=self.config['maximum_nodes'],
|
||||||
C=self.config['maximum_conns'],
|
C=self.config['maximum_conns'],
|
||||||
@@ -69,7 +72,4 @@ class NEAT:
|
|||||||
# move to device
|
# move to device
|
||||||
state = jax.device_put(state)
|
state = jax.device_put(state)
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def step(self, state, fitness):
|
|
||||||
return self.tell_func(state, fitness)
|
|
||||||
@@ -3,7 +3,7 @@ from typing import Type
|
|||||||
import jax
|
import jax
|
||||||
from jax import numpy as jnp, vmap
|
from jax import numpy as jnp, vmap
|
||||||
|
|
||||||
from .utils import rank_elements, fetch_first
|
from algorithm.utils import rank_elements, fetch_first
|
||||||
from .genome import create_mutate, create_distance, crossover
|
from .genome import create_mutate, create_distance, crossover
|
||||||
from .gene import BaseGene
|
from .gene import BaseGene
|
||||||
|
|
||||||
|
|||||||
1
config/__init__.py
Normal file
1
config/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .config import Configer
|
||||||
@@ -4,6 +4,7 @@ import configparser
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class Configer:
|
class Configer:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -28,7 +29,7 @@ class Configer:
|
|||||||
def __check_redundant_config(cls, default_config, config):
|
def __check_redundant_config(cls, default_config, config):
|
||||||
for key in config:
|
for key in config:
|
||||||
if key not in default_config:
|
if key not in default_config:
|
||||||
warnings.warn(f"Redundant config: {key} in {config.name}")
|
warnings.warn(f"Redundant config: {key} in config!")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __complete_config(cls, default_config, config):
|
def __complete_config(cls, default_config, config):
|
||||||
@@ -1,26 +1,38 @@
|
|||||||
[basic]
|
[basic]
|
||||||
|
random_seed = 0
|
||||||
|
generation_limit = 1000
|
||||||
|
|
||||||
|
[problem]
|
||||||
|
fitness_threshold = 3.9999
|
||||||
num_inputs = 2
|
num_inputs = 2
|
||||||
num_outputs = 1
|
num_outputs = 1
|
||||||
maximum_nodes = 50
|
|
||||||
maximum_conns = 100
|
[neat]
|
||||||
maximum_species = 10
|
|
||||||
forward_way = "pop"
|
|
||||||
batch_size = 4
|
|
||||||
random_seed = 0
|
|
||||||
network_type = "feedforward"
|
network_type = "feedforward"
|
||||||
activate_times = 10
|
activate_times = 5
|
||||||
|
maximum_nodes = 50
|
||||||
|
maximum_conns = 50
|
||||||
|
maximum_species = 10
|
||||||
|
|
||||||
|
[hyperneat]
|
||||||
|
below_threshold = 0.2
|
||||||
|
max_weight = 3
|
||||||
|
h_activation = "sigmoid"
|
||||||
|
h_aggregation = "sum"
|
||||||
|
h_activate_times = 5
|
||||||
|
|
||||||
|
[substrate]
|
||||||
|
input_coors = [[-1, 1], [0, 1], [1, 1]]
|
||||||
|
hidden_coors = [[-1, 0], [0, 0], [1, 0]]
|
||||||
|
output_coors = [[0, -1]]
|
||||||
|
|
||||||
[population]
|
[population]
|
||||||
fitness_threshold = 3.9999
|
pop_size = 10
|
||||||
generation_limit = 1000
|
|
||||||
fitness_criterion = "max"
|
|
||||||
pop_size = 50000
|
|
||||||
|
|
||||||
[genome]
|
[genome]
|
||||||
compatibility_disjoint = 1.0
|
compatibility_disjoint = 1.0
|
||||||
compatibility_weight = 0.5
|
compatibility_weight = 0.5
|
||||||
conn_add_prob = 0.4
|
conn_add_prob = 0.4
|
||||||
conn_add_trials = 1
|
|
||||||
conn_delete_prob = 0
|
conn_delete_prob = 0
|
||||||
node_add_prob = 0.2
|
node_add_prob = 0.2
|
||||||
node_delete_prob = 0
|
node_delete_prob = 0
|
||||||
@@ -34,39 +46,37 @@ survival_threshold = 0.2
|
|||||||
min_species_size = 1
|
min_species_size = 1
|
||||||
spawn_number_change_rate = 0.5
|
spawn_number_change_rate = 0.5
|
||||||
|
|
||||||
[gene-bias]
|
[gene]
|
||||||
|
# bias
|
||||||
bias_init_mean = 0.0
|
bias_init_mean = 0.0
|
||||||
bias_init_std = 1.0
|
bias_init_std = 1.0
|
||||||
bias_mutate_power = 0.5
|
bias_mutate_power = 0.5
|
||||||
bias_mutate_rate = 0.7
|
bias_mutate_rate = 0.7
|
||||||
bias_replace_rate = 0.1
|
bias_replace_rate = 0.1
|
||||||
|
|
||||||
[gene-response]
|
# response
|
||||||
response_init_mean = 1.0
|
response_init_mean = 1.0
|
||||||
response_init_std = 0.0
|
response_init_std = 0.0
|
||||||
response_mutate_power = 0.0
|
response_mutate_power = 0.0
|
||||||
response_mutate_rate = 0.0
|
response_mutate_rate = 0.0
|
||||||
response_replace_rate = 0.0
|
response_replace_rate = 0.0
|
||||||
|
|
||||||
[gene-activation]
|
# activation
|
||||||
activation_default = "sigmoid"
|
activation_default = "sigmoid"
|
||||||
activation_option_names = ["sigmoid"]
|
activation_option_names = ["tanh"]
|
||||||
activation_replace_rate = 0.0
|
activation_replace_rate = 0.0
|
||||||
|
|
||||||
[gene-aggregation]
|
# aggregation
|
||||||
aggregation_default = "sum"
|
aggregation_default = "sum"
|
||||||
aggregation_option_names = ["sum"]
|
aggregation_option_names = ["sum"]
|
||||||
aggregation_replace_rate = 0.0
|
aggregation_replace_rate = 0.0
|
||||||
|
|
||||||
[gene-weight]
|
# weight
|
||||||
weight_init_mean = 0.0
|
weight_init_mean = 0.0
|
||||||
weight_init_std = 1.0
|
weight_init_std = 1.0
|
||||||
weight_mutate_power = 0.5
|
weight_mutate_power = 0.5
|
||||||
weight_mutate_rate = 0.8
|
weight_mutate_rate = 0.8
|
||||||
weight_replace_rate = 0.1
|
weight_replace_rate = 0.1
|
||||||
|
|
||||||
[gene-enable]
|
|
||||||
enable_mutate_rate = 0.01
|
|
||||||
|
|
||||||
[visualize]
|
[visualize]
|
||||||
renumber_nodes = True
|
renumber_nodes = True
|
||||||
11
examples/a.py
Normal file
11
examples/a.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import numpy as np
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
a = jnp.zeros((5, 5))
|
||||||
|
k1 = jnp.array([1, 2, 3])
|
||||||
|
k2 = jnp.array([2, 3, 4])
|
||||||
|
v = jnp.array([1, 1, 1])
|
||||||
|
|
||||||
|
a = a.at[k1, k2].set(v)
|
||||||
|
|
||||||
|
print(a)
|
||||||
@@ -1,13 +1,44 @@
|
|||||||
import numpy as np
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax.tree_util import register_pytree_node_class
|
||||||
|
|
||||||
|
|
||||||
vals = np.array([1, 2])
|
@register_pytree_node_class
|
||||||
weights = np.array([[0, 4], [5, 0]])
|
class Genome:
|
||||||
|
def __init__(self, nodes, conns):
|
||||||
|
self.nodes = nodes
|
||||||
|
self.conns = conns
|
||||||
|
|
||||||
ins1 = vals * weights[:, 0]
|
def update_nodes(self, nodes):
|
||||||
ins2 = vals * weights[:, 1]
|
return Genome(nodes, self.conns)
|
||||||
ins_all = vals * weights.T
|
|
||||||
|
|
||||||
print(ins1)
|
def update_conns(self, conns):
|
||||||
print(ins2)
|
return Genome(self.nodes, conns)
|
||||||
print(ins_all)
|
|
||||||
|
def tree_flatten(self):
|
||||||
|
children = self.nodes, self.conns
|
||||||
|
aux_data = None
|
||||||
|
return children, aux_data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tree_unflatten(cls, aux_data, children):
|
||||||
|
return cls(*children)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Genome ({self.nodes}, \n\t{self.conns})"
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def add_node(self, a: int):
|
||||||
|
nodes = self.nodes.at[0, :].set(a)
|
||||||
|
return self.update_nodes(nodes)
|
||||||
|
|
||||||
|
|
||||||
|
nodes, conns = jnp.array([[1, 2, 3, 4, 5]]), jnp.array([[1, 2, 3, 4]])
|
||||||
|
g = Genome(nodes, conns)
|
||||||
|
print(g)
|
||||||
|
|
||||||
|
g = g.add_node(1)
|
||||||
|
print(g)
|
||||||
|
|
||||||
|
g = jax.jit(g.add_node)(2)
|
||||||
|
print(g)
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
import jax
|
|
||||||
from jax import numpy as jnp
|
|
||||||
from algorithm.state import State
|
|
||||||
|
|
||||||
|
|
||||||
@jax.jit
|
|
||||||
def func(state: State, a):
|
|
||||||
return state.update(a=a)
|
|
||||||
|
|
||||||
|
|
||||||
state = State(c=1, b=2)
|
|
||||||
print(state)
|
|
||||||
|
|
||||||
vmap_func = jax.vmap(func, in_axes=(None, 0))
|
|
||||||
print(vmap_func(state, jnp.array([1, 2, 3])))
|
|
||||||
@@ -1,7 +1,12 @@
|
|||||||
[basic]
|
[basic]
|
||||||
forward_way = "common"
|
|
||||||
network_type = "recurrent"
|
|
||||||
activate_times = 5
|
activate_times = 5
|
||||||
|
fitness_threshold = 4
|
||||||
|
|
||||||
[population]
|
[population]
|
||||||
fitness_threshold = 4
|
pop_size = 1000
|
||||||
|
|
||||||
|
[neat]
|
||||||
|
network_type = "recurrent"
|
||||||
|
num_inputs = 4
|
||||||
|
num_outputs = 1
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import jax
|
import jax
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from algorithm import Configer, NEAT
|
from pipeline import Pipeline
|
||||||
from algorithm.neat import NormalGene, RecurrentGene, Pipeline
|
from config import Configer
|
||||||
|
from algorithm import NEAT
|
||||||
|
from algorithm.neat import RecurrentGene
|
||||||
|
|
||||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||||
@@ -21,7 +23,6 @@ def evaluate(forward_func):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
config = Configer.load_config("xor.ini")
|
config = Configer.load_config("xor.ini")
|
||||||
# algorithm = NEAT(config, NormalGene)
|
|
||||||
algorithm = NEAT(config, RecurrentGene)
|
algorithm = NEAT(config, RecurrentGene)
|
||||||
pipeline = Pipeline(config, algorithm)
|
pipeline = Pipeline(config, algorithm)
|
||||||
best = pipeline.auto_run(evaluate)
|
best = pipeline.auto_run(evaluate)
|
||||||
|
|||||||
33
examples/xor_hyperneat.py
Normal file
33
examples/xor_hyperneat.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from pipeline import Pipeline
|
||||||
|
from config import Configer
|
||||||
|
from algorithm import NEAT, HyperNEAT
|
||||||
|
from algorithm.neat import RecurrentGene
|
||||||
|
from algorithm.hyperneat import BaseSubstrate
|
||||||
|
|
||||||
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
|
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(forward_func):
|
||||||
|
"""
|
||||||
|
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
outs = forward_func(xor_inputs)
|
||||||
|
outs = jax.device_get(outs)
|
||||||
|
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||||
|
return fitnesses
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = Configer.load_config("xor.ini")
|
||||||
|
algorithm = HyperNEAT(config, RecurrentGene, BaseSubstrate)
|
||||||
|
pipeline = Pipeline(config, algorithm)
|
||||||
|
pipeline.auto_run(evaluate)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
import jax
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from algorithm.config import Configer
|
|
||||||
from algorithm.neat import NEAT, NormalGene, RecurrentGene, Pipeline
|
|
||||||
from algorithm.neat.genome import create_mutate
|
|
||||||
|
|
||||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def single_genome(func, nodes, conns):
|
|
||||||
t = RecurrentGene.forward_transform(nodes, conns)
|
|
||||||
out1 = func(xor_inputs[0], t)
|
|
||||||
out2 = func(xor_inputs[1], t)
|
|
||||||
out3 = func(xor_inputs[2], t)
|
|
||||||
out4 = func(xor_inputs[3], t)
|
|
||||||
print(out1, out2, out3, out4)
|
|
||||||
|
|
||||||
|
|
||||||
def batch_genome(func, nodes, conns):
|
|
||||||
t = NormalGene.forward_transform(nodes, conns)
|
|
||||||
out = jax.vmap(func, in_axes=(0, None))(xor_inputs, t)
|
|
||||||
print(out)
|
|
||||||
|
|
||||||
|
|
||||||
def pop_batch_genome(func, pop_nodes, pop_conns):
|
|
||||||
t = jax.vmap(NormalGene.forward_transform)(pop_nodes, pop_conns)
|
|
||||||
func = jax.vmap(jax.vmap(func, in_axes=(0, None)), in_axes=(None, 0))
|
|
||||||
out = func(xor_inputs, t)
|
|
||||||
print(out)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
config = Configer.load_config("xor.ini")
|
|
||||||
# neat = NEAT(config, NormalGene)
|
|
||||||
neat = NEAT(config, RecurrentGene)
|
|
||||||
randkey = jax.random.PRNGKey(42)
|
|
||||||
state = neat.setup(randkey)
|
|
||||||
forward_func = RecurrentGene.create_forward(config)
|
|
||||||
mutate_func = create_mutate(config, RecurrentGene)
|
|
||||||
|
|
||||||
nodes, conns = state.pop_nodes[0], state.pop_conns[0]
|
|
||||||
single_genome(forward_func, nodes, conns)
|
|
||||||
# batch_genome(forward_func, nodes, conns)
|
|
||||||
|
|
||||||
nodes, conns = mutate_func(state, randkey, nodes, conns, 10000)
|
|
||||||
single_genome(forward_func, nodes, conns)
|
|
||||||
|
|
||||||
# batch_genome(forward_func, nodes, conns)
|
|
||||||
#
|
|
||||||
@@ -5,15 +5,18 @@ import jax
|
|||||||
from jax import vmap, jit
|
from jax import vmap, jit
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from algorithm import Algorithm
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
"""
|
"""
|
||||||
Neat algorithm pipeline.
|
Neat algorithm pipeline.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, algorithm):
|
def __init__(self, config, algorithm: Algorithm):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.algorithm = algorithm
|
self.algorithm = algorithm
|
||||||
|
|
||||||
randkey = jax.random.PRNGKey(config['random_seed'])
|
randkey = jax.random.PRNGKey(config['random_seed'])
|
||||||
self.state = algorithm.setup(randkey)
|
self.state = algorithm.setup(randkey)
|
||||||
|
|
||||||
@@ -23,18 +26,18 @@ class Pipeline:
|
|||||||
|
|
||||||
self.evaluate_time = 0
|
self.evaluate_time = 0
|
||||||
|
|
||||||
self.forward_func = algorithm.gene_type.create_forward(config)
|
self.forward_func = jit(self.algorithm.forward)
|
||||||
self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None)))
|
self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None)))
|
||||||
self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0)))
|
self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0)))
|
||||||
|
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0, 0)))
|
||||||
self.pop_transform_func = jit(vmap(algorithm.gene_type.forward_transform))
|
self.tell_func = jit(self.algorithm.tell)
|
||||||
|
|
||||||
def ask(self):
|
def ask(self):
|
||||||
pop_transforms = self.pop_transform_func(self.state.pop_nodes, self.state.pop_conns)
|
pop_transforms = self.forward_transform_func(self.state, self.state.pop_nodes, self.state.pop_conns)
|
||||||
return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms)
|
return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms)
|
||||||
|
|
||||||
def tell(self, fitness):
|
def tell(self, fitness):
|
||||||
self.state = self.algorithm.step(self.state, fitness)
|
self.state = self.tell_func(self.state, fitness)
|
||||||
|
|
||||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||||
for _ in range(self.config['generation_limit']):
|
for _ in range(self.config['generation_limit']):
|
||||||
56
test/unit/test_cartesian_product.py
Normal file
56
test/unit/test_cartesian_product.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from algorithm.hyperneat.substrate.tools import cartesian_product
|
||||||
|
|
||||||
|
|
||||||
|
def test01():
|
||||||
|
keys1 = np.array([1, 2, 3])
|
||||||
|
keys2 = np.array([4, 5, 6, 7])
|
||||||
|
|
||||||
|
coors1 = np.array([
|
||||||
|
[1, 1, 1],
|
||||||
|
[2, 2, 2],
|
||||||
|
[3, 3, 3]
|
||||||
|
])
|
||||||
|
|
||||||
|
coors2 = np.array([
|
||||||
|
[4, 4, 4],
|
||||||
|
[5, 5, 5],
|
||||||
|
[6, 6, 6],
|
||||||
|
[7, 7, 7]
|
||||||
|
])
|
||||||
|
|
||||||
|
target_coors = np.array([
|
||||||
|
[1, 1, 1, 4, 4, 4],
|
||||||
|
[1, 1, 1, 5, 5, 5],
|
||||||
|
[1, 1, 1, 6, 6, 6],
|
||||||
|
[1, 1, 1, 7, 7, 7],
|
||||||
|
[2, 2, 2, 4, 4, 4],
|
||||||
|
[2, 2, 2, 5, 5, 5],
|
||||||
|
[2, 2, 2, 6, 6, 6],
|
||||||
|
[2, 2, 2, 7, 7, 7],
|
||||||
|
[3, 3, 3, 4, 4, 4],
|
||||||
|
[3, 3, 3, 5, 5, 5],
|
||||||
|
[3, 3, 3, 6, 6, 6],
|
||||||
|
[3, 3, 3, 7, 7, 7]
|
||||||
|
])
|
||||||
|
|
||||||
|
target_keys = np.array([
|
||||||
|
[1, 4],
|
||||||
|
[1, 5],
|
||||||
|
[1, 6],
|
||||||
|
[1, 7],
|
||||||
|
[2, 4],
|
||||||
|
[2, 5],
|
||||||
|
[2, 6],
|
||||||
|
[2, 7],
|
||||||
|
[3, 4],
|
||||||
|
[3, 5],
|
||||||
|
[3, 6],
|
||||||
|
[3, 7]
|
||||||
|
])
|
||||||
|
|
||||||
|
new_coors, correspond_keys = cartesian_product(keys1, keys2, coors1, coors2)
|
||||||
|
|
||||||
|
assert np.array_equal(new_coors, target_coors)
|
||||||
|
assert np.array_equal(correspond_keys, target_keys)
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from algorithm.neat.genome.graph import topological_sort, check_cycles
|
from algorithm.neat.genome.graph import topological_sort, check_cycles
|
||||||
from algorithm.neat.utils import I_INT
|
from algorithm.utils import I_INT
|
||||||
|
|
||||||
nodes = jnp.array([
|
nodes = jnp.array([
|
||||||
[0],
|
[0],
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from algorithm.neat.utils import unflatten_connections
|
from algorithm.utils import unflatten_connections
|
||||||
|
|
||||||
|
|
||||||
def test_unflatten():
|
def test_unflatten():
|
||||||
|
|||||||
Reference in New Issue
Block a user