add License and pyproject.toml
This commit is contained in:
3
src/tensorneat/algorithm/__init__.py
Normal file
3
src/tensorneat/algorithm/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseAlgorithm
|
||||
from .neat import NEAT
|
||||
from .hyperneat import HyperNEAT
|
||||
30
src/tensorneat/algorithm/base.py
Normal file
30
src/tensorneat/algorithm/base.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
|
||||
|
||||
class BaseAlgorithm(StatefulBaseClass):
    """Abstract interface shared by all evolutionary algorithms.

    Subclasses implement the ask/tell evolution loop plus the
    genome-to-network transformation and forward inference.
    """

    def ask(self, state: State):
        """Return the population that needs to be evaluated."""
        raise NotImplementedError

    def tell(self, state: State, fitness):
        """Consume evaluated fitness and advance the algorithm state."""
        raise NotImplementedError

    def transform(self, state, individual):
        """Compile one genome into a runnable network representation."""
        raise NotImplementedError

    def forward(self, state, transformed, inputs):
        """Run the transformed network on the given inputs."""
        raise NotImplementedError

    def show_details(self, state: State, fitness):
        """Visualize/print the running details of the algorithm."""
        raise NotImplementedError

    @property
    def num_inputs(self):
        """Number of network inputs expected by this algorithm."""
        raise NotImplementedError

    @property
    def num_outputs(self):
        """Number of network outputs produced by this algorithm."""
        raise NotImplementedError
|
||||
2
src/tensorneat/algorithm/hyperneat/__init__.py
Normal file
2
src/tensorneat/algorithm/hyperneat/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .hyperneat import HyperNEAT
|
||||
from .substrate import BaseSubstrate, DefaultSubstrate, FullSubstrate
|
||||
125
src/tensorneat/algorithm/hyperneat/hyperneat.py
Normal file
125
src/tensorneat/algorithm/hyperneat/hyperneat.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .substrate import *
|
||||
from tensorneat.common import State, Act, Agg
|
||||
from tensorneat.algorithm import BaseAlgorithm, NEAT
|
||||
from tensorneat.genome import BaseNode, BaseConn, RecurrentGenome
|
||||
|
||||
|
||||
class HyperNEAT(BaseAlgorithm):
    """HyperNEAT: evolve a NEAT network (the CPPN) which, queried on
    substrate coordinates, produces the connection weights of a larger,
    fixed-topology substrate network (the "hyper" genome).
    """

    def __init__(
        self,
        substrate: BaseSubstrate,
        neat: NEAT,
        weight_threshold: float = 0.3,
        max_weight: float = 5.0,
        aggregation: Callable = Agg.sum,
        activation: Callable = Act.sigmoid,
        activate_time: int = 10,
        output_transform: Callable = Act.standard_sigmoid,
    ):
        # each CPPN query consumes one row of substrate query coordinates
        assert (
            substrate.query_coors.shape[1] == neat.num_inputs
        ), "Query coors of Substrate should be equal to NEAT input size"

        self.substrate = substrate
        self.neat = neat
        self.weight_threshold = weight_threshold
        self.max_weight = max_weight
        # the substrate network may contain cycles, so it is executed as a
        # recurrent genome for a fixed number of activation steps
        self.hyper_genome = RecurrentGenome(
            num_inputs=substrate.num_inputs,
            num_outputs=substrate.num_outputs,
            max_nodes=substrate.nodes_cnt,
            max_conns=substrate.conns_cnt,
            node_gene=HyperNEATNode(aggregation, activation),
            conn_gene=HyperNEATConn(),
            activate_time=activate_time,
            output_transform=output_transform,
        )
        self.pop_size = neat.pop_size

    def setup(self, state=State()):
        """Initialize NEAT, the substrate, and the hyper genome state."""
        state = self.neat.setup(state)
        state = self.substrate.setup(state)
        return self.hyper_genome.setup(state)

    def ask(self, state):
        """Return the CPPN population to evaluate (delegates to NEAT)."""
        return self.neat.ask(state)

    def tell(self, state, fitness):
        """Report fitness back to the underlying NEAT algorithm."""
        state = self.neat.tell(state, fitness)
        return state

    def transform(self, state, individual):
        """Query the CPPN on every substrate coordinate and build the
        substrate network (nodes plus weighted connections)."""
        transformed = self.neat.transform(state, individual)
        # one CPPN forward pass per substrate connection coordinate row
        query_res = vmap(self.neat.forward, in_axes=(None, None, 0))(
            state, transformed, self.substrate.query_coors
        )
        # mute connections whose |weight| is below weight_threshold
        query_res = jnp.where(
            (-self.weight_threshold < query_res) & (query_res < self.weight_threshold),
            0.0,
            query_res,
        )

        # shift surviving weights toward 0, then rescale so the final
        # weights lie in [-max_weight, max_weight]
        query_res = jnp.where(
            query_res > 0, query_res - self.weight_threshold, query_res
        )
        query_res = jnp.where(
            query_res < 0, query_res + self.weight_threshold, query_res
        )
        query_res = query_res / (1 - self.weight_threshold) * self.max_weight

        h_nodes, h_conns = self.substrate.make_nodes(
            query_res
        ), self.substrate.make_conns(query_res)

        return self.hyper_genome.transform(state, h_nodes, h_conns)

    def forward(self, state, transformed, inputs):
        """Run the substrate network; a constant 1 bias input is appended."""
        # add bias
        inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])

        res = self.hyper_genome.forward(state, transformed, inputs_with_bias)
        return res

    @property
    def num_inputs(self):
        return self.substrate.num_inputs - 1  # remove bias

    @property
    def num_outputs(self):
        return self.substrate.num_outputs

    def show_details(self, state, fitness):
        """Delegate detail reporting to the underlying NEAT algorithm."""
        return self.neat.show_details(state, fitness)
|
||||
|
||||
|
||||
class HyperNEATNode(BaseNode):
    """Node gene with fixed aggregation/activation and no evolvable attrs."""

    def __init__(
        self,
        aggregation=Agg.sum,
        activation=Act.sigmoid,
    ):
        super().__init__()
        self.aggregation = aggregation
        self.activation = activation

    def forward(self, state, attrs, inputs, is_output_node=False):
        # aggregate once; output nodes skip the activation function
        aggregated = self.aggregation(inputs)
        return jax.lax.cond(
            is_output_node,
            lambda agg: agg,
            lambda agg: self.activation(agg),
            aggregated,
        )
|
||||
|
||||
|
||||
class HyperNEATConn(BaseConn):
    """Connection gene carrying a single attribute: its weight."""

    custom_attrs = ["weight"]

    def forward(self, state, attrs, inputs):
        # attrs layout: [weight]
        return attrs[0] * inputs
|
||||
3
src/tensorneat/algorithm/hyperneat/substrate/__init__.py
Normal file
3
src/tensorneat/algorithm/hyperneat/substrate/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseSubstrate
|
||||
from .default import DefaultSubstrate
|
||||
from .full import FullSubstrate
|
||||
30
src/tensorneat/algorithm/hyperneat/substrate/base.py
Normal file
30
src/tensorneat/algorithm/hyperneat/substrate/base.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from tensorneat.common import StatefulBaseClass
|
||||
|
||||
|
||||
class BaseSubstrate(StatefulBaseClass):
    """Interface for HyperNEAT substrates: a fixed network layout whose
    connection weights are filled in from CPPN query results."""

    def make_nodes(self, query_res):
        """Build the node array of the substrate network."""
        raise NotImplementedError

    def make_conns(self, query_res):
        """Build the connection array from the CPPN query results."""
        raise NotImplementedError

    @property
    def query_coors(self):
        """Coordinates fed to the CPPN, one row per queried connection."""
        raise NotImplementedError

    @property
    def num_inputs(self):
        """Number of substrate input nodes."""
        raise NotImplementedError

    @property
    def num_outputs(self):
        """Number of substrate output nodes."""
        raise NotImplementedError

    @property
    def nodes_cnt(self):
        """Total node count of the substrate network."""
        raise NotImplementedError

    @property
    def conns_cnt(self):
        """Total connection count of the substrate network."""
        raise NotImplementedError
|
||||
40
src/tensorneat/algorithm/hyperneat/substrate/default.py
Normal file
40
src/tensorneat/algorithm/hyperneat/substrate/default.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .base import BaseSubstrate
|
||||
from tensorneat.genome.utils import set_conn_attrs
|
||||
|
||||
|
||||
class DefaultSubstrate(BaseSubstrate):
    """Substrate described by explicit arrays: CPPN query coordinates plus
    template node/connection arrays whose weights get filled in from the
    query results."""

    def __init__(self, num_inputs, num_outputs, coors, nodes, conns):
        self.inputs = num_inputs
        self.outputs = num_outputs
        # store the layout templates as device arrays
        self.coors = jnp.array(coors)
        self.nodes = jnp.array(nodes)
        self.conns = jnp.array(conns)

    def make_nodes(self, query_res):
        """The node array is fixed; query results do not affect it."""
        return self.nodes

    def make_conns(self, query_res):
        """Write each queried weight into its connection template row."""
        return vmap(set_conn_attrs)(self.conns, query_res)

    @property
    def query_coors(self):
        return self.coors

    @property
    def num_inputs(self):
        return self.inputs

    @property
    def num_outputs(self):
        return self.outputs

    @property
    def nodes_cnt(self):
        return self.nodes.shape[0]

    @property
    def conns_cnt(self):
        return self.conns.shape[0]
|
||||
79
src/tensorneat/algorithm/hyperneat/substrate/full.py
Normal file
79
src/tensorneat/algorithm/hyperneat/substrate/full.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import numpy as np
|
||||
from .default import DefaultSubstrate
|
||||
|
||||
|
||||
class FullSubstrate(DefaultSubstrate):
    """Fully-connected layered substrate: every input feeds every hidden
    node, hidden nodes are fully inter-connected, and every hidden node
    feeds every output node."""

    def __init__(
        self,
        input_coors=((-1, -1), (0, -1), (1, -1)),
        hidden_coors=((-1, 0), (0, 0), (1, 0)),
        output_coors=((0, 1),),
    ):
        # precompute the dense layout from the layer coordinates
        query_coors, nodes, conns = analysis_substrate(
            input_coors, output_coors, hidden_coors
        )
        super().__init__(
            len(input_coors), len(output_coors), query_coors, nodes, conns
        )
|
||||
|
||||
|
||||
def analysis_substrate(input_coors, output_coors, hidden_coors):
    """Build the arrays describing a fully-connected layered substrate.

    Connections are enumerated in the fixed order: input->hidden,
    hidden->hidden, hidden->output.

    Returns:
        query_coors: (C, 2*cd) float array; each row is the concatenated
            (source, target) coordinates fed to the CPPN.
        nodes: (N, 1) array of node keys (inputs, then outputs, then hidden).
        conns: (C, 3) float32 array of [input_key, output_key, weight];
            weights start at 0 and are filled in later.
    """
    input_coors = np.array(input_coors)
    output_coors = np.array(output_coors)
    hidden_coors = np.array(hidden_coors)

    si = len(input_coors)  # input node count
    so = len(output_coors)  # output node count
    sh = len(hidden_coors)  # hidden node count

    # node keys: inputs first, then outputs, then hidden
    input_idx = np.arange(si)
    output_idx = np.arange(si, si + so)
    hidden_idx = np.arange(si + so, si + so + sh)

    # the three layer-to-layer connection groups, in order
    segments = [
        cartesian_product(input_idx, hidden_idx, input_coors, hidden_coors),
        cartesian_product(hidden_idx, hidden_idx, hidden_coors, hidden_coors),
        cartesian_product(hidden_idx, output_idx, hidden_coors, output_coors),
    ]
    # cast to float64 so the dtype matches regardless of the coordinate dtype
    query_coors = np.concatenate([c for c, _ in segments], axis=0).astype(np.float64)
    correspond_keys = np.concatenate([k for _, k in segments], axis=0).astype(
        np.float64
    )

    nodes = np.concatenate((input_idx, output_idx, hidden_idx)).reshape(-1, 1)

    # conns rows: [input_key, output_key, weight]; weight starts at 0
    conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
    conns[:, :2] = correspond_keys

    return query_coors, nodes, conns


def cartesian_product(keys1, keys2, coors1, coors2):
    """All ordered pairs: the first set repeated blockwise, the second tiled.

    Row r pairs element r // len(keys2) of the first set with element
    r % len(keys2) of the second; returns (pair_coors, pair_keys).
    """
    n1 = len(keys1)
    n2 = len(keys2)

    left_coors = np.repeat(coors1, n2, axis=0)
    left_keys = np.repeat(keys1, n2)

    right_coors = np.tile(coors2, (n1, 1))
    right_keys = np.tile(keys2, n1)

    pair_coors = np.concatenate((left_coors, right_coors), axis=1)
    pair_keys = np.column_stack((left_keys, right_keys))

    return pair_coors, pair_keys
|
||||
2
src/tensorneat/algorithm/neat/__init__.py
Normal file
2
src/tensorneat/algorithm/neat/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .species import *
|
||||
from .neat import NEAT
|
||||
167
src/tensorneat/algorithm/neat/neat.py
Normal file
167
src/tensorneat/algorithm/neat/neat.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .species import SpeciesController
|
||||
from .. import BaseAlgorithm
|
||||
from tensorneat.common import State
|
||||
from tensorneat.genome import BaseGenome
|
||||
|
||||
|
||||
class NEAT(BaseAlgorithm):
    """NEAT algorithm: a speciated population of genomes evolved via
    crossover and mutation, with species-level stagnation and elitism."""

    def __init__(
        self,
        genome: BaseGenome,
        pop_size: int,
        species_size: int = 10,
        max_stagnation: int = 15,
        species_elitism: int = 2,
        spawn_number_change_rate: float = 0.5,
        genome_elitism: int = 2,
        survival_threshold: float = 0.1,
        min_species_size: int = 1,
        compatibility_threshold: float = 2.0,
        species_fitness_func: Callable = jnp.max,
    ):
        self.genome = genome
        self.pop_size = pop_size
        # all speciation hyper-parameters are handled by the controller
        self.species_controller = SpeciesController(
            pop_size,
            species_size,
            max_stagnation,
            species_elitism,
            spawn_number_change_rate,
            genome_elitism,
            survival_threshold,
            min_species_size,
            compatibility_threshold,
            species_fitness_func,
        )

    def setup(self, state=State()):
        """Initialize genome state, the initial population, and species."""
        # setup state
        state = self.genome.setup(state)

        k1, randkey = jax.random.split(state.randkey, 2)

        # initialize the population (one init key per individual)
        initialize_keys = jax.random.split(k1, self.pop_size)
        pop_nodes, pop_conns = vmap(self.genome.initialize, in_axes=(None, 0))(
            state, initialize_keys
        )

        state = state.register(
            pop_nodes=pop_nodes,
            pop_conns=pop_conns,
            generation=jnp.float32(0),
        )

        # initialize species state, seeded with the first genome as center
        state = self.species_controller.setup(state, pop_nodes[0], pop_conns[0])

        return state.update(randkey=randkey)

    def ask(self, state):
        """Return the current population (node and connection arrays)."""
        return state.pop_nodes, state.pop_conns

    def tell(self, state, fitness):
        """Advance one generation: speciation bookkeeping, reproduction,
        then re-speciation of the new population."""
        state = state.update(generation=state.generation + 1)

        # tell fitness to species controller; get crossover pairings back
        state, winner, loser, elite_mask = self.species_controller.update_species(
            state,
            fitness,
        )

        # create next population via crossover + mutation
        state = self._create_next_generation(state, winner, loser, elite_mask)

        # speciate the next population
        state = self.species_controller.speciate(state, self.genome.execute_distance)

        return state

    def transform(self, state, individual):
        """Compile one (nodes, conns) individual into a runnable network."""
        nodes, conns = individual
        return self.genome.transform(state, nodes, conns)

    def forward(self, state, transformed, inputs):
        """Run a transformed network on the given inputs."""
        return self.genome.forward(state, transformed, inputs)

    @property
    def num_inputs(self):
        return self.genome.num_inputs

    @property
    def num_outputs(self):
        return self.genome.num_outputs

    def _create_next_generation(self, state, winner, loser, elite_mask):
        """Produce the next population from (winner, loser) parent pairs;
        elite slots skip mutation."""

        # find the next free node key so mutation can add fresh nodes
        # (node keys live in column 0; NaN marks an empty node slot)
        all_nodes_keys = state.pop_nodes[:, :, 0]
        max_node_key = jnp.max(
            all_nodes_keys, where=~jnp.isnan(all_nodes_keys), initial=0
        )
        next_node_key = max_node_key + 1
        # one distinct new key per individual, so parallel mutations don't collide
        new_node_keys = jnp.arange(self.pop_size) + next_node_key

        # prepare random keys
        k1, k2, randkey = jax.random.split(state.randkey, 3)
        crossover_randkeys = jax.random.split(k1, self.pop_size)
        mutate_randkeys = jax.random.split(k2, self.pop_size)

        # gather parent genomes (winner = fitter parent of each pair)
        wpn, wpc = state.pop_nodes[winner], state.pop_conns[winner]
        lpn, lpc = state.pop_nodes[loser], state.pop_conns[loser]

        # batch crossover
        n_nodes, n_conns = vmap(
            self.genome.execute_crossover, in_axes=(None, 0, 0, 0, 0, 0)
        )(
            state, crossover_randkeys, wpn, wpc, lpn, lpc
        )  # new_nodes, new_conns

        # batch mutation
        m_n_nodes, m_n_conns = vmap(
            self.genome.execute_mutation, in_axes=(None, 0, 0, 0, 0)
        )(
            state, mutate_randkeys, n_nodes, n_conns, new_node_keys
        )  # mutated_new_nodes, mutated_new_conns

        # elites keep the un-mutated crossover result
        pop_nodes = jnp.where(elite_mask[:, None, None], n_nodes, m_n_nodes)
        pop_conns = jnp.where(elite_mask[:, None, None], n_conns, m_n_conns)

        return state.update(
            randkey=randkey,
            pop_nodes=pop_nodes,
            pop_conns=pop_conns,
        )

    def show_details(self, state, fitness):
        """Print per-generation population statistics (node/conn counts and
        species sizes) to stdout."""
        member_count = jax.device_get(state.species.member_count)
        species_sizes = [int(i) for i in member_count if i > 0]

        pop_nodes, pop_conns = jax.device_get([state.pop_nodes, state.pop_conns])
        # valid genes are the non-NaN rows of each genome
        nodes_cnt = (~np.isnan(pop_nodes[:, :, 0])).sum(axis=1)  # (P,)
        conns_cnt = (~np.isnan(pop_conns[:, :, 0])).sum(axis=1)  # (P,)

        max_node_cnt, min_node_cnt, mean_node_cnt = (
            max(nodes_cnt),
            min(nodes_cnt),
            np.mean(nodes_cnt),
        )

        max_conn_cnt, min_conn_cnt, mean_conn_cnt = (
            max(conns_cnt),
            min(conns_cnt),
            np.mean(conns_cnt),
        )

        print(
            f"\tnode counts: max: {max_node_cnt}, min: {min_node_cnt}, mean: {mean_node_cnt:.2f}\n",
            f"\tconn counts: max: {max_conn_cnt}, min: {min_conn_cnt}, mean: {mean_conn_cnt:.2f}\n",
            f"\tspecies: {len(species_sizes)}, {species_sizes}\n",
        )
|
||||
537
src/tensorneat/algorithm/neat/species.py
Normal file
537
src/tensorneat/algorithm/neat/species.py
Normal file
@@ -0,0 +1,537 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from tensorneat.common import (
|
||||
State,
|
||||
StatefulBaseClass,
|
||||
rank_elements,
|
||||
argmin_with_mask,
|
||||
fetch_first,
|
||||
)
|
||||
|
||||
|
||||
class SpeciesController(StatefulBaseClass):
|
||||
    def __init__(
        self,
        pop_size,
        species_size,
        max_stagnation,
        species_elitism,
        spawn_number_change_rate,
        genome_elitism,
        survival_threshold,
        min_species_size,
        compatibility_threshold,
        species_fitness_func,
    ):
        """Store the speciation hyper-parameters.

        Args:
            pop_size: total population size.
            species_size: maximum number of concurrent species slots.
            max_stagnation: generations without improvement before a species
                is removed.
            species_elitism: top-ranked species exempt from stagnation.
            spawn_number_change_rate: smoothing factor moving a species'
                member count toward its fitness-ranked target.
            genome_elitism: per-species genomes copied unchanged.
            survival_threshold: fraction of each species allowed to reproduce.
            min_species_size: lower bound on species size (stored; see
                _cal_spawn_numbers for usage).
            compatibility_threshold: genome distance below which an
                individual joins a species.
            species_fitness_func: reduction over member fitness giving the
                species fitness (e.g. jnp.max).
        """
        self.pop_size = pop_size
        self.species_size = species_size
        # static index array reused by vmapped per-species computations
        self.species_arange = np.arange(self.species_size)
        self.max_stagnation = max_stagnation
        self.species_elitism = species_elitism
        self.spawn_number_change_rate = spawn_number_change_rate
        self.genome_elitism = genome_elitism
        self.survival_threshold = survival_threshold
        self.min_species_size = min_species_size
        self.compatibility_threshold = compatibility_threshold
        self.species_fitness_func = species_fitness_func
|
||||
|
||||
def setup(self, state, first_nodes, first_conns):
|
||||
# the unique index (primary key) for each species
|
||||
species_keys = jnp.full((self.species_size,), jnp.nan)
|
||||
|
||||
# the best fitness of each species
|
||||
best_fitness = jnp.full((self.species_size,), jnp.nan)
|
||||
|
||||
# the last 1 that the species improved
|
||||
last_improved = jnp.full((self.species_size,), jnp.nan)
|
||||
|
||||
# the number of members of each species
|
||||
member_count = jnp.full((self.species_size,), jnp.nan)
|
||||
|
||||
# the species index of each individual
|
||||
idx2species = jnp.zeros(self.pop_size)
|
||||
|
||||
# nodes for each center genome of each species
|
||||
center_nodes = jnp.full(
|
||||
(self.species_size, *first_nodes.shape),
|
||||
jnp.nan,
|
||||
)
|
||||
|
||||
# connections for each center genome of each species
|
||||
center_conns = jnp.full(
|
||||
(self.species_size, *first_conns.shape),
|
||||
jnp.nan,
|
||||
)
|
||||
|
||||
species_keys = species_keys.at[0].set(0)
|
||||
best_fitness = best_fitness.at[0].set(-jnp.inf)
|
||||
last_improved = last_improved.at[0].set(0)
|
||||
member_count = member_count.at[0].set(self.pop_size)
|
||||
center_nodes = center_nodes.at[0].set(first_nodes)
|
||||
center_conns = center_conns.at[0].set(first_conns)
|
||||
|
||||
species_state = State(
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
member_count=member_count,
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
next_species_key=jnp.float32(1), # 0 is reserved for the first species
|
||||
)
|
||||
|
||||
return state.register(species=species_state)
|
||||
|
||||
    def update_species(self, state, fitness):
        """Per-generation species bookkeeping.

        Computes species fitness, removes stagnated species, re-sorts
        species by fitness, allocates spawn counts, and picks crossover
        parent pairs. Returns (state, winner, loser, elite_mask).
        """
        species_state = state.species

        # update the fitness of each species
        species_fitness = self._update_species_fitness(species_state, fitness)

        # remove stagnated species (their slots become NaN / -inf fitness)
        species_state, species_fitness = self._stagnation(
            species_state, species_fitness, state.generation
        )

        # sort species_info by their fitness. (also push nan to the end)
        sort_indices = jnp.argsort(species_fitness)[::-1]  # fitness from high to low

        species_state = species_state.update(
            species_keys=species_state.species_keys[sort_indices],
            best_fitness=species_state.best_fitness[sort_indices],
            last_improved=species_state.last_improved[sort_indices],
            member_count=species_state.member_count[sort_indices],
            center_nodes=species_state.center_nodes[sort_indices],
            center_conns=species_state.center_conns[sort_indices],
        )

        # decide the number of members of each species by their fitness
        spawn_number = self._cal_spawn_numbers(species_state)

        k1, k2 = jax.random.split(state.randkey)
        # crossover info
        winner, loser, elite_mask = self._create_crossover_pair(
            species_state, k1, spawn_number, fitness
        )

        return (
            state.update(randkey=k2, species=species_state),
            winner,
            loser,
            elite_mask,
        )
|
||||
|
||||
    def _update_species_fitness(self, species_state, fitness):
        """
        Obtain each species' fitness from its members' fitness, reduced by
        self.species_fitness_func (jnp.max by default). Non-members are
        masked to -inf before the reduction.
        """

        def aux_func(idx):
            # mask: keep fitness of members of species idx, -inf elsewhere
            s_fitness = jnp.where(
                species_state.idx2species == species_state.species_keys[idx],
                fitness,
                -jnp.inf,
            )
            val = self.species_fitness_func(s_fitness)
            return val

        return vmap(aux_func)(self.species_arange)
|
||||
|
||||
    def _stagnation(self, species_state, species_fitness, generation):
        """
        Remove stagnated species: a species whose fitness has not exceeded
        its best for more than max_stagnation generations has its slot
        cleared (NaN) and its fitness set to -inf. The top species_elitism
        species are never removed. Returns (species_state, species_fitness).
        """

        def check_stagnation(idx):
            # stagnated = no improvement AND too long since last improvement
            st = (species_fitness[idx] <= species_state.best_fitness[idx]) & (
                generation - species_state.last_improved[idx] > self.max_stagnation
            )

            # update last_improved and best_fitness when the species improved
            li, bf = jax.lax.cond(
                species_fitness[idx] > species_state.best_fitness[idx],
                lambda: (generation, species_fitness[idx]),  # update
                lambda: (
                    species_state.last_improved[idx],
                    species_state.best_fitness[idx],
                ),  # not update
            )

            return st, bf, li

        spe_st, best_fitness, last_improved = vmap(check_stagnation)(
            self.species_arange
        )

        # update species state
        species_state = species_state.update(
            best_fitness=best_fitness,
            last_improved=last_improved,
        )

        # elite species will not be stagnation
        species_rank = rank_elements(species_fitness)
        spe_st = jnp.where(
            species_rank < self.species_elitism, False, spe_st
        )  # elitism never stagnation

        # clear every field of a stagnated species' slot
        def update_func(idx):
            return jax.lax.cond(
                spe_st[idx],
                lambda: (
                    jnp.nan,  # species_key
                    jnp.nan,  # best_fitness
                    jnp.nan,  # last_improved
                    jnp.nan,  # member_count
                    jnp.full_like(species_state.center_nodes[idx], jnp.nan),
                    jnp.full_like(species_state.center_conns[idx], jnp.nan),
                    -jnp.inf,  # species_fitness
                ),  # stagnation species
                lambda: (
                    species_state.species_keys[idx],
                    species_state.best_fitness[idx],
                    species_state.last_improved[idx],
                    species_state.member_count[idx],
                    species_state.center_nodes[idx],
                    species_state.center_conns[idx],
                    species_fitness[idx],
                ),  # not stagnation species
            )

        (
            species_keys,
            best_fitness,
            last_improved,
            member_count,
            center_nodes,
            center_conns,
            species_fitness,
        ) = vmap(update_func)(self.species_arange)

        return (
            species_state.update(
                species_keys=species_keys,
                best_fitness=best_fitness,
                last_improved=last_improved,
                member_count=member_count,
                center_nodes=center_nodes,
                center_conns=center_conns,
            ),
            species_fitness,
        )
|
||||
|
||||
    def _cal_spawn_numbers(self, species_state):
        """
        Decide the number of members of each species by their fitness rank
        (species are assumed already sorted by fitness, best first).
        The species with higher fitness will have more members.
        Linear ranking selection:
        e.g. N = 3, P=10 -> probability = [0.5, 0.33, 0.17], spawn_number = [5, 3, 2]
        """

        species_keys = species_state.species_keys

        is_species_valid = ~jnp.isnan(species_keys)
        valid_species_num = jnp.sum(is_species_valid)
        denominator = (
            (valid_species_num + 1) * valid_species_num / 2
        )  # obtain 3 + 2 + 1 = 6

        rank_score = valid_species_num - self.species_arange  # obtain [3, 2, 1]
        spawn_number_rate = rank_score / denominator  # obtain [0.5, 0.33, 0.17]
        spawn_number_rate = jnp.where(
            is_species_valid, spawn_number_rate, 0
        )  # set invalid species to 0

        target_spawn_number = jnp.floor(
            spawn_number_rate * self.pop_size
        )  # calculate member

        # Avoid too much variation of numbers for a species:
        # move from the previous size toward the target at a fixed rate
        previous_size = species_state.member_count
        spawn_number = (
            previous_size
            + (target_spawn_number - previous_size) * self.spawn_number_change_rate
        )
        spawn_number = spawn_number.astype(jnp.int32)

        # must control the sum of spawn_number to be equal to pop_size
        error = self.pop_size - jnp.sum(spawn_number)

        # add error to the first species to control the sum of spawn_number
        spawn_number = spawn_number.at[0].add(error)

        return spawn_number
|
||||
|
||||
    def _create_crossover_pair(self, species_state, randkey, spawn_number, fitness):
        """Pick (winner, loser) parent index pairs for the next population,
        respecting per-species spawn counts and genome elitism.

        Returns (winner, loser, elite_mask), each of shape (pop_size,).
        """
        s_idx = self.species_arange
        p_idx = jnp.arange(self.pop_size)

        def aux_func(key, idx):
            # choose parents from within the same species
            # key -> randkey, idx -> the idx of current species

            members = species_state.idx2species == species_state.species_keys[idx]
            members_num = jnp.sum(members)

            # members sorted by fitness, best first
            members_fitness = jnp.where(members, fitness, -jnp.inf)
            sorted_member_indices = jnp.argsort(members_fitness)[::-1]

            # only the top survival_threshold fraction may reproduce
            survive_size = jnp.floor(self.survival_threshold * members_num).astype(
                jnp.int32
            )

            # uniform selection probability over the surviving members
            # NOTE(review): survive_size can be 0 for small species, making
            # select_pro 0/0 = NaN here — presumably such species spawn no
            # offspring downstream; confirm against the spawn_number logic.
            select_pro = (p_idx < survive_size) / survive_size
            fa, ma = jax.random.choice(
                key,
                sorted_member_indices,
                shape=(2, self.pop_size),
                replace=True,
                p=select_pro,
            )

            # elite: the top genome_elitism members reproduce themselves
            fa = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, fa)
            ma = jnp.where(p_idx < self.genome_elitism, sorted_member_indices, ma)
            elite = jnp.where(p_idx < self.genome_elitism, True, False)
            return fa, ma, elite

        # choose parents to crossover in each species
        # fas, mas, elites: (self.species_size, self.pop_size)
        # fas -> father indices, mas -> mother indices, elites -> whether elite or not
        fas, mas, elites = vmap(aux_func)(
            jax.random.split(randkey, self.species_size), s_idx
        )

        # merge chosen parents from each species into one array
        # winner, loser, elite_mask: (self.pop_size)
        # winner -> winner indices, loser -> loser indices, elite_mask -> whether elite or not
        spawn_number_cum = jnp.cumsum(spawn_number)

        def aux_func(idx):
            # which species this population slot belongs to
            loc = jnp.argmax(idx < spawn_number_cum)

            # elite genomes are at the beginning of the species
            idx_in_species = jnp.where(loc > 0, idx - spawn_number_cum[loc - 1], idx)
            return (
                fas[loc, idx_in_species],
                mas[loc, idx_in_species],
                elites[loc, idx_in_species],
            )

        part1, part2, elite_mask = vmap(aux_func)(p_idx)

        # the fitter parent of each pair becomes the "winner"
        is_part1_win = fitness[part1] >= fitness[part2]
        winner = jnp.where(is_part1_win, part1, part2)
        loser = jnp.where(is_part1_win, part2, part1)

        return winner, loser, elite_mask
|
||||
|
||||
def speciate(self, state, genome_distance_func: Callable):
|
||||
# prepare distance functions
|
||||
o2p_distance_func = vmap(
|
||||
genome_distance_func, in_axes=(None, None, None, 0, 0)
|
||||
) # one to population
|
||||
|
||||
# idx to specie key
|
||||
idx2species = jnp.full(
|
||||
(self.pop_size,), jnp.nan
|
||||
) # NaN means not assigned to any species
|
||||
|
||||
# the distance between genomes to its center genomes
|
||||
o2c_distances = jnp.full((self.pop_size,), jnp.inf)
|
||||
|
||||
# step 1: find new centers
|
||||
def cond_func(carry):
|
||||
# i, idx2species, center_nodes, center_conns, o2c_distances
|
||||
i, i2s, cns, ccs, o2c = carry
|
||||
|
||||
return (i < self.species_size) & (
|
||||
~jnp.isnan(state.species.species_keys[i])
|
||||
) # current species is existing
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cns, ccs, o2c = carry
|
||||
|
||||
distances = o2p_distance_func(
|
||||
state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
)
|
||||
|
||||
# find the closest one
|
||||
closest_idx = argmin_with_mask(distances, mask=jnp.isnan(i2s))
|
||||
|
||||
i2s = i2s.at[closest_idx].set(state.species.species_keys[i])
|
||||
cns = cns.at[i].set(state.pop_nodes[closest_idx])
|
||||
ccs = ccs.at[i].set(state.pop_conns[closest_idx])
|
||||
|
||||
# the genome with closest_idx will become the new center, thus its distance to center is 0.
|
||||
o2c = o2c.at[closest_idx].set(0)
|
||||
|
||||
return i + 1, i2s, cns, ccs, o2c
|
||||
|
||||
_, idx2species, center_nodes, center_conns, o2c_distances = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(
|
||||
0,
|
||||
idx2species,
|
||||
state.species.center_nodes,
|
||||
state.species.center_conns,
|
||||
o2c_distances,
|
||||
),
|
||||
)
|
||||
|
||||
state = state.update(
|
||||
species=state.species.update(
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
),
|
||||
)
|
||||
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
# i, idx2species, center_nodes, center_conns, species_keys, o2c_distances, next_species_key
|
||||
i, i2s, cns, ccs, sk, o2c, nsk = carry
|
||||
|
||||
current_species_existed = ~jnp.isnan(sk[i])
|
||||
not_all_assigned = jnp.any(jnp.isnan(i2s))
|
||||
not_reach_species_upper_bounds = i < self.species_size
|
||||
return not_reach_species_upper_bounds & (
|
||||
current_species_existed | not_all_assigned
|
||||
)
|
||||
|
||||
def body_func(carry):
|
||||
i, i2s, cns, ccs, sk, o2c, nsk = carry
|
||||
|
||||
_, i2s, cns, ccs, sk, o2c, nsk = jax.lax.cond(
|
||||
jnp.isnan(sk[i]), # whether the current species is existing or not
|
||||
create_new_species, # if not existing, create a new specie
|
||||
update_exist_specie, # if existing, update the specie
|
||||
(i, i2s, cns, ccs, sk, o2c, nsk),
|
||||
)
|
||||
|
||||
return i + 1, i2s, cns, ccs, sk, o2c, nsk
|
||||
|
||||
def create_new_species(carry):
|
||||
i, i2s, cns, ccs, sk, o2c, nsk = carry
|
||||
|
||||
# pick the first one who has not been assigned to any species
|
||||
idx = fetch_first(jnp.isnan(i2s))
|
||||
|
||||
# assign it to the new species
|
||||
# [key, best score, last update generation, member_count]
|
||||
sk = sk.at[i].set(nsk) # nsk -> next species key
|
||||
i2s = i2s.at[idx].set(nsk)
|
||||
o2c = o2c.at[idx].set(0)
|
||||
|
||||
# update center genomes
|
||||
cns = cns.at[i].set(state.pop_nodes[idx])
|
||||
ccs = ccs.at[i].set(state.pop_conns[idx])
|
||||
|
||||
# find the members for the new species
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
|
||||
|
||||
return i, i2s, cns, ccs, sk, o2c, nsk + 1 # change to next new speciate key
|
||||
|
||||
def update_exist_specie(carry):
|
||||
i, i2s, cns, ccs, sk, o2c, nsk = carry
|
||||
|
||||
i2s, o2c = speciate_by_threshold(i, i2s, cns, ccs, sk, o2c)
|
||||
|
||||
# turn to next species
|
||||
return i + 1, i2s, cns, ccs, sk, o2c, nsk
|
||||
|
||||
def speciate_by_threshold(i, i2s, cns, ccs, sk, o2c):
|
||||
# distance between such center genome and ppo genomes
|
||||
o2p_distance = o2p_distance_func(
|
||||
state, cns[i], ccs[i], state.pop_nodes, state.pop_conns
|
||||
)
|
||||
|
||||
close_enough_mask = o2p_distance < self.compatibility_threshold
|
||||
# when a genome is not assigned or the distance between its current center is bigger than this center
|
||||
catchable_mask = jnp.isnan(i2s) | (o2p_distance < o2c)
|
||||
|
||||
mask = close_enough_mask & catchable_mask
|
||||
|
||||
# update species info
|
||||
i2s = jnp.where(mask, sk[i], i2s)
|
||||
|
||||
# update distance between centers
|
||||
o2c = jnp.where(mask, o2p_distance, o2c)
|
||||
|
||||
return i2s, o2c
|
||||
|
||||
# update idx2species
|
||||
(
|
||||
_,
|
||||
idx2species,
|
||||
center_nodes,
|
||||
center_conns,
|
||||
species_keys,
|
||||
_,
|
||||
next_species_key,
|
||||
) = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(
|
||||
0,
|
||||
state.species.idx2species,
|
||||
center_nodes,
|
||||
center_conns,
|
||||
state.species.species_keys,
|
||||
o2c_distances,
|
||||
state.species.next_species_key,
|
||||
),
|
||||
)
|
||||
|
||||
# if there are still some pop genomes not assigned to any species, add them to the last genome
|
||||
# this condition can only happen when the number of species is reached species upper bounds
|
||||
idx2species = jnp.where(jnp.isnan(idx2species), species_keys[-1], idx2species)
|
||||
|
||||
# complete info of species which is created in this generation
|
||||
new_created_mask = (~jnp.isnan(species_keys)) & jnp.isnan(state.species.best_fitness)
|
||||
best_fitness = jnp.where(new_created_mask, -jnp.inf, state.species.best_fitness)
|
||||
last_improved = jnp.where(
|
||||
new_created_mask, state.generation, state.species.last_improved
|
||||
)
|
||||
|
||||
# update members count
|
||||
def count_members(idx):
|
||||
return jax.lax.cond(
|
||||
jnp.isnan(species_keys[idx]), # if the species is not existing
|
||||
lambda: jnp.nan, # nan
|
||||
lambda: jnp.sum(
|
||||
idx2species == species_keys[idx], dtype=jnp.float32
|
||||
), # count members
|
||||
)
|
||||
|
||||
member_count = vmap(count_members)(self.species_arange)
|
||||
|
||||
species_state = state.species.update(
|
||||
species_keys=species_keys,
|
||||
best_fitness=best_fitness,
|
||||
last_improved=last_improved,
|
||||
member_count=member_count,
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
next_species_key=next_species_key,
|
||||
)
|
||||
|
||||
return state.update(
|
||||
species=species_state,
|
||||
)
|
||||
56
src/tensorneat/common/__init__.py
Normal file
56
src/tensorneat/common/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from tensorneat.common.aggregation.agg_jnp import Agg, agg_func, AGG_ALL
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
from .state import State
|
||||
from .stateful_class import StatefulBaseClass
|
||||
|
||||
from .aggregation.agg_jnp import Agg, AGG_ALL, agg_func
|
||||
from .activation.act_jnp import Act, ACT_ALL, act_func
|
||||
from .aggregation.agg_sympy import *
|
||||
from .activation.act_sympy import *
|
||||
|
||||
from typing import Callable, Union
|
||||
|
||||
# Registry mapping activation/aggregation names to their sympy counterparts.
# Used by convert_to_sympy below; keys must match the Act/Agg method names.
name2sympy = {
    "sigmoid": SympySigmoid,
    "standard_sigmoid": SympyStandardSigmoid,
    "tanh": SympyTanh,
    "standard_tanh": SympyStandardTanh,
    "sin": SympySin,
    "relu": SympyRelu,
    "lelu": SympyLelu,
    "identity": SympyIdentity,
    "inv": SympyInv,
    "log": SympyLog,
    "exp": SympyExp,
    "abs": SympyAbs,
    "sum": SympySum,
    "product": SympyProduct,
    "max": SympyMax,
    "min": SympyMin,
    "maxabs": SympyMaxabs,
    "mean": SympyMean,
    "clip": SympyClip,
    "square": SympySquare,
}
|
||||
|
||||
|
||||
def convert_to_sympy(func: Union[str, Callable]):
    """Resolve an activation/aggregation (given by name or callable) to its
    registered sympy class.

    Raises ValueError when no sympy counterpart is registered.
    """
    name = func if isinstance(func, str) else func.__name__
    try:
        return name2sympy[name]
    except KeyError:
        raise ValueError(
            f"Can not convert to sympy! Function {name} not found in name2sympy"
        ) from None
|
||||
|
||||
|
||||
# Numeric implementations of the sympy functions, keyed by class name, for use
# as lambdify "modules" (numpy backend and jax.numpy backend respectively).
# NOTE(review): this relies on `partial` and `jnp` being re-exported by
# `from .tools import *` above — confirm, or import them here explicitly.
SYMPY_FUNCS_MODULE_NP = {}
SYMPY_FUNCS_MODULE_JNP = {}
for cls in name2sympy.values():
    # only classes that define numerical_eval get a numeric binding
    if hasattr(cls, "numerical_eval"):
        SYMPY_FUNCS_MODULE_NP[cls.__name__] = cls.numerical_eval
        SYMPY_FUNCS_MODULE_JNP[cls.__name__] = partial(cls.numerical_eval, backend=jnp)
|
||||
0
src/tensorneat/common/activation/__init__.py
Normal file
0
src/tensorneat/common/activation/__init__.py
Normal file
110
src/tensorneat/common/activation/act_jnp.py
Normal file
110
src/tensorneat/common/activation/act_jnp.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
sigma_3 = 2.576  # 99% quantile of the standard normal; used as the saturation bound


class Act:
    """Activation functions used by TensorNEAT.

    Most activations rescale their input by ``5 / sigma_3`` so that inputs in
    roughly ``(-sigma_3, sigma_3)`` land in the function's sensitive range,
    and several clip/saturate their output back into ``(-sigma_3, sigma_3)``.
    """

    @staticmethod
    def name2func(name):
        """Look up an activation by its method name."""
        return getattr(Act, name)

    @staticmethod
    def sigmoid(z):
        """Scaled sigmoid; output range (0, sigma_3)."""
        z = 5 * z / sigma_3
        z = 1 / (1 + jnp.exp(-z))

        return z * sigma_3  # (0, sigma_3)

    @staticmethod
    def standard_sigmoid(z):
        """Sigmoid with the same input scaling but output range (0, 1)."""
        z = 5 * z / sigma_3
        z = 1 / (1 + jnp.exp(-z))

        return z  # (0, 1)

    @staticmethod
    def tanh(z):
        """Scaled tanh; output range (-sigma_3, sigma_3)."""
        z = 5 * z / sigma_3
        return jnp.tanh(z) * sigma_3  # (-sigma_3, sigma_3)

    @staticmethod
    def standard_tanh(z):
        """Tanh with the same input scaling but output range (-1, 1)."""
        z = 5 * z / sigma_3
        return jnp.tanh(z)  # (-1, 1)

    @staticmethod
    def sin(z):
        """Clipped sine; output range (-sigma_3, sigma_3)."""
        z = jnp.clip(jnp.pi / 2 * z / sigma_3, -jnp.pi / 2, jnp.pi / 2)
        return jnp.sin(z) * sigma_3  # (-sigma_3, sigma_3)

    @staticmethod
    def relu(z):
        """ReLU on input clipped to [-sigma_3, sigma_3]; output in [0, sigma_3]."""
        z = jnp.clip(z, -sigma_3, sigma_3)
        return jnp.maximum(z, 0)  # (0, sigma_3)

    @staticmethod
    def lelu(z):
        """Leaky ReLU (negative slope 0.005) on input clipped to [-sigma_3, sigma_3]."""
        leaky = 0.005
        z = jnp.clip(z, -sigma_3, sigma_3)
        return jnp.where(z > 0, z, leaky * z)

    @staticmethod
    def identity(z):
        """Identity activation."""
        return z

    @staticmethod
    def inv(z):
        """1/z with |z| floored at 1e-7 to avoid division by zero."""
        z = jnp.where(z > 0, jnp.maximum(z, 1e-7), jnp.minimum(z, -1e-7))
        return 1 / z

    @staticmethod
    def log(z):
        """log(z) with z floored at 1e-7 to avoid log of non-positive values."""
        z = jnp.maximum(z, 1e-7)
        return jnp.log(z)

    @staticmethod
    def exp(z):
        """exp(z) with z clipped to [-10, 10] to avoid overflow."""
        z = jnp.clip(z, -10, 10)
        return jnp.exp(z)

    @staticmethod
    def square(z):
        """z squared.

        Uses jnp.power: jnp.pow is only an alias added in recent JAX releases,
        so jnp.power works across all supported JAX versions.
        """
        return jnp.power(z, 2)

    @staticmethod
    def abs(z):
        """|z| of input clipped to [-1, 1]."""
        z = jnp.clip(z, -1, 1)
        return jnp.abs(z)
||||
|
||||
|
||||
# Default activation set; genomes select activations by index into this tuple
# (see act_func below), so the order must stay stable.
# NOTE(review): standard_sigmoid / standard_tanh / square are intentionally
# absent here — confirm this matches the intended default search space.
ACT_ALL = (
    Act.sigmoid,
    Act.tanh,
    Act.sin,
    Act.relu,
    Act.lelu,
    Act.identity,
    Act.inv,
    Act.log,
    Act.exp,
    Act.abs,
)
|
||||
|
||||
|
||||
def act_func(idx, z, act_funcs):
    """
    Apply the activation selected by ``idx`` to ``z``.

    ``idx`` is stored as a float in the genome, so it is cast to int first.
    An index of -1 is a shortcut for the identity activation.
    """
    branch = jnp.asarray(idx, dtype=jnp.int32)

    return jax.lax.cond(
        branch == -1,
        lambda: z,
        lambda: jax.lax.switch(branch, act_funcs, z),
    )
|
||||
196
src/tensorneat/common/activation/act_sympy.py
Normal file
196
src/tensorneat/common/activation/act_sympy.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import sympy as sp
|
||||
import numpy as np
|
||||
|
||||
|
||||
sigma_3 = 2.576
|
||||
|
||||
|
||||
class SympyClip(sp.Function):
    """Symbolic clip(val, min_val, max_val); evaluates only for numeric args."""

    @classmethod
    def eval(cls, val, min_val, max_val):
        # returning None keeps the expression unevaluated for symbolic args
        if val.is_Number and min_val.is_Number and max_val.is_Number:
            return sp.Piecewise(
                (min_val, val < min_val), (max_val, val > max_val), (val, True)
            )
        return None

    @staticmethod
    def numerical_eval(val, min_val, max_val, backend=np):
        """Numeric counterpart used by lambdify (numpy / jax.numpy backend)."""
        return backend.clip(val, min_val, max_val)

    def _sympystr(self, printer):
        return f"clip({self.args[0]}, {self.args[1]}, {self.args[2]})"

    def _latex(self, printer):
        return rf"\mathrm{{clip}}\left({sp.latex(self.args[0])}, {self.args[1]}, {self.args[2]}\right)"
|
||||
|
||||
|
||||
class SympySigmoid_(sp.Function):
    """Plain logistic sigmoid 1 / (1 + exp(-z)); building block for the scaled variants."""

    @classmethod
    def eval(cls, z):
        z = 1 / (1 + sp.exp(-z))
        return z

    @staticmethod
    def numerical_eval(z, backend=np):
        """Numeric counterpart used by lambdify (numpy / jax.numpy backend)."""
        z = 1 / (1 + backend.exp(-z))
        return z

    def _sympystr(self, printer):
        return f"sigmoid({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{sigmoid}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympySigmoid(sp.Function):
    """Symbolic counterpart of Act.sigmoid: scaled sigmoid, range (0, sigma_3)."""

    @classmethod
    def eval(cls, z):
        return SympySigmoid_(5 * z / sigma_3) * sigma_3
|
||||
|
||||
|
||||
class SympyStandardSigmoid(sp.Function):
    """Symbolic counterpart of Act.standard_sigmoid: scaled input, range (0, 1)."""

    @classmethod
    def eval(cls, z):
        return SympySigmoid_(5 * z / sigma_3)
|
||||
|
||||
|
||||
class SympyTanh(sp.Function):
    """Symbolic counterpart of Act.tanh: scaled tanh, range (-sigma_3, sigma_3)."""

    @classmethod
    def eval(cls, z):
        z = 5 * z / sigma_3
        return sp.tanh(z) * sigma_3
|
||||
|
||||
|
||||
class SympyStandardTanh(sp.Function):
    """Symbolic counterpart of Act.standard_tanh: scaled input, range (-1, 1)."""

    @classmethod
    def eval(cls, z):
        z = 5 * z / sigma_3
        return sp.tanh(z)
|
||||
|
||||
|
||||
class SympySin(sp.Function):
    """Symbolic counterpart of Act.sin: clipped sine, range (-sigma_3, sigma_3)."""

    @classmethod
    def eval(cls, z):
        # only evaluate for numeric input; symbolic input stays unevaluated
        if z.is_Number:
            z = SympyClip(sp.pi / 2 * z / sigma_3, -sp.pi / 2, sp.pi / 2)
            return sp.sin(z) * sigma_3  # (-sigma_3, sigma_3)
        return None

    @staticmethod
    def numerical_eval(z, backend=np):
        """Numeric counterpart used by lambdify (numpy / jax.numpy backend)."""
        z = backend.clip(backend.pi / 2 * z / sigma_3, -backend.pi / 2, backend.pi / 2)
        return backend.sin(z) * sigma_3  # (-sigma_3, sigma_3)
|
||||
|
||||
|
||||
class SympyRelu(sp.Function):
    """Symbolic counterpart of Act.relu: ReLU on clipped input, range (0, sigma_3)."""

    @classmethod
    def eval(cls, z):
        # only evaluate for numeric input; symbolic input stays unevaluated
        if z.is_Number:
            z = SympyClip(z, -sigma_3, sigma_3)
            return sp.Max(z, 0)  # (0, sigma_3)
        return None

    @staticmethod
    def numerical_eval(z, backend=np):
        """Numeric counterpart used by lambdify (numpy / jax.numpy backend)."""
        z = backend.clip(z, -sigma_3, sigma_3)
        return backend.maximum(z, 0)  # (0, sigma_3)

    def _sympystr(self, printer):
        return f"relu({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{relu}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyLelu(sp.Function):
    """Symbolic leaky ReLU with negative slope 0.005.

    NOTE(review): unlike Act.lelu, neither eval nor numerical_eval clips the
    input to [-sigma_3, sigma_3] — confirm whether this divergence is intended.
    """

    @classmethod
    def eval(cls, z):
        if z.is_Number:
            leaky = 0.005
            return sp.Piecewise((z, z > 0), (leaky * z, True))
        return None

    @staticmethod
    def numerical_eval(z, backend=np):
        """Numeric counterpart; max(z, leaky*z) equals the piecewise form for leaky < 1."""
        leaky = 0.005
        return backend.maximum(z, leaky * z)

    def _sympystr(self, printer):
        return f"lelu({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{lelu}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyIdentity(sp.Function):
    """Symbolic identity activation."""

    @classmethod
    def eval(cls, z):
        return z
|
||||
|
||||
|
||||
class SympyInv(sp.Function):
    """Symbolic counterpart of Act.inv: 1/z with |z| floored at 1e-7."""

    @classmethod
    def eval(cls, z):
        # only evaluate for numeric input; symbolic input stays unevaluated
        if z.is_Number:
            z = sp.Piecewise((sp.Max(z, 1e-7), z > 0), (sp.Min(z, -1e-7), True))
            return 1 / z
        return None

    @staticmethod
    def numerical_eval(z, backend=np):
        """Numeric counterpart used by lambdify (numpy / jax.numpy backend).

        Bug fix: the previous `backend.maximum(z, 1e-7)` clamped every
        negative input to +1e-7 (yielding 1e7), disagreeing with both the
        symbolic eval above and Act.inv. Preserve the sign instead.
        """
        z = backend.where(z > 0, backend.maximum(z, 1e-7), backend.minimum(z, -1e-7))
        return 1 / z

    def _sympystr(self, printer):
        return f"1 / {self.args[0]}"

    def _latex(self, printer):
        return rf"\frac{{1}}{{{sp.latex(self.args[0])}}}"
|
||||
|
||||
|
||||
class SympyLog(sp.Function):
    """Symbolic counterpart of Act.log: log(z) with z floored at 1e-7."""

    @classmethod
    def eval(cls, z):
        # only evaluate for numeric input; symbolic input stays unevaluated
        if z.is_Number:
            z = sp.Max(z, 1e-7)
            return sp.log(z)
        return None

    @staticmethod
    def numerical_eval(z, backend=np):
        """Numeric counterpart used by lambdify (numpy / jax.numpy backend)."""
        z = backend.maximum(z, 1e-7)
        return backend.log(z)

    def _sympystr(self, printer):
        return f"log({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{log}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympyExp(sp.Function):
    """Symbolic counterpart of Act.exp: exp(z) with z clipped to [-10, 10].

    NOTE(review): this class defines no numerical_eval, so "exp" is excluded
    from SYMPY_FUNCS_MODULE_NP/JNP by the hasattr filter — confirm intended.
    """

    @classmethod
    def eval(cls, z):
        # only evaluate for numeric input; symbolic input stays unevaluated
        if z.is_Number:
            z = SympyClip(z, -10, 10)
            return sp.exp(z)
        return None

    def _sympystr(self, printer):
        return f"exp({self.args[0]})"

    def _latex(self, printer):
        return rf"\mathrm{{exp}}\left({sp.latex(self.args[0])}\right)"
|
||||
|
||||
|
||||
class SympySquare(sp.Function):
    """Symbolic square activation: z**2."""

    @classmethod
    def eval(cls, z):
        return sp.Pow(z, 2)
|
||||
|
||||
|
||||
class SympyAbs(sp.Function):
    """Symbolic absolute value.

    NOTE(review): Act.abs clips the input to [-1, 1] first; this symbolic
    version does not — confirm whether the divergence is intended.
    """

    @classmethod
    def eval(cls, z):
        return sp.Abs(z)
|
||||
0
src/tensorneat/common/aggregation/__init__.py
Normal file
0
src/tensorneat/common/aggregation/__init__.py
Normal file
66
src/tensorneat/common/aggregation/agg_jnp.py
Normal file
66
src/tensorneat/common/aggregation/agg_jnp.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class Agg:
    """Aggregation functions over a node's inputs.

    NaN entries mark absent inputs and are excluded from every aggregation.
    """

    @staticmethod
    def name2func(name):
        """Look up an aggregation by its method name."""
        return getattr(Agg, name)

    @staticmethod
    def sum(z):
        """Sum of non-NaN entries (0 when all are NaN)."""
        return jnp.sum(z, axis=0, where=~jnp.isnan(z), initial=0)

    @staticmethod
    def product(z):
        """Product of non-NaN entries (1 when all are NaN)."""
        return jnp.prod(z, axis=0, where=~jnp.isnan(z), initial=1)

    @staticmethod
    def max(z):
        """Maximum of non-NaN entries (-inf when all are NaN)."""
        return jnp.max(z, axis=0, where=~jnp.isnan(z), initial=-jnp.inf)

    @staticmethod
    def min(z):
        """Minimum of non-NaN entries (+inf when all are NaN)."""
        return jnp.min(z, axis=0, where=~jnp.isnan(z), initial=jnp.inf)

    @staticmethod
    def maxabs(z):
        """Entry with the largest absolute value; NaNs are treated as 0."""
        z = jnp.where(jnp.isnan(z), 0, z)
        abs_z = jnp.abs(z)
        max_abs_index = jnp.argmax(abs_z)
        return z[max_abs_index]

    @staticmethod
    def median(z):
        """Median of the non-NaN entries.

        jnp.sort moves NaNs to the tail, so the first n sorted values are
        exactly the valid entries.
        """
        n = jnp.sum(~jnp.isnan(z), axis=0)

        z = jnp.sort(z)  # sort

        idx1, idx2 = (n - 1) // 2, n // 2
        median = (z[idx1] + z[idx2]) / 2

        return median

    @staticmethod
    def mean(z):
        """Mean of the non-NaN entries."""
        aux = jnp.where(jnp.isnan(z), 0, z)
        valid_values_sum = jnp.sum(aux, axis=0)
        valid_values_count = jnp.sum(~jnp.isnan(z), axis=0)
        mean_without_zeros = valid_values_sum / valid_values_count
        return mean_without_zeros
|
||||
|
||||
|
||||
# Default aggregation set; genomes select aggregations by index into this
# tuple (see agg_func below), so the order must stay stable.
AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean)
|
||||
|
||||
|
||||
def agg_func(idx, z, agg_funcs):
    """
    Apply the aggregation selected by ``idx`` to the inputs ``z``.

    ``idx`` is stored as a float in the genome, so it is cast to int first.
    When every input is NaN the result is NaN without calling any aggregation.
    """
    branch = jnp.asarray(idx, dtype=jnp.int32)
    has_valid_input = ~jnp.all(jnp.isnan(z))

    return jax.lax.cond(
        has_valid_input,
        lambda: jax.lax.switch(branch, agg_funcs, z),
        lambda: jnp.nan,
    )
|
||||
65
src/tensorneat/common/aggregation/agg_sympy.py
Normal file
65
src/tensorneat/common/aggregation/agg_sympy.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
|
||||
|
||||
class SympySum(sp.Function):
    """Symbolic sum aggregation over the sequence z."""

    @classmethod
    def eval(cls, z):
        return sp.Add(*z)

    @classmethod
    def numerical_eval(cls, z, backend=np):
        """Numeric counterpart used by lambdify (numpy / jax.numpy backend)."""
        return backend.sum(z)
|
||||
|
||||
|
||||
class SympyProduct(sp.Function):
    """Symbolic product aggregation over the sequence z."""

    @classmethod
    def eval(cls, z):
        return sp.Mul(*z)
|
||||
|
||||
|
||||
class SympyMax(sp.Function):
    """Symbolic maximum aggregation over the sequence z."""

    @classmethod
    def eval(cls, z):
        return sp.Max(*z)
|
||||
|
||||
|
||||
class SympyMin(sp.Function):
    """Symbolic minimum aggregation over the sequence z."""

    @classmethod
    def eval(cls, z):
        return sp.Min(*z)
|
||||
|
||||
|
||||
class SympyMaxabs(sp.Function):
    """Symbolic max-by-absolute-value aggregation (mirrors Agg.maxabs).

    Returns the element of ``z`` with the largest absolute value.
    """

    @classmethod
    def eval(cls, z):
        # Bug fix: sympy's Max has no ``key`` parameter, so the previous
        # ``sp.Max(*z, key=sp.Abs)`` passed ``key`` as an assumption and
        # never computed max-by-|.|. Evaluate explicitly for numeric input;
        # symbolic input stays unevaluated.
        if all(e.is_number for e in z):
            return max(z, key=abs)
        return None
|
||||
|
||||
|
||||
class SympyMean(sp.Function):
    """Symbolic mean aggregation over the sequence z."""

    @classmethod
    def eval(cls, z):
        return sp.Add(*z) / len(z)
|
||||
|
||||
|
||||
class SympyMedian(sp.Function):
    """Symbolic median aggregation; evaluates only when all args are numeric."""

    @classmethod
    def eval(cls, args):

        if all(arg.is_number for arg in args):
            sorted_args = sorted(args)
            n = len(sorted_args)
            if n % 2 == 1:
                return sorted_args[n // 2]
            else:
                # even count: average the two middle values
                return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2

        # symbolic input stays unevaluated
        return None

    def _sympystr(self, printer):
        return f"median({', '.join(map(str, self.args))})"

    def _latex(self, printer):
        return (
            r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
        )
|
||||
123
src/tensorneat/common/graph.py
Normal file
123
src/tensorneat/common/graph.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Some graph algorithm implemented in jax.
|
||||
Only used in feed-forward networks.
|
||||
"""
|
||||
|
||||
import jax
|
||||
from jax import jit, Array, numpy as jnp
|
||||
from typing import Tuple, Set, List, Union
|
||||
|
||||
from .tools import fetch_first, I_INF
|
||||
|
||||
|
||||
@jit
def topological_sort(nodes: Array, conns: Array) -> Array:
    """
    a jit-able version of topological_sort!
    conns: Array[N, N]

    nodes: a row whose first column is NaN marks an unused node slot.
    Returns Array[N] of node indices in topological order (Kahn's algorithm);
    unused trailing slots hold the I_INF sentinel.
    """

    # in-degree per node; NaN for empty node slots so they are never selected
    in_degree = jnp.where(jnp.isnan(nodes[:, 0]), jnp.nan, jnp.sum(conns, axis=0))
    res = jnp.full(in_degree.shape, I_INF)

    def cond_fun(carry):
        res_, idx_, in_degree_ = carry
        # continue while some node still has in-degree 0
        i = fetch_first(in_degree_ == 0.0)
        return i != I_INF

    def body_func(carry):
        res_, idx_, in_degree_ = carry
        i = fetch_first(in_degree_ == 0.0)

        # add to res and flag it is already in it (-1 marks "emitted")
        res_ = res_.at[idx_].set(i)
        in_degree_ = in_degree_.at[i].set(-1)

        # decrease in_degree of all its children
        children = conns[i, :]
        in_degree_ = jnp.where(children, in_degree_ - 1, in_degree_)
        return res_, idx_ + 1, in_degree_

    res, _, _ = jax.lax.while_loop(cond_fun, body_func, (res, 0, in_degree))
    return res
|
||||
|
||||
|
||||
def topological_sort_python(
    nodes: Union[Set[int], List[int]],
    conns: Union[Set[Tuple[int, int]], List[Tuple[int, int]]],
) -> Tuple[List[int], List[List[int]]]:
    """
    Pure-Python topological sort (Kahn's algorithm).

    Returns the flat topological order and the topological layers (nodes whose
    in-degree reached zero at the same step; each layer sorted ascending).
    Raises ValueError if the graph contains a cycle.
    """
    # work on copies so the caller's collections are untouched
    remaining_nodes = nodes.copy()
    remaining_conns = conns.copy()

    # in-degree of every node
    in_degree = {n: 0 for n in remaining_nodes}
    for conn in remaining_conns:
        in_degree[conn[1]] += 1

    topo_order: List[int] = []
    topo_layer: List[List[int]] = []

    # current frontier: nodes with no incoming edges, kept sorted so the
    # resulting order is deterministic (small indices first)
    frontier = sorted(n for n in remaining_nodes if in_degree[n] == 0)

    while frontier:
        for n in frontier:
            remaining_nodes.remove(n)

        topo_layer.append(list(frontier))

        for n in frontier:
            topo_order.append(n)
            # drop n's outgoing edges and lower the targets' in-degree
            for conn in list(remaining_conns):
                if conn[0] == n:
                    in_degree[conn[1]] -= 1
                    remaining_conns.remove(conn)

        frontier = sorted(m for m in remaining_nodes if in_degree[m] == 0)

    # anything left over means the graph contains a cycle
    if remaining_conns or remaining_nodes:
        raise ValueError("Graph has at least one cycle, topological sort not possible")

    return topo_order, topo_layer
|
||||
|
||||
|
||||
@jit
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
    """
    Check whether a new connection (from_idx -> to_idx) will cause a cycle.

    Propagates reachability from to_idx over the adjacency matrix; a cycle
    exists iff from_idx becomes reachable. Returns a scalar bool Array.
    """

    # tentatively add the candidate edge
    conns = conns.at[from_idx, to_idx].set(True)

    visited = jnp.full(nodes.shape[0], False)
    new_visited = visited.at[to_idx].set(True)

    def cond_func(carry):
        visited_, new_visited_ = carry
        end_cond1 = jnp.all(visited_ == new_visited_)  # no new nodes been visited
        end_cond2 = new_visited_[from_idx]  # the starting node has been visited
        return jnp.logical_not(end_cond1 | end_cond2)

    def body_func(carry):
        _, visited_ = carry
        # one BFS step: nodes reachable in one hop from any visited node
        new_visited_ = jnp.dot(visited_, conns)
        new_visited_ = jnp.logical_or(visited_, new_visited_)
        return visited_, new_visited_

    _, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
    return visited[from_idx]
|
||||
49
src/tensorneat/common/state.py
Normal file
49
src/tensorneat/common/state.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
|
||||
|
||||
@register_pytree_node_class
class State:
    """Immutable key-value container registered as a JAX pytree.

    Values live in an internal ``state_dict``; attribute access is routed to
    that dict. ``register``/``update`` return a new State instead of mutating.
    """

    def __init__(self, **kwargs):
        # write through __dict__ to bypass the immutability guard in __setattr__
        self.__dict__["state_dict"] = kwargs

    def registered_keys(self):
        """Return a view of all registered keys."""
        return self.state_dict.keys()

    def register(self, **kwargs):
        """Return a new State with new keys added; keys must not already exist."""
        for key in kwargs:
            if key in self.registered_keys():
                raise ValueError(f"Key {key} already exists in state")
        return State(**{**self.state_dict, **kwargs})

    def update(self, **kwargs):
        """Return a new State with existing keys replaced; keys must exist."""
        for key in kwargs:
            if key not in self.registered_keys():
                raise ValueError(f"Key {key} does not exist in state")
        return State(**{**self.state_dict, **kwargs})

    def __getattr__(self, name):
        # Bug fix: __getattr__ must raise AttributeError, not KeyError,
        # so hasattr()/getattr(default) and attribute-probing protocols work.
        try:
            return self.state_dict[name]
        except KeyError:
            raise AttributeError(f"State has no attribute {name!r}") from None

    def __setattr__(self, name, value):
        raise AttributeError("State is immutable")

    def __repr__(self):
        return f"State ({self.state_dict})"

    def __getstate__(self):
        # pickle only the payload dict
        return self.state_dict.copy()

    def __setstate__(self, state):
        self.__dict__["state_dict"] = state

    def __contains__(self, item):
        return item in self.state_dict

    def tree_flatten(self):
        # values are pytree leaves; keys travel as static aux data
        children = list(self.state_dict.values())
        aux_data = list(self.state_dict.keys())
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(**dict(zip(aux_data, children)))
|
||||
69
src/tensorneat/common/stateful_class.py
Normal file
69
src/tensorneat/common/stateful_class.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from . import State
|
||||
import pickle
|
||||
import datetime
|
||||
import warnings
|
||||
|
||||
|
||||
class StatefulBaseClass:
    """Base class for components that keep their mutable data in a State.

    Provides default (no-op) setup, and pickle-based save/load that can
    optionally bundle a State with the object.
    """

    def setup(self, state=State()):
        """Initialize and return the (possibly extended) state. Default: no-op.

        NOTE(review): the ``State()`` default is evaluated once at class
        definition — fine while State stays immutable.
        """
        return state

    def save(self, state: Optional[State] = None, path: Optional[str] = None):
        """Pickle this object (optionally bundled with ``state``) to ``path``.

        If path is None, a timestamped filename in the current directory is used.
        """
        if path is None:
            time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            path = f"./{self.__class__.__name__} {time}.pkl"
        if state is not None:
            # temporarily attach the state so it is pickled with the object
            self.__dict__["aux_for_state"] = state
        try:
            with open(path, "wb") as f:
                pickle.dump(self, f)
        finally:
            # Bug fix: previously the bundled state stayed attached to the
            # live object after save(), silently changing its attributes.
            self.__dict__.pop("aux_for_state", None)

    def __getstate__(self):
        # only pickle the picklable attributes
        state = self.__dict__.copy()
        non_picklable_keys = []
        for key, value in state.items():
            try:
                pickle.dumps(value)
            except Exception:
                non_picklable_keys.append(key)

        for key in non_picklable_keys:
            state.pop(key)

        return state

    def show_config(self):
        """Return a nested dict describing this object's configuration."""
        config = {}
        for key, value in self.__dict__.items():
            if isinstance(value, StatefulBaseClass):
                config[str(key)] = value.show_config()
            else:
                config[str(key)] = str(value)
        return config

    @classmethod
    def load(cls, path: str, with_state: bool = False, warning: bool = True):
        """Unpickle an object saved by save().

        With with_state=True, returns (obj, state); the state is an empty
        State (plus a warning) when none was bundled. With with_state=False,
        any bundled state is dropped (with a warning) and obj is returned.
        """
        with open(path, "rb") as f:
            obj = pickle.load(f)
        if with_state:
            if "aux_for_state" not in obj.__dict__:
                if warning:
                    warnings.warn(
                        "This object does not have state to load, return empty state",
                        category=UserWarning,
                    )
                return obj, State()
            state = obj.__dict__["aux_for_state"]
            del obj.__dict__["aux_for_state"]
            return obj, state
        else:
            if "aux_for_state" in obj.__dict__:
                if warning:
                    warnings.warn(
                        "This object has state to load, ignore it",
                        category=UserWarning,
                    )
                del obj.__dict__["aux_for_state"]
            return obj
|
||||
110
src/tensorneat/common/tools.py
Normal file
110
src/tensorneat/common/tools.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import numpy as jnp, Array, jit, vmap
|
||||
|
||||
I_INF = np.iinfo(jnp.int32).max  # sentinel index meaning "no element"


def attach_with_inf(arr, idx):
    """
    Gather ``arr[idx]`` but yield NaN wherever ``idx`` equals the I_INF sentinel.
    """
    # broadcast idx so the sentinel mask lines up with the gathered values
    extra_axes = tuple(range(idx.ndim, arr.ndim + idx.ndim - 1))
    broadcast_idx = jnp.expand_dims(idx, axis=extra_axes)

    return jnp.where(broadcast_idx == I_INF, jnp.nan, arr[idx])
|
||||
|
||||
|
||||
@jit
def fetch_first(mask, default=I_INF) -> Array:
    """
    Return the index of the first True element of ``mask``.

    :param mask: array of bool
    :param default: value returned when no element is True
    :return: index of the first True element, or ``default`` if mask is all-False
    """
    candidate = jnp.argmax(mask)  # argmax returns the first maximal position
    return jnp.where(mask[candidate], candidate, default)
|
||||
|
||||
|
||||
@jit
def fetch_random(randkey, mask, default=I_INF) -> Array:
    """
    Like fetch_first, but return the index of a uniformly random True element
    of ``mask`` (or ``default`` when mask is all-False).
    """
    n_true = jnp.sum(mask)
    # pick the k-th True element (k uniform in [1, n_true]) via the cumulative sum
    k = jax.random.randint(randkey, shape=(), minval=1, maxval=n_true + 1)
    selected = jnp.where(n_true == 0, False, jnp.cumsum(mask) >= k)
    return fetch_first(selected, default)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=["reverse"])
def rank_elements(array, reverse=False):
    """
    Return the rank of each element (rank 0 = first).

    By default the largest element gets rank 0; with ``reverse=True`` the
    smallest element gets rank 0.
    """
    keys = array if reverse else -array
    # double argsort converts a sort permutation into per-element ranks
    return jnp.argsort(jnp.argsort(keys))
|
||||
|
||||
|
||||
@jit
def mutate_float(
    randkey, val, init_mean, init_std, mutate_power, mutate_rate, replace_rate
):
    """
    mutate a float value
    uniformly pick r from [0, 1)
    r in [0, mutate_rate) -> add noise
    r in [mutate_rate, mutate_rate + replace_rate) -> create a new value to replace the original value
    otherwise -> keep the original value
    """
    k1, k2, k3 = jax.random.split(randkey, num=3)
    noise = jax.random.normal(k1, ()) * mutate_power
    replace = jax.random.normal(k2, ()) * init_std + init_mean
    r = jax.random.uniform(k3, ())

    # Bug fix: the previous inner condition used the strict `mutate_rate < r`,
    # so r == mutate_rate fell through both branches, contradicting the
    # half-open ranges documented above. The outer where already captured
    # r < mutate_rate, so only the upper bound is needed here.
    val = jnp.where(
        r < mutate_rate,
        val + noise,
        jnp.where(r < mutate_rate + replace_rate, replace, val),
    )

    return val
|
||||
|
||||
|
||||
@jit
def mutate_int(randkey, val, options, replace_rate):
    """
    Mutate an int value: with probability ``replace_rate`` replace it with a
    uniformly chosen entry of ``options``; otherwise keep it.
    """
    key_r, key_choice = jax.random.split(randkey, num=2)
    do_replace = jax.random.uniform(key_r, ()) < replace_rate

    return jnp.where(do_replace, jax.random.choice(key_choice, options), val)
|
||||
|
||||
|
||||
def argmin_with_mask(arr, mask):
    """
    Index of the smallest element of ``arr`` among positions where mask is True.
    """
    # masked-out positions become +inf so argmin never selects them
    candidates = jnp.where(mask, arr, jnp.inf)
    return jnp.argmin(candidates)
|
||||
|
||||
|
||||
def hash_array(arr: Array):
    """
    Deterministic uint32 hash of a JAX array's bit pattern.

    The array is bitcast to uint32 words and folded with a boost-style
    hash_combine (golden-ratio constant 0x9E3779B9).
    NOTE(review): bitcasting a dtype wider than 32 bits adds a trailing axis,
    so ``arr[i]`` would then be a vector and the fold would return an array
    rather than a scalar — confirm callers only pass 32-bit-element arrays.
    """
    arr = jax.lax.bitcast_convert_type(arr, jnp.uint32)

    def update(i, hash_val):
        # hash_combine step: h ^= w + 0x9E3779B9 + (h << 6) + (h >> 2)
        return hash_val ^ (
            arr[i] + jnp.uint32(0x9E3779B9) + (hash_val << 6) + (hash_val >> 2)
        )

    return jax.lax.fori_loop(0, arr.size, update, jnp.uint32(0))
|
||||
6
src/tensorneat/genome/__init__.py
Normal file
6
src/tensorneat/genome/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .gene import *
|
||||
from .operations import *
|
||||
from .base import BaseGenome
|
||||
from .default import DefaultGenome
|
||||
from .recurrent import RecurrentGenome
|
||||
|
||||
223
src/tensorneat/genome/base.py
Normal file
223
src/tensorneat/genome/base.py
Normal file
@@ -0,0 +1,223 @@
|
||||
from typing import Callable, Sequence
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
from .gene import BaseNode, BaseConn
|
||||
from .operations import BaseMutation, BaseCrossover, BaseDistance
|
||||
from tensorneat.common import (
|
||||
State,
|
||||
StatefulBaseClass,
|
||||
hash_array,
|
||||
)
|
||||
from .utils import valid_cnt
|
||||
|
||||
|
||||
class BaseGenome(StatefulBaseClass):
|
||||
network_type = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_inputs: int,
|
||||
num_outputs: int,
|
||||
max_nodes: int,
|
||||
max_conns: int,
|
||||
node_gene: BaseNode,
|
||||
conn_gene: BaseConn,
|
||||
mutation: BaseMutation,
|
||||
crossover: BaseCrossover,
|
||||
distance: BaseDistance,
|
||||
output_transform: Callable = None,
|
||||
input_transform: Callable = None,
|
||||
init_hidden_layers: Sequence[int] = (),
|
||||
):
|
||||
|
||||
# check transform functions
|
||||
if input_transform is not None:
|
||||
try:
|
||||
_ = input_transform(jnp.zeros(num_inputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
_ = output_transform(jnp.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
|
||||
# prepare for initialization
|
||||
all_layers = [num_inputs] + list(init_hidden_layers) + [num_outputs]
|
||||
layer_indices = []
|
||||
next_index = 0
|
||||
for layer in all_layers:
|
||||
layer_indices.append(list(range(next_index, next_index + layer)))
|
||||
next_index += layer
|
||||
|
||||
all_init_nodes = []
|
||||
all_init_conns_in_idx = []
|
||||
all_init_conns_out_idx = []
|
||||
for i in range(len(layer_indices) - 1):
|
||||
in_layer = layer_indices[i]
|
||||
out_layer = layer_indices[i + 1]
|
||||
for in_idx in in_layer:
|
||||
for out_idx in out_layer:
|
||||
all_init_conns_in_idx.append(in_idx)
|
||||
all_init_conns_out_idx.append(out_idx)
|
||||
all_init_nodes.extend(in_layer)
|
||||
all_init_nodes.extend(layer_indices[-1]) # output layer
|
||||
|
||||
if max_nodes < len(all_init_nodes):
|
||||
raise ValueError(
|
||||
f"max_nodes={max_nodes} must be greater than or equal to the number of initial nodes={len(all_init_nodes)}"
|
||||
)
|
||||
|
||||
if max_conns < len(all_init_conns_in_idx):
|
||||
raise ValueError(
|
||||
f"max_conns={max_conns} must be greater than or equal to the number of initial connections={len(all_init_conns_in_idx)}"
|
||||
)
|
||||
|
||||
self.num_inputs = num_inputs
|
||||
self.num_outputs = num_outputs
|
||||
self.max_nodes = max_nodes
|
||||
self.max_conns = max_conns
|
||||
self.node_gene = node_gene
|
||||
self.conn_gene = conn_gene
|
||||
self.mutation = mutation
|
||||
self.crossover = crossover
|
||||
self.distance = distance
|
||||
self.output_transform = output_transform
|
||||
self.input_transform = input_transform
|
||||
|
||||
self.input_idx = np.array(layer_indices[0])
|
||||
self.output_idx = np.array(layer_indices[-1])
|
||||
self.all_init_nodes = np.array(all_init_nodes)
|
||||
self.all_init_conns = np.c_[all_init_conns_in_idx, all_init_conns_out_idx]
|
||||
|
||||
def setup(self, state=State()):
    """Register this genome's sub-components into the shared State.

    Delegates to node_gene, conn_gene, mutation, crossover and distance in
    turn; each setup() returns an updated State that is threaded into the
    next call.

    NOTE(review): ``state=State()`` is a default evaluated once at function
    definition time — confirm State is immutable/stateless before relying
    on the shared default instance.
    """
    state = self.node_gene.setup(state)
    state = self.conn_gene.setup(state)
    # mutation / crossover / distance also receive a back-reference to this genome
    state = self.mutation.setup(state, self)
    state = self.crossover.setup(state, self)
    state = self.distance.setup(state, self)
    return state
|
||||
|
||||
def transform(self, state, nodes, conns):
    """Turn raw (nodes, conns) arrays into a forward-ready representation.

    Abstract: concrete genomes (e.g. DefaultGenome) must override.
    """
    raise NotImplementedError
|
||||
|
||||
def forward(self, state, transformed, inputs):
    """Run the network produced by ``transform`` on ``inputs``.

    Abstract: concrete genomes must override.
    """
    raise NotImplementedError
|
||||
|
||||
def sympy_func(self):
    """Return a symbolic (sympy) representation of the network.

    Abstract: concrete genomes may override.
    """
    raise NotImplementedError
|
||||
|
||||
def visualize(self):
    """Render the network topology.

    Abstract: concrete genomes may override.
    """
    raise NotImplementedError
|
||||
|
||||
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
    """Delegate one mutation step to the configured mutation operator."""
    return self.mutation(state, randkey, nodes, conns, new_node_key)
|
||||
|
||||
def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
    """Delegate crossover of two genomes to the configured crossover operator."""
    return self.crossover(state, randkey, nodes1, conns1, nodes2, conns2)
|
||||
|
||||
def execute_distance(self, state, nodes1, conns1, nodes2, conns2):
    """Delegate genome-distance computation to the configured distance operator."""
    return self.distance(state, nodes1, conns1, nodes2, conns2)
|
||||
|
||||
def initialize(self, state, randkey):
    """Build the initial (nodes, conns) arrays for one genome.

    Both arrays are fixed-size (max_nodes / max_conns rows) and padded with
    NaN; only the first ``len(all_init_nodes)`` / ``len(all_init_conns)``
    rows are filled. Column 0 of ``nodes`` holds the node index; columns for
    ``conns`` are (in_idx, out_idx) followed by the gene's custom attrs.
    """
    k1, k2 = jax.random.split(randkey)  # k1 for nodes, k2 for conns

    all_nodes_cnt = len(self.all_init_nodes)
    all_conns_cnt = len(self.all_init_conns)

    # initialize nodes: NaN rows mark unused slots
    nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan)
    # create node indices (precomputed in __init__ from the layer layout)
    node_indices = self.all_init_nodes
    # create node attrs: one independent random key per node
    rand_keys_n = jax.random.split(k1, num=all_nodes_cnt)
    node_attr_func = vmap(self.node_gene.new_random_attrs, in_axes=(None, 0))
    node_attrs = node_attr_func(state, rand_keys_n)

    nodes = nodes.at[:all_nodes_cnt, 0].set(node_indices)  # set node indices
    nodes = nodes.at[:all_nodes_cnt, 1:].set(node_attrs)  # set node attrs

    # initialize conns: same NaN-padding scheme
    conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
    # create input and output indices (dense layer-to-layer wiring)
    conn_indices = self.all_init_conns
    # create conn attrs: one independent random key per connection
    rand_keys_c = jax.random.split(k2, num=all_conns_cnt)
    conns_attr_func = jax.vmap(
        self.conn_gene.new_random_attrs,
        in_axes=(
            None,
            0,
        ),
    )
    conns_attrs = conns_attr_func(state, rand_keys_c)

    conns = conns.at[:all_conns_cnt, :2].set(conn_indices)  # set conn indices
    conns = conns.at[:all_conns_cnt, 2:].set(conns_attrs)  # set conn attrs

    return nodes, conns
|
||||
|
||||
def network_dict(self, state, nodes, conns):
    """Convert device arrays into a plain-python network description.

    Returns ``{"nodes": {idx: node_dict}, "conns": {(in, out): conn_dict}}``.
    """
    return {
        "nodes": self._get_node_dict(state, nodes),
        "conns": self._get_conn_dict(state, conns),
    }
|
||||
|
||||
def get_input_idx(self):
    """Return the input node indices as a plain list of python ints."""
    return self.input_idx.tolist()
|
||||
|
||||
def get_output_idx(self):
    """Return the output node indices as a plain list of python ints."""
    return self.output_idx.tolist()
|
||||
|
||||
def hash(self, nodes, conns):
    """Deterministic integer hash of a whole genome.

    Hashes each node/conn row independently (vmap over hash_array), then
    hashes the concatenated per-row hashes into a single value.
    """
    nodes_hashs = vmap(hash_array)(nodes)
    conns_hashs = vmap(hash_array)(conns)
    return hash_array(jnp.concatenate([nodes_hashs, conns_hashs]))
|
||||
|
||||
def repr(self, state, nodes, conns, precision=2):
    """Multi-line, human-readable dump of a genome's nodes and conns.

    Floats are rounded to ``precision`` digits. Iteration stops at the
    first NaN row because valid rows are packed at the front of the arrays.
    """
    nodes, conns = jax.device_get([nodes, conns])
    nodes_cnt, conns_cnt = valid_cnt(nodes), valid_cnt(conns)
    s = f"{self.__class__.__name__}(nodes={nodes_cnt}, conns={conns_cnt}):\n"
    s += f"\tNodes:\n"
    for node in nodes:
        # NaN index marks the first unused slot; the rest are unused too
        if np.isnan(node[0]):
            break
        s += f"\t\t{self.node_gene.repr(state, node, precision=precision)}"
        node_idx = int(node[0])
        # tag input/output nodes so the role is visible in the dump
        if np.isin(node_idx, self.input_idx):
            s += " (input)"
        elif np.isin(node_idx, self.output_idx):
            s += " (output)"
        s += "\n"

    s += f"\tConns:\n"
    for conn in conns:
        if np.isnan(conn[0]):
            break
        s += f"\t\t{self.conn_gene.repr(state, conn, precision=precision)}\n"
    return s
|
||||
|
||||
def _get_conn_dict(self, state, conns):
    """Collect every valid connection gene into ``{(in_idx, out_idx): attrs_dict}``."""
    rows = jax.device_get(conns)
    result = {}
    for row in rows:
        # rows whose first slot is NaN are empty placeholders
        if not np.isnan(row[0]):
            info = self.conn_gene.to_dict(state, row)
            result[(info["in"], info["out"])] = info
    return result
|
||||
|
||||
def _get_node_dict(self, state, nodes):
    """Collect every valid node gene into ``{idx: attrs_dict}``."""
    rows = jax.device_get(nodes)
    result = {}
    for row in rows:
        # NaN in the index slot marks an unused row
        if not np.isnan(row[0]):
            info = self.node_gene.to_dict(state, row)
            result[info["idx"]] = info
    return result
|
||||
321
src/tensorneat/genome/default.py
Normal file
321
src/tensorneat/genome/default.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
|
||||
from .base import BaseGenome
|
||||
from .gene import DefaultNode, DefaultConn
|
||||
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
|
||||
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
|
||||
|
||||
from tensorneat.common import (
|
||||
topological_sort,
|
||||
topological_sort_python,
|
||||
I_INF,
|
||||
attach_with_inf,
|
||||
SYMPY_FUNCS_MODULE_NP,
|
||||
SYMPY_FUNCS_MODULE_JNP,
|
||||
)
|
||||
|
||||
|
||||
class DefaultGenome(BaseGenome):
    """Default genome class, with the same behavior as the NEAT-Python.

    A feedforward genome: ``transform`` topologically sorts the nodes and
    ``forward`` evaluates them in that order with a jax.lax.while_loop.
    """

    network_type = "feedforward"

    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        max_nodes=50,
        max_conns=100,
        node_gene=DefaultNode(),
        conn_gene=DefaultConn(),
        mutation=DefaultMutation(),
        crossover=DefaultCrossover(),
        distance=DefaultDistance(),
        output_transform=None,
        input_transform=None,
        init_hidden_layers=(),
    ):
        # NOTE(review): the gene/operator defaults are instances created once
        # at definition time and shared across DefaultGenome objects — they
        # appear to be stateless config holders, but confirm before mutating.
        super().__init__(
            num_inputs,
            num_outputs,
            max_nodes,
            max_conns,
            node_gene,
            conn_gene,
            mutation,
            crossover,
            distance,
            output_transform,
            input_transform,
            init_hidden_layers,
        )

    def transform(self, state, nodes, conns):
        """Prepare a genome for forward evaluation.

        Returns (topological node order, nodes, conns, unflattened conn
        index matrix). ``u_conns[i, j] == I_INF`` means no i->j connection.
        """
        u_conns = unflatten_conns(nodes, conns)
        conn_exist = u_conns != I_INF

        seqs = topological_sort(nodes, conn_exist)

        return seqs, nodes, conns, u_conns

    def forward(self, state, transformed, inputs):
        """Evaluate the network on one input vector.

        Walks nodes in topological order inside a jax.lax.while_loop,
        keeping a per-node value buffer (NaN = not yet computed).
        """
        if self.input_transform is not None:
            inputs = self.input_transform(inputs)

        cal_seqs, nodes, conns, u_conns = transformed

        # value buffer: one slot per potential node, inputs pre-filled
        ini_vals = jnp.full((self.max_nodes,), jnp.nan)
        ini_vals = ini_vals.at[self.input_idx].set(inputs)
        nodes_attrs = vmap(extract_node_attrs)(nodes)
        conns_attrs = vmap(extract_conn_attrs)(conns)

        def cond_fun(carry):
            values, idx = carry
            return (idx < self.max_nodes) & (
                cal_seqs[idx] != I_INF
            )  # not out of bounds and next node exists

        def body_func(carry):
            values, idx = carry
            i = cal_seqs[idx]

            def input_node():
                # input nodes already carry their value; nothing to compute
                return values

            def otherwise():
                # calculate connections feeding node i
                conn_indices = u_conns[:, i]
                hit_attrs = attach_with_inf(conns_attrs, conn_indices)  # fetch conn attrs
                ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
                    state, hit_attrs, values
                )

                # calculate nodes
                z = self.node_gene.forward(
                    state,
                    nodes_attrs[i],
                    ins,
                    is_output_node=jnp.isin(nodes[i, 0], self.output_idx),  # nodes[0] -> the key of nodes
                )

                # set new value
                new_values = values.at[i].set(z)
                return new_values

            values = jax.lax.cond(jnp.isin(i, self.input_idx), input_node, otherwise)

            return values, idx + 1

        vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))

        if self.output_transform is None:
            return vals[self.output_idx]
        else:
            return self.output_transform(vals[self.output_idx])

    def network_dict(self, state, nodes, conns):
        """Plain-python network description, plus topo_order/topo_layers keys."""
        network = super().network_dict(state, nodes, conns)
        topo_order, topo_layers = topological_sort_python(
            set(network["nodes"]), set(network["conns"])
        )
        network["topo_order"] = topo_order
        network["topo_layers"] = topo_layers
        return network

    def sympy_func(
        self,
        state,
        network,
        sympy_input_transform=None,
        sympy_output_transform=None,
        backend="jax",
    ):
        """Build a symbolic (sympy) version of the network.

        Returns (symbols, args_symbols, input_symbols, nodes_exprs,
        output_exprs, forward_func); ``forward_func`` is a callable built by
        sp.lambdify over the requested backend.
        """
        assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
        module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP

        if sympy_input_transform is None and self.input_transform is not None:
            warnings.warn(
                "genome.input_transform is not None but sympy_input_transform is None!"
            )

        if sympy_input_transform is None:
            sympy_input_transform = lambda x: x

        # NOTE(review): this condition is always True here because
        # sympy_input_transform was just defaulted above — dead check.
        if sympy_input_transform is not None:
            # broadcast a single transform to one per input
            if not isinstance(sympy_input_transform, list):
                sympy_input_transform = [sympy_input_transform] * self.num_inputs

        if sympy_output_transform is None and self.output_transform is not None:
            warnings.warn(
                "genome.output_transform is not None but sympy_output_transform is None!"
            )

        input_idx = self.get_input_idx()
        output_idx = self.get_output_idx()
        order = network["topo_order"]

        hidden_idx = [
            i for i in network["nodes"] if i not in input_idx and i not in output_idx
        ]
        # symbol naming: i<k> raw input, norm<k> transformed input,
        # o<k> output, h<k> hidden. Raw input i is keyed as -i - 1.
        symbols = {}
        for i in network["nodes"]:
            if i in input_idx:
                symbols[-i - 1] = sp.Symbol(f"i{i - min(input_idx)}")  # origin_i
                symbols[i] = sp.Symbol(f"norm{i - min(input_idx)}")
            elif i in output_idx:
                symbols[i] = sp.Symbol(f"o{i - min(output_idx)}")
            else:  # hidden
                symbols[i] = sp.Symbol(f"h{i - min(hidden_idx)}")

        nodes_exprs = {}
        args_symbols = {}
        for i in order:
            if i in input_idx:
                nodes_exprs[symbols[-i - 1]] = symbols[
                    -i - 1
                ]  # origin equal to its symbol
                nodes_exprs[symbols[i]] = sympy_input_transform[i - min(input_idx)](
                    symbols[-i - 1]
                )  # normed i

            else:
                # incoming connections of node i
                in_conns = [c for c in network["conns"] if c[1] == i]
                node_inputs = []
                for conn in in_conns:
                    val_represent = symbols[conn[0]]
                    # a_s -> args_symbols
                    val, a_s = self.conn_gene.sympy_func(
                        state,
                        network["conns"][conn],
                        val_represent,
                    )
                    args_symbols.update(a_s)
                    node_inputs.append(val)
                nodes_exprs[symbols[i]], a_s = self.node_gene.sympy_func(
                    state,
                    network["nodes"][i],
                    node_inputs,
                    is_output_node=(i in output_idx),
                )
                args_symbols.update(a_s)

                if i in output_idx and sympy_output_transform is not None:
                    nodes_exprs[symbols[i]] = sympy_output_transform(
                        nodes_exprs[symbols[i]]
                    )

        input_symbols = [symbols[-i - 1] for i in input_idx]
        # substitute in topological order so every expr is in input symbols only
        reduced_exprs = nodes_exprs.copy()
        for i in order:
            reduced_exprs[symbols[i]] = reduced_exprs[symbols[i]].subs(reduced_exprs)

        output_exprs = [reduced_exprs[symbols[i]] for i in output_idx]

        lambdify_output_funcs = [
            sp.lambdify(
                input_symbols + list(args_symbols.keys()),
                exprs,
                modules=[backend, module],
            )
            for exprs in output_exprs
        ]

        fixed_args_output_funcs = []
        for i in range(len(output_idx)):

            def f(inputs, i=i):
                # i=i default binds the loop variable (avoids late binding)
                return lambdify_output_funcs[i](*inputs, *args_symbols.values())

            fixed_args_output_funcs.append(f)

        forward_func = lambda inputs: jnp.array(
            [f(inputs) for f in fixed_args_output_funcs]
        )

        return (
            symbols,
            args_symbols,
            input_symbols,
            nodes_exprs,
            output_exprs,
            forward_func,
        )

    def visualize(
        self,
        network,
        rotate=0,
        reverse_node_order=False,
        size=(300, 300, 300),
        color=("blue", "blue", "blue"),
        save_path="network.svg",
        save_dpi=800,
        **kwargs,
    ):
        """Draw the network with networkx/matplotlib and save it to a file.

        ``size``/``color`` are (input, hidden, output) triples; a scalar is
        broadcast to all three. ``rotate`` is in degrees. Extra kwargs are
        forwarded to nx.draw.
        """
        # imported lazily so the core library does not require these packages
        import networkx as nx
        from matplotlib import pyplot as plt

        nodes_list = list(network["nodes"])
        conns_list = list(network["conns"])
        input_idx = self.get_input_idx()
        output_idx = self.get_output_idx()

        topo_order, topo_layers = network["topo_order"], network["topo_layers"]
        node2layer = {
            node: layer for layer, nodes in enumerate(topo_layers) for node in nodes
        }
        if reverse_node_order:
            topo_order = topo_order[::-1]

        G = nx.DiGraph()

        # broadcast scalar size/color to (input, hidden, output)
        if not isinstance(size, tuple):
            size = (size, size, size)
        if not isinstance(color, tuple):
            color = (color, color, color)

        for node in topo_order:
            if node in input_idx:
                G.add_node(node, subset=node2layer[node], size=size[0], color=color[0])
            elif node in output_idx:
                G.add_node(node, subset=node2layer[node], size=size[2], color=color[2])
            else:
                G.add_node(node, subset=node2layer[node], size=size[1], color=color[1])

        for conn in conns_list:
            G.add_edge(conn[0], conn[1])
        # layer-by-layer layout keyed on the "subset" node attribute
        pos = nx.multipartite_layout(G)

        def rotate_layout(pos, angle):
            # rotate all positions by `angle` degrees around the origin
            angle_rad = np.deg2rad(angle)
            cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)
            rotated_pos = {}
            for node, (x, y) in pos.items():
                rotated_pos[node] = (
                    cos_angle * x - sin_angle * y,
                    sin_angle * x + cos_angle * y,
                )
            return rotated_pos

        rotated_pos = rotate_layout(pos, rotate)

        node_sizes = [n["size"] for n in G.nodes.values()]
        node_colors = [n["color"] for n in G.nodes.values()]

        nx.draw(
            G,
            pos=rotated_pos,
            node_size=node_sizes,
            node_color=node_colors,
            **kwargs,
        )
        plt.savefig(save_path, dpi=save_dpi)
|
||||
3
src/tensorneat/genome/gene/__init__.py
Normal file
3
src/tensorneat/genome/gene/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseGene
|
||||
from .conn import *
|
||||
from .node import *
|
||||
45
src/tensorneat/genome/gene/base.py
Normal file
45
src/tensorneat/genome/gene/base.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from tensorneat.common import State, StatefulBaseClass, hash_array
|
||||
|
||||
|
||||
class BaseGene(StatefulBaseClass):
    """Base class for node genes or connection genes.

    A gene is stored as a flat array: the fixed attrs (indices) come first,
    followed by the gene-specific custom attrs.
    """

    fixed_attrs = []
    custom_attrs = []

    def __init__(self):
        pass

    def new_identity_attrs(self, state):
        """Attrs performing an identity transformation; used in mutate-add-node."""
        raise NotImplementedError

    def new_random_attrs(self, state, randkey):
        """Randomly sampled attrs; used at initialization."""
        raise NotImplementedError

    def mutate(self, state, randkey, attrs):
        """Return a mutated copy of ``attrs``."""
        raise NotImplementedError

    def crossover(self, state, randkey, attrs1, attrs2):
        """Uniform crossover: each attr comes from either parent with equal chance."""
        take_first = jax.random.normal(randkey, attrs1.shape) > 0
        return jnp.where(take_first, attrs1, attrs2)

    def distance(self, state, attrs1, attrs2):
        """Scalar distance between two genes' attrs."""
        raise NotImplementedError

    def forward(self, state, attrs, inputs):
        """Apply the gene to incoming values."""
        raise NotImplementedError

    @property
    def length(self):
        """Total number of array slots one gene occupies (fixed + custom)."""
        return len(self.fixed_attrs) + len(self.custom_attrs)

    def repr(self, state, gene, precision=2):
        """Human-readable one-line summary of the gene."""
        raise NotImplementedError

    def hash(self, gene):
        """Deterministic integer hash of the raw gene array."""
        return hash_array(gene)
|
||||
2
src/tensorneat/genome/gene/conn/__init__.py
Normal file
2
src/tensorneat/genome/gene/conn/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .base import BaseConn
|
||||
from .default import DefaultConn
|
||||
35
src/tensorneat/genome/gene/conn/base.py
Normal file
35
src/tensorneat/genome/gene/conn/base.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from ..base import BaseGene
|
||||
|
||||
|
||||
class BaseConn(BaseGene):
    """Base class for connection genes."""

    fixed_attrs = ["input_index", "output_index"]

    def __init__(self):
        super().__init__()

    def new_zero_attrs(self, state):
        """Attrs with the least influence on the network; used in mutate-add-conn."""
        raise NotImplementedError

    def forward(self, state, attrs, inputs):
        """Apply the connection to the value coming from its source node."""
        raise NotImplementedError

    def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
        """One-line summary showing the connection endpoints."""
        src, dst = (int(v) for v in conn[:2])
        return f"{self.__class__.__name__}(in: {src:<{idx_width}}, out: {dst:<{idx_width}})"

    def to_dict(self, state, conn):
        """Plain-python dict view of the fixed attrs."""
        src, dst = conn[:2]
        return {
            "in": int(src),
            "out": int(dst),
        }

    def sympy_func(self, state, conn_dict, inputs):
        """Symbolic (sympy) form of this connection's transformation."""
        raise NotImplementedError
|
||||
96
src/tensorneat/genome/gene/conn/default.py
Normal file
96
src/tensorneat/genome/gene/conn/default.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import jax.numpy as jnp
|
||||
import jax.random
|
||||
import sympy as sp
|
||||
from tensorneat.common import mutate_float
|
||||
from .base import BaseConn
|
||||
|
||||
|
||||
class DefaultConn(BaseConn):
    """Default connection gene, with the same behavior as in NEAT-python.

    Holds a single custom attribute: the connection weight.
    """

    custom_attrs = ["weight"]

    def __init__(
        self,
        weight_init_mean: float = 0.0,
        weight_init_std: float = 1.0,
        weight_mutate_power: float = 0.15,
        weight_mutate_rate: float = 0.2,
        weight_replace_rate: float = 0.015,
        weight_lower_bound: float = -5.0,
        weight_upper_bound: float = 5.0,
    ):
        super().__init__()
        self.weight_init_mean = weight_init_mean
        self.weight_init_std = weight_init_std
        self.weight_mutate_power = weight_mutate_power
        self.weight_mutate_rate = weight_mutate_rate
        self.weight_replace_rate = weight_replace_rate
        self.weight_lower_bound = weight_lower_bound
        self.weight_upper_bound = weight_upper_bound

    def new_zero_attrs(self, state):
        """Least-influence attrs (weight = 0); used by mutate-add-conn."""
        return jnp.zeros(1)

    def new_identity_attrs(self, state):
        """Identity attrs (weight = 1); used by mutate-add-node."""
        return jnp.ones(1)

    def new_random_attrs(self, state, randkey):
        """Sample a fresh weight from N(init_mean, init_std), clipped to bounds."""
        sampled = (
            jax.random.normal(randkey, ()) * self.weight_init_std
            + self.weight_init_mean
        )
        clipped = jnp.clip(sampled, self.weight_lower_bound, self.weight_upper_bound)
        return jnp.array([clipped])

    def mutate(self, state, randkey, attrs):
        """Perturb or replace the weight via mutate_float, then clip to bounds."""
        mutated = mutate_float(
            randkey,
            attrs[0],
            self.weight_init_mean,
            self.weight_init_std,
            self.weight_mutate_power,
            self.weight_mutate_rate,
            self.weight_replace_rate,
        )
        return jnp.array(
            [jnp.clip(mutated, self.weight_lower_bound, self.weight_upper_bound)]
        )

    def distance(self, state, attrs1, attrs2):
        """Gene distance: absolute weight difference."""
        return jnp.abs(attrs1[0] - attrs2[0])

    def forward(self, state, attrs, inputs):
        """Scale the incoming value by the connection weight."""
        return inputs * attrs[0]

    def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
        """One-line summary of (in, out, weight), floats rounded to precision."""
        src, dst, w = conn
        src, dst = int(src), int(dst)
        w = round(float(w), precision)
        return (
            f"{self.__class__.__name__}"
            f"(in: {src:<{idx_width}}, out: {dst:<{idx_width}}, "
            f"weight: {w:<{precision + 3}})"
        )

    def to_dict(self, state, conn):
        """Plain-python dict view of this gene's attrs."""
        return {
            "in": int(conn[0]),
            "out": int(conn[1]),
            "weight": jnp.float32(conn[2]),
        }

    def sympy_func(self, state, conn_dict, inputs, precision=None):
        """Symbolic form inputs * w, plus the {symbol: value} binding."""
        w_sym = sp.symbols(f"c_{conn_dict['in']}_{conn_dict['out']}_w")
        return inputs * w_sym, {w_sym: conn_dict["weight"]}
|
||||
3
src/tensorneat/genome/gene/node/__init__.py
Normal file
3
src/tensorneat/genome/gene/node/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseNode
|
||||
from .default import DefaultNode
|
||||
from .bias import BiasNode
|
||||
30
src/tensorneat/genome/gene/node/base.py
Normal file
30
src/tensorneat/genome/gene/node/base.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from .. import BaseGene
|
||||
|
||||
|
||||
class BaseNode(BaseGene):
    """Base class for node genes."""

    fixed_attrs = ["index"]

    def __init__(self):
        super().__init__()

    def forward(self, state, attrs, inputs, is_output_node=False):
        """Aggregate incoming values into this node's output value."""
        raise NotImplementedError

    def repr(self, state, node, precision=2, idx_width=3, func_width=8):
        """One-line summary showing the node index."""
        node_id = int(node[0])
        return f"{self.__class__.__name__}(idx={node_id:<{idx_width}})"

    def to_dict(self, state, node):
        """Plain-python dict view of the fixed attrs."""
        return {
            "idx": int(node[0]),
        }

    def sympy_func(self, state, node_dict, inputs, is_output_node=False):
        """Symbolic (sympy) form of this node's transformation."""
        raise NotImplementedError
|
||||
185
src/tensorneat/genome/gene/node/bias.py
Normal file
185
src/tensorneat/genome/gene/node/bias.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from typing import Union, Sequence, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
import sympy as sp
|
||||
from tensorneat.common import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
agg_func,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
)
|
||||
|
||||
from . import BaseNode
|
||||
|
||||
|
||||
class BiasNode(BaseNode):
    """
    Default node gene, with the same behavior as in NEAT-python.
    The attribute response is removed.
    """

    # attrs layout after the fixed index: (bias, aggregation idx, activation idx)
    custom_attrs = ["bias", "aggregation", "activation"]

    def __init__(
        self,
        bias_init_mean: float = 0.0,
        bias_init_std: float = 1.0,
        bias_mutate_power: float = 0.15,
        bias_mutate_rate: float = 0.2,
        bias_replace_rate: float = 0.015,
        bias_lower_bound: float = -5,
        bias_upper_bound: float = 5,
        aggregation_default: Optional[Callable] = None,
        aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
        aggregation_replace_rate: float = 0.1,
        activation_default: Optional[Callable] = None,
        activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
        activation_replace_rate: float = 0.1,
    ):
        super().__init__()

        # accept a single callable as shorthand for a one-element option list
        if isinstance(aggregation_options, Callable):
            aggregation_options = [aggregation_options]
        if isinstance(activation_options, Callable):
            activation_options = [activation_options]

        if aggregation_default is None:
            aggregation_default = aggregation_options[0]
        if activation_default is None:
            activation_default = activation_options[0]

        self.bias_init_mean = bias_init_mean
        self.bias_init_std = bias_init_std
        self.bias_mutate_power = bias_mutate_power
        self.bias_mutate_rate = bias_mutate_rate
        self.bias_replace_rate = bias_replace_rate
        self.bias_lower_bound = bias_lower_bound
        self.bias_upper_bound = bias_upper_bound

        # defaults are stored as indices into the options lists (JAX-friendly)
        self.aggregation_default = aggregation_options.index(aggregation_default)
        self.aggregation_options = aggregation_options
        self.aggregation_indices = np.arange(len(aggregation_options))
        self.aggregation_replace_rate = aggregation_replace_rate

        self.activation_default = activation_options.index(activation_default)
        self.activation_options = activation_options
        self.activation_indices = np.arange(len(activation_options))
        self.activation_replace_rate = activation_replace_rate

    def new_identity_attrs(self, state):
        """Attrs for an identity node (bias 0, default agg, identity act)."""
        return jnp.array(
            [0, self.aggregation_default, -1]
        )  # activation=-1 means Act.identity

    def new_random_attrs(self, state, randkey):
        """Sample fresh (bias, aggregation, activation) attrs."""
        k1, k2, k3 = jax.random.split(randkey, num=3)
        bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
        bias = jnp.clip(bias, self.bias_lower_bound, self.bias_upper_bound)
        agg = jax.random.choice(k2, self.aggregation_indices)
        act = jax.random.choice(k3, self.activation_indices)

        return jnp.array([bias, agg, act])

    def mutate(self, state, randkey, attrs):
        """Mutate bias (float perturb/replace) and agg/act (categorical replace)."""
        k1, k2, k3 = jax.random.split(randkey, num=3)
        bias, agg, act = attrs

        bias = mutate_float(
            k1,
            bias,
            self.bias_init_mean,
            self.bias_init_std,
            self.bias_mutate_power,
            self.bias_mutate_rate,
            self.bias_replace_rate,
        )
        bias = jnp.clip(bias, self.bias_lower_bound, self.bias_upper_bound)
        agg = mutate_int(
            k2, agg, self.aggregation_indices, self.aggregation_replace_rate
        )

        act = mutate_int(k3, act, self.activation_indices, self.activation_replace_rate)

        return jnp.array([bias, agg, act])

    def distance(self, state, attrs1, attrs2):
        """|Δbias| plus one for each differing categorical attr."""
        bias1, agg1, act1 = attrs1
        bias2, agg2, act2 = attrs2

        return jnp.abs(bias1 - bias2) + (agg1 != agg2) + (act1 != act2)

    def forward(self, state, attrs, inputs, is_output_node=False):
        """Aggregate inputs, add bias, then (except for output nodes) activate."""
        bias, agg, act = attrs

        z = agg_func(agg, inputs, self.aggregation_options)
        z = bias + z

        # the last output node should not be activated
        z = jax.lax.cond(
            is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
        )

        return z

    def repr(self, state, node, precision=2, idx_width=3, func_width=8):
        """One-line summary of (idx, bias, aggregation, activation)."""
        idx, bias, agg, act = node

        idx = int(idx)
        bias = round(float(bias), precision)
        agg = int(agg)
        act = int(act)

        # act == -1 is the identity sentinel written by new_identity_attrs
        # NOTE(review): the local name act_func shadows the imported act_func
        if act == -1:
            act_func = Act.identity
        else:
            act_func = self.activation_options[act]
        return "{}(idx={:<{idx_width}}, bias={:<{float_width}}, aggregation={:<{func_width}}, activation={:<{func_width}})".format(
            self.__class__.__name__,
            idx,
            bias,
            self.aggregation_options[agg].__name__,
            act_func.__name__,
            idx_width=idx_width,
            float_width=precision + 3,
            func_width=func_width,
        )

    def to_dict(self, state, node):
        """Plain-python dict view of the attrs; functions become their names."""
        idx, bias, agg, act = node

        idx = int(idx)

        bias = jnp.float32(bias)
        agg = int(agg)
        act = int(act)

        # act == -1 is the identity sentinel (see new_identity_attrs)
        if act == -1:
            act_func = Act.identity
        else:
            act_func = self.activation_options[act]

        return {
            "idx": idx,
            "bias": bias,
            "agg": self.aggregation_options[int(agg)].__name__,
            "act": act_func.__name__,
        }

    def sympy_func(self, state, node_dict, inputs, is_output_node=False):
        """Symbolic form: act(bias + agg(inputs)); outputs skip the activation."""
        nd = node_dict

        bias = sp.symbols(f"n_{nd['idx']}_b")

        z = convert_to_sympy(nd["agg"])(inputs)

        z = bias + z
        if is_output_node:
            pass
        else:
            z = convert_to_sympy(nd["act"])(z)

        return z, {bias: nd["bias"]}
|
||||
220
src/tensorneat/genome/gene/node/default.py
Normal file
220
src/tensorneat/genome/gene/node/default.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from typing import Optional, Union, Sequence, Callable
|
||||
|
||||
import numpy as np
|
||||
import jax, jax.numpy as jnp
|
||||
import sympy as sp
|
||||
|
||||
from tensorneat.common import (
|
||||
Act,
|
||||
Agg,
|
||||
act_func,
|
||||
agg_func,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
)
|
||||
|
||||
from .base import BaseNode
|
||||
|
||||
|
||||
class DefaultNode(BaseNode):
|
||||
"Default node gene, with the same behavior as in NEAT-python."
|
||||
|
||||
custom_attrs = ["bias", "response", "aggregation", "activation"]
|
||||
|
||||
def __init__(
    self,
    bias_init_mean: float = 0.0,
    bias_init_std: float = 1.0,
    bias_mutate_power: float = 0.15,
    bias_mutate_rate: float = 0.2,
    bias_replace_rate: float = 0.015,
    bias_lower_bound: float = -5,
    bias_upper_bound: float = 5,
    response_init_mean: float = 1.0,
    response_init_std: float = 0.0,
    response_mutate_power: float = 0.15,
    response_mutate_rate: float = 0.2,
    response_replace_rate: float = 0.015,
    response_lower_bound: float = -5,
    response_upper_bound: float = 5,
    aggregation_default: Optional[Callable] = None,
    aggregation_options: Union[Callable, Sequence[Callable]] = Agg.sum,
    aggregation_replace_rate: float = 0.1,
    activation_default: Optional[Callable] = None,
    activation_options: Union[Callable, Sequence[Callable]] = Act.sigmoid,
    activation_replace_rate: float = 0.1,
):
    """Configure initialization and mutation hyper-parameters for the gene.

    bias/response: init distribution (mean, std), mutation power/rate,
    replace rate, and clipping bounds. aggregation/activation: candidate
    function lists (a single callable is broadcast to a one-element list),
    a default (first option if None), and a categorical replace rate.
    """
    super().__init__()

    # accept a single callable as shorthand for a one-element option list
    if isinstance(aggregation_options, Callable):
        aggregation_options = [aggregation_options]
    if isinstance(activation_options, Callable):
        activation_options = [activation_options]

    if aggregation_default is None:
        aggregation_default = aggregation_options[0]
    if activation_default is None:
        activation_default = activation_options[0]

    self.bias_init_mean = bias_init_mean
    self.bias_init_std = bias_init_std
    self.bias_mutate_power = bias_mutate_power
    self.bias_mutate_rate = bias_mutate_rate
    self.bias_replace_rate = bias_replace_rate
    self.bias_lower_bound = bias_lower_bound
    self.bias_upper_bound = bias_upper_bound

    self.response_init_mean = response_init_mean
    self.response_init_std = response_init_std
    self.response_mutate_power = response_mutate_power
    self.response_mutate_rate = response_mutate_rate
    self.response_replace_rate = response_replace_rate
    # Fix: expose the correctly spelled attribute name.
    self.response_lower_bound = response_lower_bound
    # Keep the historical misspelled alias for backward compatibility:
    # sibling methods (new_random_attrs, mutate) still read
    # `self.reponse_lower_bound`.
    self.reponse_lower_bound = response_lower_bound
    self.response_upper_bound = response_upper_bound

    # defaults are stored as indices into the options lists (JAX-friendly)
    self.aggregation_default = aggregation_options.index(aggregation_default)
    self.aggregation_options = aggregation_options
    self.aggregation_indices = np.arange(len(aggregation_options))
    self.aggregation_replace_rate = aggregation_replace_rate

    self.activation_default = activation_options.index(activation_default)
    self.activation_options = activation_options
    self.activation_indices = np.arange(len(activation_options))
    self.activation_replace_rate = activation_replace_rate
|
||||
|
||||
def new_identity_attrs(self, state):
    """Attrs for a node inserted by mutate-add-node (bias 0, response 1).

    NOTE(review): unlike BiasNode.new_identity_attrs (which encodes the
    activation as -1, the Act.identity sentinel), this uses the *default*
    activation/aggregation indices — the new node is only a true identity
    transform if the default activation is identity. Confirm intended.
    """
    bias = 0
    res = 1
    agg = self.aggregation_default
    act = self.activation_default

    return jnp.array([bias, res, agg, act])
|
||||
|
||||
def new_random_attrs(self, state, randkey):
    """Sample a fresh attribute vector [bias, response, aggregation, activation]."""
    bias_key, res_key, agg_key, act_key = jax.random.split(randkey, num=4)

    # gaussian-initialised floats, clipped to the configured bounds
    bias = jnp.clip(
        jax.random.normal(bias_key, ()) * self.bias_init_std + self.bias_init_mean,
        self.bias_lower_bound,
        self.bias_upper_bound,
    )
    res = jnp.clip(
        jax.random.normal(res_key, ()) * self.response_init_std
        + self.response_init_mean,
        # NOTE(review): "reponse" typo matches the attribute set in __init__
        self.reponse_lower_bound,
        self.response_upper_bound,
    )

    # uniform pick among the registered function indices
    agg = jax.random.choice(agg_key, self.aggregation_indices)
    act = jax.random.choice(act_key, self.activation_indices)

    return jnp.array([bias, res, agg, act])
|
||||
|
||||
def mutate(self, state, randkey, attrs):
    """Mutate one node's attribute vector [bias, response, aggregation, activation].

    Floats are perturbed or replaced by ``mutate_float`` and re-clipped to
    their bounds; the categorical function indices are re-drawn by
    ``mutate_int`` with their configured replace rates.
    """
    k1, k2, k3, k4 = jax.random.split(randkey, num=4)
    bias, res, agg, act = attrs
    bias = mutate_float(
        k1,
        bias,
        self.bias_init_mean,
        self.bias_init_std,
        self.bias_mutate_power,
        self.bias_mutate_rate,
        self.bias_replace_rate,
    )
    # keep the mutated bias inside its configured range
    bias = jnp.clip(bias, self.bias_lower_bound, self.bias_upper_bound)
    res = mutate_float(
        k2,
        res,
        self.response_init_mean,
        self.response_init_std,
        self.response_mutate_power,
        self.response_mutate_rate,
        self.response_replace_rate,
    )
    # NOTE(review): "reponse_lower_bound" is a typo, but it matches the
    # attribute name assigned in __init__, so it must be kept in sync.
    res = jnp.clip(res, self.reponse_lower_bound, self.response_upper_bound)
    # NOTE(review): k4 mutates aggregation and k3 mutates activation —
    # swapped relative to declaration order, but statistically harmless.
    agg = mutate_int(
        k4, agg, self.aggregation_indices, self.aggregation_replace_rate
    )

    act = mutate_int(k3, act, self.activation_indices, self.activation_replace_rate)

    return jnp.array([bias, res, agg, act])
|
||||
|
||||
def distance(self, state, attrs1, attrs2):
    """Compatibility distance between two node attribute vectors.

    Continuous attrs (bias, response) contribute their absolute difference;
    categorical attrs (aggregation, activation) contribute 1 when unequal.
    """
    bias1, res1, agg1, act1 = attrs1
    bias2, res2, agg2, act2 = attrs2

    continuous_part = jnp.abs(bias1 - bias2) + jnp.abs(res1 - res2)
    categorical_part = (agg1 != agg2) + (act1 != act2)
    return continuous_part + categorical_part
|
||||
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
    """Compute a node's output: act(bias + response * agg(inputs)).

    ``agg`` and ``act`` are indices into the configured function lists.
    Output nodes skip the activation so raw values reach any output
    transform applied downstream.
    """
    bias, res, agg, act = attrs

    # aggregate incoming connection values, then apply the affine transform
    z = agg_func(agg, inputs, self.aggregation_options)
    z = bias + res * z

    # the last output node should not be activated
    z = jax.lax.cond(
        is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
    )

    return z
|
||||
|
||||
def repr(self, state, node, precision=2, idx_width=3, func_width=8):
    """Render a node row as a one-line, column-aligned string."""
    idx, bias, res, agg, act = node

    idx = int(idx)
    bias = round(float(bias), precision)
    res = round(float(res), precision)
    agg = int(agg)
    act = int(act)

    # act == -1 is a sentinel for the identity activation
    act_name = (Act.identity if act == -1 else self.activation_options[act]).__name__
    agg_name = self.aggregation_options[agg].__name__
    float_width = precision + 3

    return (
        f"{self.__class__.__name__}("
        f"idx={idx:<{idx_width}}, "
        f"bias={bias:<{float_width}}, "
        f"response={res:<{float_width}}, "
        f"aggregation={agg_name:<{func_width}}, "
        f"activation={act_name:<{func_width}})"
    )
|
||||
|
||||
def to_dict(self, state, node):
    """Serialize one node row into a plain dict (functions stored by name)."""
    idx, bias, res, agg, act = node
    agg = int(agg)
    act = int(act)

    # act == -1 is a sentinel for the identity activation
    act_func = Act.identity if act == -1 else self.activation_options[act]

    return {
        "idx": int(idx),
        "bias": jnp.float32(bias),
        "res": jnp.float32(res),
        "agg": self.aggregation_options[agg].__name__,
        "act": act_func.__name__,
    }
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
    """Build a sympy expression for this node from its serialized dict.

    Returns the expression and a substitution map that binds the created
    bias/response symbols to their concrete values.
    """
    nd = node_dict
    # one symbol per learnable scalar, named after the node index
    bias = sp.symbols(f"n_{nd['idx']}_b")
    res = sp.symbols(f"n_{nd['idx']}_r")

    z = convert_to_sympy(nd["agg"])(inputs)
    z = bias + res * z

    # output nodes are not activated (mirrors forward())
    if is_output_node:
        pass
    else:
        z = convert_to_sympy(nd["act"])(z)

    return z, {bias: nd["bias"], res: nd["res"]}
|
||||
3
src/tensorneat/genome/operations/__init__.py
Normal file
3
src/tensorneat/genome/operations/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .crossover import BaseCrossover, DefaultCrossover
|
||||
from .mutation import BaseMutation, DefaultMutation
|
||||
from .distance import BaseDistance, DefaultDistance
|
||||
2
src/tensorneat/genome/operations/crossover/__init__.py
Normal file
2
src/tensorneat/genome/operations/crossover/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .base import BaseCrossover
|
||||
from .default import DefaultCrossover
|
||||
12
src/tensorneat/genome/operations/crossover/base.py
Normal file
12
src/tensorneat/genome/operations/crossover/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from tensorneat.common import StatefulBaseClass, State
|
||||
|
||||
|
||||
class BaseCrossover(StatefulBaseClass):
    """Base class for crossover operators combining two parent genomes."""

    def setup(self, state=State(), genome=None):
        """Attach the owning genome; it supplies gene types and size limits."""
        assert genome is not None, "genome should not be None"
        self.genome = genome
        return state

    def __call__(self, state, randkey, nodes1, conns1, nodes2, conns2):
        """Return the child's (nodes, conns); genome1 is the fitter parent.

        Note: parameter order fixed to (nodes1, conns1, nodes2, conns2)
        to match DefaultCrossover — the base previously declared
        (nodes1, nodes2, conns1, conns2), which was inconsistent.
        """
        raise NotImplementedError
|
||||
87
src/tensorneat/genome/operations/crossover/default.py
Normal file
87
src/tensorneat/genome/operations/crossover/default.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .base import BaseCrossover
|
||||
from ...utils import (
|
||||
extract_node_attrs,
|
||||
extract_conn_attrs,
|
||||
set_node_attrs,
|
||||
set_conn_attrs,
|
||||
)
|
||||
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
    """NEAT-style crossover: homologous genes are mixed, the rest comes
    from the fitter parent."""

    def __call__(self, state, randkey, nodes1, conns1, nodes2, conns2):
        """
        Use genome1 and genome2 to generate a new genome.
        Notice that genome1 should have higher fitness than genome2
        (genome1 is the winner!).
        """
        randkey1, randkey2 = jax.random.split(randkey, 2)
        # one key per node / connection slot for the per-gene crossover
        randkeys1 = jax.random.split(randkey1, self.genome.max_nodes)
        randkeys2 = jax.random.split(randkey2, self.genome.max_conns)

        # crossover nodes
        keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
        # make homologous genes in nodes2 align with nodes1
        nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)

        # For non-homologous genes, use the value of nodes1 (winner).
        # For homologous genes, use the crossover result of nodes1 and nodes2.
        node_attrs1 = vmap(extract_node_attrs)(nodes1)
        node_attrs2 = vmap(extract_node_attrs)(nodes2)

        new_node_attrs = jnp.where(
            jnp.isnan(node_attrs1) | jnp.isnan(node_attrs2),  # one of them is nan
            node_attrs1,  # not homologous genes or both nan, use the value of nodes1(winner)
            vmap(self.genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(
                state, randkeys1, node_attrs1, node_attrs2
            ),  # homologous or both nan
        )
        new_nodes = vmap(set_node_attrs)(nodes1, new_node_attrs)

        # crossover connections (keys are the (in, out) index pairs)
        con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
        conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)

        conns_attrs1 = vmap(extract_conn_attrs)(conns1)
        conns_attrs2 = vmap(extract_conn_attrs)(conns2)

        new_conn_attrs = jnp.where(
            jnp.isnan(conns_attrs1) | jnp.isnan(conns_attrs2),
            conns_attrs1,  # not homologous genes or both nan, use the value of conns1(winner)
            vmap(self.genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(
                state, randkeys2, conns_attrs1, conns_attrs2
            ),  # homologous or both nan
        )
        new_conns = vmap(set_conn_attrs)(conns1, new_conn_attrs)

        return new_nodes, new_conns

    def align_array(self, seq1, seq2, ar2, is_conn: bool):
        """Reorder ``ar2`` so its rows line up with ``seq1``.

        Rows of ``ar2`` whose key (from ``seq2``) also appears in ``seq1``
        are moved to the matching position; all other positions become NaN.
        This is the most delicate part of the crossover — consider carefully
        before changing it!

        :param seq1: keys of the winner's genes
        :param seq2: keys of the loser's genes
        :param ar2: the loser's gene rows, ordered like ``seq2``
        :param is_conn: True when keys are (in, out) pairs rather than scalars
        :return: ``ar2`` permuted to ``seq1``'s order, NaN where no match
        """
        # broadcast to a (len(seq1), len(seq2)) match matrix
        seq1, seq2 = seq1[:, jnp.newaxis], seq2[jnp.newaxis, :]
        mask = (seq1 == seq2) & (~jnp.isnan(seq1))

        if is_conn:
            # a connection matches only if BOTH endpoint keys match
            mask = jnp.all(mask, axis=2)

        intersect_mask = mask.any(axis=1)
        idx = jnp.arange(0, len(seq1))
        # for each row of seq1, the position of its match in seq2 (0 if none)
        idx_fixed = jnp.dot(mask, idx)

        refactor_ar2 = jnp.where(
            intersect_mask[:, jnp.newaxis], ar2[idx_fixed], jnp.nan
        )

        return refactor_ar2
|
||||
2
src/tensorneat/genome/operations/distance/__init__.py
Normal file
2
src/tensorneat/genome/operations/distance/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .base import BaseDistance
|
||||
from .default import DefaultDistance
|
||||
15
src/tensorneat/genome/operations/distance/base.py
Normal file
15
src/tensorneat/genome/operations/distance/base.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from tensorneat.common import StatefulBaseClass, State
|
||||
|
||||
|
||||
class BaseDistance(StatefulBaseClass):
    """Base class for genome-distance measures used by speciation."""

    def setup(self, state=State(), genome=None):
        """Attach the owning genome; it supplies the gene types subclasses use."""
        assert genome is not None, "genome should not be None"
        self.genome = genome
        return state

    def __call__(self, state, nodes1, conns1, nodes2, conns2):
        """
        The distance between two genomes.

        Note: parameter order fixed to (nodes1, conns1, nodes2, conns2)
        to match DefaultDistance — the base previously declared
        (nodes1, nodes2, conns1, conns2), which was inconsistent.
        """
        raise NotImplementedError
|
||||
105
src/tensorneat/genome/operations/distance/default.py
Normal file
105
src/tensorneat/genome/operations/distance/default.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from .base import BaseDistance
|
||||
from ...utils import extract_node_attrs, extract_conn_attrs
|
||||
|
||||
|
||||
class DefaultDistance(BaseDistance):
    """NEAT compatibility distance: disjoint-gene count plus attribute
    distance of homologous genes, normalized by genome size."""

    def __init__(
        self,
        compatibility_disjoint: float = 1.0,  # weight per non-homologous gene
        compatibility_weight: float = 0.4,  # weight on homologous attr distance
    ):
        self.compatibility_disjoint = compatibility_disjoint
        self.compatibility_weight = compatibility_weight

    def __call__(self, state, nodes1, conns1, nodes2, conns2):
        """
        The distance between two genomes
        """
        d = self.node_distance(state, nodes1, nodes2) + self.conn_distance(
            state, conns1, conns2
        )
        return d

    def node_distance(self, state, nodes1, nodes2):
        """
        The distance of the nodes part for two genomes
        """
        node_cnt1 = jnp.sum(~jnp.isnan(nodes1[:, 0]))
        node_cnt2 = jnp.sum(~jnp.isnan(nodes2[:, 0]))
        max_cnt = jnp.maximum(node_cnt1, node_cnt2)

        # align homologous nodes
        # this process is similar to np.intersect1d:
        # after sorting by key, a key shared by both genomes appears twice
        # in adjacent rows, so comparing each row to its successor finds it.
        nodes = jnp.concatenate((nodes1, nodes2), axis=0)
        keys = nodes[:, 0]
        sorted_indices = jnp.argsort(keys, axis=0)
        nodes = nodes[sorted_indices]
        nodes = jnp.concatenate(
            [nodes, jnp.full((1, nodes.shape[1]), jnp.nan)], axis=0
        )  # add a nan row to the end
        fr, sr = nodes[:-1], nodes[1:]  # first row, second row

        # flag location of homologous nodes
        intersect_mask = (fr[:, 0] == sr[:, 0]) & ~jnp.isnan(nodes[:-1, 0])

        # calculate the count of non_homologous of two genomes
        non_homologous_cnt = node_cnt1 + node_cnt2 - 2 * jnp.sum(intersect_mask)

        # calculate the distance of homologous nodes
        fr_attrs = vmap(extract_node_attrs)(fr)
        sr_attrs = vmap(extract_node_attrs)(sr)
        hnd = vmap(self.genome.node_gene.distance, in_axes=(None, 0, 0))(
            state, fr_attrs, sr_attrs
        )  # homologous node distance
        hnd = jnp.where(jnp.isnan(hnd), 0, hnd)
        homologous_distance = jnp.sum(hnd * intersect_mask)

        val = (
            non_homologous_cnt * self.compatibility_disjoint
            + homologous_distance * self.compatibility_weight
        )

        val = jnp.where(max_cnt == 0, 0, val / max_cnt)  # normalize

        return val

    def conn_distance(self, state, conns1, conns2):
        """
        The distance of the conns part for two genomes
        """
        con_cnt1 = jnp.sum(~jnp.isnan(conns1[:, 0]))
        con_cnt2 = jnp.sum(~jnp.isnan(conns2[:, 0]))
        max_cnt = jnp.maximum(con_cnt1, con_cnt2)

        # same adjacent-duplicate trick as node_distance, but keys are the
        # (in, out) index pairs, so a lexicographic sort is needed.
        cons = jnp.concatenate((conns1, conns2), axis=0)
        keys = cons[:, :2]
        sorted_indices = jnp.lexsort(keys.T[::-1])
        cons = cons[sorted_indices]
        cons = jnp.concatenate(
            [cons, jnp.full((1, cons.shape[1]), jnp.nan)], axis=0
        )  # add a nan row to the end
        fr, sr = cons[:-1], cons[1:]  # first row, second row

        # both genome has such connection
        intersect_mask = jnp.all(fr[:, :2] == sr[:, :2], axis=1) & ~jnp.isnan(fr[:, 0])

        non_homologous_cnt = con_cnt1 + con_cnt2 - 2 * jnp.sum(intersect_mask)

        fr_attrs = vmap(extract_conn_attrs)(fr)
        sr_attrs = vmap(extract_conn_attrs)(sr)
        hcd = vmap(self.genome.conn_gene.distance, in_axes=(None, 0, 0))(
            state, fr_attrs, sr_attrs
        )  # homologous connection distance
        hcd = jnp.where(jnp.isnan(hcd), 0, hcd)
        homologous_distance = jnp.sum(hcd * intersect_mask)

        val = (
            non_homologous_cnt * self.compatibility_disjoint
            + homologous_distance * self.compatibility_weight
        )

        val = jnp.where(max_cnt == 0, 0, val / max_cnt)  # normalize

        return val
|
||||
2
src/tensorneat/genome/operations/mutation/__init__.py
Normal file
2
src/tensorneat/genome/operations/mutation/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .base import BaseMutation
|
||||
from .default import DefaultMutation
|
||||
12
src/tensorneat/genome/operations/mutation/base.py
Normal file
12
src/tensorneat/genome/operations/mutation/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from tensorneat.common import StatefulBaseClass, State
|
||||
|
||||
|
||||
class BaseMutation(StatefulBaseClass):
    """Base class for genome mutation operators."""

    def setup(self, state=State(), genome=None):
        """Attach the owning genome; it supplies gene types and size limits."""
        assert genome is not None, "genome should not be None"
        self.genome = genome
        return state

    def __call__(self, state, randkey, nodes, conns, new_node_key):
        """Return the mutated (nodes, conns) pair."""
        raise NotImplementedError
|
||||
292
src/tensorneat/genome/operations/mutation/default.py
Normal file
292
src/tensorneat/genome/operations/mutation/default.py
Normal file
@@ -0,0 +1,292 @@
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
from . import BaseMutation
|
||||
from tensorneat.common import (
|
||||
fetch_first,
|
||||
fetch_random,
|
||||
I_INF,
|
||||
check_cycles,
|
||||
)
|
||||
from ...utils import (
|
||||
unflatten_conns,
|
||||
add_node,
|
||||
add_conn,
|
||||
delete_node_by_pos,
|
||||
delete_conn_by_pos,
|
||||
extract_node_attrs,
|
||||
extract_conn_attrs,
|
||||
set_node_attrs,
|
||||
set_conn_attrs,
|
||||
)
|
||||
|
||||
|
||||
class DefaultMutation(BaseMutation):
    """NEAT mutation: optional structural edits (add/delete node, add/delete
    connection) followed by per-gene attribute mutation."""

    def __init__(
        self,
        conn_add: float = 0.2,  # probability of attempting to add a connection
        conn_delete: float = 0.2,  # probability of deleting a connection
        node_add: float = 0.1,  # probability of splitting a connection with a node
        node_delete: float = 0.1,  # probability of deleting a hidden node
    ):
        self.conn_add = conn_add
        self.conn_delete = conn_delete
        self.node_add = node_add
        self.node_delete = node_delete

    def __call__(self, state, randkey, nodes, conns, new_node_key):
        """Apply structural mutation, then attribute-value mutation."""
        k1, k2 = jax.random.split(randkey)

        nodes, conns = self.mutate_structure(
            state, k1, nodes, conns, new_node_key
        )
        nodes, conns = self.mutate_values(state, k2, nodes, conns)

        return nodes, conns

    def mutate_structure(self, state, randkey, nodes, conns, new_node_key):
        """Possibly apply each of the four structural mutations, gated by
        their configured probabilities (jit-compatible via lax.cond)."""

        def mutate_add_node(key_, nodes_, conns_):
            """
            add a node while do not influence the output of the network
            """

            remain_node_space = jnp.isnan(nodes_[:, 0]).sum()
            remain_conn_space = jnp.isnan(conns_[:, 0]).sum()
            i_key, o_key, idx = self.choose_connection_key(
                key_, conns_
            )  # choose a connection

            def successful_add_node():
                # remove the original connection and record its attrs
                original_attrs = extract_conn_attrs(conns_[idx])
                new_conns = delete_conn_by_pos(conns_, idx)

                # add a new node with identity attrs
                new_nodes = add_node(
                    nodes_, new_node_key, self.genome.node_gene.new_identity_attrs(state)
                )

                # add two new connections
                # first is with identity attrs
                new_conns = add_conn(
                    new_conns,
                    i_key,
                    new_node_key,
                    self.genome.conn_gene.new_identity_attrs(state),
                )
                # second is with the origin attrs
                new_conns = add_conn(
                    new_conns,
                    new_node_key,
                    o_key,
                    original_attrs,
                )

                return new_nodes, new_conns

            # needs 1 free node slot and 2 free conn slots (net +1 conn)
            return jax.lax.cond(
                (idx == I_INF) | (remain_node_space < 1) | (remain_conn_space < 2),
                lambda: (nodes_, conns_),  # do nothing
                successful_add_node,
            )

        def mutate_delete_node(key_, nodes_, conns_):
            """
            delete a node
            """
            # randomly choose a node (never an input or output node)
            key, idx = self.choose_node_key(
                key_,
                nodes_,
                self.genome.input_idx,
                self.genome.output_idx,
                allow_input_keys=False,
                allow_output_keys=False,
            )

            def successful_delete_node():
                # delete the node
                new_nodes = delete_node_by_pos(nodes_, idx)

                # delete all connections touching the deleted node
                new_conns = jnp.where(
                    ((conns_[:, 0] == key) | (conns_[:, 1] == key))[:, None],
                    jnp.nan,
                    conns_,
                )

                return new_nodes, new_conns

            return jax.lax.cond(
                idx == I_INF,  # no available node to delete
                lambda: (nodes_, conns_),  # do nothing
                successful_delete_node,
            )

        def mutate_add_conn(key_, nodes_, conns_):
            """
            add a connection while do not influence the output of the network
            """

            remain_conn_space = jnp.isnan(conns_[:, 0]).sum()

            # randomly choose two nodes
            k1_, k2_ = jax.random.split(key_, num=2)

            # input node of the connection can be any node
            i_key, from_idx = self.choose_node_key(
                k1_,
                nodes_,
                self.genome.input_idx,
                self.genome.output_idx,
                allow_input_keys=True,
                allow_output_keys=True,
            )

            # output node of the connection can be any node except input node
            o_key, to_idx = self.choose_node_key(
                k2_,
                nodes_,
                self.genome.input_idx,
                self.genome.output_idx,
                allow_input_keys=False,
                allow_output_keys=True,
            )

            conn_pos = fetch_first((conns_[:, 0] == i_key) & (conns_[:, 1] == o_key))
            is_already_exist = conn_pos != I_INF

            def nothing():
                return nodes_, conns_

            def successful():
                # add a connection with zero attrs
                return nodes_, add_conn(
                    conns_, i_key, o_key, self.genome.conn_gene.new_zero_attrs(state)
                )

            if self.genome.network_type == "feedforward":
                # feedforward nets must stay acyclic; reject cycle-creating edges
                u_conns = unflatten_conns(nodes_, conns_)
                conns_exist = u_conns != I_INF
                is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)

                return jax.lax.cond(
                    is_already_exist | is_cycle | (remain_conn_space < 1),
                    nothing,
                    successful,
                )

            elif self.genome.network_type == "recurrent":
                return jax.lax.cond(
                    is_already_exist | (remain_conn_space < 1),
                    nothing,
                    successful,
                )

            else:
                raise ValueError(f"Invalid network type: {self.genome.network_type}")

        def mutate_delete_conn(key_, nodes_, conns_):
            # randomly choose a connection
            i_key, o_key, idx = self.choose_connection_key(key_, conns_)

            return jax.lax.cond(
                idx == I_INF,
                lambda: (nodes_, conns_),  # nothing
                lambda: (nodes_, delete_conn_by_pos(conns_, idx)),  # success
            )

        k1, k2, k3, k4 = jax.random.split(randkey, num=4)
        # NOTE(review): r1..r4 are all drawn from k1, and k1..k4 are then
        # reused as the sub-mutation keys, so r1 and mutate_add_node share
        # entropy; kept as-is to preserve the existing randomness streams.
        r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))

        def nothing(_, nodes_, conns_):
            return nodes_, conns_

        # zero-probability mutations are skipped at trace time entirely
        if self.node_add > 0:
            nodes, conns = jax.lax.cond(
                r1 < self.node_add, mutate_add_node, nothing, k1, nodes, conns
            )

        if self.node_delete > 0:
            nodes, conns = jax.lax.cond(
                r2 < self.node_delete, mutate_delete_node, nothing, k2, nodes, conns
            )

        if self.conn_add > 0:
            nodes, conns = jax.lax.cond(
                r3 < self.conn_add, mutate_add_conn, nothing, k3, nodes, conns
            )

        if self.conn_delete > 0:
            nodes, conns = jax.lax.cond(
                r4 < self.conn_delete, mutate_delete_conn, nothing, k4, nodes, conns
            )

        return nodes, conns

    def mutate_values(self, state, randkey, nodes, conns):
        """Mutate attribute values of every gene with independent keys."""
        k1, k2 = jax.random.split(randkey)
        nodes_randkeys = jax.random.split(k1, num=self.genome.max_nodes)
        conns_randkeys = jax.random.split(k2, num=self.genome.max_conns)

        node_attrs = vmap(extract_node_attrs)(nodes)
        new_node_attrs = vmap(self.genome.node_gene.mutate, in_axes=(None, 0, 0))(
            state, nodes_randkeys, node_attrs
        )
        new_nodes = vmap(set_node_attrs)(nodes, new_node_attrs)

        conn_attrs = vmap(extract_conn_attrs)(conns)
        new_conn_attrs = vmap(self.genome.conn_gene.mutate, in_axes=(None, 0, 0))(
            state, conns_randkeys, conn_attrs
        )
        new_conns = vmap(set_conn_attrs)(conns, new_conn_attrs)

        # nan (unused) slots stay nan
        new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
        new_conns = jnp.where(jnp.isnan(conns), jnp.nan, new_conns)

        return new_nodes, new_conns

    def choose_node_key(
        self,
        key,
        nodes,
        input_idx,
        output_idx,
        allow_input_keys: bool = False,
        allow_output_keys: bool = False,
    ):
        """
        Randomly choose a node key from the given nodes. It guarantees that the chosen node not be the input or output node.
        :param key:
        :param nodes:
        :param input_idx:
        :param output_idx:
        :param allow_input_keys:
        :param allow_output_keys:
        :return: return its key and position(idx); key is nan and idx is I_INF
            when no eligible node exists
        """

        node_keys = nodes[:, 0]
        mask = ~jnp.isnan(node_keys)

        if not allow_input_keys:
            mask = jnp.logical_and(mask, ~jnp.isin(node_keys, input_idx))

        if not allow_output_keys:
            mask = jnp.logical_and(mask, ~jnp.isin(node_keys, output_idx))

        idx = fetch_random(key, mask)
        key = jnp.where(idx != I_INF, nodes[idx, 0], jnp.nan)
        return key, idx

    def choose_connection_key(self, key, conns):
        """
        Randomly choose a connection key from the given connections.
        :return: i_key, o_key, idx (nan/nan/I_INF when no connection exists)
        """

        idx = fetch_random(key, ~jnp.isnan(conns[:, 0]))
        i_key = jnp.where(idx != I_INF, conns[idx, 0], jnp.nan)
        o_key = jnp.where(idx != I_INF, conns[idx, 1], jnp.nan)

        return i_key, o_key, idx
|
||||
92
src/tensorneat/genome/recurrent.py
Normal file
92
src/tensorneat/genome/recurrent.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
from .utils import unflatten_conns
|
||||
|
||||
from .base import BaseGenome
|
||||
from .gene import DefaultNode, DefaultConn
|
||||
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
|
||||
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
|
||||
|
||||
from tensorneat.common import attach_with_inf
|
||||
|
||||
class RecurrentGenome(BaseGenome):
    """Default genome class, with the same behavior as the NEAT-Python.

    Recurrent variant: the network is evaluated by iterating all nodes
    ``activate_time`` times, letting values flow through cycles.
    """

    network_type = "recurrent"

    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        max_nodes=50,
        max_conns=100,
        # NOTE(review): these gene/operation defaults are instantiated once at
        # class-definition time and shared by every genome that relies on the
        # default — confirm the objects are stateless before mutating them.
        node_gene=DefaultNode(),
        conn_gene=DefaultConn(),
        mutation=DefaultMutation(),
        crossover=DefaultCrossover(),
        distance=DefaultDistance(),
        output_transform=None,
        input_transform=None,
        init_hidden_layers=(),
        activate_time=10,  # number of value-propagation iterations per forward
    ):
        super().__init__(
            num_inputs,
            num_outputs,
            max_nodes,
            max_conns,
            node_gene,
            conn_gene,
            mutation,
            crossover,
            distance,
            output_transform,
            input_transform,
            init_hidden_layers,
        )
        self.activate_time = activate_time

    def transform(self, state, nodes, conns):
        """Precompute the dense (N, N) connection-index matrix for forward()."""
        u_conns = unflatten_conns(nodes, conns)
        return nodes, conns, u_conns

    def forward(self, state, transformed, inputs):
        """Evaluate the recurrent network on one input vector."""
        nodes, conns, u_conns = transformed

        # node values start as nan (treated as "no signal yet")
        vals = jnp.full((self.max_nodes,), jnp.nan)

        nodes_attrs = vmap(extract_node_attrs)(nodes)
        conns_attrs = vmap(extract_conn_attrs)(conns)
        # (N, N, attr) attrs matrix; missing edges are marked via inf-attach
        expand_conns_attrs = attach_with_inf(conns_attrs, u_conns)

        def body_func(_, values):

            # set input values
            values = values.at[self.input_idx].set(inputs)

            # calculate connections: every (in, out) pair's contribution
            node_ins = vmap(
                vmap(self.conn_gene.forward, in_axes=(None, 0, None)),
                in_axes=(None, 0, 0),
            )(state, expand_conns_attrs, values)

            # calculate nodes from their aggregated inputs
            is_output_nodes = jnp.isin(nodes[:, 0], self.output_idx)
            values = vmap(self.node_gene.forward, in_axes=(None, 0, 0, 0))(
                state, nodes_attrs, node_ins.T, is_output_nodes
            )

            return values

        # iterate the whole network a fixed number of times
        vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)

        if self.output_transform is None:
            return vals[self.output_idx]
        else:
            return self.output_transform(vals[self.output_idx])

    def sympy_func(self, state, network, precision=3):
        raise ValueError("Sympy function is not supported for Recurrent Network!")

    def visualize(self, network):
        raise ValueError("Visualize function is not supported for Recurrent Network!")
|
||||
109
src/tensorneat/genome/utils.py
Normal file
109
src/tensorneat/genome/utils.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
|
||||
from tensorneat.common import fetch_first, I_INF
|
||||
|
||||
|
||||
def unflatten_conns(nodes, conns):
    """
    Transform the (C, CL) connections to (N, N), which contains the idx of
    the connection in conns. CL means connection length, N means the number
    of nodes, C means the number of connections.
    Returns the unflatten connection indices with shape (N, N); entries with
    no connection hold I_INF.
    """
    N = nodes.shape[0]  # max_nodes
    C = conns.shape[0]  # max_conns
    node_keys = nodes[:, 0]
    i_keys, o_keys = conns[:, 0], conns[:, 1]

    def key_to_indices(key, keys):
        # position of the node with this key (I_INF when absent)
        return fetch_first(key == keys)

    i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
    o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)

    # Is interesting that jax use clip when attach data in array
    # however, it will do nothing when setting values in an array
    # put the index of connections in the unflatten array
    unflatten = (
        jnp.full((N, N), I_INF, dtype=jnp.int32)
        .at[i_idxs, o_idxs]
        .set(jnp.arange(C, dtype=jnp.int32))
    )

    return unflatten
|
||||
|
||||
|
||||
def valid_cnt(nodes_or_conns):
    """Count the rows currently in use (first column is not NaN)."""
    first_col = nodes_or_conns[:, 0]
    return jnp.sum(~jnp.isnan(first_col))
|
||||
|
||||
|
||||
def extract_node_attrs(node):
    """Return a node row's attributes, i.e. everything after the key.

    node: Array(NL,) whose element 0 holds the node key.
    """
    key_cols = 1  # element 0 stores the node key
    return node[key_cols:]
|
||||
|
||||
|
||||
def set_node_attrs(node, attrs):
    """Return a copy of ``node`` with its attributes replaced by ``attrs``.

    node: Array(NL,), attrs: Array(NL-1,); element 0 (the key) is preserved.
    """
    updated = node.at[1:].set(attrs)
    return updated
|
||||
|
||||
|
||||
def extract_conn_attrs(conn):
    """Return a connection row's attributes, i.e. everything after the keys.

    conn: Array(CL,) whose elements 0 and 1 hold the in- and out-node keys.
    """
    key_cols = 2  # elements 0, 1 store the endpoint keys
    return conn[key_cols:]
|
||||
|
||||
|
||||
def set_conn_attrs(conn, attrs):
    """Return a copy of ``conn`` with its attributes replaced by ``attrs``.

    conn: Array(CL,), attrs: Array(CL-2,); elements 0, 1 (the endpoint
    keys) are preserved.
    """
    updated = conn.at[2:].set(attrs)
    return updated
|
||||
|
||||
|
||||
def add_node(nodes, new_key: int, attrs):
    """
    Add a new node to the genome.
    The new node will place at the first NaN row.
    """
    exist_keys = nodes[:, 0]
    # position of the first free (NaN-keyed) slot
    pos = fetch_first(jnp.isnan(exist_keys))
    new_nodes = nodes.at[pos, 0].set(new_key)
    return new_nodes.at[pos, 1:].set(attrs)
|
||||
|
||||
|
||||
def delete_node_by_pos(nodes, pos):
    """Delete the node at row ``pos`` by marking the whole row as NaN."""
    cleared = nodes.at[pos].set(jnp.nan)
    return cleared
|
||||
|
||||
|
||||
def add_conn(conns, i_key, o_key, attrs):
    """
    Add a new connection to the genome.
    The new connection will place at the first NaN row.
    """
    con_keys = conns[:, 0]
    # position of the first free (NaN-keyed) slot
    pos = fetch_first(jnp.isnan(con_keys))
    new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key]))
    return new_conns.at[pos, 2:].set(attrs)
|
||||
|
||||
|
||||
def delete_conn_by_pos(conns, pos):
    """Delete the connection at row ``pos`` by marking the whole row as NaN."""
    cleared = conns.at[pos].set(jnp.nan)
    return cleared
|
||||
189
src/tensorneat/pipeline.py
Normal file
189
src/tensorneat/pipeline.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
import datetime, time
|
||||
import numpy as np
|
||||
|
||||
from tensorneat.algorithm import BaseAlgorithm
|
||||
from tensorneat.problem import BaseProblem
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
|
||||
|
||||
class Pipeline(StatefulBaseClass):
|
||||
def __init__(
    self,
    algorithm: BaseAlgorithm,
    problem: BaseProblem,
    seed: int = 42,
    fitness_target: float = 1,
    generation_limit: int = 1000,
    is_save: bool = False,
    save_dir=None,
):
    """Wire an algorithm to a problem and prepare run bookkeeping.

    When ``is_save`` is True, creates ``save_dir`` (default:
    "./<ClassName> <timestamp>") and a "genomes" subdirectory.
    """
    assert problem.jitable, "Currently, problem must be jitable"

    self.algorithm = algorithm
    self.problem = problem
    self.seed = seed
    self.fitness_target = fitness_target
    self.generation_limit = generation_limit
    self.pop_size = self.algorithm.pop_size

    # seed numpy as well; jax randomness is seeded separately in setup()
    np.random.seed(self.seed)

    # the algorithm's input width must match the problem's feature dimension
    assert (
        algorithm.num_inputs == self.problem.input_shape[-1]
    ), f"algorithm input shape is {algorithm.num_inputs} but problem input shape is {self.problem.input_shape}"

    self.best_genome = None
    self.best_fitness = float("-inf")
    self.generation_timestamp = None
    self.is_save = is_save

    if is_save:
        if save_dir is None:
            now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            self.save_dir = f"./{self.__class__.__name__} {now}"
        else:
            self.save_dir = save_dir
        print(f"save to {self.save_dir}")
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        # per-generation genome snapshots go here
        self.genome_dir = os.path.join(self.save_dir, "genomes")
        if not os.path.exists(self.genome_dir):
            os.makedirs(self.genome_dir)
|
||||
|
||||
def setup(self, state=State()):
|
||||
print("initializing")
|
||||
state = state.register(randkey=jax.random.PRNGKey(self.seed))
|
||||
|
||||
state = self.algorithm.setup(state)
|
||||
state = self.problem.setup(state)
|
||||
|
||||
if self.is_save:
|
||||
# self.save(state=state, path=os.path.join(self.save_dir, "pipeline.pkl"))
|
||||
with open(os.path.join(self.save_dir, "config.txt"), "w") as f:
|
||||
f.write(json.dumps(self.show_config(), indent=4))
|
||||
# create log file
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "w") as f:
|
||||
f.write("Generation,Max,Min,Mean,Std,Cost Time\n")
|
||||
|
||||
print("initializing finished")
|
||||
return state
|
||||
|
||||
def step(self, state):
|
||||
|
||||
randkey_, randkey = jax.random.split(state.randkey)
|
||||
keys = jax.random.split(randkey_, self.pop_size)
|
||||
|
||||
pop = self.algorithm.ask(state)
|
||||
|
||||
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(
|
||||
state, pop
|
||||
)
|
||||
|
||||
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(None, 0, None, 0))(
|
||||
state, keys, self.algorithm.forward, pop_transformed
|
||||
)
|
||||
|
||||
# replace nan with -inf
|
||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
|
||||
|
||||
previous_pop = self.algorithm.ask(state)
|
||||
state = self.algorithm.tell(state, fitnesses)
|
||||
|
||||
return state.update(randkey=randkey), previous_pop, fitnesses
|
||||
|
||||
def auto_run(self, state):
|
||||
print("start compile")
|
||||
tic = time.time()
|
||||
compiled_step = jax.jit(self.step).lower(state).compile()
|
||||
# compiled_step = self.step
|
||||
print(
|
||||
f"compile finished, cost time: {time.time() - tic:.6f}s",
|
||||
)
|
||||
|
||||
for _ in range(self.generation_limit):
|
||||
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
state, previous_pop, fitnesses = compiled_step(state)
|
||||
|
||||
fitnesses = jax.device_get(fitnesses)
|
||||
|
||||
self.analysis(state, previous_pop, fitnesses)
|
||||
|
||||
if max(fitnesses) >= self.fitness_target:
|
||||
print("Fitness limit reached!")
|
||||
break
|
||||
|
||||
if int(state.generation) >= self.generation_limit:
|
||||
print("Generation limit reached!")
|
||||
|
||||
if self.is_save:
|
||||
best_genome = jax.device_get(self.best_genome)
|
||||
with open(os.path.join(self.genome_dir, f"best_genome.npz"), "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
return state, self.best_genome
|
||||
|
||||
def analysis(self, state, pop, fitnesses):
|
||||
|
||||
generation = int(state.generation)
|
||||
|
||||
valid_fitnesses = fitnesses[~np.isinf(fitnesses)]
|
||||
|
||||
max_f, min_f, mean_f, std_f = (
|
||||
max(valid_fitnesses),
|
||||
min(valid_fitnesses),
|
||||
np.mean(valid_fitnesses),
|
||||
np.std(valid_fitnesses),
|
||||
)
|
||||
|
||||
new_timestamp = time.time()
|
||||
|
||||
cost_time = new_timestamp - self.generation_timestamp
|
||||
|
||||
max_idx = np.argmax(fitnesses)
|
||||
if fitnesses[max_idx] > self.best_fitness:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
||||
|
||||
if self.is_save:
|
||||
# save best
|
||||
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
|
||||
file_name = os.path.join(
|
||||
self.genome_dir, f"{generation}.npz"
|
||||
)
|
||||
with open(file_name, "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
# append log
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
|
||||
f.write(
|
||||
f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
|
||||
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
||||
)
|
||||
|
||||
self.algorithm.show_details(state, fitnesses)
|
||||
|
||||
def show(self, state, best, *args, **kwargs):
|
||||
transformed = self.algorithm.transform(state, best)
|
||||
self.problem.show(
|
||||
state, state.randkey, self.algorithm.forward, transformed, *args, **kwargs
|
||||
)
|
||||
3
src/tensorneat/problem/__init__.py
Normal file
3
src/tensorneat/problem/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseProblem
|
||||
from .rl import *
|
||||
from .func_fit import *
|
||||
35
src/tensorneat/problem/base.py
Normal file
35
src/tensorneat/problem/base.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Callable
|
||||
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
|
||||
|
||||
class BaseProblem(StatefulBaseClass):
    """Abstract interface that every TensorNEAT problem must implement.

    Concrete subclasses (RL environments, function-fitting tasks, ...) supply
    `evaluate`, the input/output shapes, and optionally a visualization.
    """

    # Whether `evaluate` can be wrapped in jax.jit; subclasses must set this.
    jitable = None

    def evaluate(self, state: State, randkey, act_func: Callable, params):
        """Evaluate one individual and return its fitness.

        act_func is the network forward function: act_func(state, params, obs).
        """
        raise NotImplementedError

    @property
    def input_shape(self):
        """
        The input shape for the problem to evaluate
        In RL problem, it is the observation space
        In function fitting problem, it is the input shape of the function
        """
        raise NotImplementedError

    @property
    def output_shape(self):
        """
        The output shape for the problem to evaluate
        In RL problem, it is the action space
        In function fitting problem, it is the output shape of the function
        """
        raise NotImplementedError

    def show(self, state: State, randkey, act_func: Callable, params, *args, **kwargs):
        """
        show how a genome perform in this problem
        """
        raise NotImplementedError
|
||||
4
src/tensorneat/problem/func_fit/__init__.py
Normal file
4
src/tensorneat/problem/func_fit/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .xor import XOR
|
||||
from .xor3d import XOR3d
|
||||
from .custom import CustomFuncFit
|
||||
from .func_fit import FuncFit
|
||||
117
src/tensorneat/problem/func_fit/custom.py
Normal file
117
src/tensorneat/problem/func_fit/custom.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import Callable, Union, List, Tuple
|
||||
from jax import vmap, Array, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from .func_fit import FuncFit
|
||||
|
||||
|
||||
class CustomFuncFit(FuncFit):
    """Function-fitting problem for a user-supplied target function.

    A dataset of (input, output) pairs is generated from `func`, either by
    uniform random sampling inside the bounds (method="sample") or on a
    regular grid (method="grid").
    """

    def __init__(
        self,
        func: Callable,
        low_bounds: Union[List, Tuple, Array],
        upper_bounds: Union[List, Tuple, Array],
        method: str = "sample",
        num_samples: int = 100,
        step_size: Array = None,
        *args,
        **kwargs,
    ):
        # accept plain Python sequences; normalize to float32 arrays
        if isinstance(low_bounds, (list, tuple)):
            low_bounds = np.array(low_bounds, dtype=np.float32)
        if isinstance(upper_bounds, (list, tuple)):
            upper_bounds = np.array(upper_bounds, dtype=np.float32)

        # smoke-test: func must accept a bounds-shaped array
        # (fix: result was bound to an unused variable; exception now chained)
        try:
            func(low_bounds)
        except Exception as e:
            raise ValueError(f"func(low_bounds) raise an exception: {e}") from e
        assert low_bounds.shape == upper_bounds.shape

        assert method in {"sample", "grid"}

        self.func = func
        self.low_bounds = low_bounds
        self.upper_bounds = upper_bounds

        self.method = method
        self.num_samples = num_samples
        self.step_size = step_size

        # build the dataset eagerly so input/output shapes are known
        self.generate_dataset()

        super().__init__(*args, **kwargs)

    def generate_dataset(self):
        """Populate self.data_inputs / self.data_outputs from `func`."""
        if self.method == "sample":
            assert (
                self.num_samples > 0
            ), f"num_samples must be positive, got {self.num_samples}"

            # draw each input dimension uniformly within its own bounds
            inputs = np.zeros(
                (self.num_samples, self.low_bounds.shape[0]), dtype=np.float32
            )
            for i in range(self.low_bounds.shape[0]):
                inputs[:, i] = np.random.uniform(
                    low=self.low_bounds[i],
                    high=self.upper_bounds[i],
                    size=(self.num_samples,),
                )
        elif self.method == "grid":
            assert (
                self.step_size is not None
            ), "step_size must be provided when method is 'grid'"
            assert (
                self.step_size.shape == self.low_bounds.shape
            ), "step_size must have the same shape as low_bounds"
            assert np.all(self.step_size > 0), "step_size must be positive"

            # grow the grid one dimension at a time via cartesian products;
            # the seed column of zeros is dropped at the end
            inputs = np.zeros((1, 1))
            for i in range(self.low_bounds.shape[0]):
                new_col = np.arange(
                    self.low_bounds[i], self.upper_bounds[i], self.step_size[i]
                )
                inputs = cartesian_product(inputs, new_col[:, None])
            inputs = inputs[:, 1:]
        else:
            raise ValueError(f"Unknown method: {self.method}")

        outputs = vmap(self.func)(inputs)

        self.data_inputs = jnp.array(inputs)
        self.data_outputs = jnp.array(outputs)

    @property
    def inputs(self):
        return self.data_inputs

    @property
    def targets(self):
        return self.data_outputs

    @property
    def input_shape(self):
        return self.data_inputs.shape

    @property
    def output_shape(self):
        return self.data_outputs.shape
|
||||
|
||||
|
||||
def cartesian_product(arr1, arr2):
    """Return the row-wise Cartesian product of two arrays.

    Every row of ``arr1`` is paired with every row of ``arr2``, and the
    paired rows are concatenated along the last axis.
    """
    assert (
        arr1.ndim == arr2.ndim
    ), "arr1 and arr2 must have the same number of dimensions"
    assert arr1.ndim <= 2, "arr1 and arr2 must have at most 2 dimensions"

    n_left, n_right = arr1.shape[0], arr2.shape[0]

    # each arr1 row appears n_right times in a row; arr2 cycles underneath
    left_part = np.repeat(arr1, n_right, axis=0)
    right_part = np.tile(arr2, (n_left, 1))

    return np.concatenate((left_part, right_part), axis=1)
|
||||
72
src/tensorneat/problem/func_fit/func_fit.py
Normal file
72
src/tensorneat/problem/func_fit/func_fit.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..base import BaseProblem
|
||||
from tensorneat.common import State
|
||||
|
||||
|
||||
class FuncFit(BaseProblem):
    """Base class for function-fitting problems.

    Fitness is the negated error between the network's predictions over the
    whole dataset and the target outputs. Subclasses provide the dataset via
    the `inputs` / `targets` properties.
    """

    jitable = True

    def __init__(self, error_method: str = "mse"):
        super().__init__()

        assert error_method in {"mse", "rmse", "mae", "mape"}
        self.error_method = error_method

    def setup(self, state: State = State()):
        # nothing to initialize; the dataset is owned by the subclass
        return state

    def evaluate(self, state, randkey, act_func, params):
        """Return -error of the network's predictions over the dataset."""
        batch_forward = jax.vmap(act_func, in_axes=(None, None, 0))
        predict = batch_forward(state, params, self.inputs)

        residual = predict - self.targets
        if self.error_method == "mse":
            loss = jnp.mean(residual**2)
        elif self.error_method == "rmse":
            loss = jnp.sqrt(jnp.mean(residual**2))
        elif self.error_method == "mae":
            loss = jnp.mean(jnp.abs(residual))
        elif self.error_method == "mape":
            loss = jnp.mean(jnp.abs(residual / self.targets))
        else:
            # unreachable: __init__ validates error_method
            raise NotImplementedError

        return -loss

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        """Print per-sample predictions next to targets, plus the total loss."""
        batch_forward = jax.vmap(act_func, in_axes=(None, None, 0))
        predict = batch_forward(state, params, self.inputs)
        inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
        loss = -self.evaluate(state, randkey, act_func, params)

        lines = [
            f"input: {inputs[i]}, target: {target[i]}, predict: {predict[i]}\n"
            for i in range(inputs.shape[0])
        ]
        lines.append(f"loss: {loss}\n")
        print("".join(lines))

    @property
    def inputs(self):
        raise NotImplementedError

    @property
    def targets(self):
        raise NotImplementedError

    @property
    def input_shape(self):
        raise NotImplementedError

    @property
    def output_shape(self):
        raise NotImplementedError
|
||||
27
src/tensorneat/problem/func_fit/xor.py
Normal file
27
src/tensorneat/problem/func_fit/xor.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import numpy as np
|
||||
|
||||
from .func_fit import FuncFit
|
||||
|
||||
|
||||
class XOR(FuncFit):
    """The classic 2-input XOR truth table as a function-fitting problem."""

    @property
    def inputs(self):
        truth_table_inputs = [[0, 0], [0, 1], [1, 0], [1, 1]]
        return np.array(truth_table_inputs, dtype=np.float32)

    @property
    def targets(self):
        truth_table_outputs = [[0], [1], [1], [0]]
        return np.array(truth_table_outputs, dtype=np.float32)

    @property
    def input_shape(self):
        # 4 samples, 2 features each
        return (4, 2)

    @property
    def output_shape(self):
        # 4 samples, 1 output each
        return (4, 1)
|
||||
36
src/tensorneat/problem/func_fit/xor3d.py
Normal file
36
src/tensorneat/problem/func_fit/xor3d.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import numpy as np
|
||||
|
||||
from .func_fit import FuncFit
|
||||
|
||||
|
||||
class XOR3d(FuncFit):
    """3-input parity (XOR) truth table as a function-fitting problem."""

    @property
    def inputs(self):
        # all 8 binary combinations of three inputs, in counting order
        truth_table_inputs = [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ]
        return np.array(truth_table_inputs, dtype=np.float32)

    @property
    def targets(self):
        # output is 1 when an odd number of inputs are 1
        truth_table_outputs = [[0], [1], [1], [0], [1], [0], [0], [1]]
        return np.array(truth_table_outputs, dtype=np.float32)

    @property
    def input_shape(self):
        # 8 samples, 3 features each
        return (8, 3)

    @property
    def output_shape(self):
        # 8 samples, 1 output each
        return (8, 1)
|
||||
3
src/tensorneat/problem/rl/__init__.py
Normal file
3
src/tensorneat/problem/rl/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .gymnax import GymNaxEnv
|
||||
from .brax import BraxEnv
|
||||
from .rl_jit import RLEnv
|
||||
83
src/tensorneat/problem/rl/brax.py
Normal file
83
src/tensorneat/problem/rl/brax.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import jax.numpy as jnp
|
||||
from brax import envs
|
||||
|
||||
from .rl_jit import RLEnv
|
||||
|
||||
|
||||
class BraxEnv(RLEnv):
    """RL problem backed by a Brax physics environment."""

    def __init__(
        self, env_name: str = "ant", backend: str = "generalized", *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.env_name = env_name
        self.env = envs.create(env_name=env_name, backend=backend)

    def env_step(self, randkey, env_state, action):
        # Brax steps are deterministic given env_state; randkey is unused here
        state = self.env.step(env_state, action)
        return state.obs, state, state.reward, state.done.astype(jnp.bool_), state.info

    def env_reset(self, randkey):
        init_state = self.env.reset(randkey)
        return init_state.obs, init_state

    @property
    def input_shape(self):
        return (self.env.observation_size,)

    @property
    def output_shape(self):
        return (self.env.action_size,)

    def show(
        self,
        state,
        randkey,
        act_func,
        params,
        save_path=None,
        height=480,
        width=480,
        *args,
        **kwargs,
    ):
        """Roll out one episode, render it with Brax, and save it as a GIF."""
        import jax
        import imageio
        from brax.io import image

        obs, env_state = self.reset(randkey)
        reward, done = 0.0, False
        state_histories = [env_state.pipeline_state]

        def step(key, env_state, obs):
            key, step_key = jax.random.split(key)

            if self.action_policy is not None:
                forward_func = lambda obs: act_func(state, params, obs)
                action = self.action_policy(step_key, forward_func, obs)
            else:
                action = act_func(state, params, obs)

            # fix: use the freshly split key for the env step; the original
            # passed the outer `randkey`, discarding the split entirely
            obs, env_state, r, done, _ = self.step(step_key, env_state, action)
            return key, env_state, obs, r, done

        jit_step = jax.jit(step)

        for _ in range(self.max_step):
            # fix: thread the returned key back into the next iteration; the
            # original passed the same initial `randkey` every step, so the
            # PRNG state never advanced
            randkey, env_state, obs, r, done = jit_step(randkey, env_state, obs)
            state_histories.append(env_state.pipeline_state)
            reward += r
            if done:
                break

        imgs = image.render_array(
            sys=self.env.sys, trajectory=state_histories, height=height, width=width
        )

        if save_path is None:
            save_path = f"{self.env_name}.gif"

        imageio.mimsave(save_path, imgs, *args, **kwargs)

        print("Gif saved to: ", save_path)
        print("Total reward: ", reward)
|
||||
27
src/tensorneat/problem/rl/gymnax.py
Normal file
27
src/tensorneat/problem/rl/gymnax.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import gymnax
|
||||
|
||||
from .rl_jit import RLEnv
|
||||
|
||||
|
||||
class GymNaxEnv(RLEnv):
    """RL problem backed by a gymnax environment.

    A thin delegation layer: stepping, resetting, and space queries are all
    forwarded to the underlying gymnax env with its default parameters.
    """

    def __init__(self, env_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert env_name in gymnax.registered_envs, f"Env {env_name} not registered in gymnax."
        self.env, self.env_params = gymnax.make(env_name)

    def env_step(self, randkey, env_state, action):
        env, params = self.env, self.env_params
        return env.step(randkey, env_state, action, params)

    def env_reset(self, randkey):
        env, params = self.env, self.env_params
        return env.reset(randkey, params)

    @property
    def input_shape(self):
        return self.env.observation_space(self.env_params).shape

    @property
    def output_shape(self):
        return self.env.action_space(self.env_params).shape

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        # visualization is not implemented for gymnax environments
        raise NotImplementedError
|
||||
209
src/tensorneat/problem/rl/rl_jit.py
Normal file
209
src/tensorneat/problem/rl/rl_jit.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from ..base import BaseProblem
|
||||
from tensorneat.common import State
|
||||
|
||||
|
||||
class RLEnv(BaseProblem):
    """Base class for jittable RL problems.

    An episode is rolled out inside `jax.lax.while_loop`, so subclasses only
    implement the functional `env_step` / `env_reset` pair plus the
    observation/action shapes.
    """

    # episode rollout is built from jax primitives, so evaluate() can be jitted
    jitable = True

    def __init__(
        self,
        max_step=1000,
        repeat_times=1,
        action_policy: Callable = None,
        obs_normalization: bool = False,
        sample_policy: Callable = None,
        sample_episodes: int = 0,
    ):
        """
        Args:
            max_step: maximum number of env steps per episode.
            repeat_times: episodes averaged per fitness evaluation.
            action_policy: optional callable (randkey, forward_func, obs) -> action,
                where forward_func(obs) -> action wraps the network; lets callers
                add e.g. stochastic action selection on top of the network output.
            obs_normalization: when True, observations are standardized with
                statistics gathered in setup(); requires sample_policy and
                sample_episodes > 0.
            sample_policy: callable (randkey, obs) -> action used only to collect
                observations for the normalization statistics.
            sample_episodes: number of episodes sampled for those statistics.
        """
        super().__init__()
        self.max_step = max_step
        self.repeat_times = repeat_times
        self.action_policy = action_policy

        if obs_normalization:
            assert sample_policy is not None, "sample_policy must be provided"
            assert sample_episodes > 0, "sample_size must be greater than 0"
            self.sample_policy = sample_policy
            self.sample_episodes = sample_episodes
        self.obs_normalization = obs_normalization

    def setup(self, state=State()):
        """Optionally roll out sample episodes and register obs mean/std in state."""
        if self.obs_normalization:
            print("Sampling episodes for normalization")
            keys = jax.random.split(state.randkey, self.sample_episodes)
            dummy_act_func = (
                lambda s, p, o: o
            )  # receive state, params, obs and return the original obs
            dummy_sample_func = lambda rk, act_func, obs: self.sample_policy(
                rk, obs
            )  # ignore act_func

            # one recorded episode per key, driven by the sample policy
            def sample(rk):
                return self._evaluate_once(
                    state, rk, dummy_act_func, None, dummy_sample_func, True
                )

            rewards, episodes = jax.jit(vmap(sample))(keys)

            obs = jax.device_get(episodes["obs"])  # shape: (sample_episodes, max_step, *input_shape)
            obs = obs.reshape(
                -1, *self.input_shape
            )  # shape: (sample_episodes * max_step, *input_shape)

            # episodes shorter than max_step leave NaN rows; keep only real steps
            obs_axis = tuple(range(obs.ndim))
            valid_data_flag = np.all(~jnp.isnan(obs), axis=obs_axis[1:])
            obs = obs[valid_data_flag]

            obs_mean = np.mean(obs, axis=0)
            obs_std = np.std(obs, axis=0)

            state = state.register(
                problem_obs_mean=obs_mean,
                problem_obs_std=obs_std,
            )

            print("Sampling episodes for normalization finished.")
            print("valid data count: ", obs.shape[0])
            print("obs_mean: ", obs_mean)
            print("obs_std: ", obs_std)
        return state

    def evaluate(self, state: State, randkey, act_func: Callable, params):
        """Average episode return over `repeat_times` independent rollouts."""
        keys = jax.random.split(randkey, self.repeat_times)
        rewards = vmap(
            self._evaluate_once, in_axes=(None, 0, None, None, None, None, None)
        )(
            state,
            keys,
            act_func,
            params,
            self.action_policy,
            False,
            self.obs_normalization,
        )

        return rewards.mean()

    def _evaluate_once(
        self,
        state,
        randkey,
        act_func,
        params,
        action_policy,
        record_episode,
        normalize_obs=False,
    ):
        """Roll out a single episode; return total reward (and the episode when recording).

        record_episode and normalize_obs are Python-level flags, so the two
        code paths are resolved at trace time, not inside the while_loop.
        """
        rng_reset, rng_episode = jax.random.split(randkey)
        init_obs, init_env_state = self.reset(rng_reset)

        if record_episode:
            # NaN-filled buffers; unfilled tail rows mark steps after episode end
            obs_array = jnp.full((self.max_step, *self.input_shape), jnp.nan)
            action_array = jnp.full((self.max_step, *self.output_shape), jnp.nan)
            reward_array = jnp.full((self.max_step,), jnp.nan)
            episode = {
                "obs": obs_array,
                "action": action_array,
                "reward": reward_array,
            }
        else:
            episode = None

        # carry layout: (obs, env_state, rng, done, total_reward, count, episode, rk)
        def cond_func(carry):
            _, _, _, done, _, count, _, rk = carry
            return ~done & (count < self.max_step)

        def body_func(carry):
            (
                obs,
                env_state,
                rng,
                done,
                tr,
                count,
                epis,
                rk,
            ) = carry  # tr -> total reward; rk -> randkey

            if normalize_obs:
                obs = norm_obs(state, obs)

            if action_policy is not None:
                forward_func = lambda obs: act_func(state, params, obs)
                action = action_policy(rk, forward_func, obs)
            else:
                action = act_func(state, params, obs)
            next_obs, next_env_state, reward, done, _ = self.step(
                rng, env_state, action
            )
            next_rng, _ = jax.random.split(rng)

            if record_episode:
                # NOTE(review): when normalize_obs is also set, the recorded obs
                # are the normalized ones — confirm this is intended
                epis["obs"] = epis["obs"].at[count].set(obs)
                epis["action"] = epis["action"].at[count].set(action)
                epis["reward"] = epis["reward"].at[count].set(reward)

            return (
                next_obs,
                next_env_state,
                next_rng,
                done,
                tr + reward,
                count + 1,
                epis,
                jax.random.split(rk)[0],
            )

        _, _, _, _, total_reward, _, episode, _ = jax.lax.while_loop(
            cond_func,
            body_func,
            (init_obs, init_env_state, rng_episode, False, 0.0, 0, episode, randkey),
        )

        if record_episode:
            return total_reward, episode
        else:
            return total_reward

    def step(self, randkey, env_state, action):
        """Advance the environment one step (delegates to env_step)."""
        return self.env_step(randkey, env_state, action)

    def reset(self, randkey):
        """Reset the environment (delegates to env_reset)."""
        return self.env_reset(randkey)

    def env_step(self, randkey, env_state, action):
        # subclasses return (obs, env_state, reward, done, info)
        raise NotImplementedError

    def env_reset(self, randkey):
        # subclasses return (obs, env_state)
        raise NotImplementedError

    @property
    def input_shape(self):
        raise NotImplementedError

    @property
    def output_shape(self):
        raise NotImplementedError

    def show(self, state, randkey, act_func, params, *args, **kwargs):
        raise NotImplementedError
|
||||
|
||||
|
||||
def norm_obs(state, obs):
    """Standardize an observation with the statistics stored in `state`.

    The small epsilon keeps the division finite for zero-variance dimensions.
    """
    centered = obs - state.problem_obs_mean
    return centered / (state.problem_obs_std + 1e-6)
|
||||
Reference in New Issue
Block a user