odify genome for the official release
This commit is contained in:
@@ -75,7 +75,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
from tensorneat.utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
pipeline = Pipeline(
|
||||
|
||||
@@ -2,7 +2,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
from tensorneat.common import Act
|
||||
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
@@ -4,7 +4,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
from tensorneat.common import Act
|
||||
|
||||
|
||||
def sample_policy(randkey, obs):
|
||||
@@ -2,7 +2,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
from tensorneat.common import Act
|
||||
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
@@ -2,7 +2,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import BraxEnv
|
||||
from utils import Act
|
||||
from tensorneat.common import Act
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
@@ -2,7 +2,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
from utils import ACT_ALL, AGG_ALL, Act, Agg
|
||||
from tensorneat.common import ACT_ALL, AGG_ALL, Act, Agg
|
||||
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
@@ -1,7 +1,7 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
from algorithm.hyperneat import *
|
||||
from utils import Act
|
||||
from tensorneat.common import Act
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
|
||||
@@ -3,7 +3,7 @@ import jax
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
from algorithm.hyperneat import *
|
||||
from utils import Act
|
||||
from tensorneat.common import Act
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
|
||||
@@ -2,7 +2,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
from utils import Act
|
||||
from tensorneat.common import Act
|
||||
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
@@ -2,7 +2,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.rl_env import GymNaxEnv
|
||||
from utils import Act
|
||||
from tensorneat.common import Act
|
||||
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
@@ -11,7 +11,7 @@
|
||||
"from algorithm.neat.genome.advance import AdvanceInitialize\n",
|
||||
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
|
||||
"from utils.graph import topological_sort_python\n",
|
||||
"from utils import Act, Agg\n",
|
||||
"from tensorneat.utils import Act, Agg\n",
|
||||
"\n",
|
||||
"import numpy as np"
|
||||
],
|
||||
@@ -3,7 +3,7 @@ import jax, jax.numpy as jnp
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.genome.dense import DenseInitialize
|
||||
from utils.graph import topological_sort_python
|
||||
from utils import *
|
||||
from tensorneat.common import *
|
||||
|
||||
if __name__ == "__main__":
|
||||
genome = DenseInitialize(
|
||||
|
Before Width: | Height: | Size: 90 KiB After Width: | Height: | Size: 90 KiB |
|
Before Width: | Height: | Size: 89 KiB After Width: | Height: | Size: 89 KiB |
@@ -19,7 +19,7 @@
|
||||
"from algorithm.neat.genome.advance import AdvanceInitialize\n",
|
||||
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
|
||||
"from utils.graph import topological_sort_python\n",
|
||||
"from utils import Act, Agg\n",
|
||||
"from tensorneat.utils import Act, Agg\n",
|
||||
"\n",
|
||||
"genome = AdvanceInitialize(\n",
|
||||
" num_inputs=16,\n",
|
||||
@@ -29,7 +29,7 @@
|
||||
"from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse\n",
|
||||
"\n",
|
||||
"from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048\n",
|
||||
"from utils import Act, Agg\n",
|
||||
"from tensorneat.utils import Act, Agg\n",
|
||||
"\n",
|
||||
"pipeline = Pipeline(\n",
|
||||
" algorithm=NEAT(\n",
|
||||
@@ -4,7 +4,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
|
||||
from problem.rl_env.jumanji.jumanji_2048 import Jumanji_2048
|
||||
from utils import Act, Agg
|
||||
from tensorneat.common import Act, Agg
|
||||
|
||||
|
||||
def rot_li(li):
|
||||
10
examples/tmp.py
Normal file
10
examples/tmp.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from tensorneat.algorithm import NEAT
|
||||
from tensorneat.algorithm.neat import DefaultGenome
|
||||
|
||||
key = jax.random.key(0)
|
||||
genome = DefaultGenome(num_inputs=5, num_outputs=3, init_hidden_layers=(1, ))
|
||||
state = genome.setup()
|
||||
nodes, conns = genome.initialize(state, key)
|
||||
print(genome.repr(state, nodes, conns))
|
||||
@@ -1,4 +1,4 @@
|
||||
from utils import State, StatefulBaseClass
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
|
||||
|
||||
class BaseAlgorithm(StatefulBaseClass):
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import State, Act, Agg
|
||||
from tensorneat.common import State, Act, Agg
|
||||
from .. import BaseAlgorithm, NEAT
|
||||
from ..neat.gene import BaseNodeGene, BaseConnGene
|
||||
from ..neat.genome import RecurrentGenome
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from utils import StatefulBaseClass
|
||||
from tensorneat.common import StatefulBaseClass
|
||||
|
||||
|
||||
class BaseSubstrate(StatefulBaseClass):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from .ga import *
|
||||
from .gene import *
|
||||
from .genome import *
|
||||
from .species import *
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
from utils import StatefulBaseClass
|
||||
|
||||
|
||||
class BaseCrossover(StatefulBaseClass):
|
||||
def __call__(self, state, randkey, genome, nodes1, nodes2, conns1, conns2):
|
||||
raise NotImplementedError
|
||||
@@ -1,6 +0,0 @@
|
||||
from utils import StatefulBaseClass
|
||||
|
||||
|
||||
class BaseMutation(StatefulBaseClass):
|
||||
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
raise NotImplementedError
|
||||
@@ -1,5 +1,5 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import State, StatefulBaseClass, hash_array
|
||||
from tensorneat.common import State, StatefulBaseClass, hash_array
|
||||
|
||||
|
||||
class BaseGene(StatefulBaseClass):
|
||||
|
||||
@@ -2,7 +2,7 @@ import jax.numpy as jnp
|
||||
import jax.random
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
from utils import mutate_float
|
||||
from tensorneat.common import mutate_float
|
||||
from . import BaseConnGene
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
import sympy as sp
|
||||
|
||||
from utils import (
|
||||
from tensorneat.common import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Tuple
|
||||
import jax, jax.numpy as jnp
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
from utils import (
|
||||
from tensorneat.common import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import jax.numpy as jnp
|
||||
from . import BaseNodeGene
|
||||
from utils import Agg
|
||||
from tensorneat.common import Agg
|
||||
|
||||
|
||||
class KANNode(BaseNodeGene):
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Tuple
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
|
||||
from tensorneat.common import Act, Agg, act_func, agg_func, mutate_int, mutate_float
|
||||
from . import BaseNodeGene
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Tuple
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import Act, Agg, act_func, agg_func, mutate_int, mutate_float
|
||||
from tensorneat.common import Act, Agg, act_func, agg_func, mutate_int, mutate_float
|
||||
from . import BaseNodeGene
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from .base import BaseGenome
|
||||
from .default import DefaultGenome
|
||||
from .recurrent import RecurrentGenome
|
||||
from .hidden import HiddenInitialize
|
||||
from .dense import DenseInitialize
|
||||
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
from typing import Callable, Sequence
|
||||
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover
|
||||
from utils import State, StatefulBaseClass, topological_sort_python, hash_array
|
||||
from .operations import BaseMutation, BaseCrossover, BaseDistance
|
||||
from tensorneat.common import (
|
||||
State,
|
||||
StatefulBaseClass,
|
||||
hash_array,
|
||||
)
|
||||
from .utils import valid_cnt
|
||||
|
||||
|
||||
class BaseGenome(StatefulBaseClass):
|
||||
@@ -18,120 +26,159 @@ class BaseGenome(StatefulBaseClass):
|
||||
conn_gene: BaseConnGene,
|
||||
mutation: BaseMutation,
|
||||
crossover: BaseCrossover,
|
||||
distance: BaseDistance,
|
||||
output_transform: Callable = None,
|
||||
input_transform: Callable = None,
|
||||
init_hidden_layers: Sequence[int] = (),
|
||||
):
|
||||
|
||||
# check transform functions
|
||||
if input_transform is not None:
|
||||
try:
|
||||
_ = input_transform(jnp.zeros(num_inputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
_ = output_transform(jnp.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
|
||||
# prepare for initialization
|
||||
all_layers = [num_inputs] + list(init_hidden_layers) + [num_outputs]
|
||||
layer_indices = []
|
||||
next_index = 0
|
||||
for layer in all_layers:
|
||||
layer_indices.append(list(range(next_index, next_index + layer)))
|
||||
next_index += layer
|
||||
|
||||
all_init_nodes = []
|
||||
all_init_conns_in_idx = []
|
||||
all_init_conns_out_idx = []
|
||||
for i in range(len(layer_indices) - 1):
|
||||
in_layer = layer_indices[i]
|
||||
out_layer = layer_indices[i + 1]
|
||||
for in_idx in in_layer:
|
||||
for out_idx in out_layer:
|
||||
all_init_conns_in_idx.append(in_idx)
|
||||
all_init_conns_out_idx.append(out_idx)
|
||||
all_init_nodes.extend(in_layer)
|
||||
|
||||
if max_nodes < len(all_init_nodes):
|
||||
raise ValueError(
|
||||
f"max_nodes={max_nodes} must be greater than or equal to the number of initial nodes={len(all_init_nodes)}"
|
||||
)
|
||||
|
||||
if max_conns < len(all_init_conns_in_idx):
|
||||
raise ValueError(
|
||||
f"max_conns={max_conns} must be greater than or equal to the number of initial connections={len(all_init_conns_in_idx)}"
|
||||
)
|
||||
|
||||
self.num_inputs = num_inputs
|
||||
self.num_outputs = num_outputs
|
||||
self.input_idx = np.arange(num_inputs)
|
||||
self.output_idx = np.arange(num_inputs, num_inputs + num_outputs)
|
||||
self.max_nodes = max_nodes
|
||||
self.max_conns = max_conns
|
||||
self.node_gene = node_gene
|
||||
self.conn_gene = conn_gene
|
||||
self.mutation = mutation
|
||||
self.crossover = crossover
|
||||
self.distance = distance
|
||||
self.output_transform = output_transform
|
||||
self.input_transform = input_transform
|
||||
|
||||
self.input_idx = np.array(layer_indices[0])
|
||||
self.output_idx = np.array(layer_indices[-1])
|
||||
self.all_init_nodes = np.array(all_init_nodes)
|
||||
self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx]
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.node_gene.setup(state)
|
||||
state = self.conn_gene.setup(state)
|
||||
state = self.mutation.setup(state)
|
||||
state = self.crossover.setup(state)
|
||||
state = self.mutation.setup(state, self)
|
||||
state = self.crossover.setup(state, self)
|
||||
state = self.distance.setup(state, self)
|
||||
return state
|
||||
|
||||
def transform(self, state, nodes, conns):
|
||||
raise NotImplementedError
|
||||
|
||||
def restore(self, state, transformed):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, state, transformed, inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def sympy_func(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def visualize(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
|
||||
return self.mutation(state, randkey, self, nodes, conns, new_node_key)
|
||||
return self.mutation(state, randkey, nodes, conns, new_node_key)
|
||||
|
||||
def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
|
||||
return self.crossover(state, randkey, self, nodes1, conns1, nodes2, conns2)
|
||||
return self.crossover(state, randkey, nodes1, conns1, nodes2, conns2)
|
||||
|
||||
def execute_distance(self, state, nodes1, conns1, nodes2, conns2):
|
||||
return self.distance(state, nodes1, conns1, nodes2, conns2)
|
||||
|
||||
def initialize(self, state, randkey):
|
||||
"""
|
||||
Default initialization method for the genome.
|
||||
Add an extra hidden node.
|
||||
Make all input nodes and output nodes connected to the hidden node.
|
||||
All attributes will be initialized randomly using gene.new_random_attrs method.
|
||||
|
||||
For example, a network with 2 inputs and 1 output, the structure will be:
|
||||
nodes:
|
||||
[
|
||||
[0, attrs0], # input node 0
|
||||
[1, attrs1], # input node 1
|
||||
[2, attrs2], # output node 0
|
||||
[3, attrs3], # hidden node
|
||||
[NaN, NaN], # empty node
|
||||
]
|
||||
conns:
|
||||
[
|
||||
[0, 3, attrs0], # input node 0 -> hidden node
|
||||
[1, 3, attrs1], # input node 1 -> hidden node
|
||||
[3, 2, attrs2], # hidden node -> output node 0
|
||||
[NaN, NaN],
|
||||
[NaN, NaN],
|
||||
]
|
||||
"""
|
||||
|
||||
k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns
|
||||
|
||||
all_nodes_cnt = len(self.all_init_nodes)
|
||||
all_conns_cnt = len(self.all_init_conns)
|
||||
|
||||
# initialize nodes
|
||||
new_node_key = (
|
||||
max([*self.input_idx, *self.output_idx]) + 1
|
||||
) # the key for the hidden node
|
||||
node_keys = jnp.concatenate(
|
||||
[self.input_idx, self.output_idx, jnp.array([new_node_key])]
|
||||
) # the list of all node keys
|
||||
|
||||
# initialize nodes and connections with NaN
|
||||
nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan)
|
||||
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
|
||||
# create node indices
|
||||
node_indices = self.all_init_nodes
|
||||
# create node attrs
|
||||
rand_keys_n = jax.random.split(k1, num=all_nodes_cnt)
|
||||
node_attr_func = vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
|
||||
node_attrs = node_attr_func(state, rand_keys_n)
|
||||
|
||||
# set keys for input nodes, output nodes and hidden node
|
||||
nodes = nodes.at[node_keys, 0].set(node_keys)
|
||||
|
||||
# generate random attributes for nodes
|
||||
node_keys = jax.random.split(k1, len(node_keys))
|
||||
random_node_attrs = jax.vmap(
|
||||
self.node_gene.new_random_attrs, in_axes=(None, 0)
|
||||
)(state, node_keys)
|
||||
nodes = nodes.at[: len(node_keys), 1:].set(random_node_attrs)
|
||||
nodes = nodes.at[:all_nodes_cnt, 0].set(node_indices) # set node indices
|
||||
nodes = nodes.at[:all_nodes_cnt, 1:].set(node_attrs) # set node attrs
|
||||
|
||||
# initialize conns
|
||||
# input-hidden connections
|
||||
input_conns = jnp.c_[
|
||||
self.input_idx, jnp.full_like(self.input_idx, new_node_key)
|
||||
]
|
||||
conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys
|
||||
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
|
||||
# create input and output indices
|
||||
conn_indices = self.all_init_conns
|
||||
# create conn attrs
|
||||
rand_keys_c = jax.random.split(k2, num=all_conns_cnt)
|
||||
conns_attr_func = jax.vmap(
|
||||
self.conn_gene.new_random_attrs,
|
||||
in_axes=(
|
||||
None,
|
||||
0,
|
||||
),
|
||||
)
|
||||
conns_attrs = conns_attr_func(state, rand_keys_c)
|
||||
|
||||
# output-hidden connections
|
||||
output_conns = jnp.c_[
|
||||
jnp.full_like(self.output_idx, new_node_key), self.output_idx
|
||||
]
|
||||
conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys
|
||||
|
||||
conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx))
|
||||
# generate random attributes for conns
|
||||
random_conn_attrs = jax.vmap(
|
||||
self.conn_gene.new_random_attrs, in_axes=(None, 0)
|
||||
)(state, conn_keys)
|
||||
conns = conns.at[: len(conn_keys), 2:].set(random_conn_attrs)
|
||||
conns = conns.at[:all_conns_cnt, :2].set(conn_indices) # set conn indices
|
||||
conns = conns.at[:all_conns_cnt, 2:].set(conns_attrs) # set conn attrs
|
||||
|
||||
return nodes, conns
|
||||
|
||||
def update_by_batch(self, state, batch_input, transformed):
|
||||
"""
|
||||
Update the genome by a batch of data.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
def network_dict(self, state, nodes, conns):
|
||||
return {
|
||||
"nodes": self._get_node_dict(state, nodes),
|
||||
"conns": self._get_conn_dict(state, conns),
|
||||
}
|
||||
|
||||
def get_input_idx(self):
|
||||
return self.input_idx.tolist()
|
||||
|
||||
def get_output_idx(self):
|
||||
return self.output_idx.tolist()
|
||||
|
||||
def hash(self, nodes, conns):
|
||||
nodes_hashs = vmap(hash_array)(nodes)
|
||||
conns_hashs = vmap(hash_array)(conns)
|
||||
return hash_array(jnp.concatenate([nodes_hashs, conns_hashs]))
|
||||
|
||||
def repr(self, state, nodes, conns, precision=2):
|
||||
nodes, conns = jax.device_get([nodes, conns])
|
||||
nodes_cnt, conns_cnt = self.valid_cnt(nodes), self.valid_cnt(conns)
|
||||
nodes_cnt, conns_cnt = valid_cnt(nodes), valid_cnt(conns)
|
||||
s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n"
|
||||
s += f"\tNodes:\n"
|
||||
for node in nodes:
|
||||
@@ -152,11 +199,7 @@ class BaseGenome(StatefulBaseClass):
|
||||
s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n"
|
||||
return s
|
||||
|
||||
@classmethod
|
||||
def valid_cnt(cls, arr):
|
||||
return jnp.sum(~jnp.isnan(arr[:, 0]))
|
||||
|
||||
def get_conn_dict(self, state, conns):
|
||||
def _get_conn_dict(self, state, conns):
|
||||
conns = jax.device_get(conns)
|
||||
conn_dict = {}
|
||||
for conn in conns:
|
||||
@@ -167,7 +210,7 @@ class BaseGenome(StatefulBaseClass):
|
||||
conn_dict[(in_idx, out_idx)] = cd
|
||||
return conn_dict
|
||||
|
||||
def get_node_dict(self, state, nodes):
|
||||
def _get_node_dict(self, state, nodes):
|
||||
nodes = jax.device_get(nodes)
|
||||
node_dict = {}
|
||||
for node in nodes:
|
||||
@@ -177,92 +220,3 @@ class BaseGenome(StatefulBaseClass):
|
||||
idx = nd["idx"]
|
||||
node_dict[idx] = nd
|
||||
return node_dict
|
||||
|
||||
def network_dict(self, state, nodes, conns):
|
||||
return {
|
||||
"nodes": self.get_node_dict(state, nodes),
|
||||
"conns": self.get_conn_dict(state, conns),
|
||||
}
|
||||
|
||||
def get_input_idx(self):
|
||||
return self.input_idx.tolist()
|
||||
|
||||
def get_output_idx(self):
|
||||
return self.output_idx.tolist()
|
||||
|
||||
def sympy_func(self, state, network, sympy_output_transform=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
network,
|
||||
rotate=0,
|
||||
reverse_node_order=False,
|
||||
size=(300, 300, 300),
|
||||
color=("blue", "blue", "blue"),
|
||||
save_path="network.svg",
|
||||
save_dpi=800,
|
||||
**kwargs,
|
||||
):
|
||||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
nodes_list = list(network["nodes"])
|
||||
conns_list = list(network["conns"])
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
topo_order, topo_layers = topological_sort_python(nodes_list, conns_list)
|
||||
node2layer = {
|
||||
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
|
||||
}
|
||||
if reverse_node_order:
|
||||
topo_order = topo_order[::-1]
|
||||
|
||||
G = nx.DiGraph()
|
||||
|
||||
if not isinstance(size, tuple):
|
||||
size = (size, size, size)
|
||||
if not isinstance(color, tuple):
|
||||
color = (color, color, color)
|
||||
|
||||
for node in topo_order:
|
||||
if node in input_idx:
|
||||
G.add_node(node, subset=node2layer[node], size=size[0], color=color[0])
|
||||
elif node in output_idx:
|
||||
G.add_node(node, subset=node2layer[node], size=size[2], color=color[2])
|
||||
else:
|
||||
G.add_node(node, subset=node2layer[node], size=size[1], color=color[1])
|
||||
|
||||
for conn in conns_list:
|
||||
G.add_edge(conn[0], conn[1])
|
||||
pos = nx.multipartite_layout(G)
|
||||
|
||||
def rotate_layout(pos, angle):
|
||||
angle_rad = np.deg2rad(angle)
|
||||
cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
|
||||
rotated_pos = {}
|
||||
for node, (x, y) in pos.items():
|
||||
rotated_pos[node] = (
|
||||
cos_angle * x - sin_angle * y,
|
||||
sin_angle * x + cos_angle * y,
|
||||
)
|
||||
return rotated_pos
|
||||
|
||||
rotated_pos = rotate_layout(pos, rotate)
|
||||
|
||||
node_sizes = [n["size"] for n in G.nodes.values()]
|
||||
node_colors = [n["color"] for n in G.nodes.values()]
|
||||
|
||||
nx.draw(
|
||||
G,
|
||||
pos=rotated_pos,
|
||||
node_size=node_sizes,
|
||||
node_color=node_colors,
|
||||
**kwargs,
|
||||
)
|
||||
plt.savefig(save_path, dpi=save_dpi)
|
||||
|
||||
def hash(self, nodes, conns):
|
||||
nodes_hashs = jax.vmap(hash_array)(nodes)
|
||||
conns_hashs = jax.vmap(hash_array)(conns)
|
||||
return hash_array(jnp.concatenate([nodes_hashs, conns_hashs]))
|
||||
|
||||
@@ -1,25 +1,23 @@
|
||||
import warnings
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
from utils import (
|
||||
unflatten_conns,
|
||||
|
||||
from . import BaseGenome
|
||||
from ..gene import DefaultNodeGene, DefaultConnGene
|
||||
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
|
||||
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
|
||||
|
||||
from tensorneat.common import (
|
||||
topological_sort,
|
||||
topological_sort_python,
|
||||
I_INF,
|
||||
extract_node_attrs,
|
||||
extract_conn_attrs,
|
||||
set_node_attrs,
|
||||
set_conn_attrs,
|
||||
attach_with_inf,
|
||||
SYMPY_FUNCS_MODULE_NP,
|
||||
SYMPY_FUNCS_MODULE_JNP,
|
||||
)
|
||||
from . import BaseGenome
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
|
||||
|
||||
|
||||
class DefaultGenome(BaseGenome):
|
||||
@@ -31,15 +29,18 @@ class DefaultGenome(BaseGenome):
|
||||
self,
|
||||
num_inputs: int,
|
||||
num_outputs: int,
|
||||
max_nodes=5,
|
||||
max_conns=4,
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
mutation: BaseMutation = DefaultMutation(),
|
||||
crossover: BaseCrossover = DefaultCrossover(),
|
||||
output_transform: Callable = None,
|
||||
input_transform: Callable = None,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
node_gene=DefaultNodeGene(),
|
||||
conn_gene=DefaultConnGene(),
|
||||
mutation=DefaultMutation(),
|
||||
crossover=DefaultCrossover(),
|
||||
distance=DefaultDistance(),
|
||||
output_transform=None,
|
||||
input_transform=None,
|
||||
init_hidden_layers=(),
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
num_inputs,
|
||||
num_outputs,
|
||||
@@ -49,22 +50,12 @@ class DefaultGenome(BaseGenome):
|
||||
conn_gene,
|
||||
mutation,
|
||||
crossover,
|
||||
distance,
|
||||
output_transform,
|
||||
input_transform,
|
||||
init_hidden_layers,
|
||||
)
|
||||
|
||||
if input_transform is not None:
|
||||
try:
|
||||
_ = input_transform(np.zeros(num_inputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.input_transform = input_transform
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
_ = output_transform(np.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.output_transform = output_transform
|
||||
|
||||
def transform(self, state, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
conn_exist = u_conns != I_INF
|
||||
@@ -73,10 +64,6 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
return seqs, nodes, conns, u_conns
|
||||
|
||||
def restore(self, state, transformed):
|
||||
seqs, nodes, conns, u_conns = transformed
|
||||
return nodes, conns
|
||||
|
||||
def forward(self, state, transformed, inputs):
|
||||
|
||||
if self.input_transform is not None:
|
||||
@@ -86,8 +73,8 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
ini_vals = jnp.full((self.max_nodes,), jnp.nan)
|
||||
ini_vals = ini_vals.at[self.input_idx].set(inputs)
|
||||
nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
|
||||
conns_attrs = jax.vmap(extract_conn_attrs)(conns)
|
||||
nodes_attrs = vmap(extract_node_attrs)(nodes)
|
||||
conns_attrs = vmap(extract_conn_attrs)(conns)
|
||||
|
||||
def cond_fun(carry):
|
||||
values, idx = carry
|
||||
@@ -105,7 +92,7 @@ class DefaultGenome(BaseGenome):
|
||||
def otherwise():
|
||||
conn_indices = u_conns[:, i]
|
||||
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
|
||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
|
||||
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
|
||||
state, hit_attrs, values
|
||||
)
|
||||
|
||||
@@ -130,85 +117,14 @@ class DefaultGenome(BaseGenome):
|
||||
else:
|
||||
return self.output_transform(vals[self.output_idx])
|
||||
|
||||
def update_by_batch(self, state, batch_input, transformed):
|
||||
|
||||
if self.input_transform is not None:
|
||||
batch_input = jax.vmap(self.input_transform)(batch_input)
|
||||
|
||||
cal_seqs, nodes, conns, u_conns = transformed
|
||||
|
||||
batch_size = batch_input.shape[0]
|
||||
batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan)
|
||||
batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input)
|
||||
nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
|
||||
conns_attrs = jax.vmap(extract_conn_attrs)(conns)
|
||||
|
||||
def cond_fun(carry):
|
||||
batch_values, nodes_attrs_, conns_attrs_, idx = carry
|
||||
return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)
|
||||
|
||||
def body_func(carry):
|
||||
batch_values, nodes_attrs_, conns_attrs_, idx = carry
|
||||
i = cal_seqs[idx]
|
||||
|
||||
def input_node():
|
||||
batch, new_attrs = self.node_gene.update_input_transform(
|
||||
state, nodes_attrs_[i], batch_values[:, i]
|
||||
)
|
||||
return (
|
||||
batch_values.at[:, i].set(batch),
|
||||
nodes_attrs_.at[i].set(new_attrs),
|
||||
conns_attrs_,
|
||||
)
|
||||
|
||||
def otherwise():
|
||||
|
||||
conn_indices = u_conns[:, i]
|
||||
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
|
||||
batch_ins, new_conn_attrs = jax.vmap(
|
||||
self.conn_gene.update_by_batch,
|
||||
in_axes=(None, 0, 1),
|
||||
out_axes=(1, 0),
|
||||
)(state, hit_attrs, batch_values)
|
||||
|
||||
batch_z, new_node_attrs = self.node_gene.update_by_batch(
|
||||
state,
|
||||
nodes_attrs_[i],
|
||||
batch_ins,
|
||||
is_output_node=jnp.isin(i, self.output_idx),
|
||||
)
|
||||
|
||||
return (
|
||||
batch_values.at[:, i].set(batch_z),
|
||||
nodes_attrs_.at[i].set(new_node_attrs),
|
||||
conns_attrs_.at[conn_indices].set(new_conn_attrs),
|
||||
)
|
||||
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
(batch_values, nodes_attrs_, conns_attrs_) = jax.lax.cond(
|
||||
jnp.isin(i, self.input_idx),
|
||||
input_node,
|
||||
otherwise,
|
||||
)
|
||||
|
||||
return batch_values, nodes_attrs_, conns_attrs_, idx + 1
|
||||
|
||||
batch_vals, nodes_attrs, conns_attrs, _ = jax.lax.while_loop(
|
||||
cond_fun, body_func, (batch_ini_vals, nodes_attrs, conns_attrs, 0)
|
||||
def network_dict(self, state, nodes, conns):
|
||||
network = super().network_dict(state, nodes, conns)
|
||||
topo_order, topo_layers = topological_sort_python(
|
||||
set(network["nodes"]), set(network["conns"])
|
||||
)
|
||||
|
||||
nodes = jax.vmap(set_node_attrs)(nodes, nodes_attrs)
|
||||
conns = jax.vmap(set_conn_attrs)(conns, conns_attrs)
|
||||
|
||||
new_transformed = (cal_seqs, nodes, conns, u_conns)
|
||||
|
||||
if self.output_transform is None:
|
||||
return batch_vals[:, self.output_idx], new_transformed
|
||||
else:
|
||||
return (
|
||||
jax.vmap(self.output_transform)(batch_vals[:, self.output_idx]),
|
||||
new_transformed,
|
||||
)
|
||||
network["topo_order"] = topo_order
|
||||
network["topo_layers"] = topo_layers
|
||||
return network
|
||||
|
||||
def sympy_func(
|
||||
self,
|
||||
@@ -241,7 +157,8 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
order, _ = topological_sort_python(set(network["nodes"]), set(network["conns"]))
|
||||
order = network["topo_order"]
|
||||
|
||||
hidden_idx = [
|
||||
i for i in network["nodes"] if i not in input_idx and i not in output_idx
|
||||
]
|
||||
@@ -260,8 +177,12 @@ class DefaultGenome(BaseGenome):
|
||||
for i in order:
|
||||
|
||||
if i in input_idx:
|
||||
nodes_exprs[symbols[-i - 1]] = symbols[-i - 1] # origin equal to its symbol
|
||||
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](symbols[-i - 1]) # normed i
|
||||
nodes_exprs[symbols[-i - 1]] = symbols[
|
||||
-i - 1
|
||||
] # origin equal to its symbol
|
||||
nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](
|
||||
symbols[-i - 1]
|
||||
) # normed i
|
||||
|
||||
else:
|
||||
in_conns = [c for c in network["conns"] if c[1] == i]
|
||||
@@ -325,3 +246,73 @@ class DefaultGenome(BaseGenome):
|
||||
output_exprs,
|
||||
forward_func,
|
||||
)
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
network,
|
||||
rotate=0,
|
||||
reverse_node_order=False,
|
||||
size=(300, 300, 300),
|
||||
color=("blue", "blue", "blue"),
|
||||
save_path="network.svg",
|
||||
save_dpi=800,
|
||||
**kwargs,
|
||||
):
|
||||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
nodes_list = list(network["nodes"])
|
||||
conns_list = list(network["conns"])
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
|
||||
topo_order, topo_layers = network["topo_order"], network["topo_layers"]
|
||||
node2layer = {
|
||||
node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
|
||||
}
|
||||
if reverse_node_order:
|
||||
topo_order = topo_order[::-1]
|
||||
|
||||
G = nx.DiGraph()
|
||||
|
||||
if not isinstance(size, tuple):
|
||||
size = (size, size, size)
|
||||
if not isinstance(color, tuple):
|
||||
color = (color, color, color)
|
||||
|
||||
for node in topo_order:
|
||||
if node in input_idx:
|
||||
G.add_node(node, subset=node2layer[node], size=size[0], color=color[0])
|
||||
elif node in output_idx:
|
||||
G.add_node(node, subset=node2layer[node], size=size[2], color=color[2])
|
||||
else:
|
||||
G.add_node(node, subset=node2layer[node], size=size[1], color=color[1])
|
||||
|
||||
for conn in conns_list:
|
||||
G.add_edge(conn[0], conn[1])
|
||||
pos = nx.multipartite_layout(G)
|
||||
|
||||
def rotate_layout(pos, angle):
|
||||
angle_rad = np.deg2rad(angle)
|
||||
cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
|
||||
rotated_pos = {}
|
||||
for node, (x, y) in pos.items():
|
||||
rotated_pos[node] = (
|
||||
cos_angle * x - sin_angle * y,
|
||||
sin_angle * x + cos_angle * y,
|
||||
)
|
||||
return rotated_pos
|
||||
|
||||
rotated_pos = rotate_layout(pos, rotate)
|
||||
|
||||
node_sizes = [n["size"] for n in G.nodes.values()]
|
||||
node_colors = [n["color"] for n in G.nodes.values()]
|
||||
|
||||
nx.draw(
|
||||
G,
|
||||
pos=rotated_pos,
|
||||
node_size=node_sizes,
|
||||
node_color=node_colors,
|
||||
**kwargs,
|
||||
)
|
||||
plt.savefig(save_path, dpi=save_dpi)
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from .default import DefaultGenome
|
||||
|
||||
|
||||
class DenseInitialize(DefaultGenome):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.max_nodes >= self.num_inputs + self.num_outputs
|
||||
assert self.max_conns >= self.num_inputs * self.num_outputs
|
||||
|
||||
def initialize(self, state, randkey):
|
||||
|
||||
k1, k2 = jax.random.split(randkey, num=2)
|
||||
|
||||
input_idx, output_idx = self.input_idx, self.output_idx
|
||||
input_size = len(input_idx)
|
||||
output_size = len(output_idx)
|
||||
|
||||
nodes = jnp.full(
|
||||
(self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
|
||||
)
|
||||
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
|
||||
total_idx = input_size + output_size
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
|
||||
node_attr_func = jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
|
||||
node_attrs = node_attr_func(state, rand_keys_n)
|
||||
nodes = nodes.at[:total_idx, 1:].set(node_attrs)
|
||||
|
||||
conns = jnp.full(
|
||||
(self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
|
||||
)
|
||||
|
||||
input_to_output_ids, output_ids = jnp.meshgrid(
|
||||
input_idx, output_idx, indexing="ij"
|
||||
)
|
||||
total_conns = input_size * output_size
|
||||
conns = conns.at[:total_conns, :2].set(
|
||||
jnp.column_stack([input_to_output_ids.flatten(), output_ids.flatten()])
|
||||
)
|
||||
|
||||
rand_keys_c = jax.random.split(k2, num=total_conns)
|
||||
conns_attr_func = jax.vmap(
|
||||
self.conn_gene.new_random_attrs,
|
||||
in_axes=(
|
||||
None,
|
||||
0,
|
||||
),
|
||||
)
|
||||
conns_attrs = conns_attr_func(state, rand_keys_c)
|
||||
conns = conns.at[:total_conns, 2:].set(conns_attrs)
|
||||
|
||||
return nodes, conns
|
||||
@@ -1,70 +0,0 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from .default import DefaultGenome
|
||||
|
||||
|
||||
class HiddenInitialize(DefaultGenome):
|
||||
def __init__(self, hidden_cnt=8, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hidden_cnt = hidden_cnt
|
||||
|
||||
def initialize(self, state, randkey):
|
||||
|
||||
k1, k2 = jax.random.split(randkey, num=2)
|
||||
|
||||
input_idx, output_idx = self.input_idx, self.output_idx
|
||||
input_size = len(input_idx)
|
||||
output_size = len(output_idx)
|
||||
|
||||
hidden_idx = jnp.arange(
|
||||
input_size + output_size, input_size + output_size + self.hidden_cnt
|
||||
)
|
||||
nodes = jnp.full(
|
||||
(self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
|
||||
)
|
||||
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
|
||||
|
||||
total_idx = input_size + output_size + self.hidden_cnt
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
|
||||
node_attr_func = jax.vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
|
||||
node_attrs = node_attr_func(state, rand_keys_n)
|
||||
nodes = nodes.at[:total_idx, 1:].set(node_attrs)
|
||||
|
||||
conns = jnp.full(
|
||||
(self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
|
||||
)
|
||||
|
||||
input_to_hidden_ids, hidden_ids = jnp.meshgrid(
|
||||
input_idx, hidden_idx, indexing="ij"
|
||||
)
|
||||
total_input_to_hidden_conns = input_size * self.hidden_cnt
|
||||
conns = conns.at[:total_input_to_hidden_conns, :2].set(
|
||||
jnp.column_stack([input_to_hidden_ids.flatten(), hidden_ids.flatten()])
|
||||
)
|
||||
|
||||
hidden_to_output_ids, output_ids = jnp.meshgrid(
|
||||
hidden_idx, output_idx, indexing="ij"
|
||||
)
|
||||
total_hidden_to_output_conns = self.hidden_cnt * output_size
|
||||
conns = conns.at[
|
||||
total_input_to_hidden_conns : total_input_to_hidden_conns
|
||||
+ total_hidden_to_output_conns,
|
||||
:2,
|
||||
].set(jnp.column_stack([hidden_to_output_ids.flatten(), output_ids.flatten()]))
|
||||
|
||||
total_conns = total_input_to_hidden_conns + total_hidden_to_output_conns
|
||||
rand_keys_c = jax.random.split(k2, num=total_conns)
|
||||
conns_attr_func = jax.vmap(
|
||||
self.conn_gene.new_random_attrs,
|
||||
in_axes=(
|
||||
None,
|
||||
0,
|
||||
),
|
||||
)
|
||||
conns_attrs = conns_attr_func(state, rand_keys_c)
|
||||
conns = conns.at[:total_conns, 2:].set(conns_attrs)
|
||||
|
||||
return nodes, conns
|
||||
@@ -1,2 +1,3 @@
|
||||
from .crossover import BaseCrossover, DefaultCrossover
|
||||
from .mutation import BaseMutation, DefaultMutation
|
||||
from .distance import BaseDistance, DefaultDistance
|
||||
@@ -0,0 +1,12 @@
|
||||
from tensorneat.common import StatefulBaseClass, State
|
||||
|
||||
|
||||
class BaseCrossover(StatefulBaseClass):
|
||||
|
||||
def setup(self, state=State(), genome = None):
|
||||
assert genome is not None, "genome should not be None"
|
||||
self.genome = genome
|
||||
return state
|
||||
|
||||
def __call__(self, state, randkey, nodes1, nodes2, conns1, conns2):
|
||||
raise NotImplementedError
|
||||
@@ -1,7 +1,8 @@
|
||||
import jax, jax.numpy as jnp
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .base import BaseCrossover
|
||||
from utils.tools import (
|
||||
from ...utils import (
|
||||
extract_node_attrs,
|
||||
extract_conn_attrs,
|
||||
set_node_attrs,
|
||||
@@ -10,14 +11,14 @@ from utils.tools import (
|
||||
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
def __call__(self, state, randkey, genome, nodes1, conns1, nodes2, conns2):
|
||||
def __call__(self, state, randkey, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
"""
|
||||
randkey1, randkey2 = jax.random.split(randkey, 2)
|
||||
randkeys1 = jax.random.split(randkey1, genome.max_nodes)
|
||||
randkeys2 = jax.random.split(randkey2, genome.max_conns)
|
||||
randkeys1 = jax.random.split(randkey1, self.genome.max_nodes)
|
||||
randkeys2 = jax.random.split(randkey2, self.genome.max_conns)
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
@@ -26,33 +27,33 @@ class DefaultCrossover(BaseCrossover):
|
||||
|
||||
# For not homologous genes, use the value of nodes1(winner)
|
||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||
node_attrs1 = jax.vmap(extract_node_attrs)(nodes1)
|
||||
node_attrs2 = jax.vmap(extract_node_attrs)(nodes2)
|
||||
node_attrs1 = vmap(extract_node_attrs)(nodes1)
|
||||
node_attrs2 = vmap(extract_node_attrs)(nodes2)
|
||||
|
||||
new_node_attrs = jnp.where(
|
||||
jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2), # one of them is nan
|
||||
node_attrs1, # not homologous genes or both nan, use the value of nodes1(winner)
|
||||
jax.vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(
|
||||
vmap(self.genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(
|
||||
state, randkeys1, node_attrs1, node_attrs2
|
||||
), # homologous or both nan
|
||||
)
|
||||
new_nodes = jax.vmap(set_node_attrs)(nodes1, new_node_attrs)
|
||||
new_nodes = vmap(set_node_attrs)(nodes1, new_node_attrs)
|
||||
|
||||
# crossover connections
|
||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||
|
||||
conns_attrs1 = jax.vmap(extract_conn_attrs)(conns1)
|
||||
conns_attrs2 = jax.vmap(extract_conn_attrs)(conns2)
|
||||
conns_attrs1 = vmap(extract_conn_attrs)(conns1)
|
||||
conns_attrs2 = vmap(extract_conn_attrs)(conns2)
|
||||
|
||||
new_conn_attrs = jnp.where(
|
||||
jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2),
|
||||
conns_attrs1, # not homologous genes or both nan, use the value of conns1(winner)
|
||||
jax.vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(
|
||||
vmap(self.genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(
|
||||
state, randkeys2, conns_attrs1, conns_attrs2
|
||||
), # homologous or both nan
|
||||
)
|
||||
new_conns = jax.vmap(set_conn_attrs)(conns1, new_conn_attrs)
|
||||
new_conns = vmap(set_conn_attrs)(conns1, new_conn_attrs)
|
||||
|
||||
return new_nodes, new_conns
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
from .base import BaseDistance
|
||||
from .default import DefaultDistance
|
||||
15
tensorneat/algorithm/neat/genome/operations/distance/base.py
Normal file
15
tensorneat/algorithm/neat/genome/operations/distance/base.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from tensorneat.common import StatefulBaseClass, State
|
||||
|
||||
|
||||
class BaseDistance(StatefulBaseClass):
|
||||
|
||||
def setup(self, state=State(), genome = None):
|
||||
assert genome is not None, "genome should not be None"
|
||||
self.genome = genome
|
||||
return state
|
||||
|
||||
def __call__(self, state, nodes1, nodes2, conns1, conns2):
|
||||
"""
|
||||
The distance between two genomes
|
||||
"""
|
||||
raise NotImplementedError
|
||||
105
tensorneat/algorithm/neat/genome/operations/distance/default.py
Normal file
105
tensorneat/algorithm/neat/genome/operations/distance/default.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .base import BaseDistance
|
||||
from ...utils import extract_node_attrs, extract_conn_attrs
|
||||
|
||||
|
||||
class DefaultDistance(BaseDistance):
|
||||
def __init__(
|
||||
self,
|
||||
compatibility_disjoint: float = 1.0,
|
||||
compatibility_weight: float = 0.4,
|
||||
):
|
||||
self.compatibility_disjoint = compatibility_disjoint
|
||||
self.compatibility_weight = compatibility_weight
|
||||
|
||||
def __call__(self, state, nodes1, nodes2, conns1, conns2):
|
||||
"""
|
||||
The distance between two genomes
|
||||
"""
|
||||
d = self.node_distance(state, nodes1, nodes2) + self.conn_distance(
|
||||
state, conns1, conns2
|
||||
)
|
||||
return d
|
||||
|
||||
def node_distance(self, state, nodes1, nodes2):
|
||||
"""
|
||||
The distance of the nodes part for two genomes
|
||||
"""
|
||||
node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
|
||||
node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
|
||||
max_cnt = jnp.maximum(node_cnt1, node_cnt2)
|
||||
|
||||
# align homologous nodes
|
||||
# this process is similar to np.intersect1d.
|
||||
nodes = jnp.concatenate((nodes1, nodes2), axis=0)
|
||||
keys = nodes[:, 0]
|
||||
sorted_indices = jnp.argsort(keys, axis=0)
|
||||
nodes = nodes[sorted_indices]
|
||||
nodes = jnp.concatenate(
|
||||
[nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0
|
||||
) # add a nan row to the end
|
||||
fr, sr = nodes[:-1], nodes[1:] # first row, second row
|
||||
|
||||
# flag location of homologous nodes
|
||||
intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])
|
||||
|
||||
# calculate the count of non_homologous of two genomes
|
||||
non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
# calculate the distance of homologous nodes
|
||||
fr_attrs = vmap(extract_node_attrs)(fr)
|
||||
sr_attrs = vmap(extract_node_attrs)(sr)
|
||||
hnd = vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr_attrs, sr_attrs
|
||||
) # homologous node distance
|
||||
hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
|
||||
homologous_distance = jnp.sum(hnd * intersect_mask)
|
||||
|
||||
val = (
|
||||
non_homologous_cnt * self.compatibility_disjoint
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
return val
|
||||
|
||||
def conn_distance(self, state, conns1, conns2):
|
||||
"""
|
||||
The distance of the conns part for two genomes
|
||||
"""
|
||||
con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
|
||||
con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
|
||||
max_cnt = jnp.maximum(con_cnt1, con_cnt2)
|
||||
|
||||
cons = jnp.concatenate((conns1, conns2), axis=0)
|
||||
keys = cons[:, :2]
|
||||
sorted_indices = jnp.lexsort(keys.T[::-1])
|
||||
cons = cons[sorted_indices]
|
||||
cons = jnp.concatenate(
|
||||
[cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0
|
||||
) # add a nan row to the end
|
||||
fr, sr = cons[:-1], cons[1:] # first row, second row
|
||||
|
||||
# both genome has such connection
|
||||
intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])
|
||||
|
||||
non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)
|
||||
|
||||
fr_attrs = vmap(extract_conn_attrs)(fr)
|
||||
sr_attrs = vmap(extract_conn_attrs)(sr)
|
||||
hcd = vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
|
||||
state, fr_attrs, sr_attrs
|
||||
) # homologous connection distance
|
||||
hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
|
||||
homologous_distance = jnp.sum(hcd * intersect_mask)
|
||||
|
||||
val = (
|
||||
non_homologous_cnt * self.compatibility_disjoint
|
||||
+ homologous_distance * self.compatibility_weight
|
||||
)
|
||||
|
||||
val = jnp.where(max_cnt == 0, 0, val / max_cnt) # normalize
|
||||
|
||||
return val
|
||||
12
tensorneat/algorithm/neat/genome/operations/mutation/base.py
Normal file
12
tensorneat/algorithm/neat/genome/operations/mutation/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from tensorneat.common import StatefulBaseClass, State
|
||||
|
||||
|
||||
class BaseMutation(StatefulBaseClass):
|
||||
|
||||
def setup(self, state=State(), genome = None):
|
||||
assert genome is not None, "genome should not be None"
|
||||
self.genome = genome
|
||||
return state
|
||||
|
||||
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
raise NotImplementedError
|
||||
@@ -1,11 +1,14 @@
|
||||
import jax, jax.numpy as jnp
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
from . import BaseMutation
|
||||
from utils import (
|
||||
from tensorneat.common import (
|
||||
fetch_first,
|
||||
fetch_random,
|
||||
I_INF,
|
||||
unflatten_conns,
|
||||
check_cycles,
|
||||
)
|
||||
from ...utils import (
|
||||
unflatten_conns,
|
||||
add_node,
|
||||
add_conn,
|
||||
delete_node_by_pos,
|
||||
@@ -225,17 +228,17 @@ class DefaultMutation(BaseMutation):
|
||||
nodes_randkeys = jax.random.split(k1, num=genome.max_nodes)
|
||||
conns_randkeys = jax.random.split(k2, num=genome.max_conns)
|
||||
|
||||
node_attrs = jax.vmap(extract_node_attrs)(nodes)
|
||||
new_node_attrs = jax.vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
|
||||
node_attrs = vmap(extract_node_attrs)(nodes)
|
||||
new_node_attrs = vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
|
||||
state, nodes_randkeys, node_attrs
|
||||
)
|
||||
new_nodes = jax.vmap(set_node_attrs)(nodes, new_node_attrs)
|
||||
new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs)
|
||||
|
||||
conn_attrs = jax.vmap(extract_conn_attrs)(conns)
|
||||
new_conn_attrs = jax.vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
|
||||
conn_attrs = vmap(extract_conn_attrs)(conns)
|
||||
new_conn_attrs = vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
|
||||
state, conns_randkeys, conn_attrs
|
||||
)
|
||||
new_conns = jax.vmap(set_conn_attrs)(conns, new_conn_attrs)
|
||||
new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs)
|
||||
|
||||
# nan nodes not changed
|
||||
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import unflatten_conns
|
||||
from .utils import unflatten_conns
|
||||
|
||||
from . import BaseGenome
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover, DefaultMutation, DefaultCrossover
|
||||
from ..gene import DefaultNodeGene, DefaultConnGene
|
||||
from .operations import DefaultMutation, DefaultCrossover
|
||||
|
||||
|
||||
class RecurrentGenome(BaseGenome):
|
||||
@@ -17,13 +17,13 @@ class RecurrentGenome(BaseGenome):
|
||||
self,
|
||||
num_inputs: int,
|
||||
num_outputs: int,
|
||||
max_nodes: int,
|
||||
max_conns: int,
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
mutation: BaseMutation = DefaultMutation(),
|
||||
crossover: BaseCrossover = DefaultCrossover(),
|
||||
activate_time: int = 10,
|
||||
max_nodes = 50,
|
||||
max_conns = 100,
|
||||
node_gene=DefaultNodeGene(),
|
||||
conn_gene=DefaultConnGene(),
|
||||
mutation=DefaultMutation(),
|
||||
crossover=DefaultCrossover(),
|
||||
activate_time=10,
|
||||
output_transform: Callable = None,
|
||||
):
|
||||
super().__init__(
|
||||
|
||||
109
tensorneat/algorithm/neat/genome/utils.py
Normal file
109
tensorneat/algorithm/neat/genome/utils.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from tensorneat.common import fetch_first, I_INF
|
||||
|
||||
|
||||
def unflatten_conns(nodes, conns):
|
||||
"""
|
||||
transform the (C, CL) connections to (N, N), which contains the idx of the connection in conns
|
||||
connection length, N means the number of nodes, C means the number of connections
|
||||
returns the unflatten connection indices with shape (N, N)
|
||||
"""
|
||||
N = nodes.shape[0] # max_nodes
|
||||
C = conns.shape[0] # max_conns
|
||||
node_keys = nodes[:, 0]
|
||||
i_keys, o_keys = conns[:, 0], conns[:, 1]
|
||||
|
||||
def key_to_indices(key, keys):
|
||||
return fetch_first(key == keys)
|
||||
|
||||
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
|
||||
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
|
||||
|
||||
# Is interesting that jax use clip when attach data in array
|
||||
# however, it will do nothing when setting values in an array
|
||||
# put the index of connections in the unflatten array
|
||||
unflatten = (
|
||||
jnp.full((N, N), I_INF, dtype=jnp.int32)
|
||||
.at[i_idxs, o_idxs]
|
||||
.set(jnp.arange(C, dtype=jnp.int32))
|
||||
)
|
||||
|
||||
return unflatten
|
||||
|
||||
|
||||
def valid_cnt(nodes_or_conns):
|
||||
return jnp.sum(~jnp.isnan(nodes_or_conns[:, 0]))
|
||||
|
||||
|
||||
def extract_node_attrs(node):
|
||||
"""
|
||||
node: Array(NL, )
|
||||
extract the attributes of a node
|
||||
"""
|
||||
return node[1:] # 0 is for idx
|
||||
|
||||
|
||||
def set_node_attrs(node, attrs):
|
||||
"""
|
||||
node: Array(NL, )
|
||||
attrs: Array(NL-1, )
|
||||
set the attributes of a node
|
||||
"""
|
||||
return node.at[1:].set(attrs) # 0 is for idx
|
||||
|
||||
|
||||
def extract_conn_attrs(conn):
|
||||
"""
|
||||
conn: Array(CL, )
|
||||
extract the attributes of a connection
|
||||
"""
|
||||
return conn[2:] # 0, 1 is for in-idx and out-idx
|
||||
|
||||
|
||||
def set_conn_attrs(conn, attrs):
|
||||
"""
|
||||
conn: Array(CL, )
|
||||
attrs: Array(CL-2, )
|
||||
set the attributes of a connection
|
||||
"""
|
||||
return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx
|
||||
|
||||
|
||||
def add_node(nodes, new_key: int, attrs):
|
||||
"""
|
||||
Add a new node to the genome.
|
||||
The new node will place at the first NaN row.
|
||||
"""
|
||||
exist_keys = nodes[:, 0]
|
||||
pos = fetch_first(jnp.isnan(exist_keys))
|
||||
new_nodes = nodes.at[pos, 0].set(new_key)
|
||||
return new_nodes.at[pos, 1:].set(attrs)
|
||||
|
||||
|
||||
def delete_node_by_pos(nodes, pos):
|
||||
"""
|
||||
Delete a node from the genome.
|
||||
Delete the node by its pos in nodes.
|
||||
"""
|
||||
return nodes.at[pos].set(jnp.nan)
|
||||
|
||||
|
||||
def add_conn(conns, i_key, o_key, attrs):
|
||||
"""
|
||||
Add a new connection to the genome.
|
||||
The new connection will place at the first NaN row.
|
||||
"""
|
||||
con_keys = conns[:, 0]
|
||||
pos = fetch_first(jnp.isnan(con_keys))
|
||||
new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key]))
|
||||
return new_conns.at[pos, 2:].set(attrs)
|
||||
|
||||
|
||||
def delete_conn_by_pos(conns, pos):
|
||||
"""
|
||||
Delete a connection from the genome.
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
return conns.at[pos].set(jnp.nan)
|
||||
@@ -1,5 +1,5 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import State
|
||||
from tensorneat.common import State
|
||||
from .. import BaseAlgorithm
|
||||
from .species import *
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from utils import State, StatefulBaseClass
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
from ..genome import BaseGenome
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import (
|
||||
from tensorneat.common import (
|
||||
State,
|
||||
rank_elements,
|
||||
argmin_with_mask,
|
||||
fetch_first,
|
||||
)
|
||||
from ..genome.utils import (
|
||||
extract_conn_attrs,
|
||||
extract_node_attrs,
|
||||
)
|
||||
@@ -635,7 +637,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
# find next node key
|
||||
all_nodes_keys = state.pop_nodes[:, :, 0]
|
||||
max_node_key = jnp.max(all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0)
|
||||
max_node_key = jnp.max(
|
||||
all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
|
||||
)
|
||||
next_node_key = max_node_key + 1
|
||||
new_node_keys = jnp.arange(self.pop_size) + next_node_key
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from utils.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
||||
from tensorneat.common.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
from .state import State
|
||||
@@ -6,36 +6,6 @@ from jax import numpy as jnp, Array, jit, vmap
|
||||
|
||||
I_INF = np.iinfo(jnp.int32).max # infinite int
|
||||
|
||||
|
||||
def unflatten_conns(nodes, conns):
|
||||
"""
|
||||
transform the (C, CL) connections to (N, N), which contains the idx of the connection in conns
|
||||
connection length, N means the number of nodes, C means the number of connections
|
||||
returns the unflatten connection indices with shape (N, N)
|
||||
"""
|
||||
N = nodes.shape[0] # max_nodes
|
||||
C = conns.shape[0] # max_conns
|
||||
node_keys = nodes[:, 0]
|
||||
i_keys, o_keys = conns[:, 0], conns[:, 1]
|
||||
|
||||
def key_to_indices(key, keys):
|
||||
return fetch_first(key == keys)
|
||||
|
||||
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
|
||||
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
|
||||
|
||||
# Is interesting that jax use clip when attach data in array
|
||||
# however, it will do nothing when setting values in an array
|
||||
# put the index of connections in the unflatten array
|
||||
unflatten = (
|
||||
jnp.full((N, N), I_INF, dtype=jnp.int32)
|
||||
.at[i_idxs, o_idxs]
|
||||
.set(jnp.arange(C, dtype=jnp.int32))
|
||||
)
|
||||
|
||||
return unflatten
|
||||
|
||||
|
||||
# TODO: strange implementation
|
||||
def attach_with_inf(arr, idx):
|
||||
expand_size = arr.ndim - idx.ndim
|
||||
@@ -45,40 +15,6 @@ def attach_with_inf(arr, idx):
|
||||
return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])
|
||||
|
||||
|
||||
def extract_node_attrs(node):
|
||||
"""
|
||||
node: Array(NL, )
|
||||
extract the attributes of a node
|
||||
"""
|
||||
return node[1:] # 0 is for idx
|
||||
|
||||
|
||||
def set_node_attrs(node, attrs):
|
||||
"""
|
||||
node: Array(NL, )
|
||||
attrs: Array(NL-1, )
|
||||
set the attributes of a node
|
||||
"""
|
||||
return node.at[1:].set(attrs) # 0 is for idx
|
||||
|
||||
|
||||
def extract_conn_attrs(conn):
|
||||
"""
|
||||
conn: Array(CL, )
|
||||
extract the attributes of a connection
|
||||
"""
|
||||
return conn[2:] # 0, 1 is for in-idx and out-idx
|
||||
|
||||
|
||||
def set_conn_attrs(conn, attrs):
|
||||
"""
|
||||
conn: Array(CL, )
|
||||
attrs: Array(CL-2, )
|
||||
set the attributes of a connection
|
||||
"""
|
||||
return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_first(mask, default=I_INF) -> Array:
|
||||
"""
|
||||
@@ -164,44 +100,6 @@ def argmin_with_mask(arr, mask):
|
||||
return min_idx
|
||||
|
||||
|
||||
def add_node(nodes, new_key: int, attrs):
|
||||
"""
|
||||
Add a new node to the genome.
|
||||
The new node will place at the first NaN row.
|
||||
"""
|
||||
exist_keys = nodes[:, 0]
|
||||
pos = fetch_first(jnp.isnan(exist_keys))
|
||||
new_nodes = nodes.at[pos, 0].set(new_key)
|
||||
return new_nodes.at[pos, 1:].set(attrs)
|
||||
|
||||
|
||||
def delete_node_by_pos(nodes, pos):
|
||||
"""
|
||||
Delete a node from the genome.
|
||||
Delete the node by its pos in nodes.
|
||||
"""
|
||||
return nodes.at[pos].set(jnp.nan)
|
||||
|
||||
|
||||
def add_conn(conns, i_key, o_key, attrs):
|
||||
"""
|
||||
Add a new connection to the genome.
|
||||
The new connection will place at the first NaN row.
|
||||
"""
|
||||
con_keys = conns[:, 0]
|
||||
pos = fetch_first(jnp.isnan(con_keys))
|
||||
new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key]))
|
||||
return new_conns.at[pos, 2:].set(attrs)
|
||||
|
||||
|
||||
def delete_conn_by_pos(conns, pos):
|
||||
"""
|
||||
Delete a connection from the genome.
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
return conns.at[pos].set(jnp.nan)
|
||||
|
||||
|
||||
def hash_array(arr: Array):
|
||||
arr = jax.lax.bitcast_convert_type(arr, jnp.uint32)
|
||||
|
||||
@@ -9,7 +9,7 @@ from algorithm import BaseAlgorithm
|
||||
from problem import BaseProblem
|
||||
from problem.rl_env import RLEnv
|
||||
from problem.func_fit import FuncFit
|
||||
from utils import State, StatefulBaseClass
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
|
||||
|
||||
class Pipeline(StatefulBaseClass):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Callable
|
||||
|
||||
from utils import State, StatefulBaseClass
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
|
||||
|
||||
class BaseProblem(StatefulBaseClass):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from utils import State
|
||||
from tensorneat.common import State
|
||||
from .. import BaseProblem
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import jax, jax.numpy as jnp
|
||||
import jumanji
|
||||
|
||||
from utils import State
|
||||
from tensorneat.common import State
|
||||
from ..rl_jit import RLEnv
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from utils import State
|
||||
from tensorneat.common import State
|
||||
from .. import BaseProblem
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import Act
|
||||
from tensorneat.common import Act
|
||||
from algorithm.neat import *
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import Act
|
||||
from tensorneat.common import Act
|
||||
from algorithm.neat import *
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
"from algorithm.neat.gene.node.kan_node import KANNode\n",
|
||||
"from algorithm.neat.gene.conn.bspline import BSplineConn\n",
|
||||
"from problem.func_fit import XOR3d\n",
|
||||
"from utils import Act\n",
|
||||
"from tensorneat.utils import Act\n",
|
||||
"\n",
|
||||
"import jax, jax.numpy as jnp\n",
|
||||
"\n",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import Act
|
||||
from tensorneat.common import Act
|
||||
from algorithm.neat import *
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import jax, jax.numpy as jnp\n",
|
||||
"from utils import State\n",
|
||||
"from tensorneat.utils import State\n",
|
||||
"from problem.rl_env import BraxEnv\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
||||
@@ -145,7 +145,7 @@
|
||||
"source": [
|
||||
"from algorithm.neat.gene.node.normalized import NormalizedNode\n",
|
||||
"from algorithm.neat.gene.conn import DefaultConnGene\n",
|
||||
"from utils import Act\n",
|
||||
"from tensorneat.utils import Act\n",
|
||||
"\n",
|
||||
"genome = DefaultGenome(num_inputs=3, num_outputs=2, max_nodes=10, max_conns=10,\n",
|
||||
" node_gene=NormalizedNode(activation_default=Act.identity, activation_options=(Act.identity,)),\n",
|
||||
|
||||
Reference in New Issue
Block a user