modify act funcs and sympy act funcs;

add dense and advance initialize genome;
add input_transform for genome;
This commit is contained in:
wls2002
2024-06-18 16:01:11 +08:00
parent 907314bc80
commit ce8015d22c
7 changed files with 222 additions and 39 deletions

View File

@@ -1,4 +1,5 @@
from .base import BaseGenome
from .default import DefaultGenome
from .recurrent import RecurrentGenome
from .advance import AdvanceInitialize
from .advance import AdvanceInitialize
from .dense import DenseInitialize

View File

@@ -0,0 +1,70 @@
import jax, jax.numpy as jnp
from .default import DefaultGenome
class AdvanceInitialize(DefaultGenome):
def __init__(self, hidden_cnt=8, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hidden_cnt = hidden_cnt
def initialize(self, state, randkey):
k1, k2 = jax.random.split(randkey, num=2)
input_idx, output_idx = self.input_idx, self.output_idx
input_size = len(input_idx)
output_size = len(output_idx)
hidden_idx = jnp.arange(
input_size + output_size, input_size + output_size + self.hidden_cnt
)
nodes = jnp.full(
(self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
)
nodes = nodes.at[input_idx, 0].set(input_idx)
nodes = nodes.at[output_idx, 0].set(output_idx)
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
total_idx = input_size + output_size + self.hidden_cnt
rand_keys_n = jax.random.split(k1, num=total_idx)
node_attr_func = jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
node_attrs = node_attr_func(state, rand_keys_n)
nodes = nodes.at[:total_idx, 1:].set(node_attrs)
conns = jnp.full(
(self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
)
input_to_hidden_ids, hidden_ids = jnp.meshgrid(
input_idx, hidden_idx, indexing="ij"
)
total_input_to_hidden_conns = input_size * self.hidden_cnt
conns = conns.at[:total_input_to_hidden_conns, :2].set(
jnp.column_stack([input_to_hidden_ids.flatten(), hidden_ids.flatten()])
)
hidden_to_output_ids, output_ids = jnp.meshgrid(
hidden_idx, output_idx, indexing="ij"
)
total_hidden_to_output_conns = self.hidden_cnt * output_size
conns = conns.at[
total_input_to_hidden_conns : total_input_to_hidden_conns
+ total_hidden_to_output_conns,
:2,
].set(jnp.column_stack([hidden_to_output_ids.flatten(), output_ids.flatten()]))
total_conns = total_input_to_hidden_conns + total_hidden_to_output_conns
rand_keys_c = jax.random.split(k2, num=total_conns)
conns_attr_func = jax.vmap(
self.conn_gene.new_random_attrs,
in_axes=(
None,
0,
),
)
conns_attrs = conns_attr_func(state, rand_keys_c)
conns = conns.at[:total_conns, 2:].set(conns_attrs)
return nodes, conns

View File

@@ -2,6 +2,7 @@ import warnings
from typing import Callable
import jax, jax.numpy as jnp
import numpy as np
import sympy as sp
from utils import (
unflatten_conns,
@@ -37,6 +38,7 @@ class DefaultGenome(BaseGenome):
mutation: BaseMutation = DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(),
output_transform: Callable = None,
input_transform: Callable = None,
):
super().__init__(
num_inputs,
@@ -49,9 +51,16 @@ class DefaultGenome(BaseGenome):
crossover,
)
if input_transform is not None:
try:
_ = input_transform(np.zeros(num_inputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.input_transform = input_transform
if output_transform is not None:
try:
_ = output_transform(jnp.zeros(num_outputs))
_ = output_transform(np.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
@@ -69,6 +78,10 @@ class DefaultGenome(BaseGenome):
return nodes, conns
def forward(self, state, transformed, inputs):
if self.input_transform is not None:
inputs = self.input_transform(inputs)
cal_seqs, nodes, conns, u_conns = transformed
ini_vals = jnp.full((self.max_nodes,), jnp.nan)
@@ -118,6 +131,10 @@ class DefaultGenome(BaseGenome):
return self.output_transform(vals[self.output_idx])
def update_by_batch(self, state, batch_input, transformed):
if self.input_transform is not None:
batch_input = jax.vmap(self.input_transform)(batch_input)
cal_seqs, nodes, conns, u_conns = transformed
batch_size = batch_input.shape[0]
@@ -193,11 +210,19 @@ class DefaultGenome(BaseGenome):
new_transformed,
)
def sympy_func(self, state, network, sympy_output_transform=None, backend="jax"):
def sympy_func(self, state, network, sympy_input_transform=None, sympy_output_transform=None, backend="jax"):
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
if sympy_input_transform is None and self.input_transform is not None:
warnings.warn(
"genome.input_transform is not None but sympy_input_transform is None!"
)
if sympy_input_transform is not None:
if not isinstance(sympy_input_transform, list):
sympy_input_transform = [sympy_input_transform] * self.num_inputs
if sympy_output_transform is None and self.output_transform is not None:
warnings.warn(
"genome.output_transform is not None but sympy_output_transform is None!"
@@ -221,7 +246,10 @@ class DefaultGenome(BaseGenome):
for i in order:
if i in input_idx:
nodes_exprs[symbols[i]] = symbols[i]
if sympy_input_transform is not None:
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[i])
else:
nodes_exprs[symbols[i]] = symbols[i]
else:
in_conns = [c for c in network["conns"] if c[1] == i]
node_inputs = []

View File

@@ -0,0 +1,56 @@
import jax, jax.numpy as jnp
from .default import DefaultGenome
class DenseInitialize(DefaultGenome):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.max_nodes >= self.num_inputs + self.num_outputs
assert self.max_conns >= self.num_inputs * self.num_outputs
def initialize(self, state, randkey):
k1, k2 = jax.random.split(randkey, num=2)
input_idx, output_idx = self.input_idx, self.output_idx
input_size = len(input_idx)
output_size = len(output_idx)
nodes = jnp.full(
(self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
)
nodes = nodes.at[input_idx, 0].set(input_idx)
nodes = nodes.at[output_idx, 0].set(output_idx)
total_idx = input_size + output_size
rand_keys_n = jax.random.split(k1, num=total_idx)
node_attr_func = jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
node_attrs = node_attr_func(state, rand_keys_n)
nodes = nodes.at[:total_idx, 1:].set(node_attrs)
conns = jnp.full(
(self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
)
input_to_output_ids, output_ids = jnp.meshgrid(
input_idx, output_idx, indexing="ij"
)
total_conns = input_size * output_size
conns = conns.at[:total_conns, :2].set(
jnp.column_stack([input_to_output_ids.flatten(), output_ids.flatten()])
)
rand_keys_c = jax.random.split(k2, num=total_conns)
conns_attr_func = jax.vmap(
self.conn_gene.new_random_attrs,
in_axes=(
None,
0,
),
)
conns_attrs = conns_attr_func(state, rand_keys_c)
conns = conns.at[:total_conns, 2:].set(conns_attrs)
return nodes, conns