refactor folder locations

This commit is contained in:
root
2024-07-10 16:40:03 +08:00
parent 3170d2a3d5
commit 4cdac932d3
25 changed files with 0 additions and 1 deletion

View File

@@ -1,3 +0,0 @@
from .base import BaseGene
from .conn import *
from .node import *

View File

@@ -1,45 +0,0 @@
import jax, jax.numpy as jnp
from tensorneat.common import State, StatefulBaseClass, hash_array
class BaseGene(StatefulBaseClass):
    """Base class for node genes or connection genes.

    A gene is stored as a flat array whose first ``len(fixed_attrs)`` entries
    are structural (indices) and whose remaining entries are evolvable.
    """

    fixed_attrs = []
    custom_attrs = []

    def __init__(self):
        pass

    def new_identity_attrs(self, state):
        # the attrs which do identity transformation, used in mutate add node
        raise NotImplementedError

    def new_random_attrs(self, state, randkey):
        # random attributes of the gene, used in initialization
        raise NotImplementedError

    def mutate(self, state, randkey, attrs):
        raise NotImplementedError

    def crossover(self, state, randkey, attrs1, attrs2):
        """Per-attribute uniform crossover: each attr is taken from either parent."""
        take_first = jax.random.normal(randkey, attrs1.shape) > 0
        return jnp.where(take_first, attrs1, attrs2)

    def distance(self, state, attrs1, attrs2):
        raise NotImplementedError

    def forward(self, state, attrs, inputs):
        raise NotImplementedError

    @property
    def length(self):
        """Total number of entries stored per gene row (fixed + custom)."""
        return len(self.fixed_attrs) + len(self.custom_attrs)

    def repr(self, state, gene, precision=2):
        raise NotImplementedError

    def hash(self, gene):
        """Deterministic hash of the gene's raw array."""
        return hash_array(gene)

View File

@@ -1,2 +0,0 @@
from .base import BaseConnGene
from .default import DefaultConnGene

View File

@@ -1,36 +0,0 @@
import jax
from .. import BaseGene
class BaseConnGene(BaseGene):
    """Base class for connection genes."""

    # every connection row starts with its endpoint node indices
    fixed_attrs = ["input_index", "output_index"]

    def __init__(self):
        super().__init__()

    def new_zero_attrs(self, state):
        # the attrs which make the least influence on the network, used in mutate add conn
        raise NotImplementedError

    def forward(self, state, attrs, inputs):
        raise NotImplementedError

    def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
        """One-line human-readable description of a connection gene."""
        input_index = int(conn[0])
        output_index = int(conn[1])
        return "{}(in: {:<{idx_width}}, out: {:<{idx_width}})".format(
            self.__class__.__name__,
            input_index,
            output_index,
            idx_width=idx_width,
        )

    def to_dict(self, state, conn):
        """Serialize the fixed attributes of a connection gene."""
        return {
            "in": int(conn[0]),
            "out": int(conn[1]),
        }

    def sympy_func(self, state, conn_dict, inputs):
        raise NotImplementedError

View File

@@ -1,90 +0,0 @@
import jax.numpy as jnp
import jax.random
import sympy as sp
from tensorneat.common import mutate_float
from .base import BaseConnGene
class DefaultConnGene(BaseConnGene):
    """Default connection gene, with the same behavior as in NEAT-python."""

    custom_attrs = ["weight"]

    def __init__(
        self,
        weight_init_mean: float = 0.0,
        weight_init_std: float = 1.0,
        weight_mutate_power: float = 0.5,
        weight_mutate_rate: float = 0.8,
        weight_replace_rate: float = 0.1,
    ):
        super().__init__()
        # initialization distribution of the weight
        self.weight_init_mean = weight_init_mean
        self.weight_init_std = weight_init_std
        # mutation behavior of the weight
        self.weight_mutate_power = weight_mutate_power
        self.weight_mutate_rate = weight_mutate_rate
        self.weight_replace_rate = weight_replace_rate

    def new_zero_attrs(self, state):
        # weight = 0: the connection contributes nothing
        return jnp.array([0.0])

    def new_identity_attrs(self, state):
        # weight = 1: the connection passes its input through unchanged
        return jnp.array([1.0])

    def new_random_attrs(self, state, randkey):
        """Sample a fresh weight from N(init_mean, init_std)."""
        sample = jax.random.normal(randkey, ())
        return jnp.array([sample * self.weight_init_std + self.weight_init_mean])

    def mutate(self, state, randkey, attrs):
        """Perturb or replace the weight according to the mutate/replace rates."""
        new_weight = mutate_float(
            randkey,
            attrs[0],
            self.weight_init_mean,
            self.weight_init_std,
            self.weight_mutate_power,
            self.weight_mutate_rate,
            self.weight_replace_rate,
        )
        return jnp.array([new_weight])

    def distance(self, state, attrs1, attrs2):
        """Gene distance: absolute weight difference."""
        return jnp.abs(attrs1[0] - attrs2[0])

    def forward(self, state, attrs, inputs):
        return inputs * attrs[0]

    def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
        in_idx, out_idx, weight = conn
        return "{}(in: {:<{idx_width}}, out: {:<{idx_width}}, weight: {:<{float_width}})".format(
            self.__class__.__name__,
            int(in_idx),
            int(out_idx),
            round(float(weight), precision),
            idx_width=idx_width,
            float_width=precision + 3,
        )

    def to_dict(self, state, conn):
        return {
            "in": int(conn[0]),
            "out": int(conn[1]),
            "weight": jnp.float32(conn[2]),
        }

    def sympy_func(self, state, conn_dict, inputs, precision=None):
        """Symbolic forward pass: returns (expr, {symbol: value})."""
        weight = sp.symbols(f"c_{conn_dict['in']}_{conn_dict['out']}_w")
        return inputs * weight, {weight: conn_dict["weight"]}

View File

@@ -1,3 +0,0 @@
from .base import BaseNodeGene
from .default import DefaultNodeGene
from .bias import BiasNode

View File

@@ -1,30 +0,0 @@
import jax, jax.numpy as jnp
from .. import BaseGene
class BaseNodeGene(BaseGene):
    """Base class for node genes."""

    # every node row starts with the node's key
    fixed_attrs = ["index"]

    def __init__(self):
        super().__init__()

    def forward(self, state, attrs, inputs, is_output_node=False):
        raise NotImplementedError

    def repr(self, state, node, precision=2, idx_width=3, func_width=8):
        """One-line human-readable description of a node gene."""
        node_key = int(node[0])
        return "{}(idx={:<{idx_width}})".format(
            self.__class__.__name__, node_key, idx_width=idx_width
        )

    def to_dict(self, state, node):
        """Serialize the fixed attributes of a node gene."""
        return {
            "idx": int(node[0]),
        }

    def sympy_func(self, state, node_dict, inputs, is_output_node=False):
        raise NotImplementedError

View File

@@ -1,168 +0,0 @@
from typing import Tuple
import jax, jax.numpy as jnp
import sympy as sp
from tensorneat.common import (
Act,
Agg,
act_func,
agg_func,
mutate_int,
mutate_float,
convert_to_sympy,
)
from . import BaseNodeGene
class BiasNode(BaseNodeGene):
    """
    Default node gene, with the same behavior as in NEAT-python.
    The attribute response is removed.
    """

    custom_attrs = ["bias", "aggregation", "activation"]

    def __init__(
        self,
        bias_init_mean: float = 0.0,
        bias_init_std: float = 1.0,
        bias_mutate_power: float = 0.5,
        bias_mutate_rate: float = 0.7,
        bias_replace_rate: float = 0.1,
        aggregation_default: callable = Agg.sum,
        aggregation_options: Tuple = (Agg.sum,),
        aggregation_replace_rate: float = 0.1,
        activation_default: callable = Act.sigmoid,
        activation_options: Tuple = (Act.sigmoid,),
        activation_replace_rate: float = 0.1,
    ):
        super().__init__()
        self.bias_init_mean = bias_init_mean
        self.bias_init_std = bias_init_std
        self.bias_mutate_power = bias_mutate_power
        self.bias_mutate_rate = bias_mutate_rate
        self.bias_replace_rate = bias_replace_rate
        # aggregation/activation are stored in the gene as integer indices
        # into the corresponding options tuples
        self.aggregation_default = aggregation_options.index(aggregation_default)
        self.aggregation_options = aggregation_options
        self.aggregation_indices = jnp.arange(len(aggregation_options))
        self.aggregation_replace_rate = aggregation_replace_rate
        self.activation_default = activation_options.index(activation_default)
        self.activation_options = activation_options
        self.activation_indices = jnp.arange(len(activation_options))
        self.activation_replace_rate = activation_replace_rate

    def new_identity_attrs(self, state):
        # bias=0 with identity activation passes the aggregated input through
        return jnp.array(
            [0, self.aggregation_default, -1]
        )  # activation=-1 means Act.identity

    def new_random_attrs(self, state, randkey):
        """Sample random attrs: normal bias, uniform aggregation/activation choice."""
        k1, k2, k3 = jax.random.split(randkey, num=3)
        bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
        agg = jax.random.choice(k2, self.aggregation_indices)
        act = jax.random.choice(k3, self.activation_indices)
        return jnp.array([bias, agg, act])

    def mutate(self, state, randkey, attrs):
        """Mutate each attribute with its own independent sub-key."""
        k1, k2, k3 = jax.random.split(randkey, num=3)
        bias, agg, act = attrs
        bias = mutate_float(
            k1,
            bias,
            self.bias_init_mean,
            self.bias_init_std,
            self.bias_mutate_power,
            self.bias_mutate_rate,
            self.bias_replace_rate,
        )
        agg = mutate_int(
            k2, agg, self.aggregation_indices, self.aggregation_replace_rate
        )
        act = mutate_int(k3, act, self.activation_indices, self.activation_replace_rate)
        return jnp.array([bias, agg, act])

    def distance(self, state, attrs1, attrs2):
        """Gene distance: |bias diff| plus 1 for each differing discrete attr."""
        bias1, agg1, act1 = attrs1
        bias2, agg2, act2 = attrs2
        return jnp.abs(bias1 - bias2) + (agg1 != agg2) + (act1 != act2)

    def forward(self, state, attrs, inputs, is_output_node=False):
        bias, agg, act = attrs
        z = agg_func(agg, inputs, self.aggregation_options)
        z = bias + z
        # the last output node should not be activated
        z = jax.lax.cond(
            is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
        )
        return z

    def repr(self, state, node, precision=2, idx_width=3, func_width=8):
        idx, bias, agg, act = node
        idx = int(idx)
        bias = round(float(bias), precision)
        agg = int(agg)
        act = int(act)
        # renamed local (was `act_func`): it shadowed the imported act_func helper
        if act == -1:
            act_fn = Act.identity
        else:
            act_fn = self.activation_options[act]
        return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
            self.__class__.__name__,
            idx,
            bias,
            self.aggregation_options[agg].__name__,
            act_fn.__name__,
            idx_width=idx_width,
            float_width=precision + 3,
            func_width=func_width,
        )

    def to_dict(self, state, node):
        idx, bias, agg, act = node
        idx = int(idx)
        bias = jnp.float32(bias)
        agg = int(agg)
        act = int(act)
        # renamed local (was `act_func`): it shadowed the imported act_func helper
        if act == -1:
            act_fn = Act.identity
        else:
            act_fn = self.activation_options[act]
        return {
            "idx": idx,
            "bias": bias,
            "agg": self.aggregation_options[int(agg)].__name__,
            "act": act_fn.__name__,
        }

    def sympy_func(self, state, node_dict, inputs, is_output_node=False):
        """Symbolic forward pass: returns (expr, {symbol: value})."""
        nd = node_dict
        bias = sp.symbols(f"n_{nd['idx']}_b")
        z = convert_to_sympy(nd["agg"])(inputs)
        z = bias + z
        if is_output_node:
            # output nodes stay linear, matching forward()
            pass
        else:
            z = convert_to_sympy(nd["act"])(z)
        return z, {bias: nd["bias"]}

View File

@@ -1,196 +0,0 @@
from typing import Tuple
import numpy as np
import jax, jax.numpy as jnp
import sympy as sp
from tensorneat.common import (
Act,
Agg,
act_func,
agg_func,
mutate_int,
mutate_float,
convert_to_sympy,
)
from . import BaseNodeGene
class DefaultNodeGene(BaseNodeGene):
    """Default node gene, with the same behavior as in NEAT-python."""

    custom_attrs = ["bias", "response", "aggregation", "activation"]

    def __init__(
        self,
        bias_init_mean: float = 0.0,
        bias_init_std: float = 1.0,
        bias_mutate_power: float = 0.5,
        bias_mutate_rate: float = 0.7,
        bias_replace_rate: float = 0.1,
        response_init_mean: float = 1.0,
        response_init_std: float = 0.0,
        response_mutate_power: float = 0.5,
        response_mutate_rate: float = 0.7,
        response_replace_rate: float = 0.1,
        aggregation_default: callable = Agg.sum,
        aggregation_options: Tuple = (Agg.sum,),
        aggregation_replace_rate: float = 0.1,
        activation_default: callable = Act.sigmoid,
        activation_options: Tuple = (Act.sigmoid,),
        activation_replace_rate: float = 0.1,
    ):
        super().__init__()
        self.bias_init_mean = bias_init_mean
        self.bias_init_std = bias_init_std
        self.bias_mutate_power = bias_mutate_power
        self.bias_mutate_rate = bias_mutate_rate
        self.bias_replace_rate = bias_replace_rate
        self.response_init_mean = response_init_mean
        self.response_init_std = response_init_std
        self.response_mutate_power = response_mutate_power
        self.response_mutate_rate = response_mutate_rate
        self.response_replace_rate = response_replace_rate
        # aggregation/activation are stored in the gene as integer indices
        # into the corresponding options tuples
        self.aggregation_default = aggregation_options.index(aggregation_default)
        self.aggregation_options = aggregation_options
        self.aggregation_indices = np.arange(len(aggregation_options))
        self.aggregation_replace_rate = aggregation_replace_rate
        self.activation_default = activation_options.index(activation_default)
        self.activation_options = activation_options
        self.activation_indices = np.arange(len(activation_options))
        self.activation_replace_rate = activation_replace_rate

    def new_identity_attrs(self, state):
        # bias=0, response=1 with identity activation pass the input through
        return jnp.array(
            [0, 1, self.aggregation_default, -1]
        )  # activation=-1 means Act.identity

    def new_random_attrs(self, state, randkey):
        """Sample random attrs: normal bias/response, uniform agg/act choice."""
        k1, k2, k3, k4 = jax.random.split(randkey, num=4)
        bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
        res = (
            jax.random.normal(k2, ()) * self.response_init_std + self.response_init_mean
        )
        agg = jax.random.choice(k3, self.aggregation_indices)
        act = jax.random.choice(k4, self.activation_indices)
        return jnp.array([bias, res, agg, act])

    def mutate(self, state, randkey, attrs):
        """Mutate each attribute with its own independent sub-key."""
        k1, k2, k3, k4 = jax.random.split(randkey, num=4)
        bias, res, agg, act = attrs
        bias = mutate_float(
            k1,
            bias,
            self.bias_init_mean,
            self.bias_init_std,
            self.bias_mutate_power,
            self.bias_mutate_rate,
            self.bias_replace_rate,
        )
        res = mutate_float(
            k2,
            res,
            self.response_init_mean,
            self.response_init_std,
            self.response_mutate_power,
            self.response_mutate_rate,
            self.response_replace_rate,
        )
        # fixed: keys were previously used out of order (k4 for aggregation,
        # k3 for activation), contradicting the k1..k4 naming and BiasNode
        agg = mutate_int(
            k3, agg, self.aggregation_indices, self.aggregation_replace_rate
        )
        act = mutate_int(k4, act, self.activation_indices, self.activation_replace_rate)
        return jnp.array([bias, res, agg, act])

    def distance(self, state, attrs1, attrs2):
        """Gene distance: |float diffs| plus 1 for each differing discrete attr."""
        bias1, res1, agg1, act1 = attrs1
        bias2, res2, agg2, act2 = attrs2
        return (
            jnp.abs(bias1 - bias2)  # bias
            + jnp.abs(res1 - res2)  # response
            + (agg1 != agg2)  # aggregation
            + (act1 != act2)  # activation
        )

    def forward(self, state, attrs, inputs, is_output_node=False):
        bias, res, agg, act = attrs
        z = agg_func(agg, inputs, self.aggregation_options)
        z = bias + res * z
        # the last output node should not be activated
        z = jax.lax.cond(
            is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
        )
        return z

    def repr(self, state, node, precision=2, idx_width=3, func_width=8):
        idx, bias, res, agg, act = node
        idx = int(idx)
        bias = round(float(bias), precision)
        res = round(float(res), precision)
        agg = int(agg)
        act = int(act)
        # renamed local (was `act_func`): it shadowed the imported act_func helper
        if act == -1:
            act_fn = Act.identity
        else:
            act_fn = self.activation_options[act]
        return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, response={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
            self.__class__.__name__,
            idx,
            bias,
            res,
            self.aggregation_options[agg].__name__,
            act_fn.__name__,
            idx_width=idx_width,
            float_width=precision + 3,
            func_width=func_width,
        )

    def to_dict(self, state, node):
        idx, bias, res, agg, act = node
        idx = int(idx)
        bias = jnp.float32(bias)
        res = jnp.float32(res)
        agg = int(agg)
        act = int(act)
        # renamed local (was `act_func`): it shadowed the imported act_func helper
        if act == -1:
            act_fn = Act.identity
        else:
            act_fn = self.activation_options[act]
        return {
            "idx": idx,
            "bias": bias,
            "res": res,
            "agg": self.aggregation_options[int(agg)].__name__,
            "act": act_fn.__name__,
        }

    def sympy_func(self, state, node_dict, inputs, is_output_node=False):
        """Symbolic forward pass: returns (expr, {symbol: value})."""
        nd = node_dict
        bias = sp.symbols(f"n_{nd['idx']}_b")
        res = sp.symbols(f"n_{nd['idx']}_r")
        z = convert_to_sympy(nd["agg"])(inputs)
        z = bias + res * z
        if is_output_node:
            # output nodes stay linear, matching forward()
            pass
        else:
            z = convert_to_sympy(nd["act"])(z)
        return z, {bias: nd["bias"], res: nd["res"]}

View File

@@ -1,4 +0,0 @@
from .base import BaseGenome
from .default import DefaultGenome
from .recurrent import RecurrentGenome

View File

@@ -1,224 +0,0 @@
from typing import Callable, Sequence
import numpy as np
import jax
from jax import vmap, numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene
from .operations import BaseMutation, BaseCrossover, BaseDistance
from tensorneat.common import (
State,
StatefulBaseClass,
hash_array,
)
from .utils import valid_cnt
class BaseGenome(StatefulBaseClass):
    """Base class for genomes.

    A genome is a fixed-capacity pair of arrays: nodes with shape
    (max_nodes, node_gene.length) and conns with shape
    (max_conns, conn_gene.length); unused rows are NaN-filled.
    """

    network_type = None

    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        max_nodes: int,
        max_conns: int,
        node_gene: BaseNodeGene,
        conn_gene: BaseConnGene,
        mutation: BaseMutation,
        crossover: BaseCrossover,
        distance: BaseDistance,
        output_transform: Callable = None,
        input_transform: Callable = None,
        init_hidden_layers: Sequence[int] = (),
    ):
        # check transform functions by probing them with a zero vector
        if input_transform is not None:
            try:
                _ = input_transform(jnp.zeros(num_inputs))
            except Exception as e:
                # fixed: previously mislabeled as "Output transform function failed"
                raise ValueError(f"Input transform function failed: {e}")
        if output_transform is not None:
            try:
                _ = output_transform(jnp.zeros(num_outputs))
            except Exception as e:
                raise ValueError(f"Output transform function failed: {e}")

        # prepare for initialization: assign consecutive node indices layer by layer
        all_layers = [num_inputs] + list(init_hidden_layers) + [num_outputs]
        layer_indices = []
        next_index = 0
        for layer in all_layers:
            layer_indices.append(list(range(next_index, next_index + layer)))
            next_index += layer

        # fully connect each layer to the next one
        all_init_nodes = []
        all_init_conns_in_idx = []
        all_init_conns_out_idx = []
        for i in range(len(layer_indices) - 1):
            in_layer = layer_indices[i]
            out_layer = layer_indices[i + 1]
            for in_idx in in_layer:
                for out_idx in out_layer:
                    all_init_conns_in_idx.append(in_idx)
                    all_init_conns_out_idx.append(out_idx)
            all_init_nodes.extend(in_layer)
        all_init_nodes.extend(layer_indices[-1])  # output layer

        if max_nodes < len(all_init_nodes):
            raise ValueError(
                f"max_nodes={max_nodes} must be greater than or equal to the number of initial nodes={len(all_init_nodes)}"
            )
        if max_conns < len(all_init_conns_in_idx):
            raise ValueError(
                f"max_conns={max_conns} must be greater than or equal to the number of initial connections={len(all_init_conns_in_idx)}"
            )

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.max_nodes = max_nodes
        self.max_conns = max_conns
        self.node_gene = node_gene
        self.conn_gene = conn_gene
        self.mutation = mutation
        self.crossover = crossover
        self.distance = distance
        self.output_transform = output_transform
        self.input_transform = input_transform
        self.input_idx = np.array(layer_indices[0])
        self.output_idx = np.array(layer_indices[-1])
        self.all_init_nodes = np.array(all_init_nodes)
        self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx]
        # fixed: removed leftover debug print of self.output_idx

    def setup(self, state=State()):
        """Set up sub-components; operations receive self for genome metadata."""
        state = self.node_gene.setup(state)
        state = self.conn_gene.setup(state)
        state = self.mutation.setup(state, self)
        state = self.crossover.setup(state, self)
        state = self.distance.setup(state, self)
        return state

    def transform(self, state, nodes, conns):
        raise NotImplementedError

    def forward(self, state, transformed, inputs):
        raise NotImplementedError

    def sympy_func(self):
        raise NotImplementedError

    def visualize(self):
        raise NotImplementedError

    def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
        return self.mutation(state, randkey, nodes, conns, new_node_key)

    def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
        return self.crossover(state, randkey, nodes1, conns1, nodes2, conns2)

    def execute_distance(self, state, nodes1, conns1, nodes2, conns2):
        return self.distance(state, nodes1, conns1, nodes2, conns2)

    def initialize(self, state, randkey):
        """Create the initial NaN-padded node/conn arrays with random gene attrs."""
        k1, k2 = jax.random.split(randkey)  # k1 for nodes, k2 for conns
        all_nodes_cnt = len(self.all_init_nodes)
        all_conns_cnt = len(self.all_init_conns)

        # initialize nodes
        nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan)
        node_indices = self.all_init_nodes
        rand_keys_n = jax.random.split(k1, num=all_nodes_cnt)
        node_attr_func = vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
        node_attrs = node_attr_func(state, rand_keys_n)
        nodes = nodes.at[:all_nodes_cnt, 0].set(node_indices)  # set node indices
        nodes = nodes.at[:all_nodes_cnt, 1:].set(node_attrs)  # set node attrs

        # initialize conns
        conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
        conn_indices = self.all_init_conns
        rand_keys_c = jax.random.split(k2, num=all_conns_cnt)
        conns_attr_func = jax.vmap(
            self.conn_gene.new_random_attrs,
            in_axes=(
                None,
                0,
            ),
        )
        conns_attrs = conns_attr_func(state, rand_keys_c)
        conns = conns.at[:all_conns_cnt, :2].set(conn_indices)  # set conn indices
        conns = conns.at[:all_conns_cnt, 2:].set(conns_attrs)  # set conn attrs

        return nodes, conns

    def network_dict(self, state, nodes, conns):
        """Serialize the genome to {"nodes": {idx: ...}, "conns": {(in, out): ...}}."""
        return {
            "nodes": self._get_node_dict(state, nodes),
            "conns": self._get_conn_dict(state, conns),
        }

    def get_input_idx(self):
        return self.input_idx.tolist()

    def get_output_idx(self):
        return self.output_idx.tolist()

    def hash(self, nodes, conns):
        """Deterministic hash of the whole genome (all node and conn rows)."""
        nodes_hashs = vmap(hash_array)(nodes)
        conns_hashs = vmap(hash_array)(conns)
        return hash_array(jnp.concatenate([nodes_hashs, conns_hashs]))

    def repr(self, state, nodes, conns, precision=2):
        """Multi-line human-readable description of the genome."""
        nodes, conns = jax.device_get([nodes, conns])
        nodes_cnt, conns_cnt = valid_cnt(nodes), valid_cnt(conns)
        s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n"
        s += f"\tNodes:\n"
        for node in nodes:
            # NaN key marks the first unused row; rows are packed at the front
            if np.isnan(node[0]):
                break
            s += f"\t\t{self.node_gene.repr(state, node, precision=precision)}"
            node_idx = int(node[0])
            if np.isin(node_idx, self.input_idx):
                s += " (input)"
            elif np.isin(node_idx, self.output_idx):
                s += " (output)"
            s += "\n"
        s += f"\tConns:\n"
        for conn in conns:
            if np.isnan(conn[0]):
                break
            s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n"
        return s

    def _get_conn_dict(self, state, conns):
        # keyed by (in_idx, out_idx); NaN rows are skipped
        conns = jax.device_get(conns)
        conn_dict = {}
        for conn in conns:
            if np.isnan(conn[0]):
                continue
            cd = self.conn_gene.to_dict(state, conn)
            in_idx, out_idx = cd["in"], cd["out"]
            conn_dict[(in_idx, out_idx)] = cd
        return conn_dict

    def _get_node_dict(self, state, nodes):
        # keyed by node index; NaN rows are skipped
        nodes = jax.device_get(nodes)
        node_dict = {}
        for node in nodes:
            if np.isnan(node[0]):
                continue
            nd = self.node_gene.to_dict(state, node)
            idx = nd["idx"]
            node_dict[idx] = nd
        return node_dict

View File

@@ -1,321 +0,0 @@
import warnings
import jax
from jax import vmap, numpy as jnp
import numpy as np
import sympy as sp
from . import BaseGenome
from ..gene import DefaultNodeGene, DefaultConnGene
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
from tensorneat.common import (
topological_sort,
topological_sort_python,
I_INF,
attach_with_inf,
SYMPY_FUNCS_MODULE_NP,
SYMPY_FUNCS_MODULE_JNP,
)
class DefaultGenome(BaseGenome):
    """Default genome class, with the same behavior as the NEAT-Python"""

    network_type = "feedforward"

    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        max_nodes=50,
        max_conns=100,
        # NOTE(review): default gene/operation instances are created once at
        # definition time and shared across genomes — presumably intentional
        # since they hold only configuration; confirm they stay stateless
        node_gene=DefaultNodeGene(),
        conn_gene=DefaultConnGene(),
        mutation=DefaultMutation(),
        crossover=DefaultCrossover(),
        distance=DefaultDistance(),
        output_transform=None,
        input_transform=None,
        init_hidden_layers=(),
    ):
        super().__init__(
            num_inputs,
            num_outputs,
            max_nodes,
            max_conns,
            node_gene,
            conn_gene,
            mutation,
            crossover,
            distance,
            output_transform,
            input_transform,
            init_hidden_layers,
        )

    def transform(self, state, nodes, conns):
        """Prepare a genome for forward(): topological order plus unflattened conns."""
        u_conns = unflatten_conns(nodes, conns)
        conn_exist = u_conns != I_INF
        seqs = topological_sort(nodes, conn_exist)
        return seqs, nodes, conns, u_conns

    def forward(self, state, transformed, inputs):
        """Evaluate the network on one input vector, following topological order."""
        if self.input_transform is not None:
            inputs = self.input_transform(inputs)

        cal_seqs, nodes, conns, u_conns = transformed

        ini_vals = jnp.full((self.max_nodes,), jnp.nan)
        ini_vals = ini_vals.at[self.input_idx].set(inputs)
        nodes_attrs = vmap(extract_node_attrs)(nodes)
        conns_attrs = vmap(extract_conn_attrs)(conns)

        def cond_fun(carry):
            values, idx = carry
            return (idx < self.max_nodes) & (
                cal_seqs[idx] != I_INF
            )  # not out of bounds and next node exists

        def body_func(carry):
            values, idx = carry
            i = cal_seqs[idx]

            def input_node():
                # input values are already set; nothing to compute
                return values

            def otherwise():
                # calculate connections feeding node i
                conn_indices = u_conns[:, i]
                hit_attrs = attach_with_inf(conns_attrs, conn_indices)  # fetch conn attrs
                ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
                    state, hit_attrs, values
                )
                # calculate the node itself
                z = self.node_gene.forward(
                    state,
                    nodes_attrs[i],
                    ins,
                    # fixed: use the scalar node key nodes[i, 0]; previously
                    # nodes[0] passed a whole row, which is not a valid scalar
                    # predicate for jax.lax.cond inside node_gene.forward
                    is_output_node=jnp.isin(nodes[i, 0], self.output_idx),
                )
                # store the freshly computed value
                new_values = values.at[i].set(z)
                return new_values

            values = jax.lax.cond(jnp.isin(i, self.input_idx), input_node, otherwise)
            return values, idx + 1

        vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))

        if self.output_transform is None:
            return vals[self.output_idx]
        else:
            return self.output_transform(vals[self.output_idx])

    def network_dict(self, state, nodes, conns):
        """Serialized network plus its topological order and layer assignment."""
        network = super().network_dict(state, nodes, conns)
        topo_order, topo_layers = topological_sort_python(
            set(network["nodes"]), set(network["conns"])
        )
        network["topo_order"] = topo_order
        network["topo_layers"] = topo_layers
        return network

    def sympy_func(
        self,
        state,
        network,
        sympy_input_transform=None,
        sympy_output_transform=None,
        backend="jax",
    ):
        """Build symbolic (sympy) expressions for each output, plus a callable.

        Returns (symbols, args_symbols, input_symbols, nodes_exprs,
        output_exprs, forward_func).
        """
        assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
        module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP

        if sympy_input_transform is None and self.input_transform is not None:
            warnings.warn(
                "genome.input_transform is not None but sympy_input_transform is None!"
            )
        if sympy_input_transform is None:
            sympy_input_transform = lambda x: x
        if sympy_input_transform is not None:
            # broadcast a single transform to one per input
            if not isinstance(sympy_input_transform, list):
                sympy_input_transform = [sympy_input_transform] * self.num_inputs
        if sympy_output_transform is None and self.output_transform is not None:
            warnings.warn(
                "genome.output_transform is not None but sympy_output_transform is None!"
            )

        input_idx = self.get_input_idx()
        output_idx = self.get_output_idx()
        order = network["topo_order"]
        hidden_idx = [
            i for i in network["nodes"] if i not in input_idx and i not in output_idx
        ]

        # one symbol per node; input nodes get an extra "origin" symbol keyed
        # by -i-1 holding the untransformed input
        symbols = {}
        for i in network["nodes"]:
            if i in input_idx:
                symbols[-i - 1] = sp.Symbol(f"i{i - min(input_idx)}")  # origin_i
                symbols[i] = sp.Symbol(f"norm{i - min(input_idx)}")
            elif i in output_idx:
                symbols[i] = sp.Symbol(f"o{i - min(output_idx)}")
            else:  # hidden
                symbols[i] = sp.Symbol(f"h{i - min(hidden_idx)}")

        nodes_exprs = {}
        args_symbols = {}
        for i in order:
            if i in input_idx:
                nodes_exprs[symbols[-i - 1]] = symbols[
                    -i - 1
                ]  # origin equal to its symbol
                nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](
                    symbols[-i - 1]
                )  # normed i
            else:
                in_conns = [c for c in network["conns"] if c[1] == i]
                node_inputs = []
                for conn in in_conns:
                    val_represent = symbols[conn[0]]
                    # a_s -> args_symbols
                    val, a_s = self.conn_gene.sympy_func(
                        state,
                        network["conns"][conn],
                        val_represent,
                    )
                    args_symbols.update(a_s)
                    node_inputs.append(val)
                nodes_exprs[symbols[i]], a_s = self.node_gene.sympy_func(
                    state,
                    network["nodes"][i],
                    node_inputs,
                    is_output_node=(i in output_idx),
                )
                args_symbols.update(a_s)
                if i in output_idx and sympy_output_transform is not None:
                    nodes_exprs[symbols[i]] = sympy_output_transform(
                        nodes_exprs[symbols[i]]
                    )

        input_symbols = [symbols[-i - 1] for i in input_idx]
        # substitute in topological order so each expr depends only on inputs
        reduced_exprs = nodes_exprs.copy()
        for i in order:
            reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)
        output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]

        lambdify_output_funcs = [
            sp.lambdify(
                input_symbols + list(args_symbols.keys()),
                exprs,
                modules=[backend, module],
            )
            for exprs in output_exprs
        ]

        fixed_args_output_funcs = []
        for i in range(len(output_idx)):
            # bind i as a default to avoid the late-binding closure pitfall
            def f(inputs, i=i):
                return lambdify_output_funcs[i](*inputs, *args_symbols.values())

            fixed_args_output_funcs.append(f)

        forward_func = lambda inputs: jnp.array(
            [f(inputs) for f in fixed_args_output_funcs]
        )

        return (
            symbols,
            args_symbols,
            input_symbols,
            nodes_exprs,
            output_exprs,
            forward_func,
        )

    def visualize(
        self,
        network,
        rotate=0,
        reverse_node_order=False,
        size=(300, 300, 300),
        color=("blue", "blue", "blue"),
        save_path="network.svg",
        save_dpi=800,
        **kwargs,
    ):
        """Draw the network with networkx/matplotlib and save it to save_path.

        size/color are (input, hidden, output) triples; scalars are broadcast.
        rotate is the layout rotation in degrees.
        """
        import networkx as nx
        from matplotlib import pyplot as plt

        nodes_list = list(network["nodes"])
        conns_list = list(network["conns"])
        input_idx = self.get_input_idx()
        output_idx = self.get_output_idx()
        topo_order, topo_layers = network["topo_order"], network["topo_layers"]
        node2layer = {
            node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
        }
        if reverse_node_order:
            topo_order = topo_order[::-1]

        G = nx.DiGraph()

        if not isinstance(size, tuple):
            size = (size, size, size)
        if not isinstance(color, tuple):
            color = (color, color, color)

        for node in topo_order:
            if node in input_idx:
                G.add_node(node, subset=node2layer[node], size=size[0], color=color[0])
            elif node in output_idx:
                G.add_node(node, subset=node2layer[node], size=size[2], color=color[2])
            else:
                G.add_node(node, subset=node2layer[node], size=size[1], color=color[1])
        for conn in conns_list:
            G.add_edge(conn[0], conn[1])

        pos = nx.multipartite_layout(G)

        def rotate_layout(pos, angle):
            # rotate all layout positions around the origin by `angle` degrees
            angle_rad = np.deg2rad(angle)
            cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
            rotated_pos = {}
            for node, (x, y) in pos.items():
                rotated_pos[node] = (
                    cos_angle * x - sin_angle * y,
                    sin_angle * x + cos_angle * y,
                )
            return rotated_pos

        rotated_pos = rotate_layout(pos, rotate)

        node_sizes = [n["size"] for n in G.nodes.values()]
        node_colors = [n["color"] for n in G.nodes.values()]

        nx.draw(
            G,
            pos=rotated_pos,
            node_size=node_sizes,
            node_color=node_colors,
            **kwargs,
        )
        plt.savefig(save_path, dpi=save_dpi)

View File

@@ -1,3 +0,0 @@
from .crossover import BaseCrossover, DefaultCrossover
from .mutation import BaseMutation, DefaultMutation
from .distance import BaseDistance, DefaultDistance

View File

@@ -1,2 +0,0 @@
from .base import BaseCrossover
from .default import DefaultCrossover

View File

@@ -1,12 +0,0 @@
from tensorneat.common import StatefulBaseClass, State
class BaseCrossover(StatefulBaseClass):
    """Base class for crossover operations between two genomes."""

    def setup(self, state=State(), genome=None):
        # genome is attached so subclasses can read max_nodes/max_conns/genes
        assert genome is not None, "genome should not be None"
        self.genome = genome
        return state

    def __call__(self, state, randkey, nodes1, conns1, nodes2, conns2):
        # fixed: parameter order was (nodes1, nodes2, conns1, conns2), which
        # contradicted both DefaultCrossover and BaseGenome.execute_crossover
        raise NotImplementedError

View File

@@ -1,87 +0,0 @@
import jax
from jax import vmap, numpy as jnp
from .base import BaseCrossover
from ...utils import (
extract_node_attrs,
extract_conn_attrs,
set_node_attrs,
set_conn_attrs,
)
class DefaultCrossover(BaseCrossover):
    """NEAT-style crossover: homologous genes mix, disjoint/excess come from the winner."""

    def __call__(self, state, randkey, nodes1, conns1, nodes2, conns2):
        """
        use genome1 and genome2 to generate a new genome
        notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
        """
        randkey1, randkey2 = jax.random.split(randkey, 2)
        # one independent key per potential node/conn row
        randkeys1 = jax.random.split(randkey1, self.genome.max_nodes)
        randkeys2 = jax.random.split(randkey2, self.genome.max_conns)

        # crossover nodes
        keys1, keys2 = nodes1[:, 0], nodes2[:, 0]  # node keys are in column 0
        # make homologous genes align in nodes2 align with nodes1
        nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)

        # For not homologous genes, use the value of nodes1(winner)
        # For homologous genes, use the crossover result between nodes1 and nodes2
        node_attrs1 = vmap(extract_node_attrs)(nodes1)
        node_attrs2 = vmap(extract_node_attrs)(nodes2)

        # NaN in either side means the gene is not homologous (or the row is
        # unused), so the winner's value is kept unchanged
        new_node_attrs = jnp.where(
            jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2),  # one of them is nan
            node_attrs1,  # not homologous genes or both nan, use the value of nodes1(winner)
            vmap(self.genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(
                state, randkeys1, node_attrs1, node_attrs2
            ),  # homologous or both nan
        )
        new_nodes = vmap(set_node_attrs)(nodes1, new_node_attrs)

        # crossover connections (keyed by the (in, out) index pair)
        con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
        conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)

        conns_attrs1 = vmap(extract_conn_attrs)(conns1)
        conns_attrs2 = vmap(extract_conn_attrs)(conns2)
        new_conn_attrs = jnp.where(
            jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2),
            conns_attrs1,  # not homologous genes or both nan, use the value of conns1(winner)
            vmap(self.genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(
                state, randkeys2, conns_attrs1, conns_attrs2
            ),  # homologous or both nan
        )
        new_conns = vmap(set_conn_attrs)(conns1, new_conn_attrs)

        return new_nodes, new_conns

    def align_array(self, seq1, seq2, ar2, is_conn: bool):
        """
        Make ar2 align with ar1: rows of ar2 whose key appears in seq1 are moved
        to the position of that key in seq1; all other rows become NaN.

        This logic is order-sensitive — please consider carefully before
        changing it!

        :param seq1: keys of the reference array (node keys, or (in, out) pairs)
        :param seq2: keys of ar2, same format as seq1
        :param ar2: the array to realign
        :param is_conn: True when keys are (in, out) pairs rather than scalars
        :return: realigned copy of ar2
        """
        # broadcast to an (len(seq1), len(seq2)) pairwise-equality matrix
        seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
        mask = (seq1 == seq2) & (~jnp.isnan(seq1))

        if is_conn:
            # conn keys are pairs: both endpoints must match
            mask = jnp.all(mask, axis=2)

        intersect_mask = mask.any(axis=1)  # which rows of seq1 have a homolog
        idx = jnp.arange(0, len(seq1))
        # mask is one-hot per row, so the dot product picks the matching index
        idx_fixed = jnp.dot(mask, idx)

        refactor_ar2 = jnp.where(
            intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan
        )

        return refactor_ar2

View File

@@ -1,2 +0,0 @@
from .base import BaseDistance
from .default import DefaultDistance

View File

@@ -1,15 +0,0 @@
from tensorneat.common import StatefulBaseClass, State
class BaseDistance(StatefulBaseClass):
    """Base class for genome distance measures (used for speciation)."""

    def setup(self, state=State(), genome=None):
        # genome is attached so subclasses can read its genes and capacities
        assert genome is not None, "genome should not be None"
        self.genome = genome
        return state

    def __call__(self, state, nodes1, conns1, nodes2, conns2):
        """
        The distance between two genomes.

        fixed: parameter order was (nodes1, nodes2, conns1, conns2), which
        contradicted BaseGenome.execute_distance's call
        (state, nodes1, conns1, nodes2, conns2).
        """
        raise NotImplementedError

View File

@@ -1,105 +0,0 @@
from jax import vmap, numpy as jnp
from .base import BaseDistance
from ...utils import extract_node_attrs, extract_conn_attrs
class DefaultDistance(BaseDistance):
    """Compatibility distance between two genomes, NEAT-style.

    For nodes and connections separately:
    (count of non-homologous genes) * compatibility_disjoint
    + (summed attribute distance of homologous genes) * compatibility_weight,
    normalized by the gene count of the larger genome; the two parts are
    summed in ``__call__``.
    """

    def __init__(
        self,
        compatibility_disjoint: float = 1.0,
        compatibility_weight: float = 0.4,
    ):
        # weight for genes that appear in only one of the two genomes
        self.compatibility_disjoint = compatibility_disjoint
        # weight for the attribute distance of matched (homologous) genes
        self.compatibility_weight = compatibility_weight

    def __call__(self, state, nodes1, nodes2, conns1, conns2):
        """
        The distance between two genomes
        """
        d = self.node_distance(state, nodes1, nodes2) + self.conn_distance(
            state, conns1, conns2
        )
        return d

    def node_distance(self, state, nodes1, nodes2):
        """
        The distance of the nodes part for two genomes
        """
        # number of real (non-NaN-key) nodes in each genome
        node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
        node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
        max_cnt = jnp.maximum(node_cnt1, node_cnt2)
        # align homologous nodes
        # this process is similar to np.intersect1d:
        # after sorting the concatenated keys, a key shared by both genomes
        # appears in two adjacent rows
        nodes = jnp.concatenate((nodes1, nodes2), axis=0)
        keys = nodes[:, 0]
        sorted_indices = jnp.argsort(keys, axis=0)
        nodes = nodes[sorted_indices]
        nodes = jnp.concatenate(
            [nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0
        )  # add a nan row to the end
        fr, sr = nodes[:-1], nodes[1:]  # first row, second row
        # flag location of homologous nodes (adjacent equal keys)
        intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
        # calculate the count of non_homologous of two genomes
        non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
        # calculate the distance of homologous nodes
        fr_attrs = vmap(extract_node_attrs)(fr)
        sr_attrs = vmap(extract_node_attrs)(sr)
        hnd = vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
            state, fr_attrs, sr_attrs
        )  # homologous node distance
        # distances involving NaN (empty) rows are meaningless -> zero them
        hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
        homologous_distance = jnp.sum(hnd * intersect_mask)
        val = (
            non_homologous_cnt * self.compatibility_disjoint
            + homologous_distance * self.compatibility_weight
        )
        val = jnp.where(max_cnt == 0, 0, val / max_cnt)  # normalize
        return val

    def conn_distance(self, state, conns1, conns2):
        """
        The distance of the conns part for two genomes
        """
        con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
        con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
        max_cnt = jnp.maximum(con_cnt1, con_cnt2)
        # sort by (in-key, out-key) so shared connections become adjacent
        cons = jnp.concatenate((conns1, conns2), axis=0)
        keys = cons[:, :2]
        sorted_indices = jnp.lexsort(keys.T[::-1])
        cons = cons[sorted_indices]
        cons = jnp.concatenate(
            [cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0
        )  # add a nan row to the end
        fr, sr = cons[:-1], cons[1:]  # first row, second row
        # both genome has such connection (both key columns agree)
        intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
        non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
        fr_attrs = vmap(extract_conn_attrs)(fr)
        sr_attrs = vmap(extract_conn_attrs)(sr)
        hcd = vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
            state, fr_attrs, sr_attrs
        )  # homologous connection distance
        # distances involving NaN (empty) rows are meaningless -> zero them
        hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
        homologous_distance = jnp.sum(hcd * intersect_mask)
        val = (
            non_homologous_cnt * self.compatibility_disjoint
            + homologous_distance * self.compatibility_weight
        )
        val = jnp.where(max_cnt == 0, 0, val / max_cnt)  # normalize
        return val

View File

@@ -1,2 +0,0 @@
from .base import BaseMutation
from .default import DefaultMutation

View File

@@ -1,12 +0,0 @@
from tensorneat.common import StatefulBaseClass, State
class BaseMutation(StatefulBaseClass):
    """Abstract mutation operator.

    Binds itself to a genome during ``setup`` and defines the call
    signature concrete mutation implementations must provide.
    """

    def setup(self, state=State(), genome=None):
        # The operator is unusable without a genome to mutate.
        # NOTE(review): the `State()` default is evaluated once at class
        # definition; presumably State is an immutable pytree -- verify.
        assert genome is not None, "genome should not be None"
        self.genome = genome
        return state

    def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
        """Return mutated (nodes, conns); implemented by subclasses."""
        raise NotImplementedError

View File

@@ -1,292 +0,0 @@
import jax
from jax import vmap, numpy as jnp
from . import BaseMutation
from tensorneat.common import (
fetch_first,
fetch_random,
I_INF,
check_cycles,
)
from ...utils import (
unflatten_conns,
add_node,
add_conn,
delete_node_by_pos,
delete_conn_by_pos,
extract_node_attrs,
extract_conn_attrs,
set_node_attrs,
set_conn_attrs,
)
class DefaultMutation(BaseMutation):
    """Default NEAT mutation operator.

    Applies at most one structural mutation of each kind -- add node,
    delete node, add connection, delete connection -- each guarded by its
    own probability, then mutates the attribute values of every remaining
    gene. All branching uses ``jax.lax.cond`` so the operator stays
    jit-compatible; a mutation that cannot proceed (no free slot, no
    candidate gene) silently leaves the genome unchanged.
    """

    def __init__(
        self,
        conn_add: float = 0.2,
        conn_delete: float = 0,
        node_add: float = 0.2,
        node_delete: float = 0,
    ):
        # probability of each structural mutation; a value of 0 disables
        # the branch entirely (it is not even traced, see mutate_structure)
        self.conn_add = conn_add
        self.conn_delete = conn_delete
        self.node_add = node_add
        self.node_delete = node_delete

    def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
        """Mutate one genome: structure first, then gene attribute values."""
        k1, k2 = jax.random.split(randkey)
        nodes, conns = self.mutate_structure(
            state, k1, genome, nodes, conns, new_node_key
        )
        nodes, conns = self.mutate_values(state, k2, genome, nodes, conns)
        return nodes, conns

    def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
        """Apply the four structural mutations, each gated by its probability."""

        def mutate_add_node(key_, nodes_, conns_):
            """
            add a node while do not influence the output of the network
            """
            remain_node_space = jnp.isnan(nodes_[:, 0]).sum()
            remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
            i_key, o_key, idx = self.choose_connection_key(
                key_, conns_
            )  # choose a connection

            def successful_add_node():
                # remove the original connection and record its attrs
                original_attrs = extract_conn_attrs(conns_[idx])
                new_conns = delete_conn_by_pos(conns_, idx)
                # add a new node with identity attrs
                new_nodes = add_node(
                    nodes_, new_node_key, genome.node_gene.new_identity_attrs(state)
                )
                # add two new connections
                # first is with identity attrs
                new_conns = add_conn(
                    new_conns,
                    i_key,
                    new_node_key,
                    genome.conn_gene.new_identity_attrs(state),
                )
                # second is with the origin attrs
                new_conns = add_conn(
                    new_conns,
                    new_node_key,
                    o_key,
                    original_attrs,
                )
                return new_nodes, new_conns

            # needs 1 free node slot and 2 free conn slots (one conn is
            # removed, two are added)
            return jax.lax.cond(
                (idx == I_INF) | (remain_node_space < 1) | (remain_conn_space < 2),
                lambda: (nodes_, conns_),  # do nothing
                successful_add_node,
            )

        def mutate_delete_node(key_, nodes_, conns_):
            """
            delete a node
            """
            # randomly choose a node (inputs and outputs are never deleted)
            key, idx = self.choose_node_key(
                key_,
                nodes_,
                genome.input_idx,
                genome.output_idx,
                allow_input_keys=False,
                allow_output_keys=False,
            )

            def successful_delete_node():
                # delete the node
                new_nodes = delete_node_by_pos(nodes_, idx)
                # delete all connections touching the deleted node
                new_conns = jnp.where(
                    ((conns_[:, 0] == key) | (conns_[:, 1] == key))[:, None],
                    jnp.nan,
                    conns_,
                )
                return new_nodes, new_conns

            return jax.lax.cond(
                idx == I_INF,  # no available node to delete
                lambda: (nodes_, conns_),  # do nothing
                successful_delete_node,
            )

        def mutate_add_conn(key_, nodes_, conns_):
            """
            add a connection while do not influence the output of the network
            """
            remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
            # randomly choose two nodes
            k1_, k2_ = jax.random.split(key_, num=2)
            # input node of the connection can be any node
            i_key, from_idx = self.choose_node_key(
                k1_,
                nodes_,
                genome.input_idx,
                genome.output_idx,
                allow_input_keys=True,
                allow_output_keys=True,
            )
            # output node of the connection can be any node except input node
            o_key, to_idx = self.choose_node_key(
                k2_,
                nodes_,
                genome.input_idx,
                genome.output_idx,
                allow_input_keys=False,
                allow_output_keys=True,
            )
            # does this (i_key -> o_key) connection already exist?
            conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
            is_already_exist = conn_pos != I_INF

            def nothing():
                return nodes_, conns_

            def successful():
                # add a connection with zero attrs
                return nodes_, add_conn(
                    conns_, i_key, o_key, genome.conn_gene.new_zero_attrs(state)
                )

            if genome.network_type == "feedforward":
                # feedforward networks must stay acyclic
                u_conns = unflatten_conns(nodes_, conns_)
                conns_exist = u_conns != I_INF
                is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
                return jax.lax.cond(
                    is_already_exist | is_cycle | (remain_conn_space < 1),
                    nothing,
                    successful,
                )
            elif genome.network_type == "recurrent":
                # cycles are allowed in recurrent networks
                return jax.lax.cond(
                    is_already_exist | (remain_conn_space < 1),
                    nothing,
                    successful,
                )
            else:
                raise ValueError(f"Invalid network type: {genome.network_type}")

        def mutate_delete_conn(key_, nodes_, conns_):
            # randomly choose a connection
            i_key, o_key, idx = self.choose_connection_key(key_, conns_)
            return jax.lax.cond(
                idx == I_INF,
                lambda: (nodes_, conns_),  # nothing
                lambda: (nodes_, delete_conn_by_pos(conns_, idx)),  # success
            )

        k1, k2, k3, k4 = jax.random.split(randkey, num=4)
        # NOTE(review): all four probabilities are drawn from k1, which is
        # also passed to mutate_add_node below -- the key is reused. Confirm
        # this correlation is intended before changing it.
        r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))

        def nothing(_, nodes_, conns_):
            return nodes_, conns_

        # each branch is traced only when its probability is positive,
        # keeping the compiled graph minimal when mutations are disabled
        if self.node_add > 0:
            nodes, conns = jax.lax.cond(
                r1 < self.node_add, mutate_add_node, nothing, k1, nodes, conns
            )
        if self.node_delete > 0:
            nodes, conns = jax.lax.cond(
                r2 < self.node_delete, mutate_delete_node, nothing, k2, nodes, conns
            )
        if self.conn_add > 0:
            nodes, conns = jax.lax.cond(
                r3 < self.conn_add, mutate_add_conn, nothing, k3, nodes, conns
            )
        if self.conn_delete > 0:
            nodes, conns = jax.lax.cond(
                r4 < self.conn_delete, mutate_delete_conn, nothing, k4, nodes, conns
            )
        return nodes, conns

    def mutate_values(self, state, randkey, genome, nodes, conns):
        """Mutate attribute values of every gene, leaving empty slots NaN."""
        k1, k2 = jax.random.split(randkey)
        # one independent key per node/conn slot
        nodes_randkeys = jax.random.split(k1, num=genome.max_nodes)
        conns_randkeys = jax.random.split(k2, num=genome.max_conns)
        node_attrs = vmap(extract_node_attrs)(nodes)
        new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
            state, nodes_randkeys, node_attrs
        )
        new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs)
        conn_attrs = vmap(extract_conn_attrs)(conns)
        new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
            state, conns_randkeys, conn_attrs
        )
        new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs)
        # nan nodes not changed (empty slots must stay empty)
        new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
        new_conns = jnp.where(jnp.isnan(conns), jnp.nan, new_conns)
        return new_nodes, new_conns

    def choose_node_key(
        self,
        key,
        nodes,
        input_idx,
        output_idx,
        allow_input_keys: bool = False,
        allow_output_keys: bool = False,
    ):
        """
        Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
        :param key: jax random key
        :param nodes: (N, NL) node array; column 0 holds the node key
        :param input_idx: keys of input nodes
        :param output_idx: keys of output nodes
        :param allow_input_keys: whether input nodes are eligible
        :param allow_output_keys: whether output nodes are eligible
        :return: return its key and position(idx); (NaN, I_INF) if none eligible
        """
        node_keys = nodes[:, 0]
        # only occupied rows (NaN key means an empty slot)
        mask = ~jnp.isnan(node_keys)
        if not allow_input_keys:
            mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_idx))
        if not allow_output_keys:
            mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx))
        idx = fetch_random(key, mask)
        # idx == I_INF means nothing matched; report NaN as the key then
        key = jnp.where(idx != I_INF, nodes[idx, 0], jnp.nan)
        return key, idx

    def choose_connection_key(self, key, conns):
        """
        Randomly choose a connection key from the given connections.
        :return: i_key, o_key, idx; (NaN, NaN, I_INF) if no connection exists
        """
        idx = fetch_random(key, ~jnp.isnan(conns[:, 0]))
        i_key = jnp.where(idx != I_INF, conns[idx, 0], jnp.nan)
        o_key = jnp.where(idx != I_INF, conns[idx, 1], jnp.nan)
        return i_key, o_key, idx

View File

@@ -1,92 +0,0 @@
import jax
from jax import vmap, numpy as jnp
from .utils import unflatten_conns
from .base import BaseGenome
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
from ..gene import DefaultNodeGene, DefaultConnGene
from tensorneat.common import attach_with_inf
class RecurrentGenome(BaseGenome):
    """Recurrent genome class, with the same behavior as the NEAT-Python.

    The network is activated for a fixed number of steps
    (``activate_time``); every step pushes values through all connections
    simultaneously, so cycles are permitted.
    """

    network_type = "recurrent"

    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        max_nodes=50,
        max_conns=100,
        node_gene=None,
        conn_gene=None,
        mutation=None,
        crossover=None,
        distance=None,
        output_transform=None,
        input_transform=None,
        init_hidden_layers=(),
        activate_time=10,
    ):
        """Build a recurrent genome configuration.

        Passing ``None`` for any operator selects a freshly created default.
        FIX: the operator parameters previously used mutable default
        instances (e.g. ``node_gene=DefaultNodeGene()``) which are created
        once at class-definition time and shared by every genome built with
        defaults; operators that bind per-genome state in ``setup`` (such as
        ``self.genome = genome``) would then alias across genomes. Defaults
        are now instantiated per genome.
        """
        super().__init__(
            num_inputs,
            num_outputs,
            max_nodes,
            max_conns,
            node_gene if node_gene is not None else DefaultNodeGene(),
            conn_gene if conn_gene is not None else DefaultConnGene(),
            mutation if mutation is not None else DefaultMutation(),
            crossover if crossover is not None else DefaultCrossover(),
            distance if distance is not None else DefaultDistance(),
            output_transform,
            input_transform,
            init_hidden_layers,
        )
        # number of propagation steps per forward pass
        self.activate_time = activate_time

    def transform(self, state, nodes, conns):
        # pre-compute the dense (N, N) connection-index table once per genome
        u_conns = unflatten_conns(nodes, conns)
        return nodes, conns, u_conns

    def forward(self, state, transformed, inputs):
        """Activate the network for ``self.activate_time`` steps and return
        the output node values (optionally passed through output_transform).
        """
        nodes, conns, u_conns = transformed

        # node values start as NaN; inputs are written at every step
        vals = jnp.full((self.max_nodes,), jnp.nan)
        nodes_attrs = vmap(extract_node_attrs)(nodes)
        conns_attrs = vmap(extract_conn_attrs)(conns)
        # expand flat conn attrs to the dense (N, N, ...) layout indexed by u_conns
        expand_conns_attrs = attach_with_inf(conns_attrs, u_conns)

        def body_func(_, values):
            # set input values
            values = values.at[self.input_idx].set(inputs)
            # calculate connections
            node_ins = vmap(
                vmap(self.conn_gene.forward, in_axes=(None, 0, None)),
                in_axes=(None, 0, 0),
            )(state, expand_conns_attrs, values)
            # calculate nodes
            is_output_nodes = jnp.isin(nodes[:, 0], self.output_idx)
            values = vmap(self.node_gene.forward, in_axes=(None, 0, 0, 0))(
                state, nodes_attrs, node_ins.T, is_output_nodes
            )
            return values

        vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)

        if self.output_transform is None:
            return vals[self.output_idx]
        else:
            return self.output_transform(vals[self.output_idx])

    def sympy_func(self, state, network, precision=3):
        raise ValueError("Sympy function is not supported for Recurrent Network!")

    def visualize(self, network):
        raise ValueError("Visualize function is not supported for Recurrent Network!")

View File

@@ -1,109 +0,0 @@
import jax
from jax import vmap, numpy as jnp
from tensorneat.common import fetch_first, I_INF
def unflatten_conns(nodes, conns):
    """Scatter flat connections into a dense (N, N) index table.

    Entry (i, j) holds the row index (within ``conns``) of the connection
    going from the node stored at row i to the node stored at row j, or
    I_INF when no such connection exists. N is max_nodes, C is max_conns.
    """
    num_nodes = nodes.shape[0]  # N, max_nodes
    num_conns = conns.shape[0]  # C, max_conns
    node_keys = nodes[:, 0]

    def locate(key):
        # row of `key` inside node_keys (I_INF when the key is absent)
        return fetch_first(key == node_keys)

    in_rows = vmap(locate)(conns[:, 0])
    out_rows = vmap(locate)(conns[:, 1])
    # jax's .at[].set drops out-of-range scatter indices, so rows whose
    # keys resolved to I_INF (empty conn slots) simply write nothing,
    # while gathers clip -- hence the table keeps I_INF in empty cells
    table = jnp.full((num_nodes, num_nodes), I_INF, dtype=jnp.int32)
    return table.at[in_rows, out_rows].set(jnp.arange(num_conns, dtype=jnp.int32))
def valid_cnt(nodes_or_conns):
    """Count occupied rows, i.e. rows whose key column (column 0) is not NaN."""
    keys = nodes_or_conns[:, 0]
    return jnp.sum(jnp.logical_not(jnp.isnan(keys)))
def extract_node_attrs(node):
    """Return the attribute slice of a single node row.

    node: Array(NL,) whose column 0 stores the node key; the remaining
    columns hold the gene attributes.
    """
    key_cols = 1  # column 0 is the node key
    return node[key_cols:]
def set_node_attrs(node, attrs):
    """Write ``attrs`` into a node row, leaving the key (column 0) intact.

    node: Array(NL,), attrs: Array(NL-1,)
    """
    key_cols = 1  # column 0 is the node key
    return node.at[key_cols:].set(attrs)
def extract_conn_attrs(conn):
    """Return the attribute slice of a single connection row.

    conn: Array(CL,) whose columns 0 and 1 store the in/out node keys.
    """
    key_cols = 2  # columns 0-1 are in-key and out-key
    return conn[key_cols:]
def set_conn_attrs(conn, attrs):
    """Write ``attrs`` into a connection row, keeping both keys intact.

    conn: Array(CL,), attrs: Array(CL-2,)
    """
    key_cols = 2  # columns 0-1 are in-key and out-key
    return conn.at[key_cols:].set(attrs)
def add_node(nodes, new_key: int, attrs):
    """Insert a node into the first free slot and return the new array.

    A free slot is a row whose key column is NaN; the key goes to column 0
    and the attributes fill the rest of the row.
    """
    free_pos = fetch_first(jnp.isnan(nodes[:, 0]))
    updated = nodes.at[free_pos, 0].set(new_key)
    return updated.at[free_pos, 1:].set(attrs)
def delete_node_by_pos(nodes, pos):
    """Remove the node stored at row ``pos`` by overwriting the row with NaN."""
    return nodes.at[pos].set(jnp.nan)
def add_conn(conns, i_key, o_key, attrs):
    """Insert a connection into the first free slot and return the new array.

    A free slot is a row whose in-key column is NaN; the (in, out) key pair
    goes to columns 0-1 and the attributes fill the rest of the row.
    """
    free_pos = fetch_first(jnp.isnan(conns[:, 0]))
    updated = conns.at[free_pos, 0:2].set(jnp.array([i_key, o_key]))
    return updated.at[free_pos, 2:].set(attrs)
def delete_conn_by_pos(conns, pos):
    """Remove the connection stored at row ``pos`` by overwriting it with NaN."""
    return conns.at[pos].set(jnp.nan)

View File

@@ -1,4 +1,3 @@
import jax, jax.numpy as jnp
from tensorneat.common import State
from .. import BaseAlgorithm
from .species import *