Complete the HyperNEAT implementation.

This commit is contained in:
wls2002
2023-07-21 15:03:12 +08:00
parent 80ee5ea2ea
commit 48f90c7eef
32 changed files with 432 additions and 136 deletions

View File

@@ -1,3 +1,4 @@
from .base import Algorithm
from .state import State
from .neat import NEAT
from .config import Configer
from .hyperneat import HyperNEAT

17
algorithm/base.py Normal file
View File

@@ -0,0 +1,17 @@
from typing import Callable
from .state import State
def EMPTY(*args):
    """Placeholder hook: returns its positional arguments unchanged."""
    return args


class Algorithm:
    """Common interface for evolutionary algorithms (e.g. NEAT, HyperNEAT).

    Concrete algorithms overwrite the hook attributes in their own
    ``__init__``:
      - tell(state, fitness): advance one generation
      - ask(...): produce evaluable individuals
      - forward(inputs, transform): run one individual
      - forward_transform(state, nodes, conns): genome -> runnable transform
    """

    def __init__(self):
        # All hooks start as the EMPTY placeholder; subclasses replace them.
        for hook in ("tell", "ask", "forward", "forward_transform"):
            setattr(self, hook, EMPTY)

    def setup(self, randkey, state=State()):
        # NOTE(review): the State() default is evaluated once at definition
        # time and shared across calls; assumed safe because State.update()
        # appears to return a new object — confirm State is immutable.
        pass

View File

@@ -1,69 +0,0 @@
import os
import warnings
import configparser
import numpy as np
class Configer:
    """Load algorithm configuration from .ini files.

    User-supplied values are overlaid on the bundled ``default_config.ini``,
    unknown keys are warned about, and derived fields (activation/aggregation
    ids, input/output index arrays) are computed.
    """

    @classmethod
    def __load_default_config(cls):
        """Load ``default_config.ini`` that ships next to this module."""
        par_dir = os.path.dirname(os.path.abspath(__file__))
        default_config_path = os.path.join(par_dir, "default_config.ini")
        return cls.__load_config(default_config_path)

    @classmethod
    def __load_config(cls, config_path):
        """Parse an .ini file into a flat ``{key: value}`` dict.

        Section names are ignored; keys must therefore be globally unique.
        """
        c = configparser.ConfigParser()
        c.read(config_path)
        config = {}
        for section in c.sections():
            for key, value in c.items(section):
                # SECURITY: eval() executes arbitrary Python from the config
                # file — only load config files from trusted sources.
                # (ast.literal_eval would suffice if all values are literals.)
                config[key] = eval(value)
        return config

    @classmethod
    def __check_redundant_config(cls, default_config, config):
        """Warn about user keys that have no counterpart in the defaults."""
        for key in config:
            if key not in default_config:
                # BUG FIX: `config` is a plain dict and has no `.name`
                # attribute — the old message `{config.name}` raised
                # AttributeError whenever a redundant key was present.
                warnings.warn(f"Redundant config: {key}")

    @classmethod
    def __complete_config(cls, default_config, config):
        """Fill in missing keys from the defaults (user values win)."""
        for key in default_config:
            if key not in config:
                config[key] = default_config[key]

    @classmethod
    def load_config(cls, config_path=None):
        """Return the merged config dict.

        Falls back to pure defaults when ``config_path`` is None or the file
        does not exist (with a warning in the latter case).
        """
        default_config = cls.__load_default_config()
        if config_path is None:
            config = {}
        elif not os.path.exists(config_path):
            warnings.warn(f"config file {config_path} not exist!")
            config = {}
        else:
            config = cls.__load_config(config_path)
        cls.__check_redundant_config(default_config, config)
        cls.__complete_config(default_config, config)
        cls.refactor_activation(config)
        cls.refactor_aggregation(config)
        # Derived index arrays: input nodes come first, outputs follow.
        config['input_idx'] = np.arange(config['num_inputs'])
        config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
        return config

    @classmethod
    def refactor_activation(cls, config):
        """Replace activation names with integer ids (default = index 0)."""
        config['activation_default'] = 0
        config['activation_options'] = np.arange(len(config['activation_option_names']))

    @classmethod
    def refactor_aggregation(cls, config):
        """Replace aggregation names with integer ids (default = index 0)."""
        config['aggregation_default'] = 0
        config['aggregation_options'] = np.arange(len(config['aggregation_option_names']))

View File

@@ -1,72 +0,0 @@
; Default configuration for the NEAT algorithm.
; NOTE: values are parsed with Python eval() by Configer, so strings must be
; quoted and lists use Python literal syntax.

; -- network dimensions and execution settings --
[basic]
num_inputs = 2
num_outputs = 1
; fixed upper bounds for the pre-allocated genome arrays
maximum_nodes = 50
maximum_conns = 100
maximum_species = 10
forward_way = "pop"
batch_size = 4
random_seed = 0
network_type = "feedforward"
activate_times = 10

; -- run termination and population size --
[population]
fitness_threshold = 3.9999
generation_limit = 1000
fitness_criterion = "max"
pop_size = 50000

; -- structural mutation rates and compatibility coefficients --
[genome]
compatibility_disjoint = 1.0
compatibility_weight = 0.5
conn_add_prob = 0.4
conn_add_trials = 1
conn_delete_prob = 0
node_add_prob = 0.2
node_delete_prob = 0

; -- speciation thresholds and stagnation control --
[species]
compatibility_threshold = 3.0
species_elitism = 2
max_stagnation = 15
genome_elitism = 2
survival_threshold = 0.2
min_species_size = 1
spawn_number_change_rate = 0.5

; -- per-attribute initialization / mutation parameters --
[gene-bias]
bias_init_mean = 0.0
bias_init_std = 1.0
bias_mutate_power = 0.5
bias_mutate_rate = 0.7
bias_replace_rate = 0.1
[gene-response]
response_init_mean = 1.0
response_init_std = 0.0
response_mutate_power = 0.0
response_mutate_rate = 0.0
response_replace_rate = 0.0
[gene-activation]
activation_default = "sigmoid"
activation_option_names = ["sigmoid"]
activation_replace_rate = 0.0
[gene-aggregation]
aggregation_default = "sum"
aggregation_option_names = ["sum"]
aggregation_replace_rate = 0.0
[gene-weight]
weight_init_mean = 0.0
weight_init_std = 1.0
weight_mutate_power = 0.5
weight_mutate_rate = 0.8
weight_replace_rate = 0.1
[gene-enable]
enable_mutate_rate = 0.01

; -- visualization options --
[visualize]
renumber_nodes = True

View File

@@ -0,0 +1,2 @@
from .hyperneat import HyperNEAT
from .substrate import BaseSubstrate

View File

@@ -0,0 +1,70 @@
from typing import Type
import jax
import numpy as np
from .substrate import BaseSubstrate, analysis_substrate
from .hyperneat_gene import HyperNEATGene
from algorithm import State, Algorithm, neat
class HyperNEAT(Algorithm):
    """HyperNEAT: evolve a CPPN with NEAT and use it to generate the
    connection weights of a fixed substrate network."""

    def __init__(self, config, gene_type: Type[neat.BaseGene], substrate: Type[BaseSubstrate]):
        super().__init__()
        self.config = config
        self.gene_type = gene_type
        self.substrate = substrate
        # The inner NEAT instance evolves the CPPN genomes.
        self.neat = neat.NEAT(config, gene_type)
        # Hooks required by the Algorithm interface.
        self.tell = create_tell(self.neat)
        self.forward_transform = create_forward_transform(config, self.neat)
        self.forward = HyperNEATGene.create_forward(config)

    def setup(self, randkey, state=State()):
        """Build the substrate layout and its fixed genome, then delegate
        the remaining initialization to NEAT's setup."""
        state = state.update(
            below_threshold=self.config['below_threshold'],
            max_weight=self.config['max_weight'],
        )
        state = self.substrate.setup(state, self.config)

        in_idx, out_idx, hid_idx, coords, conn_keys = analysis_substrate(state)

        # Substrate genome: one id column per node; (in, out, weight) per conn
        # with weights left at 0 until the CPPN is queried.
        node_table = np.concatenate((in_idx, out_idx, hid_idx))[..., np.newaxis]
        conn_table = np.zeros((conn_keys.shape[0], 3), dtype=np.float32)
        conn_table[:, 0:2] = conn_keys

        state = state.update(
            # h is short for hyperneat
            h_input_idx=in_idx,
            h_output_idx=out_idx,
            h_hidden_idx=hid_idx,
            query_coors=coords,
            correspond_keys=conn_keys,
            h_nodes=node_table,
            h_conns=conn_table,
        )
        state = self.neat.setup(randkey, state=state)

        self.config['h_input_idx'] = in_idx
        self.config['h_output_idx'] = out_idx
        return state
def create_tell(neat_instance):
    """Return a tell(state, fitness) hook that delegates to the wrapped
    NEAT instance, giving HyperNEAT the standard Algorithm interface."""
    return lambda state, fitness: neat_instance.tell(state, fitness)
def create_forward_transform(config, neat_instance):
    """Build the HyperNEAT forward_transform hook.

    The returned function runs the CPPN (a NEAT genome) over every substrate
    coordinate pair to produce connection weights, then converts the filled
    substrate genome into a runnable transform.

    Note: ``config`` is currently unused but kept for interface symmetry
    with the other factory functions.
    """
    # Hoisted out of the closure: building the vmapped CPPN query function
    # once avoids reconstructing it on every forward_transform call.
    batch_forward_func = jax.vmap(neat_instance.forward, in_axes=(0, None))

    def forward_transform(state, nodes, conns):
        # CPPN transform for this genome.
        t = neat_instance.forward_transform(state, nodes, conns)
        # Query the CPPN at every substrate coordinate pair -> one weight
        # per substrate connection.
        query_res = batch_forward_func(state.query_coors, t)  # hyperneat connections
        h_nodes = state.h_nodes
        # Write the queried weights into column 2+ of the substrate conns.
        h_conns = state.h_conns.at[:, 2:].set(query_res)
        return HyperNEATGene.forward_transform(state, h_nodes, h_conns)

    return forward_transform

View File

@@ -0,0 +1,54 @@
import jax
from jax import numpy as jnp, vmap
from algorithm.neat import BaseGene
from algorithm.neat.gene import Activation
from algorithm.neat.gene import Aggregation
class HyperNEATGene(BaseGene):
    """Gene type for the substrate network evaluated by HyperNEAT.

    The substrate genome is fixed in structure: nodes carry no evolved
    attributes and each connection only stores the CPPN-queried weight.
    """

    node_attrs = []  # no node attributes
    conn_attrs = ['weight']

    @staticmethod
    def forward_transform(state, nodes, conns):
        """Densify the flat (in_key, out_key, weight) connection list into
        an N x N weight matrix; absent connections keep weight 0."""
        N = nodes.shape[0]
        u_conns = jnp.zeros((N, N), dtype=jnp.float32)
        in_keys = jnp.asarray(conns[:, 0], jnp.int32)
        out_keys = jnp.asarray(conns[:, 1], jnp.int32)
        weights = conns[:, 2]
        # u_conns[i, j] = weight of connection i -> j
        u_conns = u_conns.at[in_keys, out_keys].set(weights)
        return nodes, u_conns

    @staticmethod
    def create_forward(config):
        """Build the substrate forward function: a fixed number of
        synchronous activation passes over the dense weight matrix."""
        act = Activation.name2func[config['h_activation']]
        agg = Aggregation.name2func[config['h_aggregation']]
        # vmapped so aggregation/activation apply per-node over the batch dim
        batch_act, batch_agg = vmap(act), vmap(agg)

        def forward(inputs, transform):
            # Append a constant 1.0 as an extra bias input; h_input_idx is
            # therefore assumed to have len(inputs) + 1 entries — confirm.
            inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
            nodes, weights = transform
            input_idx = config['h_input_idx']
            output_idx = config['h_output_idx']
            N = nodes.shape[0]
            vals = jnp.full((N,), 0.)

            def body_func(i, values):
                # Clamp input nodes to the (bias-augmented) input vector on
                # every pass so they are not overwritten by propagation.
                values = values.at[input_idx].set(inputs_with_bias)
                # Row i of (values * weights.T) holds the weighted inputs
                # feeding node i.
                nodes_ins = values * weights.T
                values = batch_agg(nodes_ins)  # z = agg(ins)
                # NOTE(review): reads columns 2 and 1 as response/bias, but
                # HyperNEAT.setup builds h_nodes with a single id column and
                # jax clamps out-of-range indices instead of raising, so this
                # may silently read column 0 (the node ids) — confirm intent.
                values = values * nodes[:, 2] + nodes[:, 1]  # z = z * response + bias
                values = batch_act(values)  # z = act(z)
                return values

            # Fixed number of activation passes (recurrent-style settling).
            vals = jax.lax.fori_loop(0, config['h_activate_times'], body_func, vals)
            return vals[output_idx]

        return forward

View File

@@ -0,0 +1,2 @@
from .base import BaseSubstrate
from .tools import analysis_substrate

View File

@@ -0,0 +1,12 @@
import numpy as np
class BaseSubstrate:
    """Default substrate: node coordinates are taken directly from config."""

    @staticmethod
    def setup(state, config):
        """Store the input/output/hidden node coordinates in the state as
        float32 arrays."""
        def as_float32(key):
            return np.asarray(config[key], dtype=np.float32)

        return state.update(
            input_coors=as_float32('input_coors'),
            output_coors=as_float32('output_coors'),
            hidden_coors=as_float32('hidden_coors'),
        )

View File

@@ -0,0 +1,53 @@
from typing import Type
import numpy as np
from .base import BaseSubstrate
def analysis_substrate(state):
    """Precompute the substrate layout for HyperNEAT.

    Returns (input_idx, output_idx, hidden_idx, query_coors, correspond_keys)
    where query_coors stacks the source/target coordinate pairs fed to the
    CPPN and correspond_keys holds the matching (in_node, out_node) ids.
    Connections are generated input->hidden, hidden->hidden, hidden->output.
    """
    coor_dim = state.input_coors.shape[1]   # coordinate dimensions
    n_in = state.input_coors.shape[0]       # input coordinate size
    n_out = state.output_coors.shape[0]     # output coordinate size
    n_hid = state.hidden_coors.shape[0]     # hidden coordinate size

    # Node ids: inputs first, then outputs, then hiddens.
    input_idx = np.arange(n_in)
    output_idx = np.arange(n_in, n_in + n_out)
    hidden_idx = np.arange(n_in + n_out, n_in + n_out + n_hid)

    total_conns = n_in * n_hid + n_hid * n_hid + n_hid * n_out
    query_coors = np.zeros((total_conns, coor_dim * 2))
    correspond_keys = np.zeros((total_conns, 2))

    # Fill the three connection segments in order.
    segments = (
        (input_idx, hidden_idx, state.input_coors, state.hidden_coors),
        (hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors),
        (hidden_idx, output_idx, state.hidden_coors, state.output_coors),
    )
    offset = 0
    for src_keys, dst_keys, src_coors, dst_coors in segments:
        seg_coors, seg_keys = cartesian_product(src_keys, dst_keys, src_coors, dst_coors)
        end = offset + src_keys.shape[0] * dst_keys.shape[0]
        query_coors[offset:end, :] = seg_coors
        correspond_keys[offset:end, :] = seg_keys
        offset = end

    return input_idx, output_idx, hidden_idx, query_coors, correspond_keys


def cartesian_product(keys1, keys2, coors1, coors2):
    """All ordered pairs of (keys1, keys2) with their coordinates.

    Row k pairs element k // len2 of the first set with element k % len2 of
    the second; coordinates are concatenated side by side.
    """
    len1, len2 = keys1.shape[0], keys2.shape[0]
    left_coors = np.repeat(coors1, len2, axis=0)    # each row of coors1, len2 times
    right_coors = np.tile(coors2, (len1, 1))        # coors2 cycled len1 times
    left_keys = np.repeat(keys1, len2)
    right_keys = np.tile(keys2, len1)
    paired_coors = np.concatenate((left_coors, right_coors), axis=1)
    paired_keys = np.column_stack((left_keys, right_keys))
    return paired_coors, paired_keys

View File

@@ -1,3 +1,2 @@
from .neat import NEAT
from .gene import NormalGene, RecurrentGene
from .pipeline import Pipeline
from .gene import BaseGene, NormalGene, RecurrentGene

View File

@@ -33,12 +33,10 @@ class BaseGene:
def distance_conn(state, conn1: Array, conn2: Array):
return conn1
@staticmethod
def forward_transform(nodes, conns):
def forward_transform(state, nodes, conns):
return nodes, conns
@staticmethod
def create_forward(config):
return None
return None

View File

@@ -4,7 +4,7 @@ from jax import Array, numpy as jnp
from .base import BaseGene
from .activation import Activation
from .aggregation import Aggregation
from ..utils import unflatten_connections, I_INT
from algorithm.utils import unflatten_connections, I_INT
from ..genome import topological_sort
@@ -84,7 +84,7 @@ class NormalGene(BaseGene):
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
@staticmethod
def forward_transform(nodes, conns):
def forward_transform(state, nodes, conns):
u_conns = unflatten_connections(nodes, conns)
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)

View File

@@ -4,13 +4,13 @@ from jax import Array, numpy as jnp, vmap
from .normal import NormalGene
from .activation import Activation
from .aggregation import Aggregation
from ..utils import unflatten_connections, I_INT
from algorithm.utils import unflatten_connections
class RecurrentGene(NormalGene):
@staticmethod
def forward_transform(nodes, conns):
def forward_transform(state, nodes, conns):
u_conns = unflatten_connections(nodes, conns)
# remove un-enable connections and remove enable attr

View File

@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
from algorithm import State
from ..gene import BaseGene
from ..utils import fetch_first
from algorithm.utils import fetch_first
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
@@ -48,6 +48,7 @@ def count(nodes: Array, cons: Array):
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
return node_cnt, cons_cnt
def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]:
"""
Add a new node to the genome.

View File

@@ -6,7 +6,7 @@ Only used in feed-forward networks.
import jax
from jax import jit, Array, numpy as jnp
from ..utils import fetch_first, I_INT
from algorithm.utils import fetch_first, I_INT
@jit

View File

@@ -4,9 +4,9 @@ import jax
from jax import Array, numpy as jnp, vmap
from algorithm import State
from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx, count
from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx
from .graph import check_cycles
from ..utils import fetch_random, fetch_first, I_INT, unflatten_connections
from algorithm.utils import fetch_random, fetch_first, I_INT, unflatten_connections
from ..gene import BaseGene

View File

@@ -3,22 +3,25 @@ from typing import Type
import jax
import jax.numpy as jnp
from algorithm.state import State
from algorithm import Algorithm, State
from .gene import BaseGene
from .genome import initialize_genomes
from .population import create_tell
class NEAT:
class NEAT(Algorithm):
def __init__(self, config, gene_type: Type[BaseGene]):
super().__init__()
self.config = config
self.gene_type = gene_type
self.tell_func = jax.jit(create_tell(config, self.gene_type))
self.tell = create_tell(config, self.gene_type)
self.ask = None
self.forward = self.gene_type.create_forward(config)
self.forward_transform = self.gene_type.forward_transform
def setup(self, randkey):
state = State(
def setup(self, randkey, state=State()):
state = state.update(
P=self.config['pop_size'],
N=self.config['maximum_nodes'],
C=self.config['maximum_conns'],
@@ -69,7 +72,4 @@ class NEAT:
# move to device
state = jax.device_put(state)
return state
def step(self, state, fitness):
return self.tell_func(state, fitness)
return state

View File

@@ -1,77 +0,0 @@
import time
from typing import Union, Callable
import jax
from jax import vmap, jit
import numpy as np
class Pipeline:
    """
    Neat algorithm pipeline.

    Drives the evolutionary loop: ask the algorithm for a population-wide
    forward function, evaluate it with a user-supplied fitness function, and
    tell the fitnesses back to the algorithm until a fitness or generation
    limit is reached.
    """

    def __init__(self, config, algorithm):
        self.config = config
        self.algorithm = algorithm
        randkey = jax.random.PRNGKey(config['random_seed'])
        self.state = algorithm.setup(randkey)
        self.best_genome = None
        self.best_fitness = float('-inf')
        self.generation_timestamp = time.time()
        self.evaluate_time = 0
        # Single-genome forward, then batched over inputs, then over the
        # whole population's transforms.
        self.forward_func = algorithm.gene_type.create_forward(config)
        self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None)))
        self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0)))
        self.pop_transform_func = jit(vmap(algorithm.gene_type.forward_transform))

    def ask(self):
        """Return a function mapping a batch of inputs to the outputs of
        every genome in the current population."""
        pop_transforms = self.pop_transform_func(self.state.pop_nodes, self.state.pop_conns)
        return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms)

    def tell(self, fitness):
        """Advance the algorithm one generation with the given fitnesses."""
        self.state = self.algorithm.step(self.state, fitness)

    def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
        """Run the full ask/evaluate/tell loop.

        fitness_func: maps the population forward function to a sequence of
            fitnesses (one per genome).
        analysis: "default" for the built-in logging, a callable for custom
            per-generation reporting, or None to disable reporting.
            NOTE: best_genome is only tracked inside default_analysis, so
            with other analysis modes the returned genome may be stale.
        """
        for _ in range(self.config['generation_limit']):
            forward_func = self.ask()
            fitnesses = fitness_func(forward_func)
            if analysis is not None:
                if analysis == "default":
                    self.default_analysis(fitnesses)
                elif callable(analysis):
                    analysis(fitnesses)
                else:
                    # Was `assert callable(analysis)` with an unprofessional
                    # message; asserts are stripped under `python -O`, so
                    # raise a proper exception instead.
                    raise TypeError(
                        f"analysis must be 'default', a callable or None, got {analysis!r}"
                    )
            if max(fitnesses) >= self.config['fitness_threshold']:
                print("Fitness limit reached!")
                return self.best_genome
            self.tell(fitnesses)
        print("Generation limit reached!")
        return self.best_genome

    def default_analysis(self, fitnesses):
        """Log per-generation statistics and track the best genome seen."""
        max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)

        new_timestamp = time.time()
        cost_time = new_timestamp - self.generation_timestamp
        self.generation_timestamp = new_timestamp

        max_idx = np.argmax(fitnesses)
        if fitnesses[max_idx] > self.best_fitness:
            self.best_fitness = fitnesses[max_idx]
            self.best_genome = (self.state.pop_nodes[max_idx], self.state.pop_conns[max_idx])

        # Count living species (column 3 assumed to be member count — from
        # the original code; confirm against species_info layout).
        member_count = jax.device_get(self.state.species_info[:, 3])
        species_sizes = [int(i) for i in member_count if i > 0]

        print(f"Generation: {self.state.generation}",
              f"species: {len(species_sizes)}, {species_sizes}",
              f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")

View File

@@ -3,7 +3,7 @@ from typing import Type
import jax
from jax import numpy as jnp, vmap
from .utils import rank_elements, fetch_first
from algorithm.utils import rank_elements, fetch_first
from .genome import create_mutate, create_distance, crossover
from .gene import BaseGene