odify genome for the official release

This commit is contained in:
root
2024-07-10 11:24:11 +08:00
parent 075460f896
commit ee8ec84202
83 changed files with 588 additions and 611 deletions

View File

@@ -75,7 +75,7 @@ from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from problem.rl_env import BraxEnv from problem.rl_env import BraxEnv
from utils import Act from tensorneat.utils import Act
if __name__ == '__main__': if __name__ == '__main__':
pipeline = Pipeline( pipeline = Pipeline(

View File

@@ -2,7 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from problem.rl_env import BraxEnv from problem.rl_env import BraxEnv
from utils import Act from tensorneat.common import Act
if __name__ == "__main__": if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(

View File

@@ -4,7 +4,7 @@ from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from problem.rl_env import BraxEnv from problem.rl_env import BraxEnv
from utils import Act from tensorneat.common import Act
def sample_policy(randkey, obs): def sample_policy(randkey, obs):

View File

@@ -2,7 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from problem.rl_env import BraxEnv from problem.rl_env import BraxEnv
from utils import Act from tensorneat.common import Act
if __name__ == "__main__": if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(

View File

@@ -2,7 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from problem.rl_env import BraxEnv from problem.rl_env import BraxEnv
from utils import Act from tensorneat.common import Act
import jax, jax.numpy as jnp import jax, jax.numpy as jnp

View File

@@ -2,7 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from problem.func_fit import XOR3d from problem.func_fit import XOR3d
from utils import ACT_ALL, AGG_ALL, Act, Agg from tensorneat.common import ACT_ALL, AGG_ALL, Act, Agg
if __name__ == "__main__": if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(

View File

@@ -1,7 +1,7 @@
from pipeline import Pipeline from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from algorithm.hyperneat import * from algorithm.hyperneat import *
from utils import Act from tensorneat.common import Act
from problem.func_fit import XOR3d from problem.func_fit import XOR3d

View File

@@ -3,7 +3,7 @@ import jax
from pipeline import Pipeline from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from algorithm.hyperneat import * from algorithm.hyperneat import *
from utils import Act from tensorneat.common import Act
from problem.rl_env import GymNaxEnv from problem.rl_env import GymNaxEnv

View File

@@ -2,7 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from problem.rl_env import GymNaxEnv from problem.rl_env import GymNaxEnv
from utils import Act from tensorneat.common import Act
if __name__ == "__main__": if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(

View File

@@ -2,7 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from problem.rl_env import GymNaxEnv from problem.rl_env import GymNaxEnv
from utils import Act from tensorneat.common import Act
if __name__ == "__main__": if __name__ == "__main__":
pipeline = Pipeline( pipeline = Pipeline(

View File

@@ -11,7 +11,7 @@
"from algorithm.neat.genome.advance import AdvanceInitialize\n", "from algorithm.neat.genome.advance import AdvanceInitialize\n",
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n", "from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
"from utils.graph import topological_sort_python\n", "from utils.graph import topological_sort_python\n",
"from utils import Act, Agg\n", "from tensorneat.utils import Act, Agg\n",
"\n", "\n",
"import numpy as np" "import numpy as np"
], ],

View File

@@ -3,7 +3,7 @@ import jax, jax.numpy as jnp
from algorithm.neat import * from algorithm.neat import *
from algorithm.neat.genome.dense import DenseInitialize from algorithm.neat.genome.dense import DenseInitialize
from utils.graph import topological_sort_python from utils.graph import topological_sort_python
from utils import * from tensorneat.common import *
if __name__ == "__main__": if __name__ == "__main__":
genome = DenseInitialize( genome = DenseInitialize(

View File

Before

Width:  |  Height:  |  Size: 90 KiB

After

Width:  |  Height:  |  Size: 90 KiB

View File

Before

Width:  |  Height:  |  Size: 89 KiB

After

Width:  |  Height:  |  Size: 89 KiB

View File

@@ -19,7 +19,7 @@
"from algorithm.neat.genome.advance import AdvanceInitialize\n", "from algorithm.neat.genome.advance import AdvanceInitialize\n",
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n", "from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
"from utils.graph import topological_sort_python\n", "from utils.graph import topological_sort_python\n",
"from utils import Act, Agg\n", "from tensorneat.utils import Act, Agg\n",
"\n", "\n",
"genome = AdvanceInitialize(\n", "genome = AdvanceInitialize(\n",
" num_inputs=16,\n", " num_inputs=16,\n",

View File

@@ -29,7 +29,7 @@
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n", "from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
"\n", "\n",
"from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048\n", "from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048\n",
"from utils import Act, Agg\n", "from tensorneat.utils import Act, Agg\n",
"\n", "\n",
"pipeline = Pipeline(\n", "pipeline = Pipeline(\n",
" algorithm=NEAT(\n", " algorithm=NEAT(\n",

View File

@@ -4,7 +4,7 @@ from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048 from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
from utils import Act, Agg from tensorneat.common import Act, Agg
def rot_li(li): def rot_li(li):

10
examples/tmp.py Normal file
View File

@@ -0,0 +1,10 @@
import jax, jax.numpy as jnp
from tensorneat.algorithm import NEAT
from tensorneat.algorithm.neat import DefaultGenome
key = jax.random.key(0)
genome = DefaultGenome(num_inputs=5, num_outputs=3, init_hidden_layers=(1, ))
state = genome.setup()
nodes, conns = genome.initialize(state, key)
print(genome.repr(state, nodes, conns))

View File

@@ -1,4 +1,4 @@
from utils import State, StatefulBaseClass from tensorneat.common import State, StatefulBaseClass
class BaseAlgorithm(StatefulBaseClass): class BaseAlgorithm(StatefulBaseClass):

View File

@@ -2,7 +2,7 @@ from typing import Callable
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import State, Act, Agg from tensorneat.common import State, Act, Agg
from .. import BaseAlgorithm, NEAT from .. import BaseAlgorithm, NEAT
from ..neat.gene import BaseNodeGene, BaseConnGene from ..neat.gene import BaseNodeGene, BaseConnGene
from ..neat.genome import RecurrentGenome from ..neat.genome import RecurrentGenome

View File

@@ -1,4 +1,4 @@
from utils import StatefulBaseClass from tensorneat.common import StatefulBaseClass
class BaseSubstrate(StatefulBaseClass): class BaseSubstrate(StatefulBaseClass):

View File

@@ -1,4 +1,3 @@
from .ga import *
from .gene import * from .gene import *
from .genome import * from .genome import *
from .species import * from .species import *

View File

@@ -1,6 +0,0 @@
from utils import StatefulBaseClass
class BaseCrossover(StatefulBaseClass):
def __call__(self, state, randkey, genome, nodes1, nodes2, conns1, conns2):
raise NotImplementedError

View File

@@ -1,6 +0,0 @@
from utils import StatefulBaseClass
class BaseMutation(StatefulBaseClass):
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
raise NotImplementedError

View File

@@ -1,5 +1,5 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import State, StatefulBaseClass, hash_array from tensorneat.common import State, StatefulBaseClass, hash_array
class BaseGene(StatefulBaseClass): class BaseGene(StatefulBaseClass):

View File

@@ -2,7 +2,7 @@ import jax.numpy as jnp
import jax.random import jax.random
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from utils import mutate_float from tensorneat.common import mutate_float
from . import BaseConnGene from . import BaseConnGene

View File

@@ -4,7 +4,7 @@ import numpy as np
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
import sympy as sp import sympy as sp
from utils import ( from tensorneat.common import (
Act, Act,
Agg, Agg,
act_func, act_func,

View File

@@ -3,7 +3,7 @@ from typing import Tuple
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from utils import ( from tensorneat.common import (
Act, Act,
Agg, Agg,
act_func, act_func,

View File

@@ -1,6 +1,6 @@
import jax.numpy as jnp import jax.numpy as jnp
from . import BaseNodeGene from . import BaseNodeGene
from utils import Agg from tensorneat.common import Agg
class KANNode(BaseNodeGene): class KANNode(BaseNodeGene):

View File

@@ -2,7 +2,7 @@ from typing import Tuple
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float from tensorneat.common import Act, Agg, act_func, agg_func, mutate_int, mutate_float
from . import BaseNodeGene from . import BaseNodeGene

View File

@@ -2,7 +2,7 @@ from typing import Tuple
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float from tensorneat.common import Act, Agg, act_func, agg_func, mutate_int, mutate_float
from . import BaseNodeGene from . import BaseNodeGene

View File

@@ -1,5 +1,4 @@
from .base import BaseGenome from .base import BaseGenome
from .default import DefaultGenome from .default import DefaultGenome
from .recurrent import RecurrentGenome from .recurrent import RecurrentGenome
from .hidden import HiddenInitialize
from .dense import DenseInitialize

View File

@@ -1,8 +1,16 @@
from typing import Callable, Sequence
import numpy as np import numpy as np
import jax, jax.numpy as jnp import jax
from jax import vmap, numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene from ..gene import BaseNodeGene, BaseConnGene
from ..ga import BaseMutation, BaseCrossover from .operations import BaseMutation, BaseCrossover, BaseDistance
from utils import State, StatefulBaseClass, topological_sort_python, hash_array from tensorneat.common import (
State,
StatefulBaseClass,
hash_array,
)
from .utils import valid_cnt
class BaseGenome(StatefulBaseClass): class BaseGenome(StatefulBaseClass):
@@ -18,120 +26,159 @@ class BaseGenome(StatefulBaseClass):
conn_gene: BaseConnGene, conn_gene: BaseConnGene,
mutation: BaseMutation, mutation: BaseMutation,
crossover: BaseCrossover, crossover: BaseCrossover,
distance: BaseDistance,
output_transform: Callable = None,
input_transform: Callable = None,
init_hidden_layers: Sequence[int] = (),
): ):
# check transform functions
if input_transform is not None:
try:
_ = input_transform(jnp.zeros(num_inputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
if output_transform is not None:
try:
_ = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
# prepare for initialization
all_layers = [num_inputs] + list(init_hidden_layers) + [num_outputs]
layer_indices = []
next_index = 0
for layer in all_layers:
layer_indices.append(list(range(next_index, next_index + layer)))
next_index += layer
all_init_nodes = []
all_init_conns_in_idx = []
all_init_conns_out_idx = []
for i in range(len(layer_indices) - 1):
in_layer = layer_indices[i]
out_layer = layer_indices[i + 1]
for in_idx in in_layer:
for out_idx in out_layer:
all_init_conns_in_idx.append(in_idx)
all_init_conns_out_idx.append(out_idx)
all_init_nodes.extend(in_layer)
if max_nodes < len(all_init_nodes):
raise ValueError(
f"max_nodes={max_nodes} must be greater than or equal to the number of initial nodes={len(all_init_nodes)}"
)
if max_conns < len(all_init_conns_in_idx):
raise ValueError(
f"max_conns={max_conns} must be greater than or equal to the number of initial connections={len(all_init_conns_in_idx)}"
)
self.num_inputs = num_inputs self.num_inputs = num_inputs
self.num_outputs = num_outputs self.num_outputs = num_outputs
self.input_idx = np.arange(num_inputs)
self.output_idx = np.arange(num_inputs, num_inputs + num_outputs)
self.max_nodes = max_nodes self.max_nodes = max_nodes
self.max_conns = max_conns self.max_conns = max_conns
self.node_gene = node_gene self.node_gene = node_gene
self.conn_gene = conn_gene self.conn_gene = conn_gene
self.mutation = mutation self.mutation = mutation
self.crossover = crossover self.crossover = crossover
self.distance = distance
self.output_transform = output_transform
self.input_transform = input_transform
self.input_idx = np.array(layer_indices[0])
self.output_idx = np.array(layer_indices[-1])
self.all_init_nodes = np.array(all_init_nodes)
self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx]
def setup(self, state=State()): def setup(self, state=State()):
state = self.node_gene.setup(state) state = self.node_gene.setup(state)
state = self.conn_gene.setup(state) state = self.conn_gene.setup(state)
state = self.mutation.setup(state) state = self.mutation.setup(state, self)
state = self.crossover.setup(state) state = self.crossover.setup(state, self)
state = self.distance.setup(state, self)
return state return state
def transform(self, state, nodes, conns): def transform(self, state, nodes, conns):
raise NotImplementedError raise NotImplementedError
def restore(self, state, transformed):
raise NotImplementedError
def forward(self, state, transformed, inputs): def forward(self, state, transformed, inputs):
raise NotImplementedError raise NotImplementedError
def sympy_func(self):
raise NotImplementedError
def visualize(self):
raise NotImplementedError
def execute_mutation(self, state, randkey, nodes, conns, new_node_key): def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
return self.mutation(state, randkey, self, nodes, conns, new_node_key) return self.mutation(state, randkey, nodes, conns, new_node_key)
def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2): def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
return self.crossover(state, randkey, self, nodes1, conns1, nodes2, conns2) return self.crossover(state, randkey, nodes1, conns1, nodes2, conns2)
def execute_distance(self, state, nodes1, conns1, nodes2, conns2):
return self.distance(state, nodes1, conns1, nodes2, conns2)
def initialize(self, state, randkey): def initialize(self, state, randkey):
"""
Default initialization method for the genome.
Add an extra hidden node.
Make all input nodes and output nodes connected to the hidden node.
All attributes will be initialized randomly using gene.new_random_attrs method.
For example, a network with 2 inputs and 1 output, the structure will be:
nodes:
[
[0, attrs0], # input node 0
[1, attrs1], # input node 1
[2, attrs2], # output node 0
[3, attrs3], # hidden node
[NaN, NaN], # empty node
]
conns:
[
[0, 3, attrs0], # input node 0 -> hidden node
[1, 3, attrs1], # input node 1 -> hidden node
[3, 2, attrs2], # hidden node -> output node 0
[NaN, NaN],
[NaN, NaN],
]
"""
k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns
all_nodes_cnt = len(self.all_init_nodes)
all_conns_cnt = len(self.all_init_conns)
# initialize nodes # initialize nodes
new_node_key = (
max([*self.input_idx, *self.output_idx]) + 1
) # the key for the hidden node
node_keys = jnp.concatenate(
[self.input_idx, self.output_idx, jnp.array([new_node_key])]
) # the list of all node keys
# initialize nodes and connections with NaN
nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan) nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan)
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan) # create node indices
node_indices = self.all_init_nodes
# create node attrs
rand_keys_n = jax.random.split(k1, num=all_nodes_cnt)
node_attr_func = vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
node_attrs = node_attr_func(state, rand_keys_n)
# set keys for input nodes, output nodes and hidden node nodes = nodes.at[:all_nodes_cnt, 0].set(node_indices) # set node indices
nodes = nodes.at[node_keys, 0].set(node_keys) nodes = nodes.at[:all_nodes_cnt, 1:].set(node_attrs) # set node attrs
# generate random attributes for nodes
node_keys = jax.random.split(k1, len(node_keys))
random_node_attrs = jax.vmap(
self.node_gene.new_random_attrs, in_axes=(None, 0)
)(state, node_keys)
nodes = nodes.at[: len(node_keys), 1:].set(random_node_attrs)
# initialize conns # initialize conns
# input-hidden connections conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
input_conns = jnp.c_[ # create input and output indices
self.input_idx, jnp.full_like(self.input_idx, new_node_key) conn_indices = self.all_init_conns
] # create conn attrs
conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys rand_keys_c = jax.random.split(k2, num=all_conns_cnt)
conns_attr_func = jax.vmap(
self.conn_gene.new_random_attrs,
in_axes=(
None,
0,
),
)
conns_attrs = conns_attr_func(state, rand_keys_c)
# output-hidden connections conns = conns.at[:all_conns_cnt, :2].set(conn_indices) # set conn indices
output_conns = jnp.c_[ conns = conns.at[:all_conns_cnt, 2:].set(conns_attrs) # set conn attrs
jnp.full_like(self.output_idx, new_node_key), self.output_idx
]
conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys
conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx))
# generate random attributes for conns
random_conn_attrs = jax.vmap(
self.conn_gene.new_random_attrs, in_axes=(None, 0)
)(state, conn_keys)
conns = conns.at[: len(conn_keys), 2:].set(random_conn_attrs)
return nodes, conns return nodes, conns
def update_by_batch(self, state, batch_input, transformed): def network_dict(self, state, nodes, conns):
""" return {
Update the genome by a batch of data. "nodes": self._get_node_dict(state, nodes),
""" "conns": self._get_conn_dict(state, conns),
raise NotImplementedError }
def get_input_idx(self):
return self.input_idx.tolist()
def get_output_idx(self):
return self.output_idx.tolist()
def hash(self, nodes, conns):
nodes_hashs = vmap(hash_array)(nodes)
conns_hashs = vmap(hash_array)(conns)
return hash_array(jnp.concatenate([nodes_hashs, conns_hashs]))
def repr(self, state, nodes, conns, precision=2): def repr(self, state, nodes, conns, precision=2):
nodes, conns = jax.device_get([nodes, conns]) nodes, conns = jax.device_get([nodes, conns])
nodes_cnt, conns_cnt = self.valid_cnt(nodes), self.valid_cnt(conns) nodes_cnt, conns_cnt = valid_cnt(nodes), valid_cnt(conns)
s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n" s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n"
s += f"\tNodes:\n" s += f"\tNodes:\n"
for node in nodes: for node in nodes:
@@ -152,11 +199,7 @@ class BaseGenome(StatefulBaseClass):
s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n" s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n"
return s return s
@classmethod def _get_conn_dict(self, state, conns):
def valid_cnt(cls, arr):
return jnp.sum(~jnp.isnan(arr[:, 0]))
def get_conn_dict(self, state, conns):
conns = jax.device_get(conns) conns = jax.device_get(conns)
conn_dict = {} conn_dict = {}
for conn in conns: for conn in conns:
@@ -167,7 +210,7 @@ class BaseGenome(StatefulBaseClass):
conn_dict[(in_idx, out_idx)] = cd conn_dict[(in_idx, out_idx)] = cd
return conn_dict return conn_dict
def get_node_dict(self, state, nodes): def _get_node_dict(self, state, nodes):
nodes = jax.device_get(nodes) nodes = jax.device_get(nodes)
node_dict = {} node_dict = {}
for node in nodes: for node in nodes:
@@ -177,92 +220,3 @@ class BaseGenome(StatefulBaseClass):
idx = nd["idx"] idx = nd["idx"]
node_dict[idx] = nd node_dict[idx] = nd
return node_dict return node_dict
def network_dict(self, state, nodes, conns):
return {
"nodes": self.get_node_dict(state, nodes),
"conns": self.get_conn_dict(state, conns),
}
def get_input_idx(self):
return self.input_idx.tolist()
def get_output_idx(self):
return self.output_idx.tolist()
def sympy_func(self, state, network, sympy_output_transform=None):
raise NotImplementedError
def visualize(
self,
network,
rotate=0,
reverse_node_order=False,
size=(300, 300, 300),
color=("blue", "blue", "blue"),
save_path="network.svg",
save_dpi=800,
**kwargs,
):
import networkx as nx
from matplotlib import pyplot as plt
nodes_list = list(network["nodes"])
conns_list = list(network["conns"])
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
topo_order, topo_layers = topological_sort_python(nodes_list, conns_list)
node2layer = {
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
}
if reverse_node_order:
topo_order = topo_order[::-1]
G = nx.DiGraph()
if not isinstance(size, tuple):
size = (size, size, size)
if not isinstance(color, tuple):
color = (color, color, color)
for node in topo_order:
if node in input_idx:
G.add_node(node, subset=node2layer[node], size=size[0], color=color[0])
elif node in output_idx:
G.add_node(node, subset=node2layer[node], size=size[2], color=color[2])
else:
G.add_node(node, subset=node2layer[node], size=size[1], color=color[1])
for conn in conns_list:
G.add_edge(conn[0], conn[1])
pos = nx.multipartite_layout(G)
def rotate_layout(pos, angle):
angle_rad = np.deg2rad(angle)
cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
rotated_pos = {}
for node, (x, y) in pos.items():
rotated_pos[node] = (
cos_angle * x - sin_angle * y,
sin_angle * x + cos_angle * y,
)
return rotated_pos
rotated_pos = rotate_layout(pos, rotate)
node_sizes = [n["size"] for n in G.nodes.values()]
node_colors = [n["color"] for n in G.nodes.values()]
nx.draw(
G,
pos=rotated_pos,
node_size=node_sizes,
node_color=node_colors,
**kwargs,
)
plt.savefig(save_path, dpi=save_dpi)
def hash(self, nodes, conns):
nodes_hashs = jax.vmap(hash_array)(nodes)
conns_hashs = jax.vmap(hash_array)(conns)
return hash_array(jnp.concatenate([nodes_hashs, conns_hashs]))

View File

@@ -1,25 +1,23 @@
import warnings import warnings
from typing import Callable
import jax, jax.numpy as jnp import jax
from jax import vmap, numpy as jnp
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from utils import (
unflatten_conns, from . import BaseGenome
from ..gene import DefaultNodeGene, DefaultConnGene
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
from tensorneat.common import (
topological_sort, topological_sort,
topological_sort_python, topological_sort_python,
I_INF, I_INF,
extract_node_attrs,
extract_conn_attrs,
set_node_attrs,
set_conn_attrs,
attach_with_inf, attach_with_inf,
SYMPY_FUNCS_MODULE_NP, SYMPY_FUNCS_MODULE_NP,
SYMPY_FUNCS_MODULE_JNP, SYMPY_FUNCS_MODULE_JNP,
) )
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
class DefaultGenome(BaseGenome): class DefaultGenome(BaseGenome):
@@ -31,15 +29,18 @@ class DefaultGenome(BaseGenome):
self, self,
num_inputs: int, num_inputs: int,
num_outputs: int, num_outputs: int,
max_nodes=5, max_nodes=50,
max_conns=4, max_conns=100,
node_gene: BaseNodeGene = DefaultNodeGene(), node_gene=DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(), conn_gene=DefaultConnGene(),
mutation: BaseMutation = DefaultMutation(), mutation=DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(), crossover=DefaultCrossover(),
output_transform: Callable = None, distance=DefaultDistance(),
input_transform: Callable = None, output_transform=None,
input_transform=None,
init_hidden_layers=(),
): ):
super().__init__( super().__init__(
num_inputs, num_inputs,
num_outputs, num_outputs,
@@ -49,22 +50,12 @@ class DefaultGenome(BaseGenome):
conn_gene, conn_gene,
mutation, mutation,
crossover, crossover,
distance,
output_transform,
input_transform,
init_hidden_layers,
) )
if input_transform is not None:
try:
_ = input_transform(np.zeros(num_inputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.input_transform = input_transform
if output_transform is not None:
try:
_ = output_transform(np.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, state, nodes, conns): def transform(self, state, nodes, conns):
u_conns = unflatten_conns(nodes, conns) u_conns = unflatten_conns(nodes, conns)
conn_exist = u_conns != I_INF conn_exist = u_conns != I_INF
@@ -73,10 +64,6 @@ class DefaultGenome(BaseGenome):
return seqs, nodes, conns, u_conns return seqs, nodes, conns, u_conns
def restore(self, state, transformed):
seqs, nodes, conns, u_conns = transformed
return nodes, conns
def forward(self, state, transformed, inputs): def forward(self, state, transformed, inputs):
if self.input_transform is not None: if self.input_transform is not None:
@@ -86,8 +73,8 @@ class DefaultGenome(BaseGenome):
ini_vals = jnp.full((self.max_nodes,), jnp.nan) ini_vals = jnp.full((self.max_nodes,), jnp.nan)
ini_vals = ini_vals.at[self.input_idx].set(inputs) ini_vals = ini_vals.at[self.input_idx].set(inputs)
nodes_attrs = jax.vmap(extract_node_attrs)(nodes) nodes_attrs = vmap(extract_node_attrs)(nodes)
conns_attrs = jax.vmap(extract_conn_attrs)(conns) conns_attrs = vmap(extract_conn_attrs)(conns)
def cond_fun(carry): def cond_fun(carry):
values, idx = carry values, idx = carry
@@ -105,7 +92,7 @@ class DefaultGenome(BaseGenome):
def otherwise(): def otherwise():
conn_indices = u_conns[:, i] conn_indices = u_conns[:, i]
hit_attrs = attach_with_inf(conns_attrs, conn_indices) hit_attrs = attach_with_inf(conns_attrs, conn_indices)
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 0, 0))( ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
state, hit_attrs, values state, hit_attrs, values
) )
@@ -130,85 +117,14 @@ class DefaultGenome(BaseGenome):
else: else:
return self.output_transform(vals[self.output_idx]) return self.output_transform(vals[self.output_idx])
def update_by_batch(self, state, batch_input, transformed): def network_dict(self, state, nodes, conns):
network = super().network_dict(state, nodes, conns)
if self.input_transform is not None: topo_order, topo_layers = topological_sort_python(
batch_input = jax.vmap(self.input_transform)(batch_input) set(network["nodes"]), set(network["conns"])
cal_seqs, nodes, conns, u_conns = transformed
batch_size = batch_input.shape[0]
batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan)
batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input)
nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
conns_attrs = jax.vmap(extract_conn_attrs)(conns)
def cond_fun(carry):
batch_values, nodes_attrs_, conns_attrs_, idx = carry
return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)
def body_func(carry):
batch_values, nodes_attrs_, conns_attrs_, idx = carry
i = cal_seqs[idx]
def input_node():
batch, new_attrs = self.node_gene.update_input_transform(
state, nodes_attrs_[i], batch_values[:, i]
)
return (
batch_values.at[:, i].set(batch),
nodes_attrs_.at[i].set(new_attrs),
conns_attrs_,
)
def otherwise():
conn_indices = u_conns[:, i]
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
batch_ins, new_conn_attrs = jax.vmap(
self.conn_gene.update_by_batch,
in_axes=(None, 0, 1),
out_axes=(1, 0),
)(state, hit_attrs, batch_values)
batch_z, new_node_attrs = self.node_gene.update_by_batch(
state,
nodes_attrs_[i],
batch_ins,
is_output_node=jnp.isin(i, self.output_idx),
)
return (
batch_values.at[:, i].set(batch_z),
nodes_attrs_.at[i].set(new_node_attrs),
conns_attrs_.at[conn_indices].set(new_conn_attrs),
)
# the val of input nodes is obtained by the task, not by calculation
(batch_values, nodes_attrs_, conns_attrs_) = jax.lax.cond(
jnp.isin(i, self.input_idx),
input_node,
otherwise,
)
return batch_values, nodes_attrs_, conns_attrs_, idx + 1
batch_vals, nodes_attrs, conns_attrs, _ = jax.lax.while_loop(
cond_fun, body_func, (batch_ini_vals, nodes_attrs, conns_attrs, 0)
) )
network["topo_order"] = topo_order
nodes = jax.vmap(set_node_attrs)(nodes, nodes_attrs) network["topo_layers"] = topo_layers
conns = jax.vmap(set_conn_attrs)(conns, conns_attrs) return network
new_transformed = (cal_seqs, nodes, conns, u_conns)
if self.output_transform is None:
return batch_vals[:, self.output_idx], new_transformed
else:
return (
jax.vmap(self.output_transform)(batch_vals[:, self.output_idx]),
new_transformed,
)
def sympy_func( def sympy_func(
self, self,
@@ -241,7 +157,8 @@ class DefaultGenome(BaseGenome):
input_idx = self.get_input_idx() input_idx = self.get_input_idx()
output_idx = self.get_output_idx() output_idx = self.get_output_idx()
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"])) order = network["topo_order"]
hidden_idx = [ hidden_idx = [
i for i in network["nodes"] if i not in input_idx and i not in output_idx i for i in network["nodes"] if i not in input_idx and i not in output_idx
] ]
@@ -260,8 +177,12 @@ class DefaultGenome(BaseGenome):
for i in order: for i in order:
if i in input_idx: if i in input_idx:
nodes_exprs[symbols[-i - 1]] = symbols[-i - 1] # origin equal to its symbol nodes_exprs[symbols[-i - 1]] = symbols[
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[-i - 1]) # normed i -i - 1
] # origin equal to its symbol
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](
symbols[-i - 1]
) # normed i
else: else:
in_conns = [c for c in network["conns"] if c[1] == i] in_conns = [c for c in network["conns"] if c[1] == i]
@@ -325,3 +246,73 @@ class DefaultGenome(BaseGenome):
output_exprs, output_exprs,
forward_func, forward_func,
) )
def visualize(
self,
network,
rotate=0,
reverse_node_order=False,
size=(300, 300, 300),
color=("blue", "blue", "blue"),
save_path="network.svg",
save_dpi=800,
**kwargs,
):
import networkx as nx
from matplotlib import pyplot as plt
nodes_list = list(network["nodes"])
conns_list = list(network["conns"])
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
topo_order, topo_layers = network["topo_order"], network["topo_layers"]
node2layer = {
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
}
if reverse_node_order:
topo_order = topo_order[::-1]
G = nx.DiGraph()
if not isinstance(size, tuple):
size = (size, size, size)
if not isinstance(color, tuple):
color = (color, color, color)
for node in topo_order:
if node in input_idx:
G.add_node(node, subset=node2layer[node], size=size[0], color=color[0])
elif node in output_idx:
G.add_node(node, subset=node2layer[node], size=size[2], color=color[2])
else:
G.add_node(node, subset=node2layer[node], size=size[1], color=color[1])
for conn in conns_list:
G.add_edge(conn[0], conn[1])
pos = nx.multipartite_layout(G)
def rotate_layout(pos, angle):
angle_rad = np.deg2rad(angle)
cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
rotated_pos = {}
for node, (x, y) in pos.items():
rotated_pos[node] = (
cos_angle * x - sin_angle * y,
sin_angle * x + cos_angle * y,
)
return rotated_pos
rotated_pos = rotate_layout(pos, rotate)
node_sizes = [n["size"] for n in G.nodes.values()]
node_colors = [n["color"] for n in G.nodes.values()]
nx.draw(
G,
pos=rotated_pos,
node_size=node_sizes,
node_color=node_colors,
**kwargs,
)
plt.savefig(save_path, dpi=save_dpi)

View File

@@ -1,56 +0,0 @@
import jax, jax.numpy as jnp
from .default import DefaultGenome
class DenseInitialize(DefaultGenome):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.max_nodes >= self.num_inputs + self.num_outputs
assert self.max_conns >= self.num_inputs * self.num_outputs
def initialize(self, state, randkey):
k1, k2 = jax.random.split(randkey, num=2)
input_idx, output_idx = self.input_idx, self.output_idx
input_size = len(input_idx)
output_size = len(output_idx)
nodes = jnp.full(
(self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
)
nodes = nodes.at[input_idx, 0].set(input_idx)
nodes = nodes.at[output_idx, 0].set(output_idx)
total_idx = input_size + output_size
rand_keys_n = jax.random.split(k1, num=total_idx)
node_attr_func = jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
node_attrs = node_attr_func(state, rand_keys_n)
nodes = nodes.at[:total_idx, 1:].set(node_attrs)
conns = jnp.full(
(self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
)
input_to_output_ids, output_ids = jnp.meshgrid(
input_idx, output_idx, indexing="ij"
)
total_conns = input_size * output_size
conns = conns.at[:total_conns, :2].set(
jnp.column_stack([input_to_output_ids.flatten(), output_ids.flatten()])
)
rand_keys_c = jax.random.split(k2, num=total_conns)
conns_attr_func = jax.vmap(
self.conn_gene.new_random_attrs,
in_axes=(
None,
0,
),
)
conns_attrs = conns_attr_func(state, rand_keys_c)
conns = conns.at[:total_conns, 2:].set(conns_attrs)
return nodes, conns

View File

@@ -1,70 +0,0 @@
import jax, jax.numpy as jnp
from .default import DefaultGenome
class HiddenInitialize(DefaultGenome):
def __init__(self, hidden_cnt=8, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hidden_cnt = hidden_cnt
def initialize(self, state, randkey):
k1, k2 = jax.random.split(randkey, num=2)
input_idx, output_idx = self.input_idx, self.output_idx
input_size = len(input_idx)
output_size = len(output_idx)
hidden_idx = jnp.arange(
input_size + output_size, input_size + output_size + self.hidden_cnt
)
nodes = jnp.full(
(self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
)
nodes = nodes.at[input_idx, 0].set(input_idx)
nodes = nodes.at[output_idx, 0].set(output_idx)
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
total_idx = input_size + output_size + self.hidden_cnt
rand_keys_n = jax.random.split(k1, num=total_idx)
node_attr_func = jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
node_attrs = node_attr_func(state, rand_keys_n)
nodes = nodes.at[:total_idx, 1:].set(node_attrs)
conns = jnp.full(
(self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
)
input_to_hidden_ids, hidden_ids = jnp.meshgrid(
input_idx, hidden_idx, indexing="ij"
)
total_input_to_hidden_conns = input_size * self.hidden_cnt
conns = conns.at[:total_input_to_hidden_conns, :2].set(
jnp.column_stack([input_to_hidden_ids.flatten(), hidden_ids.flatten()])
)
hidden_to_output_ids, output_ids = jnp.meshgrid(
hidden_idx, output_idx, indexing="ij"
)
total_hidden_to_output_conns = self.hidden_cnt * output_size
conns = conns.at[
total_input_to_hidden_conns : total_input_to_hidden_conns
+ total_hidden_to_output_conns,
:2,
].set(jnp.column_stack([hidden_to_output_ids.flatten(), output_ids.flatten()]))
total_conns = total_input_to_hidden_conns + total_hidden_to_output_conns
rand_keys_c = jax.random.split(k2, num=total_conns)
conns_attr_func = jax.vmap(
self.conn_gene.new_random_attrs,
in_axes=(
None,
0,
),
)
conns_attrs = conns_attr_func(state, rand_keys_c)
conns = conns.at[:total_conns, 2:].set(conns_attrs)
return nodes, conns

View File

@@ -1,2 +1,3 @@
from .crossover import BaseCrossover, DefaultCrossover from .crossover import BaseCrossover, DefaultCrossover
from .mutation import BaseMutation, DefaultMutation from .mutation import BaseMutation, DefaultMutation
from .distance import BaseDistance, DefaultDistance

View File

@@ -0,0 +1,12 @@
from tensorneat.common import StatefulBaseClass, State
class BaseCrossover(StatefulBaseClass):
def setup(self, state=State(), genome = None):
assert genome is not None, "genome should not be None"
self.genome = genome
return state
def __call__(self, state, randkey, nodes1, nodes2, conns1, conns2):
raise NotImplementedError

View File

@@ -1,7 +1,8 @@
import jax, jax.numpy as jnp import jax
from jax import vmap, numpy as jnp
from .base import BaseCrossover from .base import BaseCrossover
from utils.tools import ( from ...utils import (
extract_node_attrs, extract_node_attrs,
extract_conn_attrs, extract_conn_attrs,
set_node_attrs, set_node_attrs,
@@ -10,14 +11,14 @@ from utils.tools import (
class DefaultCrossover(BaseCrossover): class DefaultCrossover(BaseCrossover):
def __call__(self, state, randkey, genome, nodes1, conns1, nodes2, conns2): def __call__(self, state, randkey, nodes1, conns1, nodes2, conns2):
""" """
use genome1 and genome2 to generate a new genome use genome1 and genome2 to generate a new genome
notice that genome1 should have higher fitness than genome2 (genome1 is winner!) notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
""" """
randkey1, randkey2 = jax.random.split(randkey, 2) randkey1, randkey2 = jax.random.split(randkey, 2)
randkeys1 = jax.random.split(randkey1, genome.max_nodes) randkeys1 = jax.random.split(randkey1, self.genome.max_nodes)
randkeys2 = jax.random.split(randkey2, genome.max_conns) randkeys2 = jax.random.split(randkey2, self.genome.max_conns)
# crossover nodes # crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0] keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
@@ -26,33 +27,33 @@ class DefaultCrossover(BaseCrossover):
# For not homologous genes, use the value of nodes1(winner) # For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2 # For homologous genes, use the crossover result between nodes1 and nodes2
node_attrs1 = jax.vmap(extract_node_attrs)(nodes1) node_attrs1 = vmap(extract_node_attrs)(nodes1)
node_attrs2 = jax.vmap(extract_node_attrs)(nodes2) node_attrs2 = vmap(extract_node_attrs)(nodes2)
new_node_attrs = jnp.where( new_node_attrs = jnp.where(
jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2), # one of them is nan jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2), # one of them is nan
node_attrs1, # not homologous genes or both nan, use the value of nodes1(winner) node_attrs1, # not homologous genes or both nan, use the value of nodes1(winner)
jax.vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))( vmap(self.genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(
state, randkeys1, node_attrs1, node_attrs2 state, randkeys1, node_attrs1, node_attrs2
), # homologous or both nan ), # homologous or both nan
) )
new_nodes = jax.vmap(set_node_attrs)(nodes1, new_node_attrs) new_nodes = vmap(set_node_attrs)(nodes1, new_node_attrs)
# crossover connections # crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2] con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True) conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
conns_attrs1 = jax.vmap(extract_conn_attrs)(conns1) conns_attrs1 = vmap(extract_conn_attrs)(conns1)
conns_attrs2 = jax.vmap(extract_conn_attrs)(conns2) conns_attrs2 = vmap(extract_conn_attrs)(conns2)
new_conn_attrs = jnp.where( new_conn_attrs = jnp.where(
jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2), jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2),
conns_attrs1, # not homologous genes or both nan, use the value of conns1(winner) conns_attrs1, # not homologous genes or both nan, use the value of conns1(winner)
jax.vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))( vmap(self.genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(
state, randkeys2, conns_attrs1, conns_attrs2 state, randkeys2, conns_attrs1, conns_attrs2
), # homologous or both nan ), # homologous or both nan
) )
new_conns = jax.vmap(set_conn_attrs)(conns1, new_conn_attrs) new_conns = vmap(set_conn_attrs)(conns1, new_conn_attrs)
return new_nodes, new_conns return new_nodes, new_conns

View File

@@ -0,0 +1,2 @@
from .base import BaseDistance
from .default import DefaultDistance

View File

@@ -0,0 +1,15 @@
from tensorneat.common import StatefulBaseClass, State
class BaseDistance(StatefulBaseClass):
def setup(self, state=State(), genome = None):
assert genome is not None, "genome should not be None"
self.genome = genome
return state
def __call__(self, state, nodes1, nodes2, conns1, conns2):
"""
The distance between two genomes
"""
raise NotImplementedError

View File

@@ -0,0 +1,105 @@
from jax import vmap, numpy as jnp
from .base import BaseDistance
from ...utils import extract_node_attrs, extract_conn_attrs
class DefaultDistance(BaseDistance):
def __init__(
self,
compatibility_disjoint: float = 1.0,
compatibility_weight: float = 0.4,
):
self.compatibility_disjoint = compatibility_disjoint
self.compatibility_weight = compatibility_weight
def __call__(self, state, nodes1, nodes2, conns1, conns2):
"""
The distance between two genomes
"""
d = self.node_distance(state, nodes1, nodes2) + self.conn_distance(
state, conns1, conns2
)
return d
def node_distance(self, state, nodes1, nodes2):
"""
The distance of the nodes part for two genomes
"""
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
# align homologous nodes
# this process is similar to np.intersect1d.
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
keys = nodes[:, 0]
sorted_indices = jnp.argsort(keys, axis=0)
nodes = nodes[sorted_indices]
nodes = jnp.concatenate(
[nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0
) # add a nan row to the end
fr, sr = nodes[:-1], nodes[1:] # first row, second row
# flag location of homologous nodes
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
# calculate the count of non_homologous of two genomes
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
# calculate the distance of homologous nodes
fr_attrs = vmap(extract_node_attrs)(fr)
sr_attrs = vmap(extract_node_attrs)(sr)
hnd = vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
state, fr_attrs, sr_attrs
) # homologous node distance
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
homologous_distance = jnp.sum(hnd * intersect_mask)
val = (
non_homologous_cnt * self.compatibility_disjoint
+ homologous_distance * self.compatibility_weight
)
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
return val
def conn_distance(self, state, conns1, conns2):
"""
The distance of the conns part for two genomes
"""
con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
cons = jnp.concatenate((conns1, conns2), axis=0)
keys = cons[:, :2]
sorted_indices = jnp.lexsort(keys.T[::-1])
cons = cons[sorted_indices]
cons = jnp.concatenate(
[cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0
) # add a nan row to the end
fr, sr = cons[:-1], cons[1:] # first row, second row
# both genome has such connection
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
fr_attrs = vmap(extract_conn_attrs)(fr)
sr_attrs = vmap(extract_conn_attrs)(sr)
hcd = vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
state, fr_attrs, sr_attrs
) # homologous connection distance
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
homologous_distance = jnp.sum(hcd * intersect_mask)
val = (
non_homologous_cnt * self.compatibility_disjoint
+ homologous_distance * self.compatibility_weight
)
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
return val

View File

@@ -0,0 +1,12 @@
from tensorneat.common import StatefulBaseClass, State
class BaseMutation(StatefulBaseClass):
def setup(self, state=State(), genome = None):
assert genome is not None, "genome should not be None"
self.genome = genome
return state
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
raise NotImplementedError

View File

@@ -1,11 +1,14 @@
import jax, jax.numpy as jnp import jax
from jax import vmap, numpy as jnp
from . import BaseMutation from . import BaseMutation
from utils import ( from tensorneat.common import (
fetch_first, fetch_first,
fetch_random, fetch_random,
I_INF, I_INF,
unflatten_conns,
check_cycles, check_cycles,
)
from ...utils import (
unflatten_conns,
add_node, add_node,
add_conn, add_conn,
delete_node_by_pos, delete_node_by_pos,
@@ -225,17 +228,17 @@ class DefaultMutation(BaseMutation):
nodes_randkeys = jax.random.split(k1, num=genome.max_nodes) nodes_randkeys = jax.random.split(k1, num=genome.max_nodes)
conns_randkeys = jax.random.split(k2, num=genome.max_conns) conns_randkeys = jax.random.split(k2, num=genome.max_conns)
node_attrs = jax.vmap(extract_node_attrs)(nodes) node_attrs = vmap(extract_node_attrs)(nodes)
new_node_attrs = jax.vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))( new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
state, nodes_randkeys, node_attrs state, nodes_randkeys, node_attrs
) )
new_nodes = jax.vmap(set_node_attrs)(nodes, new_node_attrs) new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs)
conn_attrs = jax.vmap(extract_conn_attrs)(conns) conn_attrs = vmap(extract_conn_attrs)(conns)
new_conn_attrs = jax.vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))( new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
state, conns_randkeys, conn_attrs state, conns_randkeys, conn_attrs
) )
new_conns = jax.vmap(set_conn_attrs)(conns, new_conn_attrs) new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs)
# nan nodes not changed # nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes) new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)

View File

@@ -1,11 +1,11 @@
from typing import Callable from typing import Callable
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import unflatten_conns from .utils import unflatten_conns
from . import BaseGenome from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene from ..gene import DefaultNodeGene, DefaultConnGene
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover from .operations import DefaultMutation, DefaultCrossover
class RecurrentGenome(BaseGenome): class RecurrentGenome(BaseGenome):
@@ -17,13 +17,13 @@ class RecurrentGenome(BaseGenome):
self, self,
num_inputs: int, num_inputs: int,
num_outputs: int, num_outputs: int,
max_nodes: int, max_nodes = 50,
max_conns: int, max_conns = 100,
node_gene: BaseNodeGene = DefaultNodeGene(), node_gene=DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(), conn_gene=DefaultConnGene(),
mutation: BaseMutation = DefaultMutation(), mutation=DefaultMutation(),
crossover: BaseCrossover = DefaultCrossover(), crossover=DefaultCrossover(),
activate_time: int = 10, activate_time=10,
output_transform: Callable = None, output_transform: Callable = None,
): ):
super().__init__( super().__init__(

View File

@@ -0,0 +1,109 @@
import jax
from jax import vmap, numpy as jnp
from tensorneat.common import fetch_first, I_INF
def unflatten_conns(nodes, conns):
"""
transform the (C, CL) connections to (N, N), which contains the idx of the connection in conns
connection length, N means the number of nodes, C means the number of connections
returns the unflatten connection indices with shape (N, N)
"""
N = nodes.shape[0] # max_nodes
C = conns.shape[0] # max_conns
node_keys = nodes[:, 0]
i_keys, o_keys = conns[:, 0], conns[:, 1]
def key_to_indices(key, keys):
return fetch_first(key == keys)
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
# Is interesting that jax use clip when attach data in array
# however, it will do nothing when setting values in an array
# put the index of connections in the unflatten array
unflatten = (
jnp.full((N, N), I_INF, dtype=jnp.int32)
.at[i_idxs, o_idxs]
.set(jnp.arange(C, dtype=jnp.int32))
)
return unflatten
def valid_cnt(nodes_or_conns):
return jnp.sum(~jnp.isnan(nodes_or_conns[:, 0]))
def extract_node_attrs(node):
"""
node: Array(NL, )
extract the attributes of a node
"""
return node[1:] # 0 is for idx
def set_node_attrs(node, attrs):
"""
node: Array(NL, )
attrs: Array(NL-1, )
set the attributes of a node
"""
return node.at[1:].set(attrs) # 0 is for idx
def extract_conn_attrs(conn):
"""
conn: Array(CL, )
extract the attributes of a connection
"""
return conn[2:] # 0, 1 is for in-idx and out-idx
def set_conn_attrs(conn, attrs):
"""
conn: Array(CL, )
attrs: Array(CL-2, )
set the attributes of a connection
"""
return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx
def add_node(nodes, new_key: int, attrs):
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = nodes[:, 0]
pos = fetch_first(jnp.isnan(exist_keys))
new_nodes = nodes.at[pos, 0].set(new_key)
return new_nodes.at[pos, 1:].set(attrs)
def delete_node_by_pos(nodes, pos):
"""
Delete a node from the genome.
Delete the node by its pos in nodes.
"""
return nodes.at[pos].set(jnp.nan)
def add_conn(conns, i_key, o_key, attrs):
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = conns[:, 0]
pos = fetch_first(jnp.isnan(con_keys))
new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key]))
return new_conns.at[pos, 2:].set(attrs)
def delete_conn_by_pos(conns, pos):
"""
Delete a connection from the genome.
Delete the connection by its idx.
"""
return conns.at[pos].set(jnp.nan)

View File

@@ -1,5 +1,5 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import State from tensorneat.common import State
from .. import BaseAlgorithm from .. import BaseAlgorithm
from .species import * from .species import *

View File

@@ -1,4 +1,4 @@
from utils import State, StatefulBaseClass from tensorneat.common import State, StatefulBaseClass
from ..genome import BaseGenome from ..genome import BaseGenome

View File

@@ -1,9 +1,11 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import ( from tensorneat.common import (
State, State,
rank_elements, rank_elements,
argmin_with_mask, argmin_with_mask,
fetch_first, fetch_first,
)
from ..genome.utils import (
extract_conn_attrs, extract_conn_attrs,
extract_node_attrs, extract_node_attrs,
) )
@@ -635,7 +637,9 @@ class DefaultSpecies(BaseSpecies):
# find next node key # find next node key
all_nodes_keys = state.pop_nodes[:, :, 0] all_nodes_keys = state.pop_nodes[:, :, 0]
max_node_key = jnp.max(all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0) max_node_key = jnp.max(
all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
)
next_node_key = max_node_key + 1 next_node_key = max_node_key + 1
new_node_keys = jnp.arange(self.pop_size) + next_node_key new_node_keys = jnp.arange(self.pop_size) + next_node_key

View File

@@ -1,4 +1,4 @@
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL from tensorneat.common.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
from .tools import * from .tools import *
from .graph import * from .graph import *
from .state import State from .state import State

View File

@@ -6,36 +6,6 @@ from jax import numpy as jnp, Array, jit, vmap
I_INF = np.iinfo(jnp.int32).max # infinite int I_INF = np.iinfo(jnp.int32).max # infinite int
def unflatten_conns(nodes, conns):
"""
transform the (C, CL) connections to (N, N), which contains the idx of the connection in conns
connection length, N means the number of nodes, C means the number of connections
returns the unflatten connection indices with shape (N, N)
"""
N = nodes.shape[0] # max_nodes
C = conns.shape[0] # max_conns
node_keys = nodes[:, 0]
i_keys, o_keys = conns[:, 0], conns[:, 1]
def key_to_indices(key, keys):
return fetch_first(key == keys)
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
# Is interesting that jax use clip when attach data in array
# however, it will do nothing when setting values in an array
# put the index of connections in the unflatten array
unflatten = (
jnp.full((N, N), I_INF, dtype=jnp.int32)
.at[i_idxs, o_idxs]
.set(jnp.arange(C, dtype=jnp.int32))
)
return unflatten
# TODO: strange implementation # TODO: strange implementation
def attach_with_inf(arr, idx): def attach_with_inf(arr, idx):
expand_size = arr.ndim - idx.ndim expand_size = arr.ndim - idx.ndim
@@ -45,40 +15,6 @@ def attach_with_inf(arr, idx):
return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx]) return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])
def extract_node_attrs(node):
"""
node: Array(NL, )
extract the attributes of a node
"""
return node[1:] # 0 is for idx
def set_node_attrs(node, attrs):
"""
node: Array(NL, )
attrs: Array(NL-1, )
set the attributes of a node
"""
return node.at[1:].set(attrs) # 0 is for idx
def extract_conn_attrs(conn):
"""
conn: Array(CL, )
extract the attributes of a connection
"""
return conn[2:] # 0, 1 is for in-idx and out-idx
def set_conn_attrs(conn, attrs):
"""
conn: Array(CL, )
attrs: Array(CL-2, )
set the attributes of a connection
"""
return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx
@jit @jit
def fetch_first(mask, default=I_INF) -> Array: def fetch_first(mask, default=I_INF) -> Array:
""" """
@@ -164,44 +100,6 @@ def argmin_with_mask(arr, mask):
return min_idx return min_idx
def add_node(nodes, new_key: int, attrs):
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = nodes[:, 0]
pos = fetch_first(jnp.isnan(exist_keys))
new_nodes = nodes.at[pos, 0].set(new_key)
return new_nodes.at[pos, 1:].set(attrs)
def delete_node_by_pos(nodes, pos):
"""
Delete a node from the genome.
Delete the node by its pos in nodes.
"""
return nodes.at[pos].set(jnp.nan)
def add_conn(conns, i_key, o_key, attrs):
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = conns[:, 0]
pos = fetch_first(jnp.isnan(con_keys))
new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key]))
return new_conns.at[pos, 2:].set(attrs)
def delete_conn_by_pos(conns, pos):
"""
Delete a connection from the genome.
Delete the connection by its idx.
"""
return conns.at[pos].set(jnp.nan)
def hash_array(arr: Array): def hash_array(arr: Array):
arr = jax.lax.bitcast_convert_type(arr, jnp.uint32) arr = jax.lax.bitcast_convert_type(arr, jnp.uint32)

View File

@@ -9,7 +9,7 @@ from algorithm import BaseAlgorithm
from problem import BaseProblem from problem import BaseProblem
from problem.rl_env import RLEnv from problem.rl_env import RLEnv
from problem.func_fit import FuncFit from problem.func_fit import FuncFit
from utils import State, StatefulBaseClass from tensorneat.common import State, StatefulBaseClass
class Pipeline(StatefulBaseClass): class Pipeline(StatefulBaseClass):

View File

@@ -1,6 +1,6 @@
from typing import Callable from typing import Callable
from utils import State, StatefulBaseClass from tensorneat.common import State, StatefulBaseClass
class BaseProblem(StatefulBaseClass): class BaseProblem(StatefulBaseClass):

View File

@@ -1,7 +1,7 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from utils import State from tensorneat.common import State
from .. import BaseProblem from .. import BaseProblem

View File

@@ -1,7 +1,7 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
import jumanji import jumanji
from utils import State from tensorneat.common import State
from ..rl_jit import RLEnv from ..rl_jit import RLEnv

View File

@@ -4,7 +4,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from utils import State from tensorneat.common import State
from .. import BaseProblem from .. import BaseProblem

View File

@@ -1,5 +1,5 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import Act from tensorneat.common import Act
from algorithm.neat import * from algorithm.neat import *
import numpy as np import numpy as np

View File

@@ -1,5 +1,5 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import Act from tensorneat.common import Act
from algorithm.neat import * from algorithm.neat import *
import numpy as np import numpy as np

View File

@@ -27,7 +27,7 @@
"from algorithm.neat.gene.node.kan_node import KANNode\n", "from algorithm.neat.gene.node.kan_node import KANNode\n",
"from algorithm.neat.gene.conn.bspline import BSplineConn\n", "from algorithm.neat.gene.conn.bspline import BSplineConn\n",
"from problem.func_fit import XOR3d\n", "from problem.func_fit import XOR3d\n",
"from utils import Act\n", "from tensorneat.utils import Act\n",
"\n", "\n",
"import jax, jax.numpy as jnp\n", "import jax, jax.numpy as jnp\n",
"\n", "\n",

View File

@@ -1,5 +1,5 @@
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import Act from tensorneat.common import Act
from algorithm.neat import * from algorithm.neat import *
import numpy as np import numpy as np

View File

@@ -14,7 +14,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"import jax, jax.numpy as jnp\n", "import jax, jax.numpy as jnp\n",
"from utils import State\n", "from tensorneat.utils import State\n",
"from problem.rl_env import BraxEnv\n", "from problem.rl_env import BraxEnv\n",
"\n", "\n",
"\n", "\n",

View File

@@ -145,7 +145,7 @@
"source": [ "source": [
"from algorithm.neat.gene.node.normalized import NormalizedNode\n", "from algorithm.neat.gene.node.normalized import NormalizedNode\n",
"from algorithm.neat.gene.conn import DefaultConnGene\n", "from algorithm.neat.gene.conn import DefaultConnGene\n",
"from utils import Act\n", "from tensorneat.utils import Act\n",
"\n", "\n",
"genome = DefaultGenome(num_inputs=3, num_outputs=2, max_nodes=10, max_conns=10,\n", "genome = DefaultGenome(num_inputs=3, num_outputs=2, max_nodes=10, max_conns=10,\n",
" node_gene=NormalizedNode(activation_default=Act.identity, activation_options=(Act.identity,)),\n", " node_gene=NormalizedNode(activation_default=Act.identity, activation_options=(Act.identity,)),\n",