Merge pull request #4 from EMI-Group/advance
add step_limit to rl envs; add more initialize methods;
This commit is contained in:
BIN
tensorneat/__pycache__/pipeline.cpython-311.pyc
Normal file
BIN
tensorneat/__pycache__/pipeline.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/algorithm/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/algorithm/__pycache__/base.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
@@ -2,5 +2,4 @@ from .ga import *
|
||||
from .gene import *
|
||||
from .genome import *
|
||||
from .species import *
|
||||
from .neat import NEAT
|
||||
|
||||
from .neat import NEAT
|
||||
BIN
tensorneat/algorithm/neat/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/neat/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/algorithm/neat/__pycache__/neat.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/neat/__pycache__/neat.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tensorneat/algorithm/neat/gene/__pycache__/base.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/neat/gene/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -26,6 +26,17 @@ class DefaultConnGene(BaseConnGene):
|
||||
|
||||
def new_custom_attrs(self):
|
||||
return jnp.array([self.weight_init_mean])
|
||||
|
||||
def new_random_attrs(self, key):
|
||||
return jnp.array([mutate_float(key,
|
||||
self.weight_init_mean,
|
||||
self.weight_init_mean,
|
||||
1.0,
|
||||
0,
|
||||
0,
|
||||
1.0,
|
||||
)
|
||||
])
|
||||
|
||||
def mutate(self, key, conn):
|
||||
input_index = conn[0]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -60,6 +60,16 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
return jnp.array(
|
||||
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
||||
)
|
||||
|
||||
def new_random_attrs(self, key):
|
||||
return jnp.array([
|
||||
mutate_float(key, self.bias_init_mean, self.bias_init_mean, self.bias_init_std,
|
||||
self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate),
|
||||
mutate_float(key, self.response_init_mean, self.response_init_mean, self.response_init_std,
|
||||
self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate),
|
||||
self.activation_default,
|
||||
self.aggregation_default,
|
||||
])
|
||||
|
||||
def mutate(self, key, node):
|
||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -19,7 +19,8 @@ class DefaultSpecies(BaseSpecies):
|
||||
genome_elitism: int = 2,
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.
|
||||
compatibility_threshold: float = 3.,
|
||||
initialize_method: str = 'one_hidden_node', # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
|
||||
):
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
@@ -34,11 +35,13 @@ class DefaultSpecies(BaseSpecies):
|
||||
self.survival_threshold = survival_threshold
|
||||
self.min_species_size = min_species_size
|
||||
self.compatibility_threshold = compatibility_threshold
|
||||
self.initialize_method = initialize_method
|
||||
|
||||
self.species_arange = jnp.arange(self.species_size)
|
||||
|
||||
def setup(self, randkey):
|
||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome)
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome,k1, self.initialize_method)
|
||||
|
||||
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
|
||||
best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species
|
||||
@@ -62,7 +65,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||
|
||||
return State(
|
||||
randkey=randkey,
|
||||
randkey=k2,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
species_keys=species_keys,
|
||||
@@ -489,31 +492,159 @@ class DefaultSpecies(BaseSpecies):
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
|
||||
|
||||
def initialize_population(pop_size, genome):
|
||||
o_nodes = np.full((genome.max_nodes, genome.node_gene.length), np.nan) # original nodes
|
||||
o_conns = np.full((genome.max_conns, genome.conn_gene.length), np.nan) # original connections
|
||||
|
||||
|
||||
def initialize_population(pop_size, genome, randkey, init_method='default'):
|
||||
rand_keys = jax.random.split(randkey, pop_size)
|
||||
|
||||
if init_method == 'one_hidden_node':
|
||||
init_func = init_one_hidden_node
|
||||
elif init_method == 'dense_hideen_layer':
|
||||
init_func = init_dense_hideen_layer
|
||||
elif init_method == 'no_hidden_random':
|
||||
init_func = init_no_hidden_random
|
||||
else:
|
||||
raise ValueError("Unknown initialization method: {}".format(init_method))
|
||||
|
||||
pop_nodes, pop_conns = jax.vmap(init_func, in_axes=(None, 0))(genome, rand_keys)
|
||||
|
||||
return pop_nodes, pop_conns
|
||||
|
||||
# one hidden node
|
||||
def init_one_hidden_node(genome, randkey):
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
new_node_key = max([*input_idx, *output_idx]) + 1
|
||||
|
||||
o_nodes[input_idx, 0] = genome.input_idx
|
||||
o_nodes[output_idx, 0] = genome.output_idx
|
||||
o_nodes[new_node_key, 0] = new_node_key # one hidden node
|
||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs()
|
||||
o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
|
||||
|
||||
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
|
||||
o_conns[input_idx, 0:2] = input_conns # in key, out key
|
||||
o_conns[input_idx, 2] = True # enabled
|
||||
o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[new_node_key, 0].set(new_node_key)
|
||||
|
||||
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
|
||||
o_conns[output_idx, 0:2] = output_conns # in key, out key
|
||||
o_conns[output_idx, 2] = True # enabled
|
||||
o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
||||
rand_keys_nodes = jax.random.split(randkey, num=len(input_idx) + len(output_idx) + 1)
|
||||
input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[len(input_idx):len(input_idx) + len(output_idx)], rand_keys_nodes[-1]
|
||||
|
||||
# repeat origin genome for P times to create population
|
||||
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))
|
||||
pop_conns = np.tile(o_conns, (pop_size, 1, 1))
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None,0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key)
|
||||
|
||||
return pop_nodes, pop_conns
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
nodes = nodes.at[new_node_key, 1:].set(hidden_attrs)
|
||||
|
||||
input_conns = jnp.c_[input_idx, jnp.full_like(input_idx, new_node_key)]
|
||||
conns = conns.at[input_idx, 0:2].set(input_conns)
|
||||
conns = conns.at[input_idx, 2].set(True)
|
||||
|
||||
output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx]
|
||||
conns = conns.at[output_idx, 0:2].set(output_conns)
|
||||
conns = conns.at[output_idx, 2].set(True)
|
||||
|
||||
rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx))
|
||||
input_conn_keys, output_conn_keys = rand_keys_conns[:len(input_idx)], rand_keys_conns[len(input_idx):]
|
||||
|
||||
conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None,0))
|
||||
input_conn_attrs = conn_attr_func(input_conn_keys)
|
||||
output_conn_attrs = conn_attr_func(output_conn_keys)
|
||||
|
||||
conns = conns.at[input_idx, 3:].set(input_conn_attrs)
|
||||
conns = conns.at[output_idx, 3:].set(output_conn_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
|
||||
#random dense connections with 1 hidden layer
|
||||
def init_dense_hideen_layer( genome, randkey,hiddens=20):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
input_size = len(input_idx)
|
||||
output_size = len(output_idx)
|
||||
|
||||
hidden_idx = jnp.arange(input_size + output_size, input_size + output_size + hiddens)
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32)
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
|
||||
|
||||
total_idx = input_size + output_size + hiddens
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
input_keys = rand_keys_n[:input_size]
|
||||
output_keys = rand_keys_n[input_size:input_size + output_size]
|
||||
hidden_keys = rand_keys_n[input_size + output_size:]
|
||||
|
||||
node_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
hidden_attrs = node_attr_func(hidden_keys)
|
||||
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
nodes = nodes.at[hidden_idx, 1:].set(hidden_attrs)
|
||||
|
||||
total_connections = input_size * hiddens + hiddens * output_size
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32)
|
||||
|
||||
rand_keys_c = jax.random.split(k2, num=total_connections)
|
||||
conns_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||
conns_attrs = conns_attr_func(rand_keys_c)
|
||||
|
||||
input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing='ij')
|
||||
hidden_to_output_ids, output_ids = jnp.meshgrid(hidden_idx, output_idx, indexing='ij')
|
||||
|
||||
conns = conns.at[:input_size * hiddens, 0].set(input_to_hidden_ids.flatten())
|
||||
conns = conns.at[:input_size * hiddens, 1].set(hidden_ids.flatten())
|
||||
conns = conns.at[input_size * hiddens: total_connections, 0].set(hidden_to_output_ids.flatten())
|
||||
conns = conns.at[input_size * hiddens: total_connections, 1].set(output_ids.flatten())
|
||||
conns = conns.at[:input_size * hiddens + hiddens * output_size, 2].set(True)
|
||||
conns = conns.at[:input_size * hiddens + hiddens * output_size, 3:].set(conns_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
# random sparse connections with no hidden nodes
|
||||
def init_no_hidden_random(genome, randkey):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
|
||||
total_idx = len(input_idx) + len(output_idx)
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
input_keys = rand_keys_n[:len(input_idx)]
|
||||
output_keys = rand_keys_n[len(input_idx):]
|
||||
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
|
||||
|
||||
num_connections_per_output = 4
|
||||
total_connections = len(output_idx) * num_connections_per_output
|
||||
|
||||
def create_connections_for_output(key):
|
||||
permuted_inputs = jax.random.permutation(key, input_idx)
|
||||
selected_inputs = permuted_inputs[:num_connections_per_output]
|
||||
return selected_inputs
|
||||
|
||||
conn_keys = jax.random.split(k2, num=len(output_idx))
|
||||
connections = jax.vmap(create_connections_for_output)(conn_keys)
|
||||
connections = connections.flatten()
|
||||
|
||||
output_repeats = jnp.repeat(output_idx, num_connections_per_output)
|
||||
|
||||
rand_keys_c = jax.random.split(k3, num=total_connections)
|
||||
conns_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||
conns_attrs = conns_attr_func(rand_keys_c)
|
||||
|
||||
conns = conns.at[:total_connections, 0].set(connections)
|
||||
conns = conns.at[:total_connections, 1].set(output_repeats)
|
||||
conns = conns.at[:total_connections, 2].set(True) # enabled
|
||||
conns = conns.at[:total_connections, 3:].set(conns_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
@@ -83,6 +83,7 @@ class Pipeline:
|
||||
print("start compile")
|
||||
tic = time.time()
|
||||
compiled_step = jax.jit(self.step).lower(ini_state).compile()
|
||||
|
||||
print(f"compile finished, cost time: {time.time() - tic:.6f}s", )
|
||||
for _ in range(self.generation_limit):
|
||||
|
||||
@@ -90,6 +91,7 @@ class Pipeline:
|
||||
|
||||
previous_pop = self.algorithm.ask(state.alg)
|
||||
|
||||
|
||||
state, fitnesses = compiled_step(state)
|
||||
|
||||
fitnesses = jax.device_get(fitnesses)
|
||||
@@ -99,7 +101,13 @@ class Pipeline:
|
||||
if max(fitnesses) >= self.fitness_target:
|
||||
print("Fitness limit reached!")
|
||||
return state, self.best_genome
|
||||
|
||||
node= previous_pop[0][0][:,0]
|
||||
node_count = jnp.sum(~jnp.isnan(node))
|
||||
conn= previous_pop[1][0][:,0]
|
||||
conn_count = jnp.sum(~jnp.isnan(conn))
|
||||
if(w%5==0):
|
||||
print("node_count",node_count)
|
||||
print("conn_count",conn_count)
|
||||
print("Generation limit reached!")
|
||||
return state, self.best_genome
|
||||
|
||||
|
||||
BIN
tensorneat/problem/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/problem/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/__pycache__/base.cpython-311.pyc
Normal file
BIN
tensorneat/problem/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/rl_env/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/problem/rl_env/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/rl_env/__pycache__/brax_env.cpython-311.pyc
Normal file
BIN
tensorneat/problem/rl_env/__pycache__/brax_env.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/rl_env/__pycache__/gymnax_env.cpython-311.pyc
Normal file
BIN
tensorneat/problem/rl_env/__pycache__/gymnax_env.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/rl_env/__pycache__/rl_jit.cpython-311.pyc
Normal file
BIN
tensorneat/problem/rl_env/__pycache__/rl_jit.cpython-311.pyc
Normal file
Binary file not shown.
@@ -8,32 +8,33 @@ from .. import BaseProblem
|
||||
class RLEnv(BaseProblem):
|
||||
jitable = True
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, max_step=1000):
|
||||
super().__init__()
|
||||
self.max_step = max_step
|
||||
|
||||
def evaluate(self, randkey, state, act_func, params):
|
||||
rng_reset, rng_episode = jax.random.split(randkey)
|
||||
init_obs, init_env_state = self.reset(rng_reset)
|
||||
|
||||
def cond_func(carry):
|
||||
_, _, _, done, _ = carry
|
||||
return ~done
|
||||
_, _, _, done, _, count = carry
|
||||
return ~done & (count < self.max_step)
|
||||
|
||||
def body_func(carry):
|
||||
obs, env_state, rng, _, tr = carry # total reward
|
||||
action = act_func(obs, params)
|
||||
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||
next_rng, _ = jax.random.split(rng)
|
||||
return next_obs, next_env_state, next_rng, done, tr + reward
|
||||
obs, env_state, rng, done, tr, count = carry # tr -> total reward
|
||||
action = act_func(obs, params)
|
||||
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||
next_rng, _ = jax.random.split(rng)
|
||||
return next_obs, next_env_state, next_rng, done, tr + reward, count + 1
|
||||
|
||||
_, _, _, _, total_reward = jax.lax.while_loop(
|
||||
_, _, _, _, total_reward, _ = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(init_obs, init_env_state, rng_episode, False, 0.0)
|
||||
(init_obs, init_env_state, rng_episode, False, 0.0, 0)
|
||||
)
|
||||
|
||||
return total_reward
|
||||
|
||||
return total_reward
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def step(self, randkey, env_state, action):
|
||||
return self.env_step(randkey, env_state, action)
|
||||
|
||||
BIN
tensorneat/utils/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/activation.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/activation.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/aggregation.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/aggregation.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/graph.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/graph.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/state.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/state.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/tools.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/tools.cpython-311.pyc
Normal file
Binary file not shown.
@@ -11,7 +11,7 @@ class Act:
|
||||
|
||||
@staticmethod
|
||||
def tanh(z):
|
||||
return jnp.tanh(z)
|
||||
return jnp.tanh(0.6 * z)
|
||||
|
||||
@staticmethod
|
||||
def sin(z):
|
||||
|
||||
Reference in New Issue
Block a user