update functions. Visualize, Interpretable and with evox
This commit is contained in:
@@ -19,7 +19,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
aggregation: Callable = AGG.sum,
|
||||
activation: Callable = ACT.sigmoid,
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = ACT.standard_sigmoid,
|
||||
output_transform: Callable = ACT.sigmoid,
|
||||
):
|
||||
assert (
|
||||
substrate.query_coors.shape[1] == neat.num_inputs
|
||||
|
||||
@@ -3,4 +3,4 @@ from .graph import *
|
||||
from .state import State
|
||||
from .stateful_class import StatefulBaseClass
|
||||
|
||||
from .functions import ACT, AGG, apply_activation, apply_aggregation
|
||||
from .functions import ACT, AGG, apply_activation, apply_aggregation, get_func_name
|
||||
|
||||
2
src/tensorneat/common/evox_adaptors/__init__.py
Normal file
2
src/tensorneat/common/evox_adaptors/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .algorithm_adaptor import EvoXAlgorithmAdaptor
|
||||
from .tensorneat_monitor import TensorNEATMonitor
|
||||
34
src/tensorneat/common/evox_adaptors/algorithm_adaptor.py
Normal file
34
src/tensorneat/common/evox_adaptors/algorithm_adaptor.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from evox import Algorithm as EvoXAlgorithm, State as EvoXState, jit_class
|
||||
|
||||
from tensorneat.algorithm import BaseAlgorithm as TensorNEATAlgorithm
|
||||
from tensorneat.common import State as TensorNEATState
|
||||
|
||||
|
||||
@jit_class
|
||||
class EvoXAlgorithmAdaptor(EvoXAlgorithm):
|
||||
def __init__(self, algorithm: TensorNEATAlgorithm):
|
||||
self.algorithm = algorithm
|
||||
self.fixed_state = None
|
||||
|
||||
def setup(self, key):
|
||||
neat_algorithm_state = TensorNEATState(randkey=key)
|
||||
neat_algorithm_state = self.algorithm.setup(neat_algorithm_state)
|
||||
self.fixed_state = neat_algorithm_state
|
||||
return EvoXState(alg_state=neat_algorithm_state)
|
||||
|
||||
def ask(self, state: EvoXState):
|
||||
population = self.algorithm.ask(state.alg_state)
|
||||
return population, state
|
||||
|
||||
def tell(self, state: EvoXState, fitness):
|
||||
fitness = jnp.where(jnp.isnan(fitness), -jnp.inf, fitness)
|
||||
neat_algorithm_state = self.algorithm.tell(state.alg_state, fitness)
|
||||
return state.replace(alg_state=neat_algorithm_state)
|
||||
|
||||
def transform(self, individual):
|
||||
return self.algorithm.transform(self.fixed_state, individual)
|
||||
|
||||
def forward(self, transformed, inputs):
|
||||
return self.algorithm.forward(self.fixed_state, transformed, inputs)
|
||||
110
src/tensorneat/common/evox_adaptors/tensorneat_monitor.py
Normal file
110
src/tensorneat/common/evox_adaptors/tensorneat_monitor.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import warnings
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.experimental import io_callback
|
||||
from evox import Monitor
|
||||
from evox import State as EvoXState
|
||||
|
||||
from tensorneat.algorithm import BaseAlgorithm as TensorNEATAlgorithm
|
||||
from tensorneat.common import State as TensorNEATState
|
||||
|
||||
|
||||
class TensorNEATMonitor(Monitor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tensorneat_algorithm: TensorNEATAlgorithm,
|
||||
save_dir: str = None,
|
||||
is_save: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tensorneat_algorithm = tensorneat_algorithm
|
||||
|
||||
self.generation_timestamp = time.time()
|
||||
self.alg_state: TensorNEATState = None
|
||||
self.fitness = None
|
||||
self.best_fitness = -np.inf
|
||||
self.best_genome = None
|
||||
|
||||
self.is_save = is_save
|
||||
|
||||
if is_save:
|
||||
if save_dir is None:
|
||||
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
self.save_dir = f"./{self.__class__.__name__} {now}"
|
||||
else:
|
||||
self.save_dir = save_dir
|
||||
print(f"save to {self.save_dir}")
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir)
|
||||
self.genome_dir = os.path.join(self.save_dir, "genomes")
|
||||
if not os.path.exists(self.genome_dir):
|
||||
os.makedirs(self.genome_dir)
|
||||
|
||||
def hooks(self):
|
||||
return ["pre_tell"]
|
||||
|
||||
def pre_tell(self, state: EvoXState, cand_sol, transformed_cand_sol, fitness, transformed_fitness):
|
||||
io_callback(
|
||||
self.store_info,
|
||||
None,
|
||||
state,
|
||||
transformed_fitness,
|
||||
)
|
||||
|
||||
def store_info(self, state: EvoXState, fitness):
|
||||
self.alg_state: TensorNEATState = state.query_state("algorithm").alg_state
|
||||
self.fitness = jax.device_get(fitness)
|
||||
|
||||
def show(self):
|
||||
pop = self.tensorneat_algorithm.ask(self.alg_state)
|
||||
generation = int(self.alg_state.generation)
|
||||
|
||||
valid_fitnesses = self.fitness[~np.isinf(self.fitness)]
|
||||
|
||||
max_f, min_f, mean_f, std_f = (
|
||||
max(valid_fitnesses),
|
||||
min(valid_fitnesses),
|
||||
np.mean(valid_fitnesses),
|
||||
np.std(valid_fitnesses),
|
||||
)
|
||||
|
||||
new_timestamp = time.time()
|
||||
|
||||
cost_time = new_timestamp - self.generation_timestamp
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
max_idx = np.argmax(self.fitness)
|
||||
if self.fitness[max_idx] > self.best_fitness:
|
||||
self.best_fitness = self.fitness[max_idx]
|
||||
self.best_genome = pop[0][max_idx], pop[1][max_idx]
|
||||
|
||||
if self.is_save:
|
||||
# save best
|
||||
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
|
||||
file_name = os.path.join(
|
||||
self.genome_dir, f"{generation}.npz"
|
||||
)
|
||||
with open(file_name, "wb") as f:
|
||||
np.savez(
|
||||
f,
|
||||
nodes=best_genome[0],
|
||||
conns=best_genome[1],
|
||||
fitness=self.best_fitness,
|
||||
)
|
||||
|
||||
# append log
|
||||
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
|
||||
f.write(
|
||||
f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
|
||||
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
|
||||
)
|
||||
|
||||
self.tensorneat_algorithm.show_details(self.alg_state, self.fitness)
|
||||
@@ -1,3 +1,5 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from .act_jnp import *
|
||||
from .act_sympy import *
|
||||
from .agg_jnp import *
|
||||
@@ -32,6 +34,7 @@ act_name2sympy = {
|
||||
"log": SympyLog,
|
||||
"exp": SympyExp,
|
||||
"abs": SympyAbs,
|
||||
"clip": SympyClip,
|
||||
}
|
||||
|
||||
agg_name2jnp = {
|
||||
@@ -40,7 +43,6 @@ agg_name2jnp = {
|
||||
"max": max_,
|
||||
"min": min_,
|
||||
"maxabs": maxabs_,
|
||||
"median": median_,
|
||||
"mean": mean_,
|
||||
}
|
||||
|
||||
@@ -50,9 +52,42 @@ agg_name2sympy = {
|
||||
"max": SympyMax,
|
||||
"min": SympyMin,
|
||||
"maxabs": SympyMaxabs,
|
||||
"median": SympyMedian,
|
||||
"mean": SympyMean,
|
||||
}
|
||||
|
||||
ACT = FunctionManager(act_name2jnp, act_name2sympy)
|
||||
AGG = FunctionManager(agg_name2jnp, agg_name2sympy)
|
||||
|
||||
def apply_activation(idx, z, act_funcs):
|
||||
"""
|
||||
calculate activation function for each node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
# change idx from float to int
|
||||
|
||||
# -1 means identity activation
|
||||
res = jax.lax.cond(
|
||||
idx == -1,
|
||||
lambda: z,
|
||||
lambda: jax.lax.switch(idx, act_funcs, z),
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def apply_aggregation(idx, z, agg_funcs):
|
||||
"""
|
||||
calculate activation function for inputs of node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
|
||||
return jax.lax.cond(
|
||||
jnp.all(jnp.isnan(z)),
|
||||
lambda: jnp.nan, # all inputs are nan
|
||||
lambda: jax.lax.switch(idx, agg_funcs, z), # otherwise
|
||||
)
|
||||
|
||||
def get_func_name(func):
|
||||
name = func.__name__
|
||||
if name.endswith("_"):
|
||||
name = name[:-1]
|
||||
return name
|
||||
@@ -1,6 +1,6 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
SCALE = 5
|
||||
SCALE = 3
|
||||
|
||||
|
||||
def scaled_sigmoid_(z):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import sympy as sp
|
||||
import numpy as np
|
||||
|
||||
SCALE = 5
|
||||
SCALE = 3
|
||||
|
||||
class SympySigmoid(sp.Function):
|
||||
@classmethod
|
||||
|
||||
@@ -23,18 +23,6 @@ def maxabs_(z):
|
||||
max_abs_index = jnp.argmax(abs_z)
|
||||
return z[max_abs_index]
|
||||
|
||||
|
||||
def median_(z):
|
||||
n = jnp.sum(~jnp.isnan(z), axis=0)
|
||||
|
||||
z = jnp.sort(z) # sort
|
||||
|
||||
idx1, idx2 = (n - 1) // 2, n // 2
|
||||
median = (z[idx1] + z[idx2]) / 2
|
||||
|
||||
return median
|
||||
|
||||
|
||||
def mean_(z):
|
||||
sumation = sum_(z)
|
||||
valid_count = jnp.sum(~jnp.isnan(z), axis=0)
|
||||
|
||||
@@ -7,30 +7,18 @@ class SympySum(sp.Function):
|
||||
def eval(cls, z):
|
||||
return sp.Add(*z)
|
||||
|
||||
@classmethod
|
||||
def numerical_eval(cls, z, backend=np):
|
||||
return backend.sum(z)
|
||||
|
||||
|
||||
class SympyProduct(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Mul(*z)
|
||||
|
||||
@classmethod
|
||||
def numerical_eval(cls, z, backend=np):
|
||||
return backend.product(z)
|
||||
|
||||
|
||||
class SympyMax(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Max(*z)
|
||||
|
||||
@classmethod
|
||||
def numerical_eval(cls, z, backend=np):
|
||||
return backend.max(z)
|
||||
|
||||
|
||||
class SympyMin(sp.Function):
|
||||
@classmethod
|
||||
@@ -48,26 +36,3 @@ class SympyMean(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Add(*z) / len(z)
|
||||
|
||||
|
||||
class SympyMedian(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, args):
|
||||
|
||||
if all(arg.is_number for arg in args):
|
||||
sorted_args = sorted(args)
|
||||
n = len(sorted_args)
|
||||
if n % 2 == 1:
|
||||
return sorted_args[n // 2]
|
||||
else:
|
||||
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
|
||||
|
||||
return None
|
||||
|
||||
def _sympystr(self, printer):
|
||||
return f"median({', '.join(map(str, self.args))})"
|
||||
|
||||
def _latex(self, printer):
|
||||
return (
|
||||
r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
|
||||
)
|
||||
|
||||
@@ -1,28 +1,32 @@
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from typing import Union, Callable
|
||||
import sympy as sp
|
||||
|
||||
|
||||
class FunctionManager:
|
||||
|
||||
def __init__(self, name2jnp, name2sympy):
|
||||
self.name2jnp = name2jnp
|
||||
self.name2sympy = name2sympy
|
||||
for name, func in name2jnp.items():
|
||||
setattr(self, name, func)
|
||||
|
||||
def get_all_funcs(self):
|
||||
all_funcs = []
|
||||
for name in self.names:
|
||||
for name in self.name2jnp:
|
||||
all_funcs.append(getattr(self, name))
|
||||
return all_funcs
|
||||
|
||||
def __getattribute__(self, name: str):
|
||||
return self.name2jnp[name]
|
||||
|
||||
def add_func(self, name, func):
|
||||
if not callable(func):
|
||||
raise ValueError("The provided function is not callable")
|
||||
if name in self.names:
|
||||
if name in self.name2jnp:
|
||||
raise ValueError(f"The provided name={name} is already in use")
|
||||
|
||||
self.name2jnp[name] = func
|
||||
setattr(self, name, func)
|
||||
|
||||
def update_sympy(self, name, sympy_cls: sp.Function):
|
||||
self.name2sympy[name] = sympy_cls
|
||||
@@ -47,3 +51,16 @@ class FunctionManager:
|
||||
if name not in self.name2sympy:
|
||||
raise ValueError(f"Func {name} doesn't have a sympy representation.")
|
||||
return self.name2sympy[name]
|
||||
|
||||
def sympy_module(self, backend: str):
|
||||
assert backend in ["jax", "numpy"]
|
||||
if backend == "jax":
|
||||
backend = jnp
|
||||
elif backend == "numpy":
|
||||
backend = np
|
||||
module = {}
|
||||
for sympy_cls in self.name2sympy.values():
|
||||
if hasattr(sympy_cls, "numerical_eval"):
|
||||
module[sympy_cls.__name__] = partial(sympy_cls.numerical_eval, backend)
|
||||
|
||||
return module
|
||||
|
||||
@@ -15,8 +15,8 @@ from tensorneat.common import (
|
||||
topological_sort_python,
|
||||
I_INF,
|
||||
attach_with_inf,
|
||||
SYMPY_FUNCS_MODULE_NP,
|
||||
SYMPY_FUNCS_MODULE_JNP,
|
||||
ACT,
|
||||
AGG
|
||||
)
|
||||
|
||||
|
||||
@@ -92,7 +92,9 @@ class DefaultGenome(BaseGenome):
|
||||
def otherwise():
|
||||
# calculate connections
|
||||
conn_indices = u_conns[:, i]
|
||||
hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs
|
||||
hit_attrs = attach_with_inf(
|
||||
conns_attrs, conn_indices
|
||||
) # fetch conn attrs
|
||||
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
|
||||
state, hit_attrs, values
|
||||
)
|
||||
@@ -102,7 +104,9 @@ class DefaultGenome(BaseGenome):
|
||||
state,
|
||||
nodes_attrs[i],
|
||||
ins,
|
||||
is_output_node=jnp.isin(nodes[i, 0], self.output_idx), # nodes[0] -> the key of nodes
|
||||
is_output_node=jnp.isin(
|
||||
nodes[i, 0], self.output_idx
|
||||
), # nodes[0] -> the key of nodes
|
||||
)
|
||||
|
||||
# set new value
|
||||
@@ -139,7 +143,6 @@ class DefaultGenome(BaseGenome):
|
||||
):
|
||||
|
||||
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
|
||||
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
|
||||
|
||||
if sympy_input_transform is None and self.input_transform is not None:
|
||||
warnings.warn(
|
||||
@@ -224,7 +227,7 @@ class DefaultGenome(BaseGenome):
|
||||
sp.lambdify(
|
||||
input_symbols + list(args_symbols.keys()),
|
||||
exprs,
|
||||
modules=[backend, module],
|
||||
modules=[backend, AGG.sympy_module(backend), ACT.sympy_module(backend)],
|
||||
)
|
||||
for exprs in output_exprs
|
||||
]
|
||||
@@ -256,7 +259,12 @@ class DefaultGenome(BaseGenome):
|
||||
rotate=0,
|
||||
reverse_node_order=False,
|
||||
size=(300, 300, 300),
|
||||
color=("blue", "blue", "blue"),
|
||||
color=("yellow", "white", "blue"),
|
||||
with_labels=False,
|
||||
edgecolors="k",
|
||||
arrowstyle="->",
|
||||
arrowsize=3,
|
||||
edge_color=(0.3, 0.3, 0.3),
|
||||
save_path="network.svg",
|
||||
save_dpi=800,
|
||||
**kwargs,
|
||||
@@ -264,7 +272,6 @@ class DefaultGenome(BaseGenome):
|
||||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
nodes_list = list(network["nodes"])
|
||||
conns_list = list(network["conns"])
|
||||
input_idx = self.get_input_idx()
|
||||
output_idx = self.get_output_idx()
|
||||
@@ -316,6 +323,11 @@ class DefaultGenome(BaseGenome):
|
||||
pos=rotated_pos,
|
||||
node_size=node_sizes,
|
||||
node_color=node_colors,
|
||||
with_labels=with_labels,
|
||||
edgecolors=edgecolors,
|
||||
arrowstyle=arrowstyle,
|
||||
arrowsize=arrowsize,
|
||||
edge_color=edge_color,
|
||||
**kwargs,
|
||||
)
|
||||
plt.savefig(save_path, dpi=save_dpi)
|
||||
|
||||
@@ -10,7 +10,7 @@ from tensorneat.common import (
|
||||
apply_aggregation,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
get_func_name
|
||||
)
|
||||
|
||||
from . import BaseNode
|
||||
@@ -141,8 +141,8 @@ class BiasNode(BaseNode):
|
||||
self.__class__.__name__,
|
||||
idx,
|
||||
bias,
|
||||
self.aggregation_options[agg].__name__,
|
||||
act_func.__name__,
|
||||
get_func_name(self.aggregation_options[agg]),
|
||||
get_func_name(act_func),
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width,
|
||||
@@ -165,21 +165,19 @@ class BiasNode(BaseNode):
|
||||
return {
|
||||
"idx": idx,
|
||||
"bias": bias,
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": act_func.__name__,
|
||||
"agg": get_func_name(self.aggregation_options[agg]),
|
||||
"act": get_func_name(act_func),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||
nd = node_dict
|
||||
bias = sp.symbols(f"n_{node_dict['idx']}_b")
|
||||
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
|
||||
z = convert_to_sympy(nd["agg"])(inputs)
|
||||
z = AGG.obtain_sympy(node_dict["agg"])(inputs)
|
||||
|
||||
z = bias + z
|
||||
if is_output_node:
|
||||
pass
|
||||
else:
|
||||
z = convert_to_sympy(nd["act"])(z)
|
||||
z = ACT.obtain_sympy(node_dict["act"])(z)
|
||||
|
||||
return z, {bias: nd["bias"]}
|
||||
return z, {bias: node_dict["bias"]}
|
||||
|
||||
@@ -11,7 +11,7 @@ from tensorneat.common import (
|
||||
apply_aggregation,
|
||||
mutate_int,
|
||||
mutate_float,
|
||||
convert_to_sympy,
|
||||
get_func_name
|
||||
)
|
||||
|
||||
from .base import BaseNode
|
||||
@@ -176,8 +176,8 @@ class DefaultNode(BaseNode):
|
||||
idx,
|
||||
bias,
|
||||
res,
|
||||
self.aggregation_options[agg].__name__,
|
||||
act_func.__name__,
|
||||
get_func_name(self.aggregation_options[agg]),
|
||||
get_func_name(act_func),
|
||||
idx_width=idx_width,
|
||||
float_width=precision + 3,
|
||||
func_width=func_width,
|
||||
@@ -200,8 +200,8 @@ class DefaultNode(BaseNode):
|
||||
"idx": idx,
|
||||
"bias": bias,
|
||||
"res": res,
|
||||
"agg": self.aggregation_options[int(agg)].__name__,
|
||||
"act": act_func.__name__,
|
||||
"agg": get_func_name(self.aggregation_options[agg]),
|
||||
"act": get_func_name(act_func),
|
||||
}
|
||||
|
||||
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
|
||||
@@ -209,12 +209,13 @@ class DefaultNode(BaseNode):
|
||||
bias = sp.symbols(f"n_{nd['idx']}_b")
|
||||
res = sp.symbols(f"n_{nd['idx']}_r")
|
||||
|
||||
z = convert_to_sympy(nd["agg"])(inputs)
|
||||
print(nd["agg"])
|
||||
z = AGG.obtain_sympy(nd["agg"])(inputs)
|
||||
z = bias + res * z
|
||||
|
||||
if is_output_node:
|
||||
pass
|
||||
else:
|
||||
z = convert_to_sympy(nd["act"])(z)
|
||||
z = ACT.obtain_sympy(nd["act"])(z)
|
||||
|
||||
return z, {bias: nd["bias"], res: nd["res"]}
|
||||
|
||||
Reference in New Issue
Block a user