modify act funcs and sympy act funcs;
add dense and advance initialize genome; add input_transform for genome;
This commit is contained in:
@@ -2,6 +2,7 @@ import warnings
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
from utils import (
|
||||
unflatten_conns,
|
||||
@@ -37,6 +38,7 @@ class DefaultGenome(BaseGenome):
|
||||
mutation: BaseMutation = DefaultMutation(),
|
||||
crossover: BaseCrossover = DefaultCrossover(),
|
||||
output_transform: Callable = None,
|
||||
input_transform: Callable = None,
|
||||
):
|
||||
super().__init__(
|
||||
num_inputs,
|
||||
@@ -49,9 +51,16 @@ class DefaultGenome(BaseGenome):
|
||||
crossover,
|
||||
)
|
||||
|
||||
if input_transform is not None:
|
||||
try:
|
||||
_ = input_transform(np.zeros(num_inputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.input_transform = input_transform
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
_ = output_transform(jnp.zeros(num_outputs))
|
||||
_ = output_transform(np.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.output_transform = output_transform
|
||||
@@ -69,6 +78,10 @@ class DefaultGenome(BaseGenome):
|
||||
return nodes, conns
|
||||
|
||||
def forward(self, state, transformed, inputs):
|
||||
|
||||
if self.input_transform is not None:
|
||||
inputs = self.input_transform(inputs)
|
||||
|
||||
cal_seqs, nodes, conns, u_conns = transformed
|
||||
|
||||
ini_vals = jnp.full((self.max_nodes,), jnp.nan)
|
||||
@@ -118,6 +131,10 @@ class DefaultGenome(BaseGenome):
|
||||
return self.output_transform(vals[self.output_idx])
|
||||
|
||||
def update_by_batch(self, state, batch_input, transformed):
|
||||
|
||||
if self.input_transform is not None:
|
||||
batch_input = jax.vmap(self.input_transform)(batch_input)
|
||||
|
||||
cal_seqs, nodes, conns, u_conns = transformed
|
||||
|
||||
batch_size = batch_input.shape[0]
|
||||
@@ -193,11 +210,19 @@ class DefaultGenome(BaseGenome):
|
||||
new_transformed,
|
||||
)
|
||||
|
||||
def sympy_func(self, state, network, sympy_output_transform=None, backend="jax"):
|
||||
def sympy_func(self, state, network, sympy_input_transform=None, sympy_output_transform=None, backend="jax"):
|
||||
|
||||
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
||||
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
|
||||
|
||||
if sympy_input_transform is None and self.input_transform is not None:
|
||||
warnings.warn(
|
||||
"genome.input_transform is not None but sympy_input_transform is None!"
|
||||
)
|
||||
if sympy_input_transform is not None:
|
||||
if not isinstance(sympy_input_transform, list):
|
||||
sympy_input_transform = [sympy_input_transform] * self.num_inputs
|
||||
|
||||
if sympy_output_transform is None and self.output_transform is not None:
|
||||
warnings.warn(
|
||||
"genome.output_transform is not None but sympy_output_transform is None!"
|
||||
@@ -221,7 +246,10 @@ class DefaultGenome(BaseGenome):
|
||||
for i in order:
|
||||
|
||||
if i in input_idx:
|
||||
nodes_exprs[symbols[i]] = symbols[i]
|
||||
if sympy_input_transform is not None:
|
||||
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[i])
|
||||
else:
|
||||
nodes_exprs[symbols[i]] = symbols[i]
|
||||
else:
|
||||
in_conns = [c for c in network["conns"] if c[1] == i]
|
||||
node_inputs = []
|
||||
|
||||
Reference in New Issue
Block a user