284 lines
9.5 KiB
Python
284 lines
9.5 KiB
Python
import warnings
|
|
from typing import Callable
|
|
|
|
import jax, jax.numpy as jnp
|
|
import sympy as sp
|
|
from utils import (
|
|
unflatten_conns,
|
|
topological_sort,
|
|
topological_sort_python,
|
|
I_INF,
|
|
extract_node_attrs,
|
|
extract_conn_attrs,
|
|
set_node_attrs,
|
|
set_conn_attrs,
|
|
attach_with_inf,
|
|
SYMPY_FUNCS_MODULE_NP,
|
|
SYMPY_FUNCS_MODULE_JNP,
|
|
)
|
|
from . import BaseGenome
|
|
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
|
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
|
|
|
|
|
|
class DefaultGenome(BaseGenome):
    """Default genome class, with the same behavior as the NEAT-Python.

    Networks are evaluated as feed-forward graphs: ``transform`` computes a
    topological node order, and ``forward`` / ``update_by_batch`` visit the
    nodes in that order, computing each non-input node from the values carried
    by its incoming connections.
    """

    network_type = "feedforward"

    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        max_nodes=5,
        max_conns=4,
        node_gene: BaseNodeGene = None,
        conn_gene: BaseConnGene = None,
        mutation: BaseMutation = None,
        crossover: BaseCrossover = None,
        output_transform: Callable = None,
    ):
        """Create a feed-forward genome.

        Args:
            num_inputs: Number of input nodes (forwarded to ``BaseGenome``).
            num_outputs: Number of output nodes (forwarded to ``BaseGenome``).
            max_nodes: Forwarded to ``BaseGenome``; also the length of the
                node-value buffer used in ``forward``.
            max_conns: Forwarded to ``BaseGenome``.
            node_gene: Node gene; ``None`` builds a fresh ``DefaultNodeGene()``.
            conn_gene: Connection gene; ``None`` builds a fresh
                ``DefaultConnGene()``.
            mutation: Mutation operator; ``None`` builds a fresh
                ``DefaultMutation()``.
            crossover: Crossover operator; ``None`` builds a fresh
                ``DefaultCrossover()``.
            output_transform: Optional callable applied to the output values
                in ``forward`` (and per sample in ``update_by_batch``).

        Raises:
            ValueError: If ``output_transform`` raises when probed with
                ``jnp.zeros(num_outputs)``.
        """
        # Build default components per genome rather than sharing a single
        # instance that would be created once, at class-definition time.
        if node_gene is None:
            node_gene = DefaultNodeGene()
        if conn_gene is None:
            conn_gene = DefaultConnGene()
        if mutation is None:
            mutation = DefaultMutation()
        if crossover is None:
            crossover = DefaultCrossover()

        super().__init__(
            num_inputs,
            num_outputs,
            max_nodes,
            max_conns,
            node_gene,
            conn_gene,
            mutation,
            crossover,
        )

        if output_transform is not None:
            # Fail fast on a transform that cannot handle an output vector.
            try:
                _ = output_transform(jnp.zeros(num_outputs))
            except Exception as e:
                raise ValueError(f"Output transform function failed: {e}") from e
        self.output_transform = output_transform

    def transform(self, state, nodes, conns):
        """Precompute what evaluation needs from a (nodes, conns) genome.

        Returns:
            ``(seqs, nodes, conns, u_conns)`` where ``u_conns`` comes from
            ``unflatten_conns`` (entries equal to ``I_INF`` mean "no
            connection") and ``seqs`` is the node order produced by
            ``topological_sort`` over the existing connections.
        """
        u_conns = unflatten_conns(nodes, conns)
        conn_exist = u_conns != I_INF

        seqs = topological_sort(nodes, conn_exist)

        return seqs, nodes, conns, u_conns

    def restore(self, state, transformed):
        """Return the ``(nodes, conns)`` pair held inside ``transformed``."""
        _, nodes, conns, _ = transformed
        return nodes, conns

    def forward(self, state, transformed, inputs):
        """Evaluate the network on a single input vector.

        Args:
            state: Passed through to the node and connection genes.
            transformed: Output of :meth:`transform`.
            inputs: Values written to the input nodes (``self.input_idx``).

        Returns:
            The values of the output nodes (``self.output_idx``), passed
            through ``self.output_transform`` when one is set.
        """
        cal_seqs, nodes, conns, u_conns = transformed

        # Every node starts as NaN; only the input nodes are seeded here.
        ini_vals = jnp.full((self.max_nodes,), jnp.nan)
        ini_vals = ini_vals.at[self.input_idx].set(inputs)
        nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
        conns_attrs = jax.vmap(extract_conn_attrs)(conns)

        def cond_fun(carry):
            values, idx = carry
            # Stop at the end of the buffer or at the first I_INF entry of the
            # order. `&` does not short-circuit, so cal_seqs[idx] is still read
            # when idx == max_nodes; JAX clamps out-of-range gather indices and
            # the first term masks the result.
            return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)

        def body_func(carry):
            values, idx = carry
            i = cal_seqs[idx]

            def input_node():
                z = self.node_gene.input_transform(state, nodes_attrs[i], values[i])
                new_values = values.at[i].set(z)
                return new_values

            def otherwise():
                # Column i of u_conns holds, per source node, the index of the
                # connection into node i (I_INF when absent).
                conn_indices = u_conns[:, i]
                hit_attrs = attach_with_inf(conns_attrs, conn_indices)
                # One connection-gene call per source node, paired with that
                # source node's current value.
                ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
                    state, hit_attrs, values
                )

                z = self.node_gene.forward(
                    state,
                    nodes_attrs[i],
                    ins,
                    is_output_node=jnp.isin(i, self.output_idx),
                )

                new_values = values.at[i].set(z)
                return new_values

            values = jax.lax.cond(jnp.isin(i, self.input_idx), input_node, otherwise)

            return values, idx + 1

        vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))

        if self.output_transform is None:
            return vals[self.output_idx]
        else:
            return self.output_transform(vals[self.output_idx])

    def update_by_batch(self, state, batch_input, transformed):
        """Evaluate a batch and let the genes update their attributes.

        Works like :meth:`forward` over a batch, but calls the genes'
        ``update_input_transform`` / ``update_by_batch`` hooks and writes the
        attributes they return back into the nodes and connections.

        Args:
            state: Passed through to the node and connection genes.
            batch_input: Input values; assumed shape (batch, num_inputs), since
                row count is the batch size and columns are written to
                ``self.input_idx``.
            transformed: Output of :meth:`transform`.

        Returns:
            ``(outputs, new_transformed)``. ``outputs`` holds the output-node
            values per sample, with ``self.output_transform`` applied per sample
            when set. ``new_transformed`` is ``transformed`` with the updated
            nodes and connections.
        """
        cal_seqs, nodes, conns, u_conns = transformed

        batch_size = batch_input.shape[0]
        batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan)
        batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input)
        nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
        conns_attrs = jax.vmap(extract_conn_attrs)(conns)

        def cond_fun(carry):
            batch_values, nodes_attrs_, conns_attrs_, idx = carry
            return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)

        def body_func(carry):
            batch_values, nodes_attrs_, conns_attrs_, idx = carry
            i = cal_seqs[idx]

            def input_node():
                batch, new_attrs = self.node_gene.update_input_transform(
                    state, nodes_attrs_[i], batch_values[:, i]
                )
                return (
                    batch_values.at[:, i].set(batch),
                    nodes_attrs_.at[i].set(new_attrs),
                    conns_attrs_,
                )

            def otherwise():
                conn_indices = u_conns[:, i]
                # Read from the loop-carried attributes, consistent with the
                # node attributes above.
                hit_attrs = attach_with_inf(conns_attrs_, conn_indices)
                # Map over source nodes: axis 1 of batch_values is the node
                # axis, and the per-source results are stacked back on axis 1.
                batch_ins, new_conn_attrs = jax.vmap(
                    self.conn_gene.update_by_batch,
                    in_axes=(None, 0, 1),
                    out_axes=(1, 0),
                )(state, hit_attrs, batch_values)

                batch_z, new_node_attrs = self.node_gene.update_by_batch(
                    state,
                    nodes_attrs_[i],
                    batch_ins,
                    is_output_node=jnp.isin(i, self.output_idx),
                )

                # NOTE(review): conn_indices contains I_INF for missing
                # connections; this relies on JAX dropping out-of-range scatter
                # updates (presumably I_INF >= max_conns -- confirm).
                return (
                    batch_values.at[:, i].set(batch_z),
                    nodes_attrs_.at[i].set(new_node_attrs),
                    conns_attrs_.at[conn_indices].set(new_conn_attrs),
                )

            # the val of input nodes is obtained by the task, not by calculation
            (batch_values, nodes_attrs_, conns_attrs_) = jax.lax.cond(
                jnp.isin(i, self.input_idx),
                input_node,
                otherwise,
            )

            return batch_values, nodes_attrs_, conns_attrs_, idx + 1

        batch_vals, nodes_attrs, conns_attrs, _ = jax.lax.while_loop(
            cond_fun, body_func, (batch_ini_vals, nodes_attrs, conns_attrs, 0)
        )

        nodes = jax.vmap(set_node_attrs)(nodes, nodes_attrs)
        conns = jax.vmap(set_conn_attrs)(conns, conns_attrs)

        new_transformed = (cal_seqs, nodes, conns, u_conns)

        if self.output_transform is None:
            return batch_vals[:, self.output_idx], new_transformed
        else:
            return (
                jax.vmap(self.output_transform)(batch_vals[:, self.output_idx]),
                new_transformed,
            )

    def sympy_func(self, state, network, sympy_output_transform=None, backend="jax"):
        """Build symbolic (SymPy) expressions and a callable for a network.

        Args:
            state: Passed through to the genes' ``sympy_func``.
            network: Mapping with ``"nodes"`` (node id -> node data) and
                ``"conns"`` ((src, dst) -> connection data).
            sympy_output_transform: Optional callable applied to each output
                node's expression.
            backend: ``"jax"`` or ``"numpy"``; selects the lambdify modules.

        Returns:
            ``(symbols, args_symbols, input_symbols, nodes_exprs, output_exprs,
            forward_func)``. ``forward_func(inputs)`` returns a list with one
            value per output node, with ``args_symbols`` values bound in.
        """
        assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
        module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP

        if sympy_output_transform is None and self.output_transform is not None:
            warnings.warn(
                "genome.output_transform is not None but sympy_output_transform is None!"
            )

        input_idx = self.get_input_idx()
        output_idx = self.get_output_idx()
        order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))

        # One symbol per node, prefixed by its role: input, output or hidden.
        symbols = {}
        for i in network["nodes"]:
            if i in input_idx:
                symbols[i] = sp.Symbol(f"i{i}")
            elif i in output_idx:
                symbols[i] = sp.Symbol(f"o{i}")
            else:  # hidden
                symbols[i] = sp.Symbol(f"h{i}")

        # Group connections by target node once, keeping their original order,
        # instead of rescanning every connection for every node.
        in_conns_of = {}
        for conn in network["conns"]:
            in_conns_of.setdefault(conn[1], []).append(conn)

        nodes_exprs = {}
        args_symbols = {}  # extra lambdify arguments and their bound values
        for i in order:
            if i in input_idx:
                nodes_exprs[symbols[i]] = symbols[i]
                continue

            node_inputs = []
            for conn in in_conns_of.get(i, []):
                val_represent = symbols[conn[0]]
                # a_s -> args_symbols
                val, a_s = self.conn_gene.sympy_func(
                    state,
                    network["conns"][conn],
                    val_represent,
                )
                args_symbols.update(a_s)
                node_inputs.append(val)

            nodes_exprs[symbols[i]], a_s = self.node_gene.sympy_func(
                state,
                network["nodes"][i],
                node_inputs,
                is_output_node=(i in output_idx),
            )
            args_symbols.update(a_s)

            if i in output_idx and sympy_output_transform is not None:
                nodes_exprs[symbols[i]] = sympy_output_transform(
                    nodes_exprs[symbols[i]]
                )

        # NOTE(review): ordered by network["nodes"] iteration order; forward_func
        # expects `inputs` in that same order (presumably input_idx order).
        input_symbols = [v for k, v in symbols.items() if k in input_idx]

        # Substitute in topological order so every node expression is rewritten
        # in terms of input symbols and argument symbols only.
        reduced_exprs = nodes_exprs.copy()
        for i in order:
            reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)

        output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]

        lambdify_output_funcs = [
            sp.lambdify(
                input_symbols + list(args_symbols.keys()),
                exprs,
                modules=[backend, module],
            )
            for exprs in output_exprs
        ]

        def _fix_args(func):
            # Bind the collected argument values so callers pass only inputs.
            def fixed(inputs):
                return func(*inputs, *args_symbols.values())

            return fixed

        fixed_args_output_funcs = [_fix_args(func) for func in lambdify_output_funcs]

        def forward_func(inputs):
            return [f(inputs) for f in fixed_args_output_funcs]

        return (
            symbols,
            args_symbols,
            input_symbols,
            nodes_exprs,
            output_exprs,
            forward_func,
        )
|