modify act funcs and sympy act funcs;

add dense and advance initialize genome;
add input_transform for genome;
This commit is contained in:
wls2002
2024-06-18 16:01:11 +08:00
parent 907314bc80
commit ce8015d22c
7 changed files with 222 additions and 39 deletions

View File

@@ -2,3 +2,4 @@ from .base import BaseGenome
from .default import DefaultGenome from .default import DefaultGenome
from .recurrent import RecurrentGenome from .recurrent import RecurrentGenome
from .advance import AdvanceInitialize from .advance import AdvanceInitialize
from .dense import DenseInitialize

View File

@@ -0,0 +1,70 @@
import jax, jax.numpy as jnp
from .default import DefaultGenome
class AdvanceInitialize(DefaultGenome):
def __init__(self, hidden_cnt=8, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hidden_cnt = hidden_cnt
def initialize(self, state, randkey):
k1, k2 = jax.random.split(randkey, num=2)
input_idx, output_idx = self.input_idx, self.output_idx
input_size = len(input_idx)
output_size = len(output_idx)
hidden_idx = jnp.arange(
input_size + output_size, input_size + output_size + self.hidden_cnt
)
nodes = jnp.full(
(self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
)
nodes = nodes.at[input_idx, 0].set(input_idx)
nodes = nodes.at[output_idx, 0].set(output_idx)
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
total_idx = input_size + output_size + self.hidden_cnt
rand_keys_n = jax.random.split(k1, num=total_idx)
node_attr_func = jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
node_attrs = node_attr_func(state, rand_keys_n)
nodes = nodes.at[:total_idx, 1:].set(node_attrs)
conns = jnp.full(
(self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
)
input_to_hidden_ids, hidden_ids = jnp.meshgrid(
input_idx, hidden_idx, indexing="ij"
)
total_input_to_hidden_conns = input_size * self.hidden_cnt
conns = conns.at[:total_input_to_hidden_conns, :2].set(
jnp.column_stack([input_to_hidden_ids.flatten(), hidden_ids.flatten()])
)
hidden_to_output_ids, output_ids = jnp.meshgrid(
hidden_idx, output_idx, indexing="ij"
)
total_hidden_to_output_conns = self.hidden_cnt * output_size
conns = conns.at[
total_input_to_hidden_conns : total_input_to_hidden_conns
+ total_hidden_to_output_conns,
:2,
].set(jnp.column_stack([hidden_to_output_ids.flatten(), output_ids.flatten()]))
total_conns = total_input_to_hidden_conns + total_hidden_to_output_conns
rand_keys_c = jax.random.split(k2, num=total_conns)
conns_attr_func = jax.vmap(
self.conn_gene.new_random_attrs,
in_axes=(
None,
0,
),
)
conns_attrs = conns_attr_func(state, rand_keys_c)
conns = conns.at[:total_conns, 2:].set(conns_attrs)
return nodes, conns

View File

@@ -2,6 +2,7 @@ import warnings
from typing import Callable from typing import Callable
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
import numpy as np
import sympy as sp import sympy as sp
from utils import ( from utils import (
unflatten_conns, unflatten_conns,
@@ -37,6 +38,7 @@ class DefaultGenome(BaseGenome):
mutation: BaseMutation = DefaultMutation(), mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(), crossover: BaseCrossover = DefaultCrossover(),
output_transform: Callable = None, output_transform: Callable = None,
input_transform: Callable = None,
): ):
super().__init__( super().__init__(
num_inputs, num_inputs,
@@ -49,9 +51,16 @@ class DefaultGenome(BaseGenome):
crossover, crossover,
) )
if input_transform is not None:
try:
_ = input_transform(np.zeros(num_inputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.input_transform = input_transform
if output_transform is not None: if output_transform is not None:
try: try:
_ = output_transform(jnp.zeros(num_outputs)) _ = output_transform(np.zeros(num_outputs))
except Exception as e: except Exception as e:
raise ValueError(f"Output transform function failed: {e}") raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform self.output_transform = output_transform
@@ -69,6 +78,10 @@ class DefaultGenome(BaseGenome):
return nodes, conns return nodes, conns
def forward(self, state, transformed, inputs): def forward(self, state, transformed, inputs):
if self.input_transform is not None:
inputs = self.input_transform(inputs)
cal_seqs, nodes, conns, u_conns = transformed cal_seqs, nodes, conns, u_conns = transformed
ini_vals = jnp.full((self.max_nodes,), jnp.nan) ini_vals = jnp.full((self.max_nodes,), jnp.nan)
@@ -118,6 +131,10 @@ class DefaultGenome(BaseGenome):
return self.output_transform(vals[self.output_idx]) return self.output_transform(vals[self.output_idx])
def update_by_batch(self, state, batch_input, transformed): def update_by_batch(self, state, batch_input, transformed):
if self.input_transform is not None:
batch_input = jax.vmap(self.input_transform)(batch_input)
cal_seqs, nodes, conns, u_conns = transformed cal_seqs, nodes, conns, u_conns = transformed
batch_size = batch_input.shape[0] batch_size = batch_input.shape[0]
@@ -193,11 +210,19 @@ class DefaultGenome(BaseGenome):
new_transformed, new_transformed,
) )
def sympy_func(self, state, network, sympy_output_transform=None, backend="jax"): def sympy_func(self, state, network, sympy_input_transform=None, sympy_output_transform=None, backend="jax"):
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'" assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
if sympy_input_transform is None and self.input_transform is not None:
warnings.warn(
"genome.input_transform is not None but sympy_input_transform is None!"
)
if sympy_input_transform is not None:
if not isinstance(sympy_input_transform, list):
sympy_input_transform = [sympy_input_transform] * self.num_inputs
if sympy_output_transform is None and self.output_transform is not None: if sympy_output_transform is None and self.output_transform is not None:
warnings.warn( warnings.warn(
"genome.output_transform is not None but sympy_output_transform is None!" "genome.output_transform is not None but sympy_output_transform is None!"
@@ -221,6 +246,9 @@ class DefaultGenome(BaseGenome):
for i in order: for i in order:
if i in input_idx: if i in input_idx:
if sympy_input_transform is not None:
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[i])
else:
nodes_exprs[symbols[i]] = symbols[i] nodes_exprs[symbols[i]] = symbols[i]
else: else:
in_conns = [c for c in network["conns"] if c[1] == i] in_conns = [c for c in network["conns"] if c[1] == i]

View File

@@ -0,0 +1,56 @@
import jax, jax.numpy as jnp
from .default import DefaultGenome
class DenseInitialize(DefaultGenome):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.max_nodes >= self.num_inputs + self.num_outputs
assert self.max_conns >= self.num_inputs * self.num_outputs
def initialize(self, state, randkey):
k1, k2 = jax.random.split(randkey, num=2)
input_idx, output_idx = self.input_idx, self.output_idx
input_size = len(input_idx)
output_size = len(output_idx)
nodes = jnp.full(
(self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
)
nodes = nodes.at[input_idx, 0].set(input_idx)
nodes = nodes.at[output_idx, 0].set(output_idx)
total_idx = input_size + output_size
rand_keys_n = jax.random.split(k1, num=total_idx)
node_attr_func = jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
node_attrs = node_attr_func(state, rand_keys_n)
nodes = nodes.at[:total_idx, 1:].set(node_attrs)
conns = jnp.full(
(self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
)
input_to_output_ids, output_ids = jnp.meshgrid(
input_idx, output_idx, indexing="ij"
)
total_conns = input_size * output_size
conns = conns.at[:total_conns, :2].set(
jnp.column_stack([input_to_output_ids.flatten(), output_ids.flatten()])
)
rand_keys_c = jax.random.split(k2, num=total_conns)
conns_attr_func = jax.vmap(
self.conn_gene.new_random_attrs,
in_axes=(
None,
0,
),
)
conns_attrs = conns_attr_func(state, rand_keys_c)
conns = conns.at[:total_conns, 2:].set(conns_attrs)
return nodes, conns

View File

@@ -20,7 +20,6 @@ name2sympy = {
"relu": SympyRelu, "relu": SympyRelu,
"lelu": SympyLelu, "lelu": SympyLelu,
"identity": SympyIdentity, "identity": SympyIdentity,
"clamped": SympyClamped,
"inv": SympyInv, "inv": SympyInv,
"log": SympyLog, "log": SympyLog,
"exp": SympyExp, "exp": SympyExp,

View File

@@ -2,6 +2,9 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
sigma_3 = 2.576
class Act: class Act:
@staticmethod @staticmethod
def name2func(name): def name2func(name):
@@ -9,35 +12,42 @@ class Act:
@staticmethod @staticmethod
def sigmoid(z): def sigmoid(z):
z = jnp.clip(5 * z, -10, 10) z = jnp.clip(5 * z / sigma_3, -5, 5)
return 1 / (1 + jnp.exp(-z)) z = 1 / (1 + jnp.exp(-z))
return z * sigma_3 # (0, sigma_3)
@staticmethod @staticmethod
def tanh(z): def tanh(z):
z = jnp.clip(0.6*z, -3, 3) z = jnp.clip(5 * z / sigma_3, -5, 5)
return jnp.tanh(z) return jnp.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
@staticmethod
def standard_tanh(z):
z = jnp.clip(5 * z / sigma_3, -5, 5)
return jnp.tanh(z) # (-1, 1)
@staticmethod @staticmethod
def sin(z): def sin(z):
return jnp.sin(z) z = jnp.clip(jnp.pi / 2 * z / sigma_3, -jnp.pi / 2, jnp.pi / 2)
return jnp.sin(z) * sigma_3 # (-sigma_3, sigma_3)
@staticmethod @staticmethod
def relu(z): def relu(z):
return jnp.maximum(z, 0) z = jnp.clip(z, -sigma_3, sigma_3)
return jnp.maximum(z, 0) # (0, sigma_3)
@staticmethod @staticmethod
def lelu(z): def lelu(z):
leaky = 0.005 leaky = 0.005
z = jnp.clip(z, -sigma_3, sigma_3)
return jnp.where(z > 0, z, leaky * z) return jnp.where(z > 0, z, leaky * z)
@staticmethod @staticmethod
def identity(z): def identity(z):
z = jnp.clip(z, -sigma_3, sigma_3)
return z return z
@staticmethod
def clamped(z):
return jnp.clip(z, -1, 1)
@staticmethod @staticmethod
def inv(z): def inv(z):
z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7)) z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
@@ -55,6 +65,7 @@ class Act:
@staticmethod @staticmethod
def abs(z): def abs(z):
z = jnp.clip(z, -1, 1)
return jnp.abs(z) return jnp.abs(z)
@@ -65,7 +76,6 @@ ACT_ALL = (
Act.relu, Act.relu,
Act.lelu, Act.lelu,
Act.identity, Act.identity,
Act.clamped,
Act.inv, Act.inv,
Act.log, Act.log,
Act.exp, Act.exp,

View File

@@ -2,6 +2,9 @@ import sympy as sp
import numpy as np import numpy as np
sigma_3 = 2.576
class SympyClip(sp.Function): class SympyClip(sp.Function):
@classmethod @classmethod
def eval(cls, val, min_val, max_val): def eval(cls, val, min_val, max_val):
@@ -26,14 +29,17 @@ class SympySigmoid(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
if z.is_Number: if z.is_Number:
z = SympyClip(5 * z, -10, 10) z = SympyClip(5 * z / sigma_3, -5, 5)
return 1 / (1 + sp.exp(-z)) z = 1 / (1 + sp.exp(-z))
return z * sigma_3
return None return None
@staticmethod @staticmethod
def numerical_eval(z, backend=np): def numerical_eval(z, backend=np):
z = backend.clip(5 * z, -10, 10) z = backend.clip(5 * z / sigma_3, -5, 5)
return 1 / (1 + backend.exp(-z)) z = 1 / (1 + backend.exp(-z))
return z * sigma_3 # (0, sigma_3)
def _sympystr(self, printer): def _sympystr(self, printer):
return f"sigmoid({self.args[0]})" return f"sigmoid({self.args[0]})"
@@ -46,36 +52,56 @@ class SympyTanh(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
if z.is_Number: if z.is_Number:
z = SympyClip(0.6 * z, -3, 3) z = SympyClip(5 * z / sigma_3, -5, 5)
return sp.tanh(z) * sigma_3
return None
@staticmethod
def numerical_eval(z, backend=np):
z = backend.clip(5 * z / sigma_3, -5, 5)
return backend.tanh(z) * sigma_3 # (-sigma_3, sigma_3)
class SympyStandardTanh(sp.Function):
@classmethod
def eval(cls, z):
if z.is_Number:
z = SympyClip(5 * z / sigma_3, -5, 5)
return sp.tanh(z) return sp.tanh(z)
return None return None
@staticmethod @staticmethod
def numerical_eval(z, backend=np): def numerical_eval(z, backend=np):
z = backend.clip(0.6*z, -3, 3) z = backend.clip(5 * z / sigma_3, -5, 5)
return backend.tanh(z) return backend.tanh(z) # (-1, 1)
class SympySin(sp.Function): class SympySin(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
return sp.sin(z) if z.is_Number:
z = SympyClip(sp.pi / 2 * z / sigma_3, -sp.pi / 2, sp.pi / 2)
return sp.sin(z) * sigma_3 # (-sigma_3, sigma_3)
return None
@staticmethod @staticmethod
def numerical_eval(z, backend=np): def numerical_eval(z, backend=np):
return backend.sin(z) z = backend.clip(backend.pi / 2 * z / sigma_3, -backend.pi / 2, backend.pi / 2)
return backend.sin(z) * sigma_3 # (-sigma_3, sigma_3)
class SympyRelu(sp.Function): class SympyRelu(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
if z.is_Number: if z.is_Number:
return sp.Piecewise((z, z > 0), (0, True)) z = SympyClip(z, -sigma_3, sigma_3)
return sp.Max(z, 0) # (0, sigma_3)
return None return None
@staticmethod @staticmethod
def numerical_eval(z, backend=np): def numerical_eval(z, backend=np):
return backend.maximum(z, 0) z = backend.clip(z, -sigma_3, sigma_3)
return backend.maximum(z, 0) # (0, sigma_3)
def _sympystr(self, printer): def _sympystr(self, printer):
return f"relu({self.args[0]})" return f"relu({self.args[0]})"
@@ -107,21 +133,14 @@ class SympyLelu(sp.Function):
class SympyIdentity(sp.Function): class SympyIdentity(sp.Function):
@classmethod @classmethod
def eval(cls, z): def eval(cls, z):
if z.is_Number:
z = SympyClip(z, -sigma_3, sigma_3)
return z return z
return None
@staticmethod @staticmethod
def numerical_eval(z, backend=np): def numerical_eval(z, backend=np):
return z return backend.clip(z, -sigma_3, sigma_3)
class SympyClamped(sp.Function):
@classmethod
def eval(cls, z):
return SympyClip(z, -1, 1)
@staticmethod
def numerical_eval(z, backend=np):
return backend.clip(z, -1, 1)
class SympyInv(sp.Function): class SympyInv(sp.Function):