remove create_func....

This commit is contained in:
wls2002
2023-08-02 15:02:08 +08:00
parent 1499e062fe
commit c7fb1ddabe
22 changed files with 425 additions and 21 deletions

View File

@@ -1 +1,3 @@
from .normal import NormalGene, NormalGeneConfig
from .recurrent import RecurrentGene, RecurrentGeneConfig

View File

@@ -24,11 +24,11 @@ class NormalGeneConfig(GeneConfig):
response_replace_rate: float = 0.1
activation_default: str = 'sigmoid'
activation_options: Tuple[str] = ('sigmoid',)
activation_options: Tuple = ('sigmoid',)
activation_replace_rate: float = 0.1
aggregation_default: str = 'sum'
aggregation_options: Tuple[str] = ('sum',)
aggregation_options: Tuple = ('sum',)
aggregation_replace_rate: float = 0.1
weight_init_mean: float = 0.0

View File

@@ -0,0 +1,57 @@
from dataclasses import dataclass
import jax
from jax import numpy as jnp, vmap
from .normal import NormalGene, NormalGeneConfig
from core import State, Genome
from utils import unflatten_conns, act, agg
@dataclass(frozen=True)
class RecurrentGeneConfig(NormalGeneConfig):
activate_times: int = 10
def __post_init__(self):
super().__post_init__()
assert self.activate_times > 0
class RecurrentGene(NormalGene):
def __init__(self, config: RecurrentGeneConfig):
self.config = config
super().__init__(config)
def forward_transform(self, state: State, genome: Genome):
u_conns = unflatten_conns(genome.nodes, genome.conns)
# remove un-enable connections and remove enable attr
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return genome.nodes, u_conns
def forward(self, state: State, inputs, transformed):
nodes, conns = transformed
batch_act, batch_agg = vmap(act, in_axes=(0, 0, None)), vmap(agg, in_axes=(0, 0, None))
input_idx = state.input_idx
output_idx = state.output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
weights = conns[0, :]
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z)
return values
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)
return vals[output_idx]

View File

@@ -1,3 +1,5 @@
from typing import Type
import jax
from jax import numpy as jnp
import numpy as np
@@ -10,9 +12,9 @@ from .species import SpeciesInfo, update_species, speciate
class NEAT(Algorithm):
def __init__(self, config: Config, gene: Gene):
def __init__(self, config: Config, gene_type: Type[Gene]):
self.config = config
self.gene = gene
self.gene = gene_type(config.gene)
self.forward_func = None
self.tell_func = None