remove create_func....

This commit is contained in:
wls2002
2023-08-04 17:29:36 +08:00
parent c7fb1ddabe
commit 0e44b13291
29 changed files with 591 additions and 259 deletions

View File

@@ -5,30 +5,30 @@ from jax import numpy as jnp, Array, vmap
import numpy as np
from config import Config, HyperNeatConfig
from core import Algorithm, Substrate, State, Genome
from core import Algorithm, Substrate, State, Genome, Gene
from utils import Activation, Aggregation
from algorithm.neat import NEAT
from .substrate import analysis_substrate
from algorithm import NEAT
class HyperNEAT(Algorithm):
def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]):
def __init__(self, config: Config, gene: Type[Gene], substrate: Type[Substrate]):
self.config = config
self.neat = neat
self.neat = NEAT(config, gene)
self.substrate = substrate
def setup(self, randkey, state=State()):
neat_key, randkey = jax.random.split(randkey)
state = state.update(
below_threshold=self.config.hyper_neat.below_threshold,
max_weight=self.config.hyper_neat.max_weight,
below_threshold=self.config.hyperneat.below_threshold,
max_weight=self.config.hyperneat.max_weight,
)
state = self.neat.setup(neat_key, state)
state = self.substrate.setup(self.config.substrate, state)
assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
assert self.config.hyper_neat.outputs == state.output_coors.shape[0]
assert self.config.hyperneat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
assert self.config.hyperneat.outputs == state.output_coors.shape[0]
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
@@ -53,7 +53,7 @@ class HyperNEAT(Algorithm):
return self.neat.tell(state, fitness)
def forward(self, state, inputs: Array, transformed: Array):
return HyperNEATGene.forward(self.config.hyper_neat, state, inputs, transformed)
return HyperNEATGene.forward(self.config.hyperneat, state, inputs, transformed)
def forward_transform(self, state: State, genome: Genome):
t = self.neat.forward_transform(state, genome)
@@ -68,6 +68,7 @@ class HyperNEAT(Algorithm):
query_res = query_res / (1 - state.below_threshold) * state.max_weight
h_conns = state.h_conns.at[:, 2:].set(query_res)
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))

View File

@@ -9,9 +9,9 @@ from config import SubstrateConfig
@dataclass(frozen=True)
class NormalSubstrateConfig(SubstrateConfig):
input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1))
hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0))
output_coors: Tuple[Tuple[float]] = ((0, 1),)
input_coors: Tuple = ((-1, -1), (0, -1), (1, -1))
hidden_coors: Tuple = ((-1, 0), (0, 0), (1, 0))
output_coors: Tuple = ((0, 1),)
class NormalSubstrate(Substrate):

View File

@@ -1 +1,2 @@
from .neat import NEAT
from .gene import *

View File

@@ -66,7 +66,7 @@ class NormalGene(Gene):
node_attrs = ['bias', 'response', 'aggregation', 'activation']
conn_attrs = ['weight']
def __init__(self, config: NormalGeneConfig):
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
self.config = config
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
@@ -101,7 +101,7 @@ class NormalGene(Gene):
)
def update(self, state):
pass
return state
def new_node_attrs(self, state):
return jnp.array([state.bias_init_mean, state.response_init_mean,

View File

@@ -19,7 +19,7 @@ class RecurrentGeneConfig(NormalGeneConfig):
class RecurrentGene(NormalGene):
def __init__(self, config: RecurrentGeneConfig):
def __init__(self, config: RecurrentGeneConfig = RecurrentGeneConfig()):
self.config = config
super().__init__(config)

View File

@@ -28,9 +28,9 @@ class NEAT(Algorithm):
state = state.update(
P=self.config.basic.pop_size,
N=self.config.neat.maximum_nodes,
C=self.config.neat.maximum_conns,
S=self.config.neat.maximum_species,
N=self.config.neat.max_nodes,
C=self.config.neat.max_conns,
S=self.config.neat.max_species,
NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes
CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes
max_stagnation=self.config.neat.max_stagnation,
@@ -80,6 +80,8 @@ class NEAT(Algorithm):
return state.pop_genomes
def tell_algorithm(self, state: State, fitness):
state = self.gene.update(state)
k1, k2, randkey = jax.random.split(state.randkey, 3)
state = state.update(

View File

@@ -17,9 +17,9 @@ class NeatConfig:
network_type: str = "feedforward"
inputs: int = 2
outputs: int = 1
maximum_nodes: int = 50
maximum_conns: int = 100
maximum_species: int = 10
max_nodes: int = 50
max_conns: int = 100
max_species: int = 10
# genome config
compatibility_disjoint: float = 1
@@ -44,9 +44,9 @@ class NeatConfig:
assert self.inputs > 0, "the inputs number of neat must be greater than 0"
assert self.outputs > 0, "the outputs number of neat must be greater than 0"
assert self.maximum_nodes > 0, "the maximum nodes must be greater than 0"
assert self.maximum_conns > 0, "the maximum connections must be greater than 0"
assert self.maximum_species > 0, "the maximum species must be greater than 0"
assert self.max_nodes > 0, "the maximum nodes must be greater than 0"
assert self.max_conns > 0, "the maximum connections must be greater than 0"
assert self.max_species > 0, "the maximum species must be greater than 0"
assert self.compatibility_disjoint > 0, "the compatibility disjoint must be greater than 0"
assert self.compatibility_weight > 0, "the compatibility weight must be greater than 0"
@@ -101,7 +101,7 @@ class ProblemConfig:
class Config:
basic: BasicConfig = BasicConfig()
neat: NeatConfig = NeatConfig()
hyper_neat: HyperNeatConfig = HyperNeatConfig()
hyperneat: HyperNeatConfig = HyperNeatConfig()
gene: GeneConfig = GeneConfig()
substrate: SubstrateConfig = SubstrateConfig()
problem: ProblemConfig = ProblemConfig()

View File

@@ -6,7 +6,7 @@ class Gene:
node_attrs = []
conn_attrs = []
def __init__(self, config: GeneConfig):
def __init__(self, config: GeneConfig = GeneConfig()):
raise NotImplementedError
def setup(self, state=State()):

View File

@@ -19,6 +19,11 @@ class Genome:
def __getitem__(self, idx):
return self.__class__(self.nodes[idx], self.conns[idx])
def __eq__(self, other):
nodes_eq = jnp.alltrue((self.nodes == other.nodes) | (jnp.isnan(self.nodes) & jnp.isnan(other.nodes)))
conns_eq = jnp.alltrue((self.conns == other.conns) | (jnp.isnan(self.conns) & jnp.isnan(other.conns)))
return nodes_eq & conns_eq
def set(self, idx, value: Genome):
return self.__class__(self.nodes.at[idx].set(value.nodes), self.conns.at[idx].set(value.conns))
@@ -83,4 +88,3 @@ class Genome:
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)

View File

@@ -1,15 +1,27 @@
from typing import Callable
from config import ProblemConfig
from state import State
from .state import State
class Problem:
def __init__(self, config: ProblemConfig):
def __init__(self, problem_config: ProblemConfig = ProblemConfig()):
self.config = problem_config
def evaluate(self, randkey, state: State, act_func: Callable, params):
raise NotImplementedError
def setup(self, state=State()):
@property
def input_shape(self):
raise NotImplementedError
def evaluate(self, state: State, act_func: Callable, params):
@property
def output_shape(self):
raise NotImplementedError
def show(self, randkey, state: State, act_func: Callable, params):
"""
show how a genome perform in this problem
"""
raise NotImplementedError

View File

@@ -4,7 +4,5 @@ from config import SubstrateConfig
class Substrate:
@staticmethod
def setup(state, config: SubstrateConfig):
def setup(state, config: SubstrateConfig = SubstrateConfig()):
return state

View File

36
examples/func_fit/xor.py Normal file
View File

@@ -0,0 +1,36 @@
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.func_fit import XOR, FuncFitConfig
if __name__ == '__main__':
    # Solve binary XOR with plain NEAT, scoring genomes by negative RMSE.
    conf = Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=-1e-2,
            pop_size=10000,
        ),
        neat=NeatConfig(
            max_nodes=50,
            max_conns=100,
            max_species=30,
            conn_add=0.8,
            conn_delete=0,
            node_add=0.4,
            node_delete=0,
            inputs=2,
            outputs=1,
        ),
        gene=NormalGeneConfig(),
        problem=FuncFitConfig(error_method='rmse'),
    )

    algo = NEAT(conf, NormalGene)
    pl = Pipeline(conf, algo, XOR)

    run_state = pl.setup()
    pl.pre_compile(run_state)  # trigger jit compilation up front
    run_state, best_genome = pl.auto_run(run_state)
    pl.show(run_state, best_genome)

View File

@@ -0,0 +1,40 @@
from config import *
from pipeline import Pipeline
from algorithm.neat import NormalGene, NormalGeneConfig
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
from problem.func_fit import XOR3d, FuncFitConfig
if __name__ == '__main__':
    # HyperNEAT on 3-input XOR: the CPPN evolved by NEAT has 4 inputs,
    # while the substrate network itself takes 3 inputs.
    conf = Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=0,
            pop_size=1000,
        ),
        neat=NeatConfig(
            max_nodes=50,
            max_conns=100,
            max_species=30,
            inputs=4,
            outputs=1,
        ),
        hyperneat=HyperNeatConfig(
            inputs=3,
            outputs=1,
        ),
        substrate=NormalSubstrateConfig(
            # one input coordinate per substrate input (3 inputs + bias)
            input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)),
        ),
        gene=NormalGeneConfig(
            activation_default='tanh',
            activation_options=('tanh', ),
        ),
        problem=FuncFitConfig(),
    )

    algo = HyperNEAT(conf, NormalGene, NormalSubstrate)
    pl = Pipeline(conf, algo, XOR3d)

    run_state = pl.setup()
    run_state, best_genome = pl.auto_run(run_state)
    pl.show(run_state, best_genome)

View File

@@ -0,0 +1,40 @@
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
from problem.func_fit import XOR3d, FuncFitConfig
if __name__ == '__main__':
    # 3-input XOR with a recurrent genome; the network is activated
    # `activate_times` steps per forward pass.
    conf = Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=-1e-2,
            generation_limit=300,
            pop_size=1000,
        ),
        neat=NeatConfig(
            network_type="recurrent",
            max_nodes=50,
            max_conns=100,
            max_species=30,
            conn_add=0.5,
            conn_delete=0.5,
            node_add=0.4,
            node_delete=0.4,
            inputs=3,
            outputs=1,
        ),
        gene=RecurrentGeneConfig(activate_times=10),
        problem=FuncFitConfig(error_method='rmse'),
    )

    algo = NEAT(conf, RecurrentGene)
    pl = Pipeline(conf, algo, XOR3d)

    run_state = pl.setup()
    run_state, best_genome = pl.auto_run(run_state)
    pl.show(run_state, best_genome)

View File

@@ -0,0 +1,84 @@
import jax.numpy as jnp
from config import *
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from problem.rl_env import GymNaxConfig, GymNaxEnv
def _cartpole_config(activation, outputs, output_transform):
    """Build the shared CartPole-v1 config.

    The three example configs below are identical except for the output
    activation, the number of network outputs, and how the raw output is
    decoded into a discrete action in {0, 1}; those three knobs are the
    parameters here.
    """
    return Config(
        basic=BasicConfig(
            seed=42,
            fitness_target=500,
            pop_size=10000
        ),
        neat=NeatConfig(
            inputs=4,
            outputs=outputs,
        ),
        gene=NormalGeneConfig(
            activation_default=activation,
            activation_options=(activation,),
        ),
        problem=GymNaxConfig(
            env_name='CartPole-v1',
            output_transform=output_transform
        )
    )


def example_conf1():
    """Single sigmoid output in (0, 1): threshold at 0.5 to pick the action."""
    return _cartpole_config('sigmoid', 1, lambda out: jnp.where(out[0] > 0.5, 1, 0))


def example_conf2():
    """Single tanh output in (-1, 1): threshold at 0 to pick the action."""
    return _cartpole_config('tanh', 1, lambda out: jnp.where(out[0] > 0, 1, 0))


def example_conf3():
    """Two tanh outputs: the action is the argmax of the pair."""
    return _cartpole_config('tanh', 2, lambda out: jnp.argmax(out))


if __name__ == '__main__':
    # all config files above can solve cartpole
    conf = example_conf3()
    algorithm = NEAT(conf, NormalGene)
    pipeline = Pipeline(conf, algorithm, GymNaxEnv)
    state = pipeline.setup()
    state, best = pipeline.auto_run(state)
    pipeline.show(state, best)

View File

@@ -1,31 +0,0 @@
from functools import partial
import jax
from utils import unflatten_conns, act, agg, Activation, Aggregation
from algorithm.neat.gene import RecurrentGeneConfig
config = RecurrentGeneConfig(
activation_options=("tanh", "sigmoid"),
activation_default="tanh",
)
class A:
def __init__(self):
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
self.isTrue = False
@partial(jax.jit, static_argnums=(0,))
def step(self):
i = jax.numpy.array([0, 1])
z = jax.numpy.array([
[1, 1],
[2, 2]
])
print(self.act_funcs)
return jax.vmap(act, in_axes=(0, 0, None))(i, z, self.act_funcs)
AA = A()
print(AA.step())

View File

@@ -1,40 +0,0 @@
import jax
import numpy as np
from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import NormalGene, NormalGeneConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
fitness_target=3.9999999,
pop_size=10000
),
neat=NeatConfig(
maximum_nodes=50,
maximum_conns=100,
compatibility_threshold=4
),
gene=NormalGeneConfig()
)
algorithm = NEAT(config, NormalGene)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)

View File

@@ -1,49 +0,0 @@
import jax
import numpy as np
from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT, HyperNEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
from algorithm.hyperneat.substrate import NormalSubstrate, NormalSubstrateConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
fitness_target=3.99999,
pop_size=10000
),
neat=NeatConfig(
network_type="recurrent",
maximum_nodes=50,
maximum_conns=100,
inputs=4,
outputs=1
),
gene=RecurrentGeneConfig(
activation_default="tanh",
activation_options=("tanh",),
),
substrate=NormalSubstrateConfig(),
)
neat = NEAT(config, RecurrentGene)
hyperNEAT = HyperNEAT(config, neat, NormalSubstrate)
pipeline = Pipeline(config, hyperNEAT)
pipeline.auto_run(evaluate)

View File

@@ -1,42 +0,0 @@
import jax
import numpy as np
from config import Config, BasicConfig, NeatConfig
from pipeline import Pipeline
from algorithm import NEAT
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
if __name__ == '__main__':
config = Config(
basic=BasicConfig(
fitness_target=3.99999,
pop_size=10000
),
neat=NeatConfig(
network_type="recurrent",
maximum_nodes=50,
maximum_conns=100
),
gene=RecurrentGeneConfig(
activate_times=3
)
)
algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)

View File

@@ -1,83 +1,115 @@
import time
from typing import Union, Callable
from functools import partial
from typing import Type
import jax
from jax import vmap, jit
import time
import numpy as np
from algorithm import NEAT, HyperNEAT
from config import Config
from core import Algorithm, Genome
from core import State, Algorithm, Problem
class Pipeline:
"""
Simple pipeline.
"""
def __init__(self, config: Config, algorithm: Algorithm):
def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]):
self.config = config
self.algorithm = algorithm
self.problem = problem_type(config.problem)
randkey = jax.random.PRNGKey(config.basic.seed)
self.state = algorithm.setup(randkey)
if isinstance(algorithm, NEAT):
assert config.neat.inputs == self.problem.input_shape[-1]
elif isinstance(algorithm, HyperNEAT):
assert config.hyperneat.inputs == self.problem.input_shape[-1]
else:
raise NotImplementedError
self.act_func = self.algorithm.act
for _ in range(len(self.problem.input_shape) - 1):
self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
self.best_genome = None
self.best_fitness = float('-inf')
self.generation_timestamp = time.time()
self.generation_timestamp = None
self.evaluate_time = 0
def setup(self):
key = jax.random.PRNGKey(self.config.basic.seed)
algorithm_key, evaluate_key = jax.random.split(key, 2)
state = State()
state = self.algorithm.setup(algorithm_key, state)
return state.update(
evaluate_key=evaluate_key
)
self.act_func = jit(self.algorithm.act)
self.batch_act_func = jit(vmap(self.act_func, in_axes=(None, 0, None)))
self.pop_batch_act_func = jit(vmap(self.batch_act_func, in_axes=(None, None, 0)))
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0)))
self.tell_func = jit(self.algorithm.tell)
@partial(jax.jit, static_argnums=(0,))
def step(self, state):
def ask(self):
pop_transforms = self.forward_transform_func(self.state, self.state.pop_genomes)
return lambda inputs: self.pop_batch_act_func(self.state, inputs, pop_transforms)
key, sub_key = jax.random.split(state.evaluate_key)
keys = jax.random.split(key, self.config.basic.pop_size)
def tell(self, fitness):
# self.state = self.tell_func(self.state, fitness)
new_state = self.tell_func(self.state, fitness)
self.state = new_state
pop = self.algorithm.ask(state)
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop)
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func,
pop_transformed)
state = self.algorithm.tell(state, fitnesses)
return state.update(evaluate_key=sub_key), fitnesses
def auto_run(self, ini_state):
state = ini_state
for _ in range(self.config.basic.generation_limit):
forward_func = self.ask()
fitnesses = fitness_func(forward_func)
self.generation_timestamp = time.time()
if analysis is not None:
if analysis == "default":
self.default_analysis(fitnesses)
else:
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
analysis(fitnesses)
previous_pop = self.algorithm.ask(state)
state, fitnesses = self.step(state)
fitnesses = jax.device_get(fitnesses)
self.analysis(state, previous_pop, fitnesses)
if max(fitnesses) >= self.config.basic.fitness_target:
print("Fitness limit reached!")
return self.best_genome
return state, self.best_genome
self.tell(fitnesses)
print("Generation limit reached!")
return self.best_genome
return state, self.best_genome
def analysis(self, state, pop, fitnesses):
def default_analysis(self, fitnesses):
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp
self.generation_timestamp = new_timestamp
max_idx = np.argmax(fitnesses)
if fitnesses[max_idx] > self.best_fitness:
self.best_fitness = fitnesses[max_idx]
self.best_genome = Genome(self.state.pop_genomes.nodes[max_idx], self.state.pop_genomes.conns[max_idx])
self.best_genome = pop[max_idx]
member_count = jax.device_get(self.state.species_info.member_count)
member_count = jax.device_get(state.species_info.member_count)
species_sizes = [int(i) for i in member_count if i > 0]
print(f"Generation: {self.state.generation}",
print(f"Generation: {state.generation}",
f"species: {len(species_sizes)}, {species_sizes}",
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
def show(self, state, genome):
transformed = self.algorithm.transform(state, genome)
self.problem.show(state.evaluate_key, state, self.act_func, transformed)
def pre_compile(self, state):
tic = time.time()
print("start compile")
self.step.lower(self, state).compile()
# compiled_step = jax.jit(self.step, static_argnums=(0,)).lower(state).compile()
# self.__dict__['step'] = compiled_step
print(f"compile finished, cost time: {time.time() - tic}s")

View File

@@ -0,0 +1,3 @@
from .func_fit import FuncFit, FuncFitConfig
from .xor import XOR
from .xor3d import XOR3d

View File

@@ -0,0 +1,69 @@
from typing import Callable
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from config import ProblemConfig
from core import Problem, State
@dataclass(frozen=True)
class FuncFitConfig(ProblemConfig):
    """Configuration for function-fitting problems.

    error_method: loss used by FuncFit.evaluate — one of
    'mse', 'rmse', 'mae', 'mape'.
    """
    error_method: str = 'mse'

    def __post_init__(self):
        # raise (not assert) so validation survives `python -O`
        if self.error_method not in {'mse', 'rmse', 'mae', 'mape'}:
            raise ValueError(f"unknown error_method: {self.error_method!r}")
class FuncFit(Problem):
    """Base problem for fitting a fixed set of input/target pairs.

    Subclasses provide `inputs`, `targets`, `input_shape` and
    `output_shape`; the loss used by `evaluate` is selected by
    `FuncFitConfig.error_method`.
    """

    def __init__(self, config: FuncFitConfig = FuncFitConfig()):
        self.config = config
        super().__init__(config)

    def evaluate(self, randkey, state: State, act_func: Callable, params):
        """Return the negated loss, so higher fitness means smaller error."""
        predict = act_func(state, self.inputs, params)
        diff = predict - self.targets
        # dispatch over the configured error method
        loss_funcs = {
            'mse': lambda: jnp.mean(diff ** 2),
            'rmse': lambda: jnp.sqrt(jnp.mean(diff ** 2)),
            'mae': lambda: jnp.mean(jnp.abs(diff)),
            'mape': lambda: jnp.mean(jnp.abs(diff / self.targets)),
        }
        method = self.config.error_method
        if method not in loss_funcs:
            raise NotImplementedError
        loss = loss_funcs[method]()
        return -loss

    def show(self, randkey, state: State, act_func: Callable, params):
        """Print each (input, target, prediction) triple plus the total loss."""
        predict = act_func(state, self.inputs, params)
        inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
        loss = -self.evaluate(randkey, state, act_func, params)
        lines = [
            f"input: {inputs[i]}, target: {target[i]}, predict: {predict[i]}"
            for i in range(inputs.shape[0])
        ]
        lines.append(f"loss: {loss}")
        print("\n".join(lines) + "\n")

    @property
    def inputs(self):
        """Fitting inputs; must be supplied by the subclass."""
        raise NotImplementedError

    @property
    def targets(self):
        """Fitting targets; must be supplied by the subclass."""
        raise NotImplementedError

    @property
    def input_shape(self):
        raise NotImplementedError

    @property
    def output_shape(self):
        raise NotImplementedError

View File

@@ -1,21 +0,0 @@
from dataclasses import dataclass
from typing import Callable
from config import ProblemConfig
from core import Problem, State
@dataclass(frozen=True)
class FuncFitConfig:
pass
class FuncFit(Problem):
def __init__(self, config: ProblemConfig):
self.config = ProblemConfig
def setup(self, state=State()):
pass
def evaluate(self, state: State, act_func: Callable, params):
pass

36
problem/func_fit/xor.py Normal file
View File

@@ -0,0 +1,36 @@
import numpy as np
from .func_fit import FuncFit, FuncFitConfig
class XOR(FuncFit):
    """Classic 2-input XOR fitting problem."""

    def __init__(self, config: FuncFitConfig = FuncFitConfig()):
        self.config = config
        super().__init__(config)

    @property
    def inputs(self):
        # full 2-bit truth table, one row per sample
        return np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

    @property
    def targets(self):
        # XOR of the two input bits, one column per sample
        return np.array([[0], [1], [1], [0]])

    @property
    def input_shape(self):
        return (4, 2)

    @property
    def output_shape(self):
        return (4, 1)

44
problem/func_fit/xor3d.py Normal file
View File

@@ -0,0 +1,44 @@
import numpy as np
from .func_fit import FuncFit, FuncFitConfig
class XOR3d(FuncFit):
    """3-input XOR (odd-parity) fitting problem."""

    def __init__(self, config: FuncFitConfig = FuncFitConfig()):
        self.config = config
        super().__init__(config)

    @property
    def inputs(self):
        # full 3-bit truth table, one row per sample
        return np.array([
            [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
            [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1],
        ])

    @property
    def targets(self):
        # parity of each row of `inputs`
        return np.array([[0], [1], [1], [0], [1], [0], [0], [1]])

    @property
    def input_shape(self):
        return (8, 3)

    @property
    def output_shape(self):
        return (8, 1)

View File

@@ -0,0 +1 @@
from .gymnax_env import GymNaxEnv, GymNaxConfig

View File

@@ -0,0 +1,42 @@
from dataclasses import dataclass
from typing import Callable
import jax
import jax.numpy as jnp
import gymnax
from core import State
from .rl_env import RLEnv, RLEnvConfig
@dataclass(frozen=True)
class GymNaxConfig(RLEnvConfig):
    """Configuration for a gymnax-backed RL environment.

    env_name: must be registered in gymnax (see gymnax.registered_envs).
    """
    env_name: str = "CartPole-v1"

    def __post_init__(self):
        # raise (not assert) so validation survives `python -O`
        if self.env_name not in gymnax.registered_envs:
            raise ValueError(f"Env {self.env_name} not registered")
class GymNaxEnv(RLEnv):
    """RLEnv backed by a gymnax environment.

    Thin adapter: step/reset delegate to the gymnax env created from
    ``config.env_name``, passing the env's default parameters.
    """

    def __init__(self, config: GymNaxConfig = GymNaxConfig()):
        super().__init__(config)
        self.config = config
        # env holds the environment object, env_params its default parameters
        self.env, self.env_params = gymnax.make(config.env_name)

    def env_step(self, randkey, env_state, action):
        """Advance the env one step; returns gymnax's step tuple unchanged."""
        return self.env.step(randkey, env_state, action, self.env_params)

    def env_reset(self, randkey):
        """Reset the env; returns gymnax's (obs, env_state) unchanged."""
        return self.env.reset(randkey, self.env_params)

    @property
    def input_shape(self):
        # shape of the observation space (network input)
        return self.env.observation_space(self.env_params).shape

    @property
    def output_shape(self):
        # shape of the action space (network output)
        return self.env.action_space(self.env_params).shape

    def show(self, randkey, state: State, act_func: Callable, params):
        # rendering is deliberately unsupported for gymnax envs
        raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")

70
problem/rl_env/rl_env.py Normal file
View File

@@ -0,0 +1,70 @@
from dataclasses import dataclass
from typing import Callable
from functools import partial
import jax
from config import ProblemConfig
from core import Problem, State
@dataclass(frozen=True)
class RLEnvConfig(ProblemConfig):
    """Base config for RL environment problems.

    output_transform maps the raw network output to an environment
    action; the default is the identity.
    """
    output_transform: Callable = lambda x: x
class RLEnv(Problem):
    """Base class for RL problems evaluated by rolling out one episode.

    Subclasses implement env_step/env_reset plus input_shape/output_shape;
    `evaluate` returns the episode's total reward as fitness.
    """

    def __init__(self, config: RLEnvConfig = RLEnvConfig()):
        super().__init__(config)
        self.config = config

    def evaluate(self, randkey, state: State, act_func: Callable, params):
        """Run a single episode with jax.lax.while_loop; returns total reward.

        The loop carry is (obs, env_state, rng, done, total_reward) and
        runs until the env reports done.
        NOTE(review): there is no step cap — a never-terminating env would
        loop forever; confirm all target envs set `done`.
        """
        rng_reset, rng_episode = jax.random.split(randkey)
        init_obs, init_env_state = self.reset(rng_reset)

        def cond_func(carry):
            # keep stepping while the episode is not done
            _, _, _, done, _ = carry
            return ~done

        def body_func(carry):
            obs, env_state, rng, _, tr = carry  # total reward
            net_out = act_func(state, obs, params)
            # decode raw network output into an env action
            action = self.config.output_transform(net_out)
            next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
            # NOTE(review): `rng` is used for the step above and then split
            # to derive the next key — confirm this key reuse is intended
            next_rng, _ = jax.random.split(rng)
            return next_obs, next_env_state, next_rng, done, tr + reward

        _, _, _, _, total_reward = jax.lax.while_loop(
            cond_func,
            body_func,
            (init_obs, init_env_state, rng_episode, False, 0.0)
        )
        return total_reward

    # jitted per-instance wrappers; static_argnums=(0,) keeps `self` static
    @partial(jax.jit, static_argnums=(0,))
    def step(self, randkey, env_state, action):
        return self.env_step(randkey, env_state, action)

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, randkey):
        return self.env_reset(randkey)

    def env_step(self, randkey, env_state, action):
        """Advance the underlying env one step; subclass responsibility."""
        raise NotImplementedError

    def env_reset(self, randkey):
        """Reset the underlying env; subclass responsibility."""
        raise NotImplementedError

    @property
    def input_shape(self):
        raise NotImplementedError

    @property
    def output_shape(self):
        raise NotImplementedError

    def show(self, randkey, state: State, act_func: Callable, params):
        raise NotImplementedError