remove create_func....
This commit is contained in:
@@ -1 +1,2 @@
|
|||||||
from .neat import NEAT
|
from .neat import NEAT
|
||||||
|
from .hyperneat import HyperNEAT
|
||||||
|
|||||||
2
algorithm/hyperneat/__init__.py
Normal file
2
algorithm/hyperneat/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .hyperneat import HyperNEAT
|
||||||
|
from .substrate import NormalSubstrate, NormalSubstrateConfig
|
||||||
116
algorithm/hyperneat/hyperneat.py
Normal file
116
algorithm/hyperneat/hyperneat.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax import numpy as jnp, Array, vmap
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from config import Config, HyperNeatConfig
|
||||||
|
from core import Algorithm, Substrate, State, Genome
|
||||||
|
from utils import Activation, Aggregation
|
||||||
|
from algorithm.neat import NEAT
|
||||||
|
from .substrate import analysis_substrate
|
||||||
|
|
||||||
|
|
||||||
|
class HyperNEAT(Algorithm):
|
||||||
|
|
||||||
|
def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]):
|
||||||
|
self.config = config
|
||||||
|
self.neat = neat
|
||||||
|
self.substrate = substrate
|
||||||
|
|
||||||
|
def setup(self, randkey, state=State()):
|
||||||
|
neat_key, randkey = jax.random.split(randkey)
|
||||||
|
state = state.update(
|
||||||
|
below_threshold=self.config.hyper_neat.below_threshold,
|
||||||
|
max_weight=self.config.hyper_neat.max_weight,
|
||||||
|
)
|
||||||
|
state = self.neat.setup(neat_key, state)
|
||||||
|
state = self.substrate.setup(self.config.substrate, state)
|
||||||
|
|
||||||
|
assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
|
||||||
|
assert self.config.hyper_neat.outputs == state.output_coors.shape[0]
|
||||||
|
|
||||||
|
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
|
||||||
|
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
|
||||||
|
h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
|
||||||
|
h_conns[:, 0:2] = correspond_keys
|
||||||
|
|
||||||
|
state = state.update(
|
||||||
|
h_input_idx=h_input_idx,
|
||||||
|
h_output_idx=h_output_idx,
|
||||||
|
h_hidden_idx=h_hidden_idx,
|
||||||
|
h_nodes=h_nodes,
|
||||||
|
h_conns=h_conns,
|
||||||
|
query_coors=query_coors,
|
||||||
|
)
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
|
def ask_algorithm(self, state: State):
|
||||||
|
return state.pop_genomes
|
||||||
|
|
||||||
|
def tell_algorithm(self, state: State, fitness):
|
||||||
|
return self.neat.tell(state, fitness)
|
||||||
|
|
||||||
|
def forward(self, state, inputs: Array, transformed: Array):
|
||||||
|
return HyperNEATGene.forward(self.config.hyper_neat, state, inputs, transformed)
|
||||||
|
|
||||||
|
def forward_transform(self, state: State, genome: Genome):
|
||||||
|
t = self.neat.forward_transform(state, genome)
|
||||||
|
query_res = vmap(self.neat.forward, in_axes=(None, 0, None))(state, state.query_coors, t)
|
||||||
|
|
||||||
|
# mute the connection with weight below threshold
|
||||||
|
query_res = jnp.where((-state.below_threshold < query_res) & (query_res < state.below_threshold), 0., query_res)
|
||||||
|
|
||||||
|
# make query res in range [-max_weight, max_weight]
|
||||||
|
query_res = jnp.where(query_res > 0, query_res - state.below_threshold, query_res)
|
||||||
|
query_res = jnp.where(query_res < 0, query_res + state.below_threshold, query_res)
|
||||||
|
query_res = query_res / (1 - state.below_threshold) * state.max_weight
|
||||||
|
|
||||||
|
h_conns = state.h_conns.at[:, 2:].set(query_res)
|
||||||
|
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))
|
||||||
|
|
||||||
|
|
||||||
|
class HyperNEATGene:
|
||||||
|
node_attrs = [] # no node attributes
|
||||||
|
conn_attrs = ['weight']
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward_transform(genome: Genome):
|
||||||
|
N = genome.nodes.shape[0]
|
||||||
|
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
|
||||||
|
|
||||||
|
in_keys = jnp.asarray(genome.conns[:, 0], jnp.int32)
|
||||||
|
out_keys = jnp.asarray(genome.conns[:, 1], jnp.int32)
|
||||||
|
weights = genome.conns[:, 2]
|
||||||
|
|
||||||
|
u_conns = u_conns.at[in_keys, out_keys].set(weights)
|
||||||
|
return genome.nodes, u_conns
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(config: HyperNeatConfig, state: State, inputs, transformed):
|
||||||
|
act = Activation.name2func[config.activation]
|
||||||
|
agg = Aggregation.name2func[config.aggregation]
|
||||||
|
|
||||||
|
batch_act, batch_agg = jax.vmap(act), jax.vmap(agg)
|
||||||
|
|
||||||
|
nodes, weights = transformed
|
||||||
|
|
||||||
|
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
|
||||||
|
|
||||||
|
input_idx = state.h_input_idx
|
||||||
|
output_idx = state.h_output_idx
|
||||||
|
|
||||||
|
N = nodes.shape[0]
|
||||||
|
vals = jnp.full((N,), 0.)
|
||||||
|
|
||||||
|
def body_func(i, values):
|
||||||
|
values = values.at[input_idx].set(inputs_with_bias)
|
||||||
|
nodes_ins = values * weights.T
|
||||||
|
values = batch_agg(nodes_ins) # z = agg(ins)
|
||||||
|
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||||
|
values = batch_act(values) # z = act(z)
|
||||||
|
return values
|
||||||
|
|
||||||
|
vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals)
|
||||||
|
return vals[output_idx]
|
||||||
2
algorithm/hyperneat/substrate/__init__.py
Normal file
2
algorithm/hyperneat/substrate/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .normal import NormalSubstrate, NormalSubstrateConfig
|
||||||
|
from .tools import analysis_substrate
|
||||||
25
algorithm/hyperneat/substrate/normal.py
Normal file
25
algorithm/hyperneat/substrate/normal.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from core import Substrate, State
|
||||||
|
from config import SubstrateConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class NormalSubstrateConfig(SubstrateConfig):
|
||||||
|
input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1))
|
||||||
|
hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0))
|
||||||
|
output_coors: Tuple[Tuple[float]] = ((0, 1),)
|
||||||
|
|
||||||
|
|
||||||
|
class NormalSubstrate(Substrate):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def setup(config: NormalSubstrateConfig, state: State = State()):
|
||||||
|
return state.update(
|
||||||
|
input_coors=np.asarray(config.input_coors, dtype=np.float32),
|
||||||
|
output_coors=np.asarray(config.output_coors, dtype=np.float32),
|
||||||
|
hidden_coors=np.asarray(config.hidden_coors, dtype=np.float32),
|
||||||
|
)
|
||||||
49
algorithm/hyperneat/substrate/tools.py
Normal file
49
algorithm/hyperneat/substrate/tools.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def analysis_substrate(state):
|
||||||
|
cd = state.input_coors.shape[1] # coordinate dimensions
|
||||||
|
si = state.input_coors.shape[0] # input coordinate size
|
||||||
|
so = state.output_coors.shape[0] # output coordinate size
|
||||||
|
sh = state.hidden_coors.shape[0] # hidden coordinate size
|
||||||
|
|
||||||
|
input_idx = np.arange(si)
|
||||||
|
output_idx = np.arange(si, si + so)
|
||||||
|
hidden_idx = np.arange(si + so, si + so + sh)
|
||||||
|
|
||||||
|
total_conns = si * sh + sh * sh + sh * so
|
||||||
|
query_coors = np.zeros((total_conns, cd * 2))
|
||||||
|
correspond_keys = np.zeros((total_conns, 2))
|
||||||
|
|
||||||
|
# connect input to hidden
|
||||||
|
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors)
|
||||||
|
query_coors[0: si * sh, :] = aux_coors
|
||||||
|
correspond_keys[0: si * sh, :] = aux_keys
|
||||||
|
|
||||||
|
# connect hidden to hidden
|
||||||
|
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors)
|
||||||
|
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
|
||||||
|
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
|
||||||
|
|
||||||
|
# connect hidden to output
|
||||||
|
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors)
|
||||||
|
query_coors[si * sh + sh * sh:, :] = aux_coors
|
||||||
|
correspond_keys[si * sh + sh * sh:, :] = aux_keys
|
||||||
|
|
||||||
|
return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
|
||||||
|
|
||||||
|
|
||||||
|
def cartesian_product(keys1, keys2, coors1, coors2):
|
||||||
|
len1 = keys1.shape[0]
|
||||||
|
len2 = keys2.shape[0]
|
||||||
|
|
||||||
|
repeated_coors1 = np.repeat(coors1, len2, axis=0)
|
||||||
|
repeated_keys1 = np.repeat(keys1, len2)
|
||||||
|
|
||||||
|
tiled_coors2 = np.tile(coors2, (len1, 1))
|
||||||
|
tiled_keys2 = np.tile(keys2, len1)
|
||||||
|
|
||||||
|
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
|
||||||
|
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
|
||||||
|
|
||||||
|
return new_coors, correspond_keys
|
||||||
@@ -1 +1,3 @@
|
|||||||
from .normal import NormalGene, NormalGeneConfig
|
from .normal import NormalGene, NormalGeneConfig
|
||||||
|
from .recurrent import RecurrentGene, RecurrentGeneConfig
|
||||||
|
|
||||||
|
|||||||
@@ -24,11 +24,11 @@ class NormalGeneConfig(GeneConfig):
|
|||||||
response_replace_rate: float = 0.1
|
response_replace_rate: float = 0.1
|
||||||
|
|
||||||
activation_default: str = 'sigmoid'
|
activation_default: str = 'sigmoid'
|
||||||
activation_options: Tuple[str] = ('sigmoid',)
|
activation_options: Tuple = ('sigmoid',)
|
||||||
activation_replace_rate: float = 0.1
|
activation_replace_rate: float = 0.1
|
||||||
|
|
||||||
aggregation_default: str = 'sum'
|
aggregation_default: str = 'sum'
|
||||||
aggregation_options: Tuple[str] = ('sum',)
|
aggregation_options: Tuple = ('sum',)
|
||||||
aggregation_replace_rate: float = 0.1
|
aggregation_replace_rate: float = 0.1
|
||||||
|
|
||||||
weight_init_mean: float = 0.0
|
weight_init_mean: float = 0.0
|
||||||
|
|||||||
57
algorithm/neat/gene/recurrent.py
Normal file
57
algorithm/neat/gene/recurrent.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import jax
|
||||||
|
from jax import numpy as jnp, vmap
|
||||||
|
|
||||||
|
from .normal import NormalGene, NormalGeneConfig
|
||||||
|
from core import State, Genome
|
||||||
|
from utils import unflatten_conns, act, agg
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RecurrentGeneConfig(NormalGeneConfig):
|
||||||
|
activate_times: int = 10
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
assert self.activate_times > 0
|
||||||
|
|
||||||
|
|
||||||
|
class RecurrentGene(NormalGene):
|
||||||
|
|
||||||
|
def __init__(self, config: RecurrentGeneConfig):
|
||||||
|
self.config = config
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
def forward_transform(self, state: State, genome: Genome):
|
||||||
|
u_conns = unflatten_conns(genome.nodes, genome.conns)
|
||||||
|
|
||||||
|
# remove un-enable connections and remove enable attr
|
||||||
|
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||||
|
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||||
|
|
||||||
|
return genome.nodes, u_conns
|
||||||
|
|
||||||
|
def forward(self, state: State, inputs, transformed):
|
||||||
|
nodes, conns = transformed
|
||||||
|
|
||||||
|
batch_act, batch_agg = vmap(act, in_axes=(0, 0, None)), vmap(agg, in_axes=(0, 0, None))
|
||||||
|
|
||||||
|
input_idx = state.input_idx
|
||||||
|
output_idx = state.output_idx
|
||||||
|
|
||||||
|
N = nodes.shape[0]
|
||||||
|
vals = jnp.full((N,), 0.)
|
||||||
|
|
||||||
|
weights = conns[0, :]
|
||||||
|
|
||||||
|
def body_func(i, values):
|
||||||
|
values = values.at[input_idx].set(inputs)
|
||||||
|
nodes_ins = values * weights.T
|
||||||
|
values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins)
|
||||||
|
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||||
|
values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z)
|
||||||
|
return values
|
||||||
|
|
||||||
|
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)
|
||||||
|
return vals[output_idx]
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -10,9 +12,9 @@ from .species import SpeciesInfo, update_species, speciate
|
|||||||
|
|
||||||
class NEAT(Algorithm):
|
class NEAT(Algorithm):
|
||||||
|
|
||||||
def __init__(self, config: Config, gene: Gene):
|
def __init__(self, config: Config, gene_type: Type[Gene]):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.gene = gene
|
self.gene = gene_type(config.gene)
|
||||||
|
|
||||||
self.forward_func = None
|
self.forward_func = None
|
||||||
self.tell_func = None
|
self.tell_func = None
|
||||||
|
|||||||
@@ -92,6 +92,11 @@ class SubstrateConfig:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ProblemConfig:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Config:
|
class Config:
|
||||||
basic: BasicConfig = BasicConfig()
|
basic: BasicConfig = BasicConfig()
|
||||||
@@ -99,3 +104,4 @@ class Config:
|
|||||||
hyper_neat: HyperNeatConfig = HyperNeatConfig()
|
hyper_neat: HyperNeatConfig = HyperNeatConfig()
|
||||||
gene: GeneConfig = GeneConfig()
|
gene: GeneConfig = GeneConfig()
|
||||||
substrate: SubstrateConfig = SubstrateConfig()
|
substrate: SubstrateConfig = SubstrateConfig()
|
||||||
|
problem: ProblemConfig = ProblemConfig()
|
||||||
|
|||||||
@@ -2,4 +2,5 @@ from .algorithm import Algorithm
|
|||||||
from .state import State
|
from .state import State
|
||||||
from .genome import Genome
|
from .genome import Genome
|
||||||
from .gene import Gene
|
from .gene import Gene
|
||||||
from .substrate import Substrate
|
from .substrate import Substrate
|
||||||
|
from .problem import Problem
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ class Gene:
|
|||||||
node_attrs = []
|
node_attrs = []
|
||||||
conn_attrs = []
|
conn_attrs = []
|
||||||
|
|
||||||
|
def __init__(self, config: GeneConfig):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def setup(self, state=State()):
|
def setup(self, state=State()):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
15
core/problem.py
Normal file
15
core/problem.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from typing import Callable
|
||||||
|
from config import ProblemConfig
|
||||||
|
from state import State
|
||||||
|
|
||||||
|
|
||||||
|
class Problem:
|
||||||
|
|
||||||
|
def __init__(self, config: ProblemConfig):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def setup(self, state=State()):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def evaluate(self, state: State, act_func: Callable, params):
|
||||||
|
raise NotImplementedError
|
||||||
@@ -6,3 +6,5 @@ class Substrate:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def setup(state, config: SubstrateConfig):
|
def setup(state, config: SubstrateConfig):
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +1,31 @@
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
import jax
|
import jax
|
||||||
|
|
||||||
|
from utils import unflatten_conns, act, agg, Activation, Aggregation
|
||||||
|
from algorithm.neat.gene import RecurrentGeneConfig
|
||||||
|
|
||||||
|
config = RecurrentGeneConfig(
|
||||||
|
activation_options=("tanh", "sigmoid"),
|
||||||
|
activation_default="tanh",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class A:
|
class A:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.a = 1
|
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
|
||||||
self.b = 2
|
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
|
||||||
self.isTrue = False
|
self.isTrue = False
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(0,))
|
@partial(jax.jit, static_argnums=(0,))
|
||||||
def step(self):
|
def step(self):
|
||||||
if self.isTrue:
|
i = jax.numpy.array([0, 1])
|
||||||
return self.a + 1
|
z = jax.numpy.array([
|
||||||
else:
|
[1, 1],
|
||||||
return self.b + 1
|
[2, 2]
|
||||||
|
])
|
||||||
|
print(self.act_funcs)
|
||||||
|
return jax.vmap(act, in_axes=(0, 0, None))(i, z, self.act_funcs)
|
||||||
|
|
||||||
|
|
||||||
AA = A()
|
AA = A()
|
||||||
print(AA.step(), hash(AA))
|
print(AA.step())
|
||||||
print(AA.step(), hash(AA))
|
|
||||||
print(AA.step(), hash(AA))
|
|
||||||
AA.a = (2, 3, 4)
|
|
||||||
print(AA.step(), hash(AA))
|
|
||||||
|
|||||||
@@ -28,11 +28,13 @@ if __name__ == '__main__':
|
|||||||
pop_size=10000
|
pop_size=10000
|
||||||
),
|
),
|
||||||
neat=NeatConfig(
|
neat=NeatConfig(
|
||||||
maximum_nodes=20,
|
maximum_nodes=50,
|
||||||
maximum_conns=50,
|
maximum_conns=100,
|
||||||
)
|
compatibility_threshold=4
|
||||||
|
),
|
||||||
|
gene=NormalGeneConfig()
|
||||||
)
|
)
|
||||||
normal_gene = NormalGene(NormalGeneConfig())
|
|
||||||
algorithm = NEAT(config, normal_gene)
|
algorithm = NEAT(config, NormalGene)
|
||||||
pipeline = Pipeline(config, algorithm)
|
pipeline = Pipeline(config, algorithm)
|
||||||
pipeline.auto_run(evaluate)
|
pipeline.auto_run(evaluate)
|
||||||
|
|||||||
49
examples/xor_hyperneat.py
Normal file
49
examples/xor_hyperneat.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from config import Config, BasicConfig, NeatConfig
|
||||||
|
from pipeline import Pipeline
|
||||||
|
from algorithm import NEAT, HyperNEAT
|
||||||
|
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
|
||||||
|
from algorithm.hyperneat.substrate import NormalSubstrate, NormalSubstrateConfig
|
||||||
|
|
||||||
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
|
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(forward_func):
|
||||||
|
"""
|
||||||
|
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
outs = forward_func(xor_inputs)
|
||||||
|
outs = jax.device_get(outs)
|
||||||
|
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||||
|
return fitnesses
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config = Config(
|
||||||
|
basic=BasicConfig(
|
||||||
|
fitness_target=3.99999,
|
||||||
|
pop_size=10000
|
||||||
|
),
|
||||||
|
neat=NeatConfig(
|
||||||
|
network_type="recurrent",
|
||||||
|
maximum_nodes=50,
|
||||||
|
maximum_conns=100,
|
||||||
|
inputs=4,
|
||||||
|
outputs=1
|
||||||
|
|
||||||
|
),
|
||||||
|
gene=RecurrentGeneConfig(
|
||||||
|
activation_default="tanh",
|
||||||
|
activation_options=("tanh",),
|
||||||
|
),
|
||||||
|
substrate=NormalSubstrateConfig(),
|
||||||
|
)
|
||||||
|
neat = NEAT(config, RecurrentGene)
|
||||||
|
hyperNEAT = HyperNEAT(config, neat, NormalSubstrate)
|
||||||
|
|
||||||
|
pipeline = Pipeline(config, hyperNEAT)
|
||||||
|
pipeline.auto_run(evaluate)
|
||||||
42
examples/xor_recurrent.py
Normal file
42
examples/xor_recurrent.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from config import Config, BasicConfig, NeatConfig
|
||||||
|
from pipeline import Pipeline
|
||||||
|
from algorithm import NEAT
|
||||||
|
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
|
||||||
|
|
||||||
|
|
||||||
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
|
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(forward_func):
|
||||||
|
"""
|
||||||
|
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
outs = forward_func(xor_inputs)
|
||||||
|
outs = jax.device_get(outs)
|
||||||
|
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||||
|
return fitnesses
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config = Config(
|
||||||
|
basic=BasicConfig(
|
||||||
|
fitness_target=3.99999,
|
||||||
|
pop_size=10000
|
||||||
|
),
|
||||||
|
neat=NeatConfig(
|
||||||
|
network_type="recurrent",
|
||||||
|
maximum_nodes=50,
|
||||||
|
maximum_conns=100
|
||||||
|
),
|
||||||
|
gene=RecurrentGeneConfig(
|
||||||
|
activate_times=3
|
||||||
|
)
|
||||||
|
)
|
||||||
|
algorithm = NEAT(config, RecurrentGene)
|
||||||
|
pipeline = Pipeline(config, algorithm)
|
||||||
|
pipeline.auto_run(evaluate)
|
||||||
0
problem/__init__.py
Normal file
0
problem/__init__.py
Normal file
0
problem/func_fit/__init__.py
Normal file
0
problem/func_fit/__init__.py
Normal file
21
problem/func_fit/func_fitting.py
Normal file
21
problem/func_fit/func_fitting.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from config import ProblemConfig
|
||||||
|
from core import Problem, State
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FuncFitConfig:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FuncFit(Problem):
|
||||||
|
def __init__(self, config: ProblemConfig):
|
||||||
|
self.config = ProblemConfig
|
||||||
|
|
||||||
|
def setup(self, state=State()):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def evaluate(self, state: State, act_func: Callable, params):
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user