remove create_func....

This commit is contained in:
wls2002
2023-08-02 15:02:08 +08:00
parent 1499e062fe
commit c7fb1ddabe
22 changed files with 425 additions and 21 deletions

View File

@@ -1 +1,2 @@
from .neat import NEAT
from .hyperneat import HyperNEAT

View File

@@ -0,0 +1,2 @@
from .hyperneat import HyperNEAT
from .substrate import NormalSubstrate, NormalSubstrateConfig

View File

@@ -0,0 +1,116 @@
from typing import Type
import jax
from jax import numpy as jnp, Array, vmap
import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome
from utils import Activation, Aggregation
from algorithm.neat import NEAT
from .substrate import analysis_substrate
class HyperNEAT(Algorithm):
def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]):
self.config = config
self.neat = neat
self.substrate = substrate
def setup(self, randkey, state=State()):
neat_key, randkey = jax.random.split(randkey)
state = state.update(
below_threshold=self.config.hyper_neat.below_threshold,
max_weight=self.config.hyper_neat.max_weight,
)
state = self.neat.setup(neat_key, state)
state = self.substrate.setup(self.config.substrate, state)
assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
assert self.config.hyper_neat.outputs == state.output_coors.shape[0]
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
h_conns[:, 0:2] = correspond_keys
state = state.update(
h_input_idx=h_input_idx,
h_output_idx=h_output_idx,
h_hidden_idx=h_hidden_idx,
h_nodes=h_nodes,
h_conns=h_conns,
query_coors=query_coors,
)
return state
def ask_algorithm(self, state: State):
return state.pop_genomes
def tell_algorithm(self, state: State, fitness):
return self.neat.tell(state, fitness)
def forward(self, state, inputs: Array, transformed: Array):
return HyperNEATGene.forward(self.config.hyper_neat, state, inputs, transformed)
def forward_transform(self, state: State, genome: Genome):
t = self.neat.forward_transform(state, genome)
query_res = vmap(self.neat.forward, in_axes=(None, 0, None))(state, state.query_coors, t)
# mute the connection with weight below threshold
query_res = jnp.where((-state.below_threshold < query_res) & (query_res < state.below_threshold), 0., query_res)
# make query res in range [-max_weight, max_weight]
query_res = jnp.where(query_res > 0, query_res - state.below_threshold, query_res)
query_res = jnp.where(query_res < 0, query_res + state.below_threshold, query_res)
query_res = query_res / (1 - state.below_threshold) * state.max_weight
h_conns = state.h_conns.at[:, 2:].set(query_res)
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))
class HyperNEATGene:
node_attrs = [] # no node attributes
conn_attrs = ['weight']
@staticmethod
def forward_transform(genome: Genome):
N = genome.nodes.shape[0]
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
in_keys = jnp.asarray(genome.conns[:, 0], jnp.int32)
out_keys = jnp.asarray(genome.conns[:, 1], jnp.int32)
weights = genome.conns[:, 2]
u_conns = u_conns.at[in_keys, out_keys].set(weights)
return genome.nodes, u_conns
@staticmethod
def forward(config: HyperNeatConfig, state: State, inputs, transformed):
act = Activation.name2func[config.activation]
agg = Aggregation.name2func[config.aggregation]
batch_act, batch_agg = jax.vmap(act), jax.vmap(agg)
nodes, weights = transformed
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
input_idx = state.h_input_idx
output_idx = state.h_output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
def body_func(i, values):
values = values.at[input_idx].set(inputs_with_bias)
nodes_ins = values * weights.T
values = batch_agg(nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(values) # z = act(z)
return values
vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals)
return vals[output_idx]

View File

@@ -0,0 +1,2 @@
from .normal import NormalSubstrate, NormalSubstrateConfig
from .tools import analysis_substrate

View File

@@ -0,0 +1,25 @@
from dataclasses import dataclass
from typing import Tuple
import numpy as np
from core import Substrate, State
from config import SubstrateConfig
@dataclass(frozen=True)
class NormalSubstrateConfig(SubstrateConfig):
input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1))
hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0))
output_coors: Tuple[Tuple[float]] = ((0, 1),)
class NormalSubstrate(Substrate):
@staticmethod
def setup(config: NormalSubstrateConfig, state: State = State()):
return state.update(
input_coors=np.asarray(config.input_coors, dtype=np.float32),
output_coors=np.asarray(config.output_coors, dtype=np.float32),
hidden_coors=np.asarray(config.hidden_coors, dtype=np.float32),
)

View File

@@ -0,0 +1,49 @@
import numpy as np
def analysis_substrate(state):
cd = state.input_coors.shape[1] # coordinate dimensions
si = state.input_coors.shape[0] # input coordinate size
so = state.output_coors.shape[0] # output coordinate size
sh = state.hidden_coors.shape[0] # hidden coordinate size
input_idx = np.arange(si)
output_idx = np.arange(si, si + so)
hidden_idx = np.arange(si + so, si + so + sh)
total_conns = si * sh + sh * sh + sh * so
query_coors = np.zeros((total_conns, cd * 2))
correspond_keys = np.zeros((total_conns, 2))
# connect input to hidden
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors)
query_coors[0: si * sh, :] = aux_coors
correspond_keys[0: si * sh, :] = aux_keys
# connect hidden to hidden
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors)
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
# connect hidden to output
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors)
query_coors[si * sh + sh * sh:, :] = aux_coors
correspond_keys[si * sh + sh * sh:, :] = aux_keys
return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
def cartesian_product(keys1, keys2, coors1, coors2):
len1 = keys1.shape[0]
len2 = keys2.shape[0]
repeated_coors1 = np.repeat(coors1, len2, axis=0)
repeated_keys1 = np.repeat(keys1, len2)
tiled_coors2 = np.tile(coors2, (len1, 1))
tiled_keys2 = np.tile(keys2, len1)
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
return new_coors, correspond_keys

View File

@@ -1 +1,3 @@
from .normal import NormalGene, NormalGeneConfig
from .recurrent import RecurrentGene, RecurrentGeneConfig

View File

@@ -24,11 +24,11 @@ class NormalGeneConfig(GeneConfig):
response_replace_rate: float = 0.1
activation_default: str = 'sigmoid'
activation_options: Tuple[str] = ('sigmoid',)
activation_options: Tuple = ('sigmoid',)
activation_replace_rate: float = 0.1
aggregation_default: str = 'sum'
aggregation_options: Tuple[str] = ('sum',)
aggregation_options: Tuple = ('sum',)
aggregation_replace_rate: float = 0.1
weight_init_mean: float = 0.0

View File

@@ -0,0 +1,57 @@
from dataclasses import dataclass
import jax
from jax import numpy as jnp, vmap
from .normal import NormalGene, NormalGeneConfig
from core import State, Genome
from utils import unflatten_conns, act, agg
@dataclass(frozen=True)
class RecurrentGeneConfig(NormalGeneConfig):
activate_times: int = 10
def __post_init__(self):
super().__post_init__()
assert self.activate_times > 0
class RecurrentGene(NormalGene):
def __init__(self, config: RecurrentGeneConfig):
self.config = config
super().__init__(config)
def forward_transform(self, state: State, genome: Genome):
u_conns = unflatten_conns(genome.nodes, genome.conns)
# remove un-enable connections and remove enable attr
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return genome.nodes, u_conns
def forward(self, state: State, inputs, transformed):
nodes, conns = transformed
batch_act, batch_agg = vmap(act, in_axes=(0, 0, None)), vmap(agg, in_axes=(0, 0, None))
input_idx = state.input_idx
output_idx = state.output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
weights = conns[0, :]
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z)
return values
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)
return vals[output_idx]

View File

@@ -1,3 +1,5 @@
from typing import Type
import jax
from jax import numpy as jnp
import numpy as np
@@ -10,9 +12,9 @@ from .species import SpeciesInfo, update_species, speciate
class NEAT(Algorithm):
def __init__(self, config: Config, gene: Gene):
def __init__(self, config: Config, gene_type: Type[Gene]):
self.config = config
self.gene = gene
self.gene = gene_type(config.gene)
self.forward_func = None
self.tell_func = None

View File

@@ -92,6 +92,11 @@ class SubstrateConfig:
pass
@dataclass(frozen=True)
class ProblemConfig:
pass
@dataclass(frozen=True)
class Config:
basic: BasicConfig = BasicConfig()
@@ -99,3 +104,4 @@ class Config:
hyper_neat: HyperNeatConfig = HyperNeatConfig()
gene: GeneConfig = GeneConfig()
substrate: SubstrateConfig = SubstrateConfig()
problem: ProblemConfig = ProblemConfig()

View File

@@ -3,3 +3,4 @@ from .state import State
from .genome import Genome
from .gene import Gene
from .substrate import Substrate
from .problem import Problem

View File

@@ -6,6 +6,9 @@ class Gene:
node_attrs = []
conn_attrs = []
def __init__(self, config: GeneConfig):
raise NotImplementedError
def setup(self, state=State()):
raise NotImplementedError

15
core/problem.py Normal file
View File

@@ -0,0 +1,15 @@
from typing import Callable
from config import ProblemConfig
from state import State
class Problem:
def __init__(self, config: ProblemConfig):
raise NotImplementedError
def setup(self, state=State()):
raise NotImplementedError
def evaluate(self, state: State, act_func: Callable, params):
raise NotImplementedError

View File

@@ -6,3 +6,5 @@ class Substrate:
@staticmethod
def setup(state, config: SubstrateConfig):
return state

View File

@@ -1,24 +1,31 @@
from functools import partial
import jax
from utils import unflatten_conns, act, agg, Activation, Aggregation
from algorithm.neat.gene import RecurrentGeneConfig
config = RecurrentGeneConfig(
activation_options=("tanh", "sigmoid"),
activation_default="tanh",
)
class A:
def __init__(self):
self.a = 1
self.b = 2
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
self.isTrue = False
@partial(jax.jit, static_argnums=(0,))
def step(self):
if self.isTrue:
return self.a + 1
else:
return self.b + 1
i = jax.numpy.array([0, 1])
z = jax.numpy.array([
[1, 1],
[2, 2]
])
print(self.act_funcs)
return jax.vmap(act, in_axes=(0, 0, None))(i, z, self.act_funcs)
AA = A()
print(AA.step(), hash(AA))
print(AA.step(), hash(AA))
print(AA.step(), hash(AA))
AA.a = (2, 3, 4)
print(AA.step(), hash(AA))
print(AA.step())

View File

@@ -28,11 +28,13 @@ if __name__ == '__main__':
pop_size=10000
),
neat=NeatConfig(
maximum_nodes=20,
maximum_conns=50,
)
maximum_nodes=50,
maximum_conns=100,
compatibility_threshold=4
),
gene=NormalGeneConfig()
)
normal_gene = NormalGene(NormalGeneConfig())
algorithm = NEAT(config, normal_gene)
algorithm = NEAT(config, NormalGene)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)

49
examples/xor_hyperneat.py Normal file
View File

@@ -0,0 +1,49 @@
import jax
import numpy as np
from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT, HyperNEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
from algorithm.hyperneat.substrate import NormalSubstrate, NormalSubstrateConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
fitness_target=3.99999,
pop_size=10000
),
neat=NeatConfig(
network_type="recurrent",
maximum_nodes=50,
maximum_conns=100,
inputs=4,
outputs=1
),
gene=RecurrentGeneConfig(
activation_default="tanh",
activation_options=("tanh",),
),
substrate=NormalSubstrateConfig(),
)
neat = NEAT(config, RecurrentGene)
hyperNEAT = HyperNEAT(config, neat, NormalSubstrate)
pipeline = Pipeline(config, hyperNEAT)
pipeline.auto_run(evaluate)

42
examples/xor_recurrent.py Normal file
View File

@@ -0,0 +1,42 @@
import jax
import numpy as np
from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
fitness_target=3.99999,
pop_size=10000
),
neat=NeatConfig(
network_type="recurrent",
maximum_nodes=50,
maximum_conns=100
),
gene=RecurrentGeneConfig(
activate_times=3
)
)
algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)

0
problem/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import Callable
from config import ProblemConfig
from core import Problem, State
@dataclass(frozen=True)
class FuncFitConfig:
pass
class FuncFit(Problem):
def __init__(self, config: ProblemConfig):
self.config = ProblemConfig
def setup(self, state=State()):
pass
def evaluate(self, state: State, act_func: Callable, params):
pass