Update functions; add visualization, interpretability, and EvoX integration

This commit is contained in:
root
2024-07-12 04:35:22 +08:00
parent 5fc63fdaf1
commit 0d6e7477bf
32 changed files with 207 additions and 427 deletions

View File

@@ -19,7 +19,7 @@ class HyperNEAT(BaseAlgorithm):
aggregation: Callable = AGG.sum,
activation: Callable = ACT.sigmoid,
activate_time: int = 10,
output_transform: Callable = ACT.standard_sigmoid,
output_transform: Callable = ACT.sigmoid,
):
assert (
substrate.query_coors.shape[1] == neat.num_inputs

View File

@@ -3,4 +3,4 @@ from .graph import *
from .state import State
from .stateful_class import StatefulBaseClass
from .functions import ACT, AGG, apply_activation, apply_aggregation
from .functions import ACT, AGG, apply_activation, apply_aggregation, get_func_name

View File

@@ -0,0 +1,2 @@
from .algorithm_adaptor import EvoXAlgorithmAdaptor
from .tensorneat_monitor import TensorNEATMonitor

View File

@@ -0,0 +1,34 @@
import jax.numpy as jnp
from evox import Algorithm as EvoXAlgorithm, State as EvoXState, jit_class
from tensorneat.algorithm import BaseAlgorithm as TensorNEATAlgorithm
from tensorneat.common import State as TensorNEATState
@jit_class
class EvoXAlgorithmAdaptor(EvoXAlgorithm):
    """Adapter exposing a TensorNEAT algorithm through the EvoX Algorithm API."""

    def __init__(self, algorithm: TensorNEATAlgorithm):
        # the wrapped TensorNEAT algorithm
        self.algorithm = algorithm
        # snapshot of the TensorNEAT state, filled in by setup();
        # transform()/forward() read it — NOTE(review): assumes setup() runs first
        self.fixed_state = None

    def setup(self, key):
        """Build the initial TensorNEAT state from a PRNG key and wrap it in an EvoX state."""
        neat_algorithm_state = TensorNEATState(randkey=key)
        neat_algorithm_state = self.algorithm.setup(neat_algorithm_state)
        # keep a host-side handle so transform()/forward() can run without the EvoX state
        self.fixed_state = neat_algorithm_state
        return EvoXState(alg_state=neat_algorithm_state)

    def ask(self, state: EvoXState):
        """Return the current population; the EvoX state is passed through unchanged."""
        population = self.algorithm.ask(state.alg_state)
        return population, state

    def tell(self, state: EvoXState, fitness):
        """Report fitness to TensorNEAT; NaN fitness values are mapped to -inf."""
        fitness = jnp.where(jnp.isnan(fitness), -jnp.inf, fitness)
        neat_algorithm_state = self.algorithm.tell(state.alg_state, fitness)
        return state.replace(alg_state=neat_algorithm_state)

    def transform(self, individual):
        # uses the state captured at setup time, not the evolving alg_state
        return self.algorithm.transform(self.fixed_state, individual)

    def forward(self, transformed, inputs):
        # forward pass of a transformed individual on the given inputs
        return self.algorithm.forward(self.fixed_state, transformed, inputs)

View File

@@ -0,0 +1,110 @@
import datetime
import os
import time
import warnings

import numpy as np

import jax
from jax.experimental import io_callback

from evox import Monitor
from evox import State as EvoXState

from tensorneat.algorithm import BaseAlgorithm as TensorNEATAlgorithm
from tensorneat.common import State as TensorNEATState
class TensorNEATMonitor(Monitor):
    """EvoX monitor that records TensorNEAT training progress.

    A ``pre_tell`` hook snapshots the algorithm state and the transformed
    fitness onto the host; ``show()`` then prints per-generation statistics.
    With ``is_save`` enabled, each generation's best genome is written as an
    ``.npz`` file and a CSV line is appended to ``log.txt`` under ``save_dir``.
    """

    def __init__(
        self,
        tensorneat_algorithm: TensorNEATAlgorithm,
        save_dir: str = None,
        is_save: bool = False,
    ):
        super().__init__()
        self.tensorneat_algorithm = tensorneat_algorithm
        # timestamp of the previous show(); used to report per-generation cost
        self.generation_timestamp = time.time()
        self.alg_state: TensorNEATState = None  # set by store_info()
        self.fitness = None                     # host copy of last fitness
        self.best_fitness = -np.inf
        self.best_genome = None
        self.is_save = is_save
        if is_save:
            if save_dir is None:
                # default: "<ClassName> <timestamp>" in the working directory
                now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
                self.save_dir = f"./{self.__class__.__name__} {now}"
            else:
                self.save_dir = save_dir
            print(f"save to {self.save_dir}")
            # exist_ok avoids the check-then-create race of os.path.exists
            os.makedirs(self.save_dir, exist_ok=True)
            self.genome_dir = os.path.join(self.save_dir, "genomes")
            os.makedirs(self.genome_dir, exist_ok=True)

    def hooks(self):
        # register for the hook fired just before algorithm.tell()
        return ["pre_tell"]

    def pre_tell(
        self, state: EvoXState, cand_sol, transformed_cand_sol, fitness, transformed_fitness
    ):
        # io_callback escapes the jitted workflow so we can stash host-side copies
        io_callback(
            self.store_info,
            None,
            state,
            transformed_fitness,
        )

    def store_info(self, state: EvoXState, fitness):
        # host side: snapshot the TensorNEAT state and the fitness for show()
        self.alg_state: TensorNEATState = state.query_state("algorithm").alg_state
        self.fitness = jax.device_get(fitness)

    def show(self):
        """Print statistics for the last recorded generation and save artifacts."""
        pop = self.tensorneat_algorithm.ask(self.alg_state)
        generation = int(self.alg_state.generation)

        # -inf entries are NaN fitness values replaced upstream; exclude them.
        # NOTE(review): assumes at least one finite fitness per generation.
        valid_fitnesses = self.fitness[~np.isinf(self.fitness)]
        max_f, min_f, mean_f, std_f = (
            max(valid_fitnesses),
            min(valid_fitnesses),
            np.mean(valid_fitnesses),
            np.std(valid_fitnesses),
        )

        new_timestamp = time.time()
        cost_time = new_timestamp - self.generation_timestamp
        self.generation_timestamp = time.time()

        max_idx = np.argmax(self.fitness)
        if self.fitness[max_idx] > self.best_fitness:
            self.best_fitness = self.fitness[max_idx]
            self.best_genome = pop[0][max_idx], pop[1][max_idx]

        if self.is_save:
            # save best genome of this generation (file is named by generation)
            best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
            file_name = os.path.join(self.genome_dir, f"{generation}.npz")
            with open(file_name, "wb") as f:
                np.savez(
                    f,
                    nodes=best_genome[0],
                    conns=best_genome[1],
                    fitness=self.best_fitness,
                )
            # append one CSV log line per generation
            with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
                f.write(
                    f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
                )

        print(
            f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
            f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
        )

        self.tensorneat_algorithm.show_details(self.alg_state, self.fitness)

View File

@@ -1,3 +1,5 @@
import jax, jax.numpy as jnp
from .act_jnp import *
from .act_sympy import *
from .agg_jnp import *
@@ -32,6 +34,7 @@ act_name2sympy = {
"log": SympyLog,
"exp": SympyExp,
"abs": SympyAbs,
"clip": SympyClip,
}
agg_name2jnp = {
@@ -40,7 +43,6 @@ agg_name2jnp = {
"max": max_,
"min": min_,
"maxabs": maxabs_,
"median": median_,
"mean": mean_,
}
@@ -50,9 +52,42 @@ agg_name2sympy = {
"max": SympyMax,
"min": SympyMin,
"maxabs": SympyMaxabs,
"median": SympyMedian,
"mean": SympyMean,
}
ACT = FunctionManager(act_name2jnp, act_name2sympy)
AGG = FunctionManager(agg_name2jnp, agg_name2sympy)
def apply_activation(idx, z, act_funcs):
    """Apply the activation function selected by ``idx`` to ``z``.

    ``idx`` is cast to int32 so float-encoded indices are accepted.  An index
    of -1 means identity (no activation); any other index dispatches into
    ``act_funcs`` through ``jax.lax.switch`` so the selection stays traceable
    under jit/vmap.
    """
    # change idx from float to int
    idx = jnp.asarray(idx, dtype=jnp.int32)

    def _identity():
        # -1 means identity activation
        return z

    def _dispatch():
        return jax.lax.switch(idx, act_funcs, z)

    return jax.lax.cond(idx == -1, _identity, _dispatch)
def apply_aggregation(idx, z, agg_funcs):
    """Aggregate the node inputs ``z`` with the function selected by ``idx``.

    ``idx`` is cast to int32 so float-encoded indices are accepted.  When
    every entry of ``z`` is NaN the result is NaN; otherwise the selected
    entry of ``agg_funcs`` is applied via ``jax.lax.switch``.
    """
    # change idx from float to int
    idx = jnp.asarray(idx, dtype=jnp.int32)

    def _all_nan():
        # no valid inputs at all
        return jnp.nan

    def _aggregate():
        return jax.lax.switch(idx, agg_funcs, z)

    return jax.lax.cond(jnp.all(jnp.isnan(z)), _all_nan, _aggregate)
def get_func_name(func):
    """Return ``func.__name__`` with one trailing underscore stripped.

    The jnp aggregation/activation helpers are named like ``sum_`` to avoid
    shadowing builtins; this recovers the display name (``"sum"``).
    """
    raw = func.__name__
    return raw[:-1] if raw.endswith("_") else raw

View File

@@ -1,6 +1,6 @@
import jax.numpy as jnp
SCALE = 5
SCALE = 3
def scaled_sigmoid_(z):

View File

@@ -1,7 +1,7 @@
import sympy as sp
import numpy as np
SCALE = 5
SCALE = 3
class SympySigmoid(sp.Function):
@classmethod

View File

@@ -23,18 +23,6 @@ def maxabs_(z):
max_abs_index = jnp.argmax(abs_z)
return z[max_abs_index]
def median_(z):
    """NaN-aware median of ``z``.

    ``jnp.sort`` pushes NaNs to the end, so the middle element(s) of the
    first ``valid`` entries give the median of the non-NaN values.
    """
    valid = jnp.sum(~jnp.isnan(z), axis=0)  # number of non-NaN entries
    ordered = jnp.sort(z)
    lo = (valid - 1) // 2
    hi = valid // 2
    # lo == hi for an odd count, so this averages one or two middle values
    return (ordered[lo] + ordered[hi]) / 2
def mean_(z):
sumation = sum_(z)
valid_count = jnp.sum(~jnp.isnan(z), axis=0)

View File

@@ -7,30 +7,18 @@ class SympySum(sp.Function):
def eval(cls, z):
return sp.Add(*z)
@classmethod
def numerical_eval(cls, z, backend=np):
return backend.sum(z)
class SympyProduct(sp.Function):
    """Symbolic product of its arguments, with a numerical fallback."""

    @classmethod
    def eval(cls, z):
        # symbolic evaluation: multiply all operands
        return sp.Mul(*z)

    @classmethod
    def numerical_eval(cls, z, backend=np):
        # backend.prod exists in both numpy and jax.numpy;
        # the `product` alias was removed in NumPy 2.0
        return backend.prod(z)
class SympyMax(sp.Function):
    """Symbolic maximum of its arguments, with a numerical fallback."""

    @classmethod
    def eval(cls, z):
        # symbolic evaluation via sympy's Max
        return sp.Max(*z)

    @classmethod
    def numerical_eval(cls, z, backend=np):
        # backend is an array module (numpy or jax.numpy)
        return backend.max(z)
class SympyMin(sp.Function):
@classmethod
@@ -48,26 +36,3 @@ class SympyMean(sp.Function):
@classmethod
def eval(cls, z):
return sp.Add(*z) / len(z)
class SympyMedian(sp.Function):
    """Symbolic median; evaluates eagerly only when every argument is numeric."""

    @classmethod
    def eval(cls, args):
        # NOTE(review): takes a single iterable `args` rather than sympy's
        # usual eval(cls, *args) — looks like call sites pass the input list
        # as one argument; confirm against callers.
        if all(arg.is_number for arg in args):
            sorted_args = sorted(args)
            n = len(sorted_args)
            if n % 2 == 1:
                return sorted_args[n // 2]
            else:
                # even count: average of the two middle values
                return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
        # returning None leaves the expression unevaluated (sympy convention)
        return None

    def _sympystr(self, printer):
        # plain-text printer: median(a, b, ...)
        return f"median({', '.join(map(str, self.args))})"

    def _latex(self, printer):
        # LaTeX printer: \mathrm{median}(a, b, ...)
        return (
            r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
        )

View File

@@ -1,28 +1,32 @@
from functools import partial
import numpy as np
import jax.numpy as jnp
from typing import Union, Callable
import sympy as sp
class FunctionManager:
def __init__(self, name2jnp, name2sympy):
self.name2jnp = name2jnp
self.name2sympy = name2sympy
for name, func in name2jnp.items():
setattr(self, name, func)
def get_all_funcs(self):
all_funcs = []
for name in self.names:
for name in self.name2jnp:
all_funcs.append(getattr(self, name))
return all_funcs
def __getattribute__(self, name: str):
return self.name2jnp[name]
def add_func(self, name, func):
if not callable(func):
raise ValueError("The provided function is not callable")
if name in self.names:
if name in self.name2jnp:
raise ValueError(f"The provided name={name} is already in use")
self.name2jnp[name] = func
setattr(self, name, func)
def update_sympy(self, name, sympy_cls: sp.Function):
self.name2sympy[name] = sympy_cls
@@ -47,3 +51,16 @@ class FunctionManager:
if name not in self.name2sympy:
raise ValueError(f"Func {name} doesn't have a sympy representation.")
return self.name2sympy[name]
def sympy_module(self, backend: str):
assert backend in ["jax", "numpy"]
if backend == "jax":
backend = jnp
elif backend == "numpy":
backend = np
module = {}
for sympy_cls in self.name2sympy.values():
if hasattr(sympy_cls, "numerical_eval"):
module[sympy_cls.__name__] = partial(sympy_cls.numerical_eval, backend)
return module

View File

@@ -15,8 +15,8 @@ from tensorneat.common import (
topological_sort_python,
I_INF,
attach_with_inf,
SYMPY_FUNCS_MODULE_NP,
SYMPY_FUNCS_MODULE_JNP,
ACT,
AGG
)
@@ -92,7 +92,9 @@ class DefaultGenome(BaseGenome):
def otherwise():
# calculate connections
conn_indices = u_conns[:, i]
hit_attrs = attach_with_inf(conns_attrs, conn_indices) # fetch conn attrs
hit_attrs = attach_with_inf(
conns_attrs, conn_indices
) # fetch conn attrs
ins = vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
state, hit_attrs, values
)
@@ -102,7 +104,9 @@ class DefaultGenome(BaseGenome):
state,
nodes_attrs[i],
ins,
is_output_node=jnp.isin(nodes[i, 0], self.output_idx), # nodes[0] -> the key of nodes
is_output_node=jnp.isin(
nodes[i, 0], self.output_idx
), # nodes[0] -> the key of nodes
)
# set new value
@@ -139,7 +143,6 @@ class DefaultGenome(BaseGenome):
):
assert backend in ["jax", "numpy"], "backend should be 'jax' or 'numpy'"
module = SYMPY_FUNCS_MODULE_JNP if backend == "jax" else SYMPY_FUNCS_MODULE_NP
if sympy_input_transform is None and self.input_transform is not None:
warnings.warn(
@@ -224,7 +227,7 @@ class DefaultGenome(BaseGenome):
sp.lambdify(
input_symbols + list(args_symbols.keys()),
exprs,
modules=[backend, module],
modules=[backend, AGG.sympy_module(backend), ACT.sympy_module(backend)],
)
for exprs in output_exprs
]
@@ -256,7 +259,12 @@ class DefaultGenome(BaseGenome):
rotate=0,
reverse_node_order=False,
size=(300, 300, 300),
color=("blue", "blue", "blue"),
color=("yellow", "white", "blue"),
with_labels=False,
edgecolors="k",
arrowstyle="->",
arrowsize=3,
edge_color=(0.3, 0.3, 0.3),
save_path="network.svg",
save_dpi=800,
**kwargs,
@@ -264,7 +272,6 @@ class DefaultGenome(BaseGenome):
import networkx as nx
from matplotlib import pyplot as plt
nodes_list = list(network["nodes"])
conns_list = list(network["conns"])
input_idx = self.get_input_idx()
output_idx = self.get_output_idx()
@@ -316,6 +323,11 @@ class DefaultGenome(BaseGenome):
pos=rotated_pos,
node_size=node_sizes,
node_color=node_colors,
with_labels=with_labels,
edgecolors=edgecolors,
arrowstyle=arrowstyle,
arrowsize=arrowsize,
edge_color=edge_color,
**kwargs,
)
plt.savefig(save_path, dpi=save_dpi)

View File

@@ -10,7 +10,7 @@ from tensorneat.common import (
apply_aggregation,
mutate_int,
mutate_float,
convert_to_sympy,
get_func_name
)
from . import BaseNode
@@ -141,8 +141,8 @@ class BiasNode(BaseNode):
self.__class__.__name__,
idx,
bias,
self.aggregation_options[agg].__name__,
act_func.__name__,
get_func_name(self.aggregation_options[agg]),
get_func_name(act_func),
idx_width=idx_width,
float_width=precision + 3,
func_width=func_width,
@@ -165,21 +165,19 @@ class BiasNode(BaseNode):
return {
"idx": idx,
"bias": bias,
"agg": self.aggregation_options[int(agg)].__name__,
"act": act_func.__name__,
"agg": get_func_name(self.aggregation_options[agg]),
"act": get_func_name(act_func),
}
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
nd = node_dict
bias = sp.symbols(f"n_{node_dict['idx']}_b")
bias = sp.symbols(f"n_{nd['idx']}_b")
z = convert_to_sympy(nd["agg"])(inputs)
z = AGG.obtain_sympy(node_dict["agg"])(inputs)
z = bias + z
if is_output_node:
pass
else:
z = convert_to_sympy(nd["act"])(z)
z = ACT.obtain_sympy(node_dict["act"])(z)
return z, {bias: nd["bias"]}
return z, {bias: node_dict["bias"]}

View File

@@ -11,7 +11,7 @@ from tensorneat.common import (
apply_aggregation,
mutate_int,
mutate_float,
convert_to_sympy,
get_func_name
)
from .base import BaseNode
@@ -176,8 +176,8 @@ class DefaultNode(BaseNode):
idx,
bias,
res,
self.aggregation_options[agg].__name__,
act_func.__name__,
get_func_name(self.aggregation_options[agg]),
get_func_name(act_func),
idx_width=idx_width,
float_width=precision + 3,
func_width=func_width,
@@ -200,8 +200,8 @@ class DefaultNode(BaseNode):
"idx": idx,
"bias": bias,
"res": res,
"agg": self.aggregation_options[int(agg)].__name__,
"act": act_func.__name__,
"agg": get_func_name(self.aggregation_options[agg]),
"act": get_func_name(act_func),
}
def sympy_func(self, state, node_dict, inputs, is_output_node=False):
@@ -209,12 +209,13 @@ class DefaultNode(BaseNode):
bias = sp.symbols(f"n_{nd['idx']}_b")
res = sp.symbols(f"n_{nd['idx']}_r")
z = convert_to_sympy(nd["agg"])(inputs)
print(nd["agg"])
z = AGG.obtain_sympy(nd["agg"])(inputs)
z = bias + res * z
if is_output_node:
pass
else:
z = convert_to_sympy(nd["act"])(z)
z = ACT.obtain_sympy(nd["act"])(z)
return z, {bias: nd["bias"], res: nd["res"]}