Update functions: visualization, interpretability, and EvoX integration.
This commit is contained in:
@@ -3,4 +3,4 @@ from .graph import *
|
||||
from .state import State
|
||||
from .stateful_class import StatefulBaseClass
|
||||
|
||||
from .functions import ACT, AGG, apply_activation, apply_aggregation
|
||||
from .functions import ACT, AGG, apply_activation, apply_aggregation, get_func_name
|
||||
|
||||
2
src/tensorneat/common/evox_adaptors/__init__.py
Normal file
2
src/tensorneat/common/evox_adaptors/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .algorithm_adaptor import EvoXAlgorithmAdaptor
|
||||
from .tensorneat_monitor import TensorNEATMonitor
|
||||
34
src/tensorneat/common/evox_adaptors/algorithm_adaptor.py
Normal file
34
src/tensorneat/common/evox_adaptors/algorithm_adaptor.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from evox import Algorithm as EvoXAlgorithm, State as EvoXState, jit_class
|
||||
|
||||
from tensorneat.algorithm import BaseAlgorithm as TensorNEATAlgorithm
|
||||
from tensorneat.common import State as TensorNEATState
|
||||
|
||||
|
||||
@jit_class
class EvoXAlgorithmAdaptor(EvoXAlgorithm):
    """Adapt a TensorNEAT algorithm to the EvoX ``Algorithm`` interface.

    Translates between TensorNEAT's ``State`` and EvoX's ``State`` while
    delegating ask/tell to the wrapped algorithm.
    """

    def __init__(self, algorithm: TensorNEATAlgorithm):
        self.algorithm = algorithm
        # Raw TensorNEAT state captured at setup time; used by transform()
        # and forward(), which run outside the EvoX state-threading flow.
        self.fixed_state = None

    def setup(self, key):
        """Initialize the wrapped algorithm and wrap its state for EvoX."""
        neat_algorithm_state = TensorNEATState(randkey=key)
        neat_algorithm_state = self.algorithm.setup(neat_algorithm_state)
        # NOTE(review): mutating self inside a @jit_class method — presumably
        # setup() is executed untraced; confirm against evox's jit_class docs.
        self.fixed_state = neat_algorithm_state
        return EvoXState(alg_state=neat_algorithm_state)

    def ask(self, state: EvoXState):
        """Return the current population; the EvoX state is unchanged."""
        population = self.algorithm.ask(state.alg_state)
        return population, state

    def tell(self, state: EvoXState, fitness):
        """Feed fitnesses back; NaN fitnesses are replaced with -inf."""
        fitness = jnp.where(jnp.isnan(fitness), -jnp.inf, fitness)
        neat_algorithm_state = self.algorithm.tell(state.alg_state, fitness)
        return state.replace(alg_state=neat_algorithm_state)

    def transform(self, individual):
        """Transform one genome for inference, using the setup-time state."""
        return self.algorithm.transform(self.fixed_state, individual)

    def forward(self, transformed, inputs):
        """Run the transformed network on ``inputs``."""
        return self.algorithm.forward(self.fixed_state, transformed, inputs)
|
||||
110
src/tensorneat/common/evox_adaptors/tensorneat_monitor.py
Normal file
110
src/tensorneat/common/evox_adaptors/tensorneat_monitor.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import warnings
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.experimental import io_callback
|
||||
from evox import Monitor
|
||||
from evox import State as EvoXState
|
||||
|
||||
from tensorneat.algorithm import BaseAlgorithm as TensorNEATAlgorithm
|
||||
from tensorneat.common import State as TensorNEATState
|
||||
|
||||
|
||||
class TensorNEATMonitor(Monitor):
    """EvoX monitor that tracks and reports TensorNEAT training progress.

    Hooks into ``pre_tell`` to capture the algorithm state and fitnesses on
    the host; ``show()`` then prints per-generation statistics and, when
    ``is_save`` is set, saves the generation's best genome plus a CSV-style
    log line under ``save_dir``.
    """

    def __init__(
        self,
        tensorneat_algorithm: TensorNEATAlgorithm,
        save_dir: str = None,
        is_save: bool = False,
    ):
        super().__init__()
        self.tensorneat_algorithm = tensorneat_algorithm

        self.generation_timestamp = time.time()
        self.alg_state: TensorNEATState = None  # latest captured state
        self.fitness = None  # latest fitness array (host copy)
        self.best_fitness = -np.inf  # best fitness seen over the run
        self.best_genome = None  # (nodes, conns) of all-time best

        self.is_save = is_save

        if is_save:
            if save_dir is None:
                # BUGFIX: the original called datetime.datetime.now() but the
                # file never imports datetime, raising NameError here.
                # time.strftime produces the same "%Y%m%d%H%M%S" stamp with
                # the already-imported time module.
                now = time.strftime("%Y%m%d%H%M%S")
                self.save_dir = f"./{self.__class__.__name__} {now}"
            else:
                self.save_dir = save_dir
            print(f"save to {self.save_dir}")
            # exist_ok avoids the check-then-create race of the original
            os.makedirs(self.save_dir, exist_ok=True)
            self.genome_dir = os.path.join(self.save_dir, "genomes")
            os.makedirs(self.genome_dir, exist_ok=True)

    def hooks(self):
        """Register for the pre_tell hook only."""
        return ["pre_tell"]

    def pre_tell(self, state: EvoXState, cand_sol, transformed_cand_sol, fitness, transformed_fitness):
        """Capture state and fitness on the host via io_callback (jit-safe)."""
        io_callback(
            self.store_info,
            None,
            state,
            transformed_fitness,
        )

    def store_info(self, state: EvoXState, fitness):
        """Host-side callback: stash the algorithm state and fitness array."""
        self.alg_state: TensorNEATState = state.query_state("algorithm").alg_state
        self.fitness = jax.device_get(fitness)

    def show(self):
        """Print generation statistics and optionally save genome + log."""
        pop = self.tensorneat_algorithm.ask(self.alg_state)
        generation = int(self.alg_state.generation)

        # statistics over finite fitnesses only (inf marks invalid genomes)
        valid_fitnesses = self.fitness[~np.isinf(self.fitness)]

        max_f, min_f, mean_f, std_f = (
            max(valid_fitnesses),
            min(valid_fitnesses),
            np.mean(valid_fitnesses),
            np.std(valid_fitnesses),
        )

        new_timestamp = time.time()
        cost_time = new_timestamp - self.generation_timestamp
        self.generation_timestamp = time.time()

        # track the best individual seen over the whole run
        max_idx = np.argmax(self.fitness)
        if self.fitness[max_idx] > self.best_fitness:
            self.best_fitness = self.fitness[max_idx]
            self.best_genome = pop[0][max_idx], pop[1][max_idx]

        if self.is_save:
            # save this generation's best genome
            # NOTE(review): the saved fitness is the all-time best
            # (self.best_fitness) while the genome is this generation's
            # best — confirm this pairing is intended.
            best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
            file_name = os.path.join(
                self.genome_dir, f"{generation}.npz"
            )
            with open(file_name, "wb") as f:
                np.savez(
                    f,
                    nodes=best_genome[0],
                    conns=best_genome[1],
                    fitness=self.best_fitness,
                )

            # append one CSV line per generation to the run log
            with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
                f.write(
                    f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
                )

        print(
            f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
            f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
        )

        self.tensorneat_algorithm.show_details(self.alg_state, self.fitness)
|
||||
@@ -1,3 +1,5 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from .act_jnp import *
|
||||
from .act_sympy import *
|
||||
from .agg_jnp import *
|
||||
@@ -32,6 +34,7 @@ act_name2sympy = {
|
||||
"log": SympyLog,
|
||||
"exp": SympyExp,
|
||||
"abs": SympyAbs,
|
||||
"clip": SympyClip,
|
||||
}
|
||||
|
||||
agg_name2jnp = {
|
||||
@@ -40,7 +43,6 @@ agg_name2jnp = {
|
||||
"max": max_,
|
||||
"min": min_,
|
||||
"maxabs": maxabs_,
|
||||
"median": median_,
|
||||
"mean": mean_,
|
||||
}
|
||||
|
||||
@@ -50,9 +52,42 @@ agg_name2sympy = {
|
||||
"max": SympyMax,
|
||||
"min": SympyMin,
|
||||
"maxabs": SympyMaxabs,
|
||||
"median": SympyMedian,
|
||||
"mean": SympyMean,
|
||||
}
|
||||
|
||||
ACT = FunctionManager(act_name2jnp, act_name2sympy)
|
||||
AGG = FunctionManager(agg_name2jnp, agg_name2sympy)
|
||||
|
||||
def apply_activation(idx, z, act_funcs):
    """
    Apply the activation function selected by ``idx`` to the node value ``z``.

    ``idx`` may arrive as a float from the genome encoding, so it is cast to
    int32 first. An index of -1 selects the identity activation.
    """
    act_idx = jnp.asarray(idx, dtype=jnp.int32)

    def _identity():
        # -1 means identity activation
        return z

    def _lookup():
        return jax.lax.switch(act_idx, act_funcs, z)

    return jax.lax.cond(act_idx == -1, _identity, _lookup)
|
||||
|
||||
def apply_aggregation(idx, z, agg_funcs):
    """
    Aggregate the node-input vector ``z`` with the function selected by ``idx``.

    ``idx`` may arrive as a float, so it is cast to int32 first. When every
    entry of ``z`` is NaN there is nothing to aggregate and NaN is returned.
    """
    agg_idx = jnp.asarray(idx, dtype=jnp.int32)

    def _all_nan():
        # no valid inputs at all
        return jnp.nan

    def _aggregate():
        return jax.lax.switch(agg_idx, agg_funcs, z)

    return jax.lax.cond(jnp.all(jnp.isnan(z)), _all_nan, _aggregate)
|
||||
|
||||
def get_func_name(func):
    """Return ``func.__name__`` without the trailing-underscore convention
    used for builtin-shadowing names (``sum_`` -> ``"sum"``)."""
    return func.__name__.removesuffix("_")
|
||||
@@ -1,6 +1,6 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
SCALE = 5
|
||||
SCALE = 3
|
||||
|
||||
|
||||
def scaled_sigmoid_(z):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import sympy as sp
|
||||
import numpy as np
|
||||
|
||||
SCALE = 5
|
||||
SCALE = 3
|
||||
|
||||
class SympySigmoid(sp.Function):
|
||||
@classmethod
|
||||
|
||||
@@ -23,18 +23,6 @@ def maxabs_(z):
|
||||
max_abs_index = jnp.argmax(abs_z)
|
||||
return z[max_abs_index]
|
||||
|
||||
|
||||
def median_(z):
    """
    NaN-aware median of ``z``.

    ``jnp.sort`` pushes NaNs to the tail, so only the leading ``valid``
    (non-NaN) entries are considered when picking the middle element(s).
    """
    valid = jnp.sum(~jnp.isnan(z), axis=0)
    ordered = jnp.sort(z)  # NaNs end up at the tail
    lower, upper = (valid - 1) // 2, valid // 2
    # for an odd count lower == upper, so this averages the element with itself
    return (ordered[lower] + ordered[upper]) / 2
|
||||
|
||||
|
||||
def mean_(z):
|
||||
sumation = sum_(z)
|
||||
valid_count = jnp.sum(~jnp.isnan(z), axis=0)
|
||||
|
||||
@@ -7,30 +7,18 @@ class SympySum(sp.Function):
|
||||
def eval(cls, z):
|
||||
return sp.Add(*z)
|
||||
|
||||
@classmethod
|
||||
def numerical_eval(cls, z, backend=np):
|
||||
return backend.sum(z)
|
||||
|
||||
|
||||
class SympyProduct(sp.Function):
    """Symbolic product aggregation with a numerical backend hook."""

    @classmethod
    def eval(cls, z):
        # symbolic form: multiply all arguments
        return sp.Mul(*z)

    @classmethod
    def numerical_eval(cls, z, backend=np):
        # BUGFIX: the original used backend.product, but np.product was
        # deprecated and removed in NumPy 2.0; prod exists on both numpy
        # and jax.numpy backends.
        return backend.prod(z)
|
||||
|
||||
|
||||
class SympyMax(sp.Function):
    """Symbolic max aggregation with a numerical backend hook."""

    @classmethod
    def eval(cls, z):
        # symbolic form: sympy Max over all arguments
        return sp.Max(*z)

    @classmethod
    def numerical_eval(cls, z, backend=np):
        # numeric form for lambdified evaluation (numpy or jax.numpy backend)
        return backend.max(z)
|
||||
|
||||
|
||||
class SympyMin(sp.Function):
|
||||
@classmethod
|
||||
@@ -48,26 +36,3 @@ class SympyMean(sp.Function):
|
||||
@classmethod
|
||||
def eval(cls, z):
|
||||
return sp.Add(*z) / len(z)
|
||||
|
||||
|
||||
class SympyMedian(sp.Function):
    """Symbolic median of the arguments, evaluated only for numeric input."""

    @classmethod
    def eval(cls, args):
        # Evaluate eagerly only when every argument is a concrete number;
        # returning None tells sympy to keep the call unevaluated.
        if all(arg.is_number for arg in args):
            sorted_args = sorted(args)
            n = len(sorted_args)
            if n % 2 == 1:
                return sorted_args[n // 2]
            else:
                return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2

        return None

    def _sympystr(self, printer):
        # plain-text printer: median(a, b, ...)
        return f"median({', '.join(map(str, self.args))})"

    def _latex(self, printer):
        # LaTeX printer: \mathrm{median}(...)
        return (
            r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
        )
|
||||
|
||||
@@ -1,28 +1,32 @@
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from typing import Union, Callable
|
||||
import sympy as sp
|
||||
|
||||
|
||||
class FunctionManager:
|
||||
|
||||
def __init__(self, name2jnp, name2sympy):
    """
    Store the name->jnp-callable and name->sympy-class tables, and expose
    each jnp function as an attribute on the manager (e.g. ``manager.sum``).
    """
    self.name2jnp = name2jnp
    self.name2sympy = name2sympy
    # mirror every registered function as an attribute for direct access
    for name, func in name2jnp.items():
        setattr(self, name, func)
|
||||
|
||||
def get_all_funcs(self):
|
||||
all_funcs = []
|
||||
for name in self.names:
|
||||
for name in self.name2jnp:
|
||||
all_funcs.append(getattr(self, name))
|
||||
return all_funcs
|
||||
|
||||
def __getattribute__(self, name: str):
|
||||
return self.name2jnp[name]
|
||||
|
||||
def add_func(self, name, func):
|
||||
if not callable(func):
|
||||
raise ValueError("The provided function is not callable")
|
||||
if name in self.names:
|
||||
if name in self.name2jnp:
|
||||
raise ValueError(f"The provided name={name} is already in use")
|
||||
|
||||
self.name2jnp[name] = func
|
||||
setattr(self, name, func)
|
||||
|
||||
def update_sympy(self, name, sympy_cls: sp.Function):
|
||||
self.name2sympy[name] = sympy_cls
|
||||
@@ -47,3 +51,16 @@ class FunctionManager:
|
||||
if name not in self.name2sympy:
|
||||
raise ValueError(f"Func {name} doesn't have a sympy representation.")
|
||||
return self.name2sympy[name]
|
||||
|
||||
def sympy_module(self, backend: str):
    """
    Build a lambdify "module" dict mapping sympy class names to numeric
    callables for the chosen backend.

    backend: "jax" or "numpy" — selects jnp or np for numerical_eval.
    Only sympy classes that define numerical_eval are included.
    """
    assert backend in ["jax", "numpy"]
    if backend == "jax":
        backend = jnp
    elif backend == "numpy":
        backend = np
    module = {}
    for sympy_cls in self.name2sympy.values():
        if hasattr(sympy_cls, "numerical_eval"):
            # BUGFIX: bind the backend module by keyword. The original
            # partial(numerical_eval, backend) bound it positionally to the
            # *z* parameter (visible signatures are numerical_eval(cls, z,
            # backend=np)), so lambdified calls would pass the data array
            # as the backend argument.
            module[sympy_cls.__name__] = partial(
                sympy_cls.numerical_eval, backend=backend
            )

    return module
|
||||
|
||||
Reference in New Issue
Block a user