remove create_func....
This commit is contained in:
@@ -5,30 +5,30 @@ from jax import numpy as jnp, Array, vmap
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from config import Config, HyperNeatConfig
|
from config import Config, HyperNeatConfig
|
||||||
from core import Algorithm, Substrate, State, Genome
|
from core import Algorithm, Substrate, State, Genome, Gene
|
||||||
from utils import Activation, Aggregation
|
from utils import Activation, Aggregation
|
||||||
from algorithm.neat import NEAT
|
|
||||||
from .substrate import analysis_substrate
|
from .substrate import analysis_substrate
|
||||||
|
from algorithm import NEAT
|
||||||
|
|
||||||
|
|
||||||
class HyperNEAT(Algorithm):
|
class HyperNEAT(Algorithm):
|
||||||
|
|
||||||
def __init__(self, config: Config, neat: NEAT, substrate: Type[Substrate]):
|
def __init__(self, config: Config, gene: Type[Gene], substrate: Type[Substrate]):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.neat = neat
|
self.neat = NEAT(config, gene)
|
||||||
self.substrate = substrate
|
self.substrate = substrate
|
||||||
|
|
||||||
def setup(self, randkey, state=State()):
|
def setup(self, randkey, state=State()):
|
||||||
neat_key, randkey = jax.random.split(randkey)
|
neat_key, randkey = jax.random.split(randkey)
|
||||||
state = state.update(
|
state = state.update(
|
||||||
below_threshold=self.config.hyper_neat.below_threshold,
|
below_threshold=self.config.hyperneat.below_threshold,
|
||||||
max_weight=self.config.hyper_neat.max_weight,
|
max_weight=self.config.hyperneat.max_weight,
|
||||||
)
|
)
|
||||||
state = self.neat.setup(neat_key, state)
|
state = self.neat.setup(neat_key, state)
|
||||||
state = self.substrate.setup(self.config.substrate, state)
|
state = self.substrate.setup(self.config.substrate, state)
|
||||||
|
|
||||||
assert self.config.hyper_neat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
|
assert self.config.hyperneat.inputs + 1 == state.input_coors.shape[0] # +1 for bias
|
||||||
assert self.config.hyper_neat.outputs == state.output_coors.shape[0]
|
assert self.config.hyperneat.outputs == state.output_coors.shape[0]
|
||||||
|
|
||||||
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
|
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
|
||||||
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
|
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
|
||||||
@@ -53,7 +53,7 @@ class HyperNEAT(Algorithm):
|
|||||||
return self.neat.tell(state, fitness)
|
return self.neat.tell(state, fitness)
|
||||||
|
|
||||||
def forward(self, state, inputs: Array, transformed: Array):
|
def forward(self, state, inputs: Array, transformed: Array):
|
||||||
return HyperNEATGene.forward(self.config.hyper_neat, state, inputs, transformed)
|
return HyperNEATGene.forward(self.config.hyperneat, state, inputs, transformed)
|
||||||
|
|
||||||
def forward_transform(self, state: State, genome: Genome):
|
def forward_transform(self, state: State, genome: Genome):
|
||||||
t = self.neat.forward_transform(state, genome)
|
t = self.neat.forward_transform(state, genome)
|
||||||
@@ -68,6 +68,7 @@ class HyperNEAT(Algorithm):
|
|||||||
query_res = query_res / (1 - state.below_threshold) * state.max_weight
|
query_res = query_res / (1 - state.below_threshold) * state.max_weight
|
||||||
|
|
||||||
h_conns = state.h_conns.at[:, 2:].set(query_res)
|
h_conns = state.h_conns.at[:, 2:].set(query_res)
|
||||||
|
|
||||||
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))
|
return HyperNEATGene.forward_transform(Genome(state.h_nodes, h_conns))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,9 +9,9 @@ from config import SubstrateConfig
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class NormalSubstrateConfig(SubstrateConfig):
|
class NormalSubstrateConfig(SubstrateConfig):
|
||||||
input_coors: Tuple[Tuple[float]] = ((-1, -1), (0, -1), (1, -1))
|
input_coors: Tuple = ((-1, -1), (0, -1), (1, -1))
|
||||||
hidden_coors: Tuple[Tuple[float]] = ((-1, 0), (0, 0), (1, 0))
|
hidden_coors: Tuple = ((-1, 0), (0, 0), (1, 0))
|
||||||
output_coors: Tuple[Tuple[float]] = ((0, 1),)
|
output_coors: Tuple = ((0, 1),)
|
||||||
|
|
||||||
|
|
||||||
class NormalSubstrate(Substrate):
|
class NormalSubstrate(Substrate):
|
||||||
|
|||||||
@@ -1 +1,2 @@
|
|||||||
from .neat import NEAT
|
from .neat import NEAT
|
||||||
|
from .gene import *
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class NormalGene(Gene):
|
|||||||
node_attrs = ['bias', 'response', 'aggregation', 'activation']
|
node_attrs = ['bias', 'response', 'aggregation', 'activation']
|
||||||
conn_attrs = ['weight']
|
conn_attrs = ['weight']
|
||||||
|
|
||||||
def __init__(self, config: NormalGeneConfig):
|
def __init__(self, config: NormalGeneConfig = NormalGeneConfig()):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
|
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
|
||||||
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
|
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
|
||||||
@@ -101,7 +101,7 @@ class NormalGene(Gene):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def update(self, state):
|
def update(self, state):
|
||||||
pass
|
return state
|
||||||
|
|
||||||
def new_node_attrs(self, state):
|
def new_node_attrs(self, state):
|
||||||
return jnp.array([state.bias_init_mean, state.response_init_mean,
|
return jnp.array([state.bias_init_mean, state.response_init_mean,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class RecurrentGeneConfig(NormalGeneConfig):
|
|||||||
|
|
||||||
class RecurrentGene(NormalGene):
|
class RecurrentGene(NormalGene):
|
||||||
|
|
||||||
def __init__(self, config: RecurrentGeneConfig):
|
def __init__(self, config: RecurrentGeneConfig = RecurrentGeneConfig()):
|
||||||
self.config = config
|
self.config = config
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
|||||||
@@ -28,9 +28,9 @@ class NEAT(Algorithm):
|
|||||||
|
|
||||||
state = state.update(
|
state = state.update(
|
||||||
P=self.config.basic.pop_size,
|
P=self.config.basic.pop_size,
|
||||||
N=self.config.neat.maximum_nodes,
|
N=self.config.neat.max_nodes,
|
||||||
C=self.config.neat.maximum_conns,
|
C=self.config.neat.max_conns,
|
||||||
S=self.config.neat.maximum_species,
|
S=self.config.neat.max_species,
|
||||||
NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes
|
NL=1 + len(self.gene.node_attrs), # node length = (key) + attributes
|
||||||
CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes
|
CL=3 + len(self.gene.conn_attrs), # conn length = (in, out, key) + attributes
|
||||||
max_stagnation=self.config.neat.max_stagnation,
|
max_stagnation=self.config.neat.max_stagnation,
|
||||||
@@ -80,6 +80,8 @@ class NEAT(Algorithm):
|
|||||||
return state.pop_genomes
|
return state.pop_genomes
|
||||||
|
|
||||||
def tell_algorithm(self, state: State, fitness):
|
def tell_algorithm(self, state: State, fitness):
|
||||||
|
state = self.gene.update(state)
|
||||||
|
|
||||||
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
k1, k2, randkey = jax.random.split(state.randkey, 3)
|
||||||
|
|
||||||
state = state.update(
|
state = state.update(
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ class NeatConfig:
|
|||||||
network_type: str = "feedforward"
|
network_type: str = "feedforward"
|
||||||
inputs: int = 2
|
inputs: int = 2
|
||||||
outputs: int = 1
|
outputs: int = 1
|
||||||
maximum_nodes: int = 50
|
max_nodes: int = 50
|
||||||
maximum_conns: int = 100
|
max_conns: int = 100
|
||||||
maximum_species: int = 10
|
max_species: int = 10
|
||||||
|
|
||||||
# genome config
|
# genome config
|
||||||
compatibility_disjoint: float = 1
|
compatibility_disjoint: float = 1
|
||||||
@@ -44,9 +44,9 @@ class NeatConfig:
|
|||||||
assert self.inputs > 0, "the inputs number of neat must be greater than 0"
|
assert self.inputs > 0, "the inputs number of neat must be greater than 0"
|
||||||
assert self.outputs > 0, "the outputs number of neat must be greater than 0"
|
assert self.outputs > 0, "the outputs number of neat must be greater than 0"
|
||||||
|
|
||||||
assert self.maximum_nodes > 0, "the maximum nodes must be greater than 0"
|
assert self.max_nodes > 0, "the maximum nodes must be greater than 0"
|
||||||
assert self.maximum_conns > 0, "the maximum connections must be greater than 0"
|
assert self.max_conns > 0, "the maximum connections must be greater than 0"
|
||||||
assert self.maximum_species > 0, "the maximum species must be greater than 0"
|
assert self.max_species > 0, "the maximum species must be greater than 0"
|
||||||
|
|
||||||
assert self.compatibility_disjoint > 0, "the compatibility disjoint must be greater than 0"
|
assert self.compatibility_disjoint > 0, "the compatibility disjoint must be greater than 0"
|
||||||
assert self.compatibility_weight > 0, "the compatibility weight must be greater than 0"
|
assert self.compatibility_weight > 0, "the compatibility weight must be greater than 0"
|
||||||
@@ -101,7 +101,7 @@ class ProblemConfig:
|
|||||||
class Config:
|
class Config:
|
||||||
basic: BasicConfig = BasicConfig()
|
basic: BasicConfig = BasicConfig()
|
||||||
neat: NeatConfig = NeatConfig()
|
neat: NeatConfig = NeatConfig()
|
||||||
hyper_neat: HyperNeatConfig = HyperNeatConfig()
|
hyperneat: HyperNeatConfig = HyperNeatConfig()
|
||||||
gene: GeneConfig = GeneConfig()
|
gene: GeneConfig = GeneConfig()
|
||||||
substrate: SubstrateConfig = SubstrateConfig()
|
substrate: SubstrateConfig = SubstrateConfig()
|
||||||
problem: ProblemConfig = ProblemConfig()
|
problem: ProblemConfig = ProblemConfig()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ class Gene:
|
|||||||
node_attrs = []
|
node_attrs = []
|
||||||
conn_attrs = []
|
conn_attrs = []
|
||||||
|
|
||||||
def __init__(self, config: GeneConfig):
|
def __init__(self, config: GeneConfig = GeneConfig()):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def setup(self, state=State()):
|
def setup(self, state=State()):
|
||||||
|
|||||||
@@ -19,6 +19,11 @@ class Genome:
|
|||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return self.__class__(self.nodes[idx], self.conns[idx])
|
return self.__class__(self.nodes[idx], self.conns[idx])
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
nodes_eq = jnp.alltrue((self.nodes == other.nodes) | (jnp.isnan(self.nodes) & jnp.isnan(other.nodes)))
|
||||||
|
conns_eq = jnp.alltrue((self.conns == other.conns) | (jnp.isnan(self.conns) & jnp.isnan(other.conns)))
|
||||||
|
return nodes_eq & conns_eq
|
||||||
|
|
||||||
def set(self, idx, value: Genome):
|
def set(self, idx, value: Genome):
|
||||||
return self.__class__(self.nodes.at[idx].set(value.nodes), self.conns.at[idx].set(value.conns))
|
return self.__class__(self.nodes.at[idx].set(value.nodes), self.conns.at[idx].set(value.conns))
|
||||||
|
|
||||||
@@ -83,4 +88,3 @@ class Genome:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def tree_unflatten(cls, aux_data, children):
|
def tree_unflatten(cls, aux_data, children):
|
||||||
return cls(*children)
|
return cls(*children)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,27 @@
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from config import ProblemConfig
|
from config import ProblemConfig
|
||||||
from state import State
|
from .state import State
|
||||||
|
|
||||||
|
|
||||||
class Problem:
|
class Problem:
|
||||||
|
|
||||||
def __init__(self, config: ProblemConfig):
|
def __init__(self, problem_config: ProblemConfig = ProblemConfig()):
|
||||||
|
self.config = problem_config
|
||||||
|
|
||||||
|
def evaluate(self, randkey, state: State, act_func: Callable, params):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def setup(self, state=State()):
|
@property
|
||||||
|
def input_shape(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def evaluate(self, state: State, act_func: Callable, params):
|
@property
|
||||||
|
def output_shape(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def show(self, randkey, state: State, act_func: Callable, params):
|
||||||
|
"""
|
||||||
|
show how a genome perform in this problem
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -4,7 +4,5 @@ from config import SubstrateConfig
|
|||||||
class Substrate:
|
class Substrate:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def setup(state, config: SubstrateConfig):
|
def setup(state, config: SubstrateConfig = SubstrateConfig()):
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
36
examples/func_fit/xor.py
Normal file
36
examples/func_fit/xor.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from config import *
|
||||||
|
from pipeline import Pipeline
|
||||||
|
from algorithm import NEAT
|
||||||
|
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||||
|
from problem.func_fit import XOR, FuncFitConfig
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config = Config(
|
||||||
|
basic=BasicConfig(
|
||||||
|
seed=42,
|
||||||
|
fitness_target=-1e-2,
|
||||||
|
pop_size=10000
|
||||||
|
),
|
||||||
|
neat=NeatConfig(
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
max_species=30,
|
||||||
|
conn_add=0.8,
|
||||||
|
conn_delete=0,
|
||||||
|
node_add=0.4,
|
||||||
|
node_delete=0,
|
||||||
|
inputs=2,
|
||||||
|
outputs=1
|
||||||
|
),
|
||||||
|
gene=NormalGeneConfig(),
|
||||||
|
problem=FuncFitConfig(
|
||||||
|
error_method='rmse'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
algorithm = NEAT(config, NormalGene)
|
||||||
|
pipeline = Pipeline(config, algorithm, XOR)
|
||||||
|
state = pipeline.setup()
|
||||||
|
pipeline.pre_compile(state)
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
|
pipeline.show(state, best)
|
||||||
40
examples/func_fit/xor_hyperneat.py
Normal file
40
examples/func_fit/xor_hyperneat.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
from config import *
|
||||||
|
from pipeline import Pipeline
|
||||||
|
from algorithm.neat import NormalGene, NormalGeneConfig
|
||||||
|
from algorithm.hyperneat import HyperNEAT, NormalSubstrate, NormalSubstrateConfig
|
||||||
|
from problem.func_fit import XOR3d, FuncFitConfig
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config = Config(
|
||||||
|
basic=BasicConfig(
|
||||||
|
seed=42,
|
||||||
|
fitness_target=0,
|
||||||
|
pop_size=1000
|
||||||
|
),
|
||||||
|
neat=NeatConfig(
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
max_species=30,
|
||||||
|
inputs=4,
|
||||||
|
outputs=1
|
||||||
|
),
|
||||||
|
hyperneat=HyperNeatConfig(
|
||||||
|
inputs=3,
|
||||||
|
outputs=1
|
||||||
|
),
|
||||||
|
substrate=NormalSubstrateConfig(
|
||||||
|
input_coors=((-1, -1), (-0.5, -1), (0.5, -1), (1, -1)),
|
||||||
|
),
|
||||||
|
gene=NormalGeneConfig(
|
||||||
|
activation_default='tanh',
|
||||||
|
activation_options=('tanh', ),
|
||||||
|
),
|
||||||
|
problem=FuncFitConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
algorithm = HyperNEAT(config, NormalGene, NormalSubstrate)
|
||||||
|
pipeline = Pipeline(config, algorithm, XOR3d)
|
||||||
|
state = pipeline.setup()
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
|
pipeline.show(state, best)
|
||||||
40
examples/func_fit/xor_recurrent.py
Normal file
40
examples/func_fit/xor_recurrent.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
from config import *
|
||||||
|
from pipeline import Pipeline
|
||||||
|
from algorithm import NEAT
|
||||||
|
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
|
||||||
|
from problem.func_fit import XOR3d, FuncFitConfig
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config = Config(
|
||||||
|
basic=BasicConfig(
|
||||||
|
seed=42,
|
||||||
|
fitness_target=-1e-2,
|
||||||
|
generation_limit=300,
|
||||||
|
pop_size=1000
|
||||||
|
),
|
||||||
|
neat=NeatConfig(
|
||||||
|
network_type="recurrent",
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
max_species=30,
|
||||||
|
conn_add=0.5,
|
||||||
|
conn_delete=0.5,
|
||||||
|
node_add=0.4,
|
||||||
|
node_delete=0.4,
|
||||||
|
inputs=3,
|
||||||
|
outputs=1
|
||||||
|
),
|
||||||
|
gene=RecurrentGeneConfig(
|
||||||
|
activate_times=10
|
||||||
|
),
|
||||||
|
problem=FuncFitConfig(
|
||||||
|
error_method='rmse'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
algorithm = NEAT(config, RecurrentGene)
|
||||||
|
pipeline = Pipeline(config, algorithm, XOR3d)
|
||||||
|
state = pipeline.setup()
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
|
pipeline.show(state, best)
|
||||||
84
examples/gymnax/cartpole.py
Normal file
84
examples/gymnax/cartpole.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from config import *
|
||||||
|
from pipeline import Pipeline
|
||||||
|
from algorithm import NEAT
|
||||||
|
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
||||||
|
from problem.rl_env import GymNaxConfig, GymNaxEnv
|
||||||
|
|
||||||
|
|
||||||
|
def example_conf1():
|
||||||
|
return Config(
|
||||||
|
basic=BasicConfig(
|
||||||
|
seed=42,
|
||||||
|
fitness_target=500,
|
||||||
|
pop_size=10000
|
||||||
|
),
|
||||||
|
neat=NeatConfig(
|
||||||
|
inputs=4,
|
||||||
|
outputs=1,
|
||||||
|
),
|
||||||
|
gene=NormalGeneConfig(
|
||||||
|
activation_default='sigmoid',
|
||||||
|
activation_options=('sigmoid',),
|
||||||
|
),
|
||||||
|
problem=GymNaxConfig(
|
||||||
|
env_name='CartPole-v1',
|
||||||
|
output_transform=lambda out: jnp.where(out[0] > 0.5, 1, 0) # the action of cartpole is {0, 1}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def example_conf2():
|
||||||
|
return Config(
|
||||||
|
basic=BasicConfig(
|
||||||
|
seed=42,
|
||||||
|
fitness_target=500,
|
||||||
|
pop_size=10000
|
||||||
|
),
|
||||||
|
neat=NeatConfig(
|
||||||
|
inputs=4,
|
||||||
|
outputs=1,
|
||||||
|
),
|
||||||
|
gene=NormalGeneConfig(
|
||||||
|
activation_default='tanh',
|
||||||
|
activation_options=('tanh',),
|
||||||
|
),
|
||||||
|
problem=GymNaxConfig(
|
||||||
|
env_name='CartPole-v1',
|
||||||
|
output_transform=lambda out: jnp.where(out[0] > 0, 1, 0) # the action of cartpole is {0, 1}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def example_conf3():
|
||||||
|
return Config(
|
||||||
|
basic=BasicConfig(
|
||||||
|
seed=42,
|
||||||
|
fitness_target=500,
|
||||||
|
pop_size=10000
|
||||||
|
),
|
||||||
|
neat=NeatConfig(
|
||||||
|
inputs=4,
|
||||||
|
outputs=2,
|
||||||
|
),
|
||||||
|
gene=NormalGeneConfig(
|
||||||
|
activation_default='tanh',
|
||||||
|
activation_options=('tanh',),
|
||||||
|
),
|
||||||
|
problem=GymNaxConfig(
|
||||||
|
env_name='CartPole-v1',
|
||||||
|
output_transform=lambda out: jnp.argmax(out) # the action of cartpole is {0, 1}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# all config files above can solve cartpole
|
||||||
|
conf = example_conf3()
|
||||||
|
|
||||||
|
algorithm = NEAT(conf, NormalGene)
|
||||||
|
pipeline = Pipeline(conf, algorithm, GymNaxEnv)
|
||||||
|
state = pipeline.setup()
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
|
pipeline.show(state, best)
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
from functools import partial
|
|
||||||
import jax
|
|
||||||
|
|
||||||
from utils import unflatten_conns, act, agg, Activation, Aggregation
|
|
||||||
from algorithm.neat.gene import RecurrentGeneConfig
|
|
||||||
|
|
||||||
config = RecurrentGeneConfig(
|
|
||||||
activation_options=("tanh", "sigmoid"),
|
|
||||||
activation_default="tanh",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class A:
|
|
||||||
def __init__(self):
|
|
||||||
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
|
|
||||||
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
|
|
||||||
self.isTrue = False
|
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(0,))
|
|
||||||
def step(self):
|
|
||||||
i = jax.numpy.array([0, 1])
|
|
||||||
z = jax.numpy.array([
|
|
||||||
[1, 1],
|
|
||||||
[2, 2]
|
|
||||||
])
|
|
||||||
print(self.act_funcs)
|
|
||||||
return jax.vmap(act, in_axes=(0, 0, None))(i, z, self.act_funcs)
|
|
||||||
|
|
||||||
|
|
||||||
AA = A()
|
|
||||||
print(AA.step())
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
import jax
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from config import Config, BasicConfig, NeatConfig
|
|
||||||
from pipeline import Pipeline
|
|
||||||
from algorithm import NEAT
|
|
||||||
from algorithm.neat.gene import NormalGene, NormalGeneConfig
|
|
||||||
|
|
||||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
|
||||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(forward_func):
|
|
||||||
"""
|
|
||||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
outs = forward_func(xor_inputs)
|
|
||||||
outs = jax.device_get(outs)
|
|
||||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
|
||||||
return fitnesses
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
config = Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
fitness_target=3.9999999,
|
|
||||||
pop_size=10000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
maximum_nodes=50,
|
|
||||||
maximum_conns=100,
|
|
||||||
compatibility_threshold=4
|
|
||||||
),
|
|
||||||
gene=NormalGeneConfig()
|
|
||||||
)
|
|
||||||
|
|
||||||
algorithm = NEAT(config, NormalGene)
|
|
||||||
pipeline = Pipeline(config, algorithm)
|
|
||||||
pipeline.auto_run(evaluate)
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
import jax
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from config import Config, BasicConfig, NeatConfig
|
|
||||||
from pipeline import Pipeline
|
|
||||||
from algorithm import NEAT, HyperNEAT
|
|
||||||
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
|
|
||||||
from algorithm.hyperneat.substrate import NormalSubstrate, NormalSubstrateConfig
|
|
||||||
|
|
||||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
|
||||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(forward_func):
|
|
||||||
"""
|
|
||||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
outs = forward_func(xor_inputs)
|
|
||||||
outs = jax.device_get(outs)
|
|
||||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
|
||||||
return fitnesses
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
config = Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
fitness_target=3.99999,
|
|
||||||
pop_size=10000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
network_type="recurrent",
|
|
||||||
maximum_nodes=50,
|
|
||||||
maximum_conns=100,
|
|
||||||
inputs=4,
|
|
||||||
outputs=1
|
|
||||||
|
|
||||||
),
|
|
||||||
gene=RecurrentGeneConfig(
|
|
||||||
activation_default="tanh",
|
|
||||||
activation_options=("tanh",),
|
|
||||||
),
|
|
||||||
substrate=NormalSubstrateConfig(),
|
|
||||||
)
|
|
||||||
neat = NEAT(config, RecurrentGene)
|
|
||||||
hyperNEAT = HyperNEAT(config, neat, NormalSubstrate)
|
|
||||||
|
|
||||||
pipeline = Pipeline(config, hyperNEAT)
|
|
||||||
pipeline.auto_run(evaluate)
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
import jax
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from config import Config, BasicConfig, NeatConfig
|
|
||||||
from pipeline import Pipeline
|
|
||||||
from algorithm import NEAT
|
|
||||||
from algorithm.neat.gene import RecurrentGene, RecurrentGeneConfig
|
|
||||||
|
|
||||||
|
|
||||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
|
||||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(forward_func):
|
|
||||||
"""
|
|
||||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
outs = forward_func(xor_inputs)
|
|
||||||
outs = jax.device_get(outs)
|
|
||||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
|
||||||
return fitnesses
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
config = Config(
|
|
||||||
basic=BasicConfig(
|
|
||||||
fitness_target=3.99999,
|
|
||||||
pop_size=10000
|
|
||||||
),
|
|
||||||
neat=NeatConfig(
|
|
||||||
network_type="recurrent",
|
|
||||||
maximum_nodes=50,
|
|
||||||
maximum_conns=100
|
|
||||||
),
|
|
||||||
gene=RecurrentGeneConfig(
|
|
||||||
activate_times=3
|
|
||||||
)
|
|
||||||
)
|
|
||||||
algorithm = NEAT(config, RecurrentGene)
|
|
||||||
pipeline = Pipeline(config, algorithm)
|
|
||||||
pipeline.auto_run(evaluate)
|
|
||||||
116
pipeline.py
116
pipeline.py
@@ -1,83 +1,115 @@
|
|||||||
import time
|
from functools import partial
|
||||||
from typing import Union, Callable
|
from typing import Type
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
from jax import vmap, jit
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from algorithm import NEAT, HyperNEAT
|
||||||
from config import Config
|
from config import Config
|
||||||
from core import Algorithm, Genome
|
from core import State, Algorithm, Problem
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
"""
|
|
||||||
Simple pipeline.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: Config, algorithm: Algorithm):
|
def __init__(self, config: Config, algorithm: Algorithm, problem_type: Type[Problem]):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.algorithm = algorithm
|
self.algorithm = algorithm
|
||||||
|
self.problem = problem_type(config.problem)
|
||||||
|
|
||||||
randkey = jax.random.PRNGKey(config.basic.seed)
|
if isinstance(algorithm, NEAT):
|
||||||
self.state = algorithm.setup(randkey)
|
assert config.neat.inputs == self.problem.input_shape[-1]
|
||||||
|
|
||||||
|
elif isinstance(algorithm, HyperNEAT):
|
||||||
|
assert config.hyperneat.inputs == self.problem.input_shape[-1]
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
self.act_func = self.algorithm.act
|
||||||
|
|
||||||
|
for _ in range(len(self.problem.input_shape) - 1):
|
||||||
|
self.act_func = jax.vmap(self.act_func, in_axes=(None, 0, None))
|
||||||
|
|
||||||
self.best_genome = None
|
self.best_genome = None
|
||||||
self.best_fitness = float('-inf')
|
self.best_fitness = float('-inf')
|
||||||
self.generation_timestamp = time.time()
|
self.generation_timestamp = None
|
||||||
|
|
||||||
self.evaluate_time = 0
|
def setup(self):
|
||||||
|
key = jax.random.PRNGKey(self.config.basic.seed)
|
||||||
|
algorithm_key, evaluate_key = jax.random.split(key, 2)
|
||||||
|
state = State()
|
||||||
|
state = self.algorithm.setup(algorithm_key, state)
|
||||||
|
return state.update(
|
||||||
|
evaluate_key=evaluate_key
|
||||||
|
)
|
||||||
|
|
||||||
self.act_func = jit(self.algorithm.act)
|
@partial(jax.jit, static_argnums=(0,))
|
||||||
self.batch_act_func = jit(vmap(self.act_func, in_axes=(None, 0, None)))
|
def step(self, state):
|
||||||
self.pop_batch_act_func = jit(vmap(self.batch_act_func, in_axes=(None, None, 0)))
|
|
||||||
self.forward_transform_func = jit(vmap(self.algorithm.forward_transform, in_axes=(None, 0)))
|
|
||||||
self.tell_func = jit(self.algorithm.tell)
|
|
||||||
|
|
||||||
def ask(self):
|
key, sub_key = jax.random.split(state.evaluate_key)
|
||||||
pop_transforms = self.forward_transform_func(self.state, self.state.pop_genomes)
|
keys = jax.random.split(key, self.config.basic.pop_size)
|
||||||
return lambda inputs: self.pop_batch_act_func(self.state, inputs, pop_transforms)
|
|
||||||
|
|
||||||
def tell(self, fitness):
|
pop = self.algorithm.ask(state)
|
||||||
# self.state = self.tell_func(self.state, fitness)
|
|
||||||
new_state = self.tell_func(self.state, fitness)
|
|
||||||
self.state = new_state
|
|
||||||
|
|
||||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
pop_transformed = jax.vmap(self.algorithm.transform, in_axes=(None, 0))(state, pop)
|
||||||
|
|
||||||
|
fitnesses = jax.vmap(self.problem.evaluate, in_axes=(0, None, None, 0))(keys, state, self.act_func,
|
||||||
|
pop_transformed)
|
||||||
|
|
||||||
|
state = self.algorithm.tell(state, fitnesses)
|
||||||
|
|
||||||
|
return state.update(evaluate_key=sub_key), fitnesses
|
||||||
|
|
||||||
|
def auto_run(self, ini_state):
|
||||||
|
state = ini_state
|
||||||
for _ in range(self.config.basic.generation_limit):
|
for _ in range(self.config.basic.generation_limit):
|
||||||
forward_func = self.ask()
|
|
||||||
|
|
||||||
fitnesses = fitness_func(forward_func)
|
self.generation_timestamp = time.time()
|
||||||
|
|
||||||
if analysis is not None:
|
previous_pop = self.algorithm.ask(state)
|
||||||
if analysis == "default":
|
|
||||||
self.default_analysis(fitnesses)
|
state, fitnesses = self.step(state)
|
||||||
else:
|
|
||||||
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
|
fitnesses = jax.device_get(fitnesses)
|
||||||
analysis(fitnesses)
|
|
||||||
|
self.analysis(state, previous_pop, fitnesses)
|
||||||
|
|
||||||
if max(fitnesses) >= self.config.basic.fitness_target:
|
if max(fitnesses) >= self.config.basic.fitness_target:
|
||||||
print("Fitness limit reached!")
|
print("Fitness limit reached!")
|
||||||
return self.best_genome
|
return state, self.best_genome
|
||||||
|
|
||||||
self.tell(fitnesses)
|
|
||||||
print("Generation limit reached!")
|
print("Generation limit reached!")
|
||||||
return self.best_genome
|
return state, self.best_genome
|
||||||
|
|
||||||
|
def analysis(self, state, pop, fitnesses):
|
||||||
|
|
||||||
def default_analysis(self, fitnesses):
|
|
||||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||||
|
|
||||||
new_timestamp = time.time()
|
new_timestamp = time.time()
|
||||||
|
|
||||||
cost_time = new_timestamp - self.generation_timestamp
|
cost_time = new_timestamp - self.generation_timestamp
|
||||||
self.generation_timestamp = new_timestamp
|
|
||||||
|
|
||||||
max_idx = np.argmax(fitnesses)
|
max_idx = np.argmax(fitnesses)
|
||||||
if fitnesses[max_idx] > self.best_fitness:
|
if fitnesses[max_idx] > self.best_fitness:
|
||||||
self.best_fitness = fitnesses[max_idx]
|
self.best_fitness = fitnesses[max_idx]
|
||||||
self.best_genome = Genome(self.state.pop_genomes.nodes[max_idx], self.state.pop_genomes.conns[max_idx])
|
self.best_genome = pop[max_idx]
|
||||||
|
|
||||||
member_count = jax.device_get(self.state.species_info.member_count)
|
member_count = jax.device_get(state.species_info.member_count)
|
||||||
species_sizes = [int(i) for i in member_count if i > 0]
|
species_sizes = [int(i) for i in member_count if i > 0]
|
||||||
|
|
||||||
print(f"Generation: {self.state.generation}",
|
print(f"Generation: {state.generation}",
|
||||||
f"species: {len(species_sizes)}, {species_sizes}",
|
f"species: {len(species_sizes)}, {species_sizes}",
|
||||||
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
|
f"fitness: {max_f:.6f}, {min_f:.6f}, {mean_f:.6f}, {std_f:.6f}, Cost time: {cost_time * 1000:.6f}ms")
|
||||||
|
|
||||||
|
def show(self, state, genome):
|
||||||
|
transformed = self.algorithm.transform(state, genome)
|
||||||
|
self.problem.show(state.evaluate_key, state, self.act_func, transformed)
|
||||||
|
|
||||||
|
def pre_compile(self, state):
|
||||||
|
tic = time.time()
|
||||||
|
print("start compile")
|
||||||
|
self.step.lower(self, state).compile()
|
||||||
|
# compiled_step = jax.jit(self.step, static_argnums=(0,)).lower(state).compile()
|
||||||
|
# self.__dict__['step'] = compiled_step
|
||||||
|
print(f"compile finished, cost time: {time.time() - tic}s")
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .func_fit import FuncFit, FuncFitConfig
|
||||||
|
from .xor import XOR
|
||||||
|
from .xor3d import XOR3d
|
||||||
|
|||||||
69
problem/func_fit/func_fit.py
Normal file
69
problem/func_fit/func_fit.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
from typing import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from config import ProblemConfig
|
||||||
|
from core import Problem, State
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FuncFitConfig(ProblemConfig):
|
||||||
|
error_method: str = 'mse'
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
assert self.error_method in {'mse', 'rmse', 'mae', 'mape'}
|
||||||
|
|
||||||
|
|
||||||
|
class FuncFit(Problem):
|
||||||
|
|
||||||
|
def __init__(self, config: FuncFitConfig = FuncFitConfig()):
|
||||||
|
self.config = config
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
def evaluate(self, randkey, state: State, act_func: Callable, params):
|
||||||
|
|
||||||
|
predict = act_func(state, self.inputs, params)
|
||||||
|
|
||||||
|
if self.config.error_method == 'mse':
|
||||||
|
loss = jnp.mean((predict - self.targets) ** 2)
|
||||||
|
|
||||||
|
elif self.config.error_method == 'rmse':
|
||||||
|
loss = jnp.sqrt(jnp.mean((predict - self.targets) ** 2))
|
||||||
|
|
||||||
|
elif self.config.error_method == 'mae':
|
||||||
|
loss = jnp.mean(jnp.abs(predict - self.targets))
|
||||||
|
|
||||||
|
elif self.config.error_method == 'mape':
|
||||||
|
loss = jnp.mean(jnp.abs((predict - self.targets) / self.targets))
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
return -loss
|
||||||
|
|
||||||
|
def show(self, randkey, state: State, act_func: Callable, params):
|
||||||
|
predict = act_func(state, self.inputs, params)
|
||||||
|
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||||
|
loss = -self.evaluate(randkey, state, act_func, params)
|
||||||
|
msg = ""
|
||||||
|
for i in range(inputs.shape[0]):
|
||||||
|
msg += f"input: {inputs[i]}, target: {target[i]}, predict: {predict[i]}\n"
|
||||||
|
msg += f"loss: {loss}\n"
|
||||||
|
print(msg)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def targets(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shape(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self):
|
||||||
|
raise NotImplementedError
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
from config import ProblemConfig
|
|
||||||
from core import Problem, State
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class FuncFitConfig:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class FuncFit(Problem):
|
|
||||||
def __init__(self, config: ProblemConfig):
|
|
||||||
self.config = ProblemConfig
|
|
||||||
|
|
||||||
def setup(self, state=State()):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def evaluate(self, state: State, act_func: Callable, params):
|
|
||||||
pass
|
|
||||||
36
problem/func_fit/xor.py
Normal file
36
problem/func_fit/xor.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .func_fit import FuncFit, FuncFitConfig
|
||||||
|
|
||||||
|
|
||||||
|
class XOR(FuncFit):
|
||||||
|
|
||||||
|
def __init__(self, config: FuncFitConfig = FuncFitConfig()):
|
||||||
|
self.config = config
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self):
|
||||||
|
return np.array([
|
||||||
|
[0, 0],
|
||||||
|
[0, 1],
|
||||||
|
[1, 0],
|
||||||
|
[1, 1]
|
||||||
|
])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def targets(self):
|
||||||
|
return np.array([
|
||||||
|
[0],
|
||||||
|
[1],
|
||||||
|
[1],
|
||||||
|
[0]
|
||||||
|
])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shape(self):
|
||||||
|
return (4, 2)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self):
|
||||||
|
return (4, 1)
|
||||||
44
problem/func_fit/xor3d.py
Normal file
44
problem/func_fit/xor3d.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .func_fit import FuncFit, FuncFitConfig
|
||||||
|
|
||||||
|
|
||||||
|
class XOR3d(FuncFit):
|
||||||
|
|
||||||
|
def __init__(self, config: FuncFitConfig = FuncFitConfig()):
|
||||||
|
self.config = config
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self):
|
||||||
|
return np.array([
|
||||||
|
[0, 0, 0],
|
||||||
|
[0, 0, 1],
|
||||||
|
[0, 1, 0],
|
||||||
|
[0, 1, 1],
|
||||||
|
[1, 0, 0],
|
||||||
|
[1, 0, 1],
|
||||||
|
[1, 1, 0],
|
||||||
|
[1, 1, 1],
|
||||||
|
])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def targets(self):
|
||||||
|
return np.array([
|
||||||
|
[0],
|
||||||
|
[1],
|
||||||
|
[1],
|
||||||
|
[0],
|
||||||
|
[1],
|
||||||
|
[0],
|
||||||
|
[0],
|
||||||
|
[1]
|
||||||
|
])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shape(self):
|
||||||
|
return (8, 3)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self):
|
||||||
|
return (8, 1)
|
||||||
1
problem/rl_env/__init__.py
Normal file
1
problem/rl_env/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .gymnax_env import GymNaxEnv, GymNaxConfig
|
||||||
42
problem/rl_env/gymnax_env.py
Normal file
42
problem/rl_env/gymnax_env.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import gymnax
|
||||||
|
|
||||||
|
from core import State
|
||||||
|
from .rl_env import RLEnv, RLEnvConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GymNaxConfig(RLEnvConfig):
|
||||||
|
env_name: str = "CartPole-v1"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
assert self.env_name in gymnax.registered_envs, f"Env {self.env_name} not registered"
|
||||||
|
|
||||||
|
|
||||||
|
class GymNaxEnv(RLEnv):
|
||||||
|
|
||||||
|
def __init__(self, config: GymNaxConfig = GymNaxConfig()):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.env, self.env_params = gymnax.make(config.env_name)
|
||||||
|
|
||||||
|
def env_step(self, randkey, env_state, action):
|
||||||
|
return self.env.step(randkey, env_state, action, self.env_params)
|
||||||
|
|
||||||
|
def env_reset(self, randkey):
|
||||||
|
return self.env.reset(randkey, self.env_params)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shape(self):
|
||||||
|
return self.env.observation_space(self.env_params).shape
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self):
|
||||||
|
return self.env.action_space(self.env_params).shape
|
||||||
|
|
||||||
|
def show(self, randkey, state: State, act_func: Callable, params):
|
||||||
|
raise NotImplementedError("GymNax render must rely on gym 0.19.0(old version).")
|
||||||
70
problem/rl_env/rl_env.py
Normal file
70
problem/rl_env/rl_env.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
|
|
||||||
|
from config import ProblemConfig
|
||||||
|
|
||||||
|
from core import Problem, State
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RLEnvConfig(ProblemConfig):
|
||||||
|
output_transform: Callable = lambda x: x
|
||||||
|
|
||||||
|
|
||||||
|
class RLEnv(Problem):
|
||||||
|
|
||||||
|
def __init__(self, config: RLEnvConfig = RLEnvConfig()):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def evaluate(self, randkey, state: State, act_func: Callable, params):
|
||||||
|
rng_reset, rng_episode = jax.random.split(randkey)
|
||||||
|
init_obs, init_env_state = self.reset(rng_reset)
|
||||||
|
|
||||||
|
def cond_func(carry):
|
||||||
|
_, _, _, done, _ = carry
|
||||||
|
return ~done
|
||||||
|
|
||||||
|
def body_func(carry):
|
||||||
|
obs, env_state, rng, _, tr = carry # total reward
|
||||||
|
net_out = act_func(state, obs, params)
|
||||||
|
action = self.config.output_transform(net_out)
|
||||||
|
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||||
|
next_rng, _ = jax.random.split(rng)
|
||||||
|
return next_obs, next_env_state, next_rng, done, tr + reward
|
||||||
|
|
||||||
|
_, _, _, _, total_reward = jax.lax.while_loop(
|
||||||
|
cond_func,
|
||||||
|
body_func,
|
||||||
|
(init_obs, init_env_state, rng_episode, False, 0.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
return total_reward
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnums=(0,))
|
||||||
|
def step(self, randkey, env_state, action):
|
||||||
|
return self.env_step(randkey, env_state, action)
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnums=(0,))
|
||||||
|
def reset(self, randkey):
|
||||||
|
return self.env_reset(randkey)
|
||||||
|
|
||||||
|
def env_step(self, randkey, env_state, action):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def env_reset(self, randkey):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shape(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def show(self, randkey, state: State, act_func: Callable, params):
|
||||||
|
raise NotImplementedError
|
||||||
Reference in New Issue
Block a user