remove create_func....

This commit is contained in:
wls2002
2023-08-02 13:26:01 +08:00
parent 85318f98f3
commit 1499e062fe
34 changed files with 558 additions and 1022 deletions

View File

@@ -1,2 +1 @@
from .normal import NormalGene, NormalGeneConfig
from .recurrent import RecurrentGene, RecurrentGeneConfig

View File

@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
from config import GeneConfig
from core import Gene, Genome, State
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT
from utils import Activation, Aggregation, unflatten_conns, topological_sort, I_INT, act, agg
@dataclass(frozen=True)
@@ -66,48 +66,51 @@ class NormalGene(Gene):
node_attrs = ['bias', 'response', 'aggregation', 'activation']
conn_attrs = ['weight']
@staticmethod
def setup(config: NormalGeneConfig, state: State = State()):
def __init__(self, config: NormalGeneConfig):
self.config = config
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
def setup(self, state: State = State()):
return state.update(
bias_init_mean=config.bias_init_mean,
bias_init_std=config.bias_init_std,
bias_mutate_power=config.bias_mutate_power,
bias_mutate_rate=config.bias_mutate_rate,
bias_replace_rate=config.bias_replace_rate,
bias_init_mean=self.config.bias_init_mean,
bias_init_std=self.config.bias_init_std,
bias_mutate_power=self.config.bias_mutate_power,
bias_mutate_rate=self.config.bias_mutate_rate,
bias_replace_rate=self.config.bias_replace_rate,
response_init_mean=config.response_init_mean,
response_init_std=config.response_init_std,
response_mutate_power=config.response_mutate_power,
response_mutate_rate=config.response_mutate_rate,
response_replace_rate=config.response_replace_rate,
response_init_mean=self.config.response_init_mean,
response_init_std=self.config.response_init_std,
response_mutate_power=self.config.response_mutate_power,
response_mutate_rate=self.config.response_mutate_rate,
response_replace_rate=self.config.response_replace_rate,
activation_replace_rate=config.activation_replace_rate,
activation_replace_rate=self.config.activation_replace_rate,
activation_default=0,
activation_options=jnp.arange(len(config.activation_options)),
activation_options=jnp.arange(len(self.config.activation_options)),
aggregation_replace_rate=config.aggregation_replace_rate,
aggregation_replace_rate=self.config.aggregation_replace_rate,
aggregation_default=0,
aggregation_options=jnp.arange(len(config.aggregation_options)),
aggregation_options=jnp.arange(len(self.config.aggregation_options)),
weight_init_mean=config.weight_init_mean,
weight_init_std=config.weight_init_std,
weight_mutate_power=config.weight_mutate_power,
weight_mutate_rate=config.weight_mutate_rate,
weight_replace_rate=config.weight_replace_rate,
weight_init_mean=self.config.weight_init_mean,
weight_init_std=self.config.weight_init_std,
weight_mutate_power=self.config.weight_mutate_power,
weight_mutate_rate=self.config.weight_mutate_rate,
weight_replace_rate=self.config.weight_replace_rate,
)
@staticmethod
def new_node_attrs(state):
def update(self, state):
pass
def new_node_attrs(self, state):
return jnp.array([state.bias_init_mean, state.response_init_mean,
state.activation_default, state.aggregation_default])
@staticmethod
def new_conn_attrs(state):
def new_conn_attrs(self, state):
return jnp.array([state.weight_init_mean])
@staticmethod
def mutate_node(state, attrs: Array, key):
def mutate_node(self, state, key, attrs: Array):
k1, k2, k3, k4 = jax.random.split(key, num=4)
bias = NormalGene._mutate_float(k1, attrs[0], state.bias_init_mean, state.bias_init_std,
@@ -120,26 +123,22 @@ class NormalGene(Gene):
return jnp.array([bias, res, act, agg])
@staticmethod
def mutate_conn(state, attrs: Array, key):
def mutate_conn(self, state, key, attrs: Array):
weight = NormalGene._mutate_float(key, attrs[0], state.weight_init_mean, state.weight_init_std,
state.weight_mutate_power, state.weight_mutate_rate,
state.weight_replace_rate)
return jnp.array([weight])
@staticmethod
def distance_node(state, node1: Array, node2: Array):
def distance_node(self, state, node1: Array, node2: Array):
# bias + response + activation + aggregation
return jnp.abs(node1[1] - node2[1]) + jnp.abs(node1[2] - node2[2]) + \
(node1[3] != node2[3]) + (node1[4] != node2[4])
@staticmethod
def distance_conn(state, con1: Array, con2: Array):
def distance_conn(self, state, con1: Array, con2: Array):
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
@staticmethod
def forward_transform(state: State, genome: Genome):
def forward_transform(self, state: State, genome: Genome):
u_conns = unflatten_conns(genome.nodes, genome.conns)
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
@@ -149,87 +148,46 @@ class NormalGene(Gene):
return seqs, genome.nodes, u_conns
@staticmethod
def create_forward(state: State, config: NormalGeneConfig):
activation_funcs = [Activation.name2func[name] for name in config.activation_options]
aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
def forward(self, state: State, inputs, transformed):
cal_seqs, nodes, cons = transformed
def act(idx, z):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, activation_funcs, z)
return res
input_idx = state.input_idx
output_idx = state.output_idx
def agg(idx, z):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
def all_nan():
return 0.
weights = cons[0, :]
def not_all_nan():
return jax.lax.switch(idx, aggregation_funcs, z)
def cond_fun(carry):
values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def forward(inputs, transformed) -> Array:
"""
forward for single input shaped (input_num, )
def hit():
ins = values * weights[:, i]
z = agg(nodes[i, 4], ins, self.agg_funcs) # z = agg(ins)
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
z = act(nodes[i, 3], z, self.act_funcs) # z = act(z)
:argument inputs: (input_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
new_values = values.at[i].set(z)
return new_values
:return (output_num, )
"""
def miss():
return values
cal_seqs, nodes, cons = transformed
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
input_idx = state.input_idx
output_idx = state.output_idx
return values, idx + 1
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
weights = cons[0, :]
def cond_fun(carry):
values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def hit():
ins = values * weights[:, i]
z = agg(nodes[i, 4], ins) # z = agg(ins)
z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
z = act(nodes[i, 3], z) # z = act(z)
new_values = values.at[i].set(z)
return new_values
def miss():
return values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
return values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[output_idx]
return forward
return vals[output_idx]
@staticmethod
def _mutate_float(key, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate):

View File

@@ -1,84 +0,0 @@
from dataclasses import dataclass
import jax
from jax import Array, numpy as jnp, vmap
from .normal import NormalGene, NormalGeneConfig
from core import State, Genome
from utils import Activation, Aggregation, unflatten_conns
@dataclass(frozen=True)
class RecurrentGeneConfig(NormalGeneConfig):
activate_times: int = 10
def __post_init__(self):
super().__post_init__()
assert self.activate_times > 0
class RecurrentGene(NormalGene):
@staticmethod
def forward_transform(state: State, genome: Genome):
u_conns = unflatten_conns(genome.nodes, genome.conns)
# remove un-enable connections and remove enable attr
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return genome.nodes, u_conns
@staticmethod
def create_forward(state: State, config: RecurrentGeneConfig):
activation_funcs = [Activation.name2func[name] for name in config.activation_options]
aggregation_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
def act(idx, z):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, activation_funcs, z)
return res
def agg(idx, z):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
def all_nan():
return 0.
def not_all_nan():
return jax.lax.switch(idx, aggregation_funcs, z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
batch_act, batch_agg = vmap(act), vmap(agg)
def forward(inputs, transform) -> Array:
nodes, cons = transform
input_idx = state.input_idx
output_idx = state.output_idx
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
weights = cons[0, :]
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values) # z = act(z)
return values
vals = jax.lax.fori_loop(0, config.activate_times, body_func, vals)
return vals[output_idx]
return forward