update functions. Visualize, Interpretable and with evox

This commit is contained in:
root
2024-07-12 04:35:22 +08:00
parent 5fc63fdaf1
commit 0d6e7477bf
32 changed files with 207 additions and 427 deletions

View File

@@ -3,4 +3,4 @@ from .graph import *
from .state import State
from .stateful_class import StatefulBaseClass
from .functions import ACT, AGG, apply_activation, apply_aggregation
from .functions import ACT, AGG, apply_activation, apply_aggregation, get_func_name

View File

@@ -0,0 +1,2 @@
from .algorithm_adaptor import EvoXAlgorithmAdaptor
from .tensorneat_monitor import TensorNEATMonitor

View File

@@ -0,0 +1,34 @@
import jax.numpy as jnp
from evox import Algorithm as EvoXAlgorithm, State as EvoXState, jit_class
from tensorneat.algorithm import BaseAlgorithm as TensorNEATAlgorithm
from tensorneat.common import State as TensorNEATState
@jit_class
class EvoXAlgorithmAdaptor(EvoXAlgorithm):
def __init__(self, algorithm: TensorNEATAlgorithm):
self.algorithm = algorithm
self.fixed_state = None
def setup(self, key):
neat_algorithm_state = TensorNEATState(randkey=key)
neat_algorithm_state = self.algorithm.setup(neat_algorithm_state)
self.fixed_state = neat_algorithm_state
return EvoXState(alg_state=neat_algorithm_state)
def ask(self, state: EvoXState):
population = self.algorithm.ask(state.alg_state)
return population, state
def tell(self, state: EvoXState, fitness):
fitness = jnp.where(jnp.isnan(fitness), -jnp.inf, fitness)
neat_algorithm_state = self.algorithm.tell(state.alg_state, fitness)
return state.replace(alg_state=neat_algorithm_state)
def transform(self, individual):
return self.algorithm.transform(self.fixed_state, individual)
def forward(self, transformed, inputs):
return self.algorithm.forward(self.fixed_state, transformed, inputs)

View File

@@ -0,0 +1,110 @@
import warnings
import os
import time
import numpy as np
import jax
from jax.experimental import io_callback
from evox import Monitor
from evox import State as EvoXState
from tensorneat.algorithm import BaseAlgorithm as TensorNEATAlgorithm
from tensorneat.common import State as TensorNEATState
class TensorNEATMonitor(Monitor):
def __init__(
self,
tensorneat_algorithm: TensorNEATAlgorithm,
save_dir: str = None,
is_save: bool = False,
):
super().__init__()
self.tensorneat_algorithm = tensorneat_algorithm
self.generation_timestamp = time.time()
self.alg_state: TensorNEATState = None
self.fitness = None
self.best_fitness = -np.inf
self.best_genome = None
self.is_save = is_save
if is_save:
if save_dir is None:
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.save_dir = f"./{self.__class__.__name__} {now}"
else:
self.save_dir = save_dir
print(f"save to {self.save_dir}")
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
self.genome_dir = os.path.join(self.save_dir, "genomes")
if not os.path.exists(self.genome_dir):
os.makedirs(self.genome_dir)
def hooks(self):
return ["pre_tell"]
def pre_tell(self, state: EvoXState, cand_sol, transformed_cand_sol, fitness, transformed_fitness):
io_callback(
self.store_info,
None,
state,
transformed_fitness,
)
def store_info(self, state: EvoXState, fitness):
self.alg_state: TensorNEATState = state.query_state("algorithm").alg_state
self.fitness = jax.device_get(fitness)
def show(self):
pop = self.tensorneat_algorithm.ask(self.alg_state)
generation = int(self.alg_state.generation)
valid_fitnesses = self.fitness[~np.isinf(self.fitness)]
max_f, min_f, mean_f, std_f = (
max(valid_fitnesses),
min(valid_fitnesses),
np.mean(valid_fitnesses),
np.std(valid_fitnesses),
)
new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp
self.generation_timestamp = time.time()
max_idx = np.argmax(self.fitness)
if self.fitness[max_idx] > self.best_fitness:
self.best_fitness = self.fitness[max_idx]
self.best_genome = pop[0][max_idx], pop[1][max_idx]
if self.is_save:
# save best
best_genome = jax.device_get((pop[0][max_idx], pop[1][max_idx]))
file_name = os.path.join(
self.genome_dir, f"{generation}.npz"
)
with open(file_name, "wb") as f:
np.savez(
f,
nodes=best_genome[0],
conns=best_genome[1],
fitness=self.best_fitness,
)
# append log
with open(os.path.join(self.save_dir, "log.txt"), "a") as f:
f.write(
f"{generation},{max_f},{min_f},{mean_f},{std_f},{cost_time}\n"
)
print(
f"Generation: {generation}, Cost time: {cost_time * 1000:.2f}ms\n",
f"\tfitness: valid cnt: {len(valid_fitnesses)}, max: {max_f:.4f}, min: {min_f:.4f}, mean: {mean_f:.4f}, std: {std_f:.4f}\n",
)
self.tensorneat_algorithm.show_details(self.alg_state, self.fitness)

View File

@@ -1,3 +1,5 @@
import jax, jax.numpy as jnp
from .act_jnp import *
from .act_sympy import *
from .agg_jnp import *
@@ -32,6 +34,7 @@ act_name2sympy = {
"log": SympyLog,
"exp": SympyExp,
"abs": SympyAbs,
"clip": SympyClip,
}
agg_name2jnp = {
@@ -40,7 +43,6 @@ agg_name2jnp = {
"max": max_,
"min": min_,
"maxabs": maxabs_,
"median": median_,
"mean": mean_,
}
@@ -50,9 +52,42 @@ agg_name2sympy = {
"max": SympyMax,
"min": SympyMin,
"maxabs": SympyMaxabs,
"median": SympyMedian,
"mean": SympyMean,
}
ACT = FunctionManager(act_name2jnp, act_name2sympy)
AGG = FunctionManager(agg_name2jnp, agg_name2sympy)
def apply_activation(idx, z, act_funcs):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
# -1 means identity activation
res = jax.lax.cond(
idx == -1,
lambda: z,
lambda: jax.lax.switch(idx, act_funcs, z),
)
return res
def apply_aggregation(idx, z, agg_funcs):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
return jax.lax.cond(
jnp.all(jnp.isnan(z)),
lambda: jnp.nan, # all inputs are nan
lambda: jax.lax.switch(idx, agg_funcs, z), # otherwise
)
def get_func_name(func):
name = func.__name__
if name.endswith("_"):
name = name[:-1]
return name

View File

@@ -1,6 +1,6 @@
import jax.numpy as jnp
SCALE = 5
SCALE = 3
def scaled_sigmoid_(z):

View File

@@ -1,7 +1,7 @@
import sympy as sp
import numpy as np
SCALE = 5
SCALE = 3
class SympySigmoid(sp.Function):
@classmethod

View File

@@ -23,18 +23,6 @@ def maxabs_(z):
max_abs_index = jnp.argmax(abs_z)
return z[max_abs_index]
def median_(z):
n = jnp.sum(~jnp.isnan(z), axis=0)
z = jnp.sort(z) # sort
idx1, idx2 = (n - 1) // 2, n // 2
median = (z[idx1] + z[idx2]) / 2
return median
def mean_(z):
sumation = sum_(z)
valid_count = jnp.sum(~jnp.isnan(z), axis=0)

View File

@@ -7,30 +7,18 @@ class SympySum(sp.Function):
def eval(cls, z):
return sp.Add(*z)
@classmethod
def numerical_eval(cls, z, backend=np):
return backend.sum(z)
class SympyProduct(sp.Function):
@classmethod
def eval(cls, z):
return sp.Mul(*z)
@classmethod
def numerical_eval(cls, z, backend=np):
return backend.product(z)
class SympyMax(sp.Function):
@classmethod
def eval(cls, z):
return sp.Max(*z)
@classmethod
def numerical_eval(cls, z, backend=np):
return backend.max(z)
class SympyMin(sp.Function):
@classmethod
@@ -48,26 +36,3 @@ class SympyMean(sp.Function):
@classmethod
def eval(cls, z):
return sp.Add(*z) / len(z)
class SympyMedian(sp.Function):
@classmethod
def eval(cls, args):
if all(arg.is_number for arg in args):
sorted_args = sorted(args)
n = len(sorted_args)
if n % 2 == 1:
return sorted_args[n // 2]
else:
return (sorted_args[n // 2 - 1] + sorted_args[n // 2]) / 2
return None
def _sympystr(self, printer):
return f"median({', '.join(map(str, self.args))})"
def _latex(self, printer):
return (
r"\mathrm{median}\left(" + ", ".join(map(sp.latex, self.args)) + r"\right)"
)

View File

@@ -1,28 +1,32 @@
from functools import partial
import numpy as np
import jax.numpy as jnp
from typing import Union, Callable
import sympy as sp
class FunctionManager:
def __init__(self, name2jnp, name2sympy):
self.name2jnp = name2jnp
self.name2sympy = name2sympy
for name, func in name2jnp.items():
setattr(self, name, func)
def get_all_funcs(self):
all_funcs = []
for name in self.names:
for name in self.name2jnp:
all_funcs.append(getattr(self, name))
return all_funcs
def __getattribute__(self, name: str):
return self.name2jnp[name]
def add_func(self, name, func):
if not callable(func):
raise ValueError("The provided function is not callable")
if name in self.names:
if name in self.name2jnp:
raise ValueError(f"The provided name={name} is already in use")
self.name2jnp[name] = func
setattr(self, name, func)
def update_sympy(self, name, sympy_cls: sp.Function):
self.name2sympy[name] = sympy_cls
@@ -47,3 +51,16 @@ class FunctionManager:
if name not in self.name2sympy:
raise ValueError(f"Func {name} doesn't have a sympy representation.")
return self.name2sympy[name]
def sympy_module(self, backend: str):
assert backend in ["jax", "numpy"]
if backend == "jax":
backend = jnp
elif backend == "numpy":
backend = np
module = {}
for sympy_cls in self.name2sympy.values():
if hasattr(sympy_cls, "numerical_eval"):
module[sympy_cls.__name__] = partial(sympy_cls.numerical_eval, backend)
return module