initialize methods
This commit is contained in:
BIN
tensorneat/__pycache__/pipeline.cpython-311.pyc
Normal file
BIN
tensorneat/__pycache__/pipeline.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/algorithm/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/algorithm/__pycache__/base.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
@@ -1,5 +1,5 @@
|
||||
from .ga import *
|
||||
from .gene import *
|
||||
from .genome import *
|
||||
from .species import *
|
||||
from .neat import NEAT
|
||||
|
||||
|
||||
BIN
tensorneat/algorithm/neat/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/neat/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/algorithm/neat/__pycache__/neat.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/neat/__pycache__/neat.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tensorneat/algorithm/neat/gene/__pycache__/base.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/neat/gene/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -27,6 +27,17 @@ class DefaultConnGene(BaseConnGene):
|
||||
def new_custom_attrs(self):
|
||||
return jnp.array([self.weight_init_mean])
|
||||
|
||||
def new_random_attrs(self, key):
|
||||
return jnp.array([mutate_float(key,
|
||||
self.weight_init_mean,
|
||||
self.weight_init_mean,
|
||||
1.0,
|
||||
0,
|
||||
0,
|
||||
1.0,
|
||||
)
|
||||
])
|
||||
|
||||
def mutate(self, key, conn):
|
||||
input_index = conn[0]
|
||||
output_index = conn[1]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -61,6 +61,16 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
||||
)
|
||||
|
||||
def new_random_attrs(self, key):
|
||||
return jnp.array([
|
||||
mutate_float(key, self.bias_init_mean, self.bias_init_mean, self.bias_init_std,
|
||||
self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate),
|
||||
mutate_float(key, self.response_init_mean, self.response_init_mean, self.response_init_std,
|
||||
self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate),
|
||||
self.activation_default,
|
||||
self.aggregation_default,
|
||||
])
|
||||
|
||||
def mutate(self, key, node):
|
||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||
index = node[0]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -19,7 +19,8 @@ class DefaultSpecies(BaseSpecies):
|
||||
genome_elitism: int = 2,
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.
|
||||
compatibility_threshold: float = 3.,
|
||||
initialize_method: str = 'one_hidden_node', # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
|
||||
):
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
@@ -34,11 +35,13 @@ class DefaultSpecies(BaseSpecies):
|
||||
self.survival_threshold = survival_threshold
|
||||
self.min_species_size = min_species_size
|
||||
self.compatibility_threshold = compatibility_threshold
|
||||
self.initialize_method = initialize_method
|
||||
|
||||
self.species_arange = jnp.arange(self.species_size)
|
||||
|
||||
def setup(self, randkey):
|
||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome)
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome,k1, self.initialize_method)
|
||||
|
||||
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
|
||||
best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species
|
||||
@@ -62,7 +65,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||
|
||||
return State(
|
||||
randkey=randkey,
|
||||
randkey=k2,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
species_keys=species_keys,
|
||||
@@ -489,31 +492,159 @@ class DefaultSpecies(BaseSpecies):
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
|
||||
|
||||
def initialize_population(pop_size, genome):
|
||||
o_nodes = np.full((genome.max_nodes, genome.node_gene.length), np.nan) # original nodes
|
||||
o_conns = np.full((genome.max_conns, genome.conn_gene.length), np.nan) # original connections
|
||||
|
||||
|
||||
def initialize_population(pop_size, genome, randkey, init_method='default'):
|
||||
rand_keys = jax.random.split(randkey, pop_size)
|
||||
|
||||
if init_method == 'one_hidden_node':
|
||||
init_func = init_one_hidden_node
|
||||
elif init_method == 'dense_hideen_layer':
|
||||
init_func = init_dense_hideen_layer
|
||||
elif init_method == 'no_hidden_random':
|
||||
init_func = init_no_hidden_random
|
||||
else:
|
||||
raise ValueError("Unknown initialization method: {}".format(init_method))
|
||||
|
||||
pop_nodes, pop_conns = jax.vmap(init_func, in_axes=(None, 0))(genome, rand_keys)
|
||||
|
||||
return pop_nodes, pop_conns
|
||||
|
||||
# one hidden node
|
||||
def init_one_hidden_node(genome, randkey):
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
new_node_key = max([*input_idx, *output_idx]) + 1
|
||||
|
||||
o_nodes[input_idx, 0] = genome.input_idx
|
||||
o_nodes[output_idx, 0] = genome.output_idx
|
||||
o_nodes[new_node_key, 0] = new_node_key # one hidden node
|
||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs()
|
||||
o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
|
||||
|
||||
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
|
||||
o_conns[input_idx, 0:2] = input_conns # in key, out key
|
||||
o_conns[input_idx, 2] = True # enabled
|
||||
o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[new_node_key, 0].set(new_node_key)
|
||||
|
||||
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
|
||||
o_conns[output_idx, 0:2] = output_conns # in key, out key
|
||||
o_conns[output_idx, 2] = True # enabled
|
||||
o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
||||
rand_keys_nodes = jax.random.split(randkey, num=len(input_idx) + len(output_idx) + 1)
|
||||
input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[len(input_idx):len(input_idx) + len(output_idx)], rand_keys_nodes[-1]
|
||||
|
||||
# repeat origin genome for P times to create population
|
||||
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))
|
||||
pop_conns = np.tile(o_conns, (pop_size, 1, 1))
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None,0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key)
|
||||
|
||||
return pop_nodes, pop_conns
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
nodes = nodes.at[new_node_key, 1:].set(hidden_attrs)
|
||||
|
||||
input_conns = jnp.c_[input_idx, jnp.full_like(input_idx, new_node_key)]
|
||||
conns = conns.at[input_idx, 0:2].set(input_conns)
|
||||
conns = conns.at[input_idx, 2].set(True)
|
||||
|
||||
output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx]
|
||||
conns = conns.at[output_idx, 0:2].set(output_conns)
|
||||
conns = conns.at[output_idx, 2].set(True)
|
||||
|
||||
rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx))
|
||||
input_conn_keys, output_conn_keys = rand_keys_conns[:len(input_idx)], rand_keys_conns[len(input_idx):]
|
||||
|
||||
conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None,0))
|
||||
input_conn_attrs = conn_attr_func(input_conn_keys)
|
||||
output_conn_attrs = conn_attr_func(output_conn_keys)
|
||||
|
||||
conns = conns.at[input_idx, 3:].set(input_conn_attrs)
|
||||
conns = conns.at[output_idx, 3:].set(output_conn_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
|
||||
#random dense connections with 1 hidden layer
|
||||
def init_dense_hideen_layer( genome, randkey,hiddens=20):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
input_size = len(input_idx)
|
||||
output_size = len(output_idx)
|
||||
|
||||
hidden_idx = jnp.arange(input_size + output_size, input_size + output_size + hiddens)
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32)
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
|
||||
|
||||
total_idx = input_size + output_size + hiddens
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
input_keys = rand_keys_n[:input_size]
|
||||
output_keys = rand_keys_n[input_size:input_size + output_size]
|
||||
hidden_keys = rand_keys_n[input_size + output_size:]
|
||||
|
||||
node_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
hidden_attrs = node_attr_func(hidden_keys)
|
||||
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
nodes = nodes.at[hidden_idx, 1:].set(hidden_attrs)
|
||||
|
||||
total_connections = input_size * hiddens + hiddens * output_size
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32)
|
||||
|
||||
rand_keys_c = jax.random.split(k2, num=total_connections)
|
||||
conns_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||
conns_attrs = conns_attr_func(rand_keys_c)
|
||||
|
||||
input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing='ij')
|
||||
hidden_to_output_ids, output_ids = jnp.meshgrid(hidden_idx, output_idx, indexing='ij')
|
||||
|
||||
conns = conns.at[:input_size * hiddens, 0].set(input_to_hidden_ids.flatten())
|
||||
conns = conns.at[:input_size * hiddens, 1].set(hidden_ids.flatten())
|
||||
conns = conns.at[input_size * hiddens: total_connections, 0].set(hidden_to_output_ids.flatten())
|
||||
conns = conns.at[input_size * hiddens: total_connections, 1].set(output_ids.flatten())
|
||||
conns = conns.at[:input_size * hiddens + hiddens * output_size, 2].set(True)
|
||||
conns = conns.at[:input_size * hiddens + hiddens * output_size, 3:].set(conns_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
# random sparse connections with no hidden nodes
|
||||
def init_no_hidden_random(genome, randkey):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
|
||||
total_idx = len(input_idx) + len(output_idx)
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
input_keys = rand_keys_n[:len(input_idx)]
|
||||
output_keys = rand_keys_n[len(input_idx):]
|
||||
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
|
||||
|
||||
num_connections_per_output = 4
|
||||
total_connections = len(output_idx) * num_connections_per_output
|
||||
|
||||
def create_connections_for_output(key):
|
||||
permuted_inputs = jax.random.permutation(key, input_idx)
|
||||
selected_inputs = permuted_inputs[:num_connections_per_output]
|
||||
return selected_inputs
|
||||
|
||||
conn_keys = jax.random.split(k2, num=len(output_idx))
|
||||
connections = jax.vmap(create_connections_for_output)(conn_keys)
|
||||
connections = connections.flatten()
|
||||
|
||||
output_repeats = jnp.repeat(output_idx, num_connections_per_output)
|
||||
|
||||
rand_keys_c = jax.random.split(k3, num=total_connections)
|
||||
conns_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||
conns_attrs = conns_attr_func(rand_keys_c)
|
||||
|
||||
conns = conns.at[:total_connections, 0].set(connections)
|
||||
conns = conns.at[:total_connections, 1].set(output_repeats)
|
||||
conns = conns.at[:total_connections, 2].set(True) # enabled
|
||||
conns = conns.at[:total_connections, 3:].set(conns_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
@@ -82,12 +82,13 @@ class Pipeline:
|
||||
state = ini_state
|
||||
compiled_step = jax.jit(self.step).lower(ini_state).compile()
|
||||
|
||||
for _ in range(self.generation_limit):
|
||||
for w in range(self.generation_limit):
|
||||
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
previous_pop = self.algorithm.ask(state.alg)
|
||||
|
||||
|
||||
state, fitnesses = compiled_step(state)
|
||||
|
||||
fitnesses = jax.device_get(fitnesses)
|
||||
@@ -102,7 +103,13 @@ class Pipeline:
|
||||
if max(fitnesses) >= self.fitness_target:
|
||||
print("Fitness limit reached!")
|
||||
return state, self.best_genome
|
||||
|
||||
node= previous_pop[0][0][:,0]
|
||||
node_count = jnp.sum(~jnp.isnan(node))
|
||||
conn= previous_pop[1][0][:,0]
|
||||
conn_count = jnp.sum(~jnp.isnan(conn))
|
||||
if(w%5==0):
|
||||
print("node_count",node_count)
|
||||
print("conn_count",conn_count)
|
||||
print("Generation limit reached!")
|
||||
return state, self.best_genome
|
||||
|
||||
|
||||
BIN
tensorneat/problem/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/problem/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/__pycache__/base.cpython-311.pyc
Normal file
BIN
tensorneat/problem/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/rl_env/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/problem/rl_env/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/rl_env/__pycache__/brax_env.cpython-311.pyc
Normal file
BIN
tensorneat/problem/rl_env/__pycache__/brax_env.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/rl_env/__pycache__/gymnax_env.cpython-311.pyc
Normal file
BIN
tensorneat/problem/rl_env/__pycache__/gymnax_env.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/rl_env/__pycache__/rl_jit.cpython-311.pyc
Normal file
BIN
tensorneat/problem/rl_env/__pycache__/rl_jit.cpython-311.pyc
Normal file
Binary file not shown.
@@ -9,32 +9,55 @@ class RLEnv(BaseProblem):
|
||||
jitable = True
|
||||
|
||||
# TODO: move output transform to algorithm
|
||||
def __init__(self):
|
||||
def __init__(self, max_step=1000):
|
||||
super().__init__()
|
||||
self.max_step = max_step
|
||||
|
||||
# def evaluate(self, randkey, state, act_func, params):
|
||||
# rng_reset, rng_episode = jax.random.split(randkey)
|
||||
# init_obs, init_env_state = self.reset(rng_reset)
|
||||
|
||||
# def cond_func(carry):
|
||||
# _, _, _, done, _ = carry
|
||||
# return ~done
|
||||
|
||||
# def body_func(carry):
|
||||
# obs, env_state, rng, _, tr = carry # total reward
|
||||
# action = act_func(obs, params)
|
||||
# next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||
# next_rng, _ = jax.random.split(rng)
|
||||
# return next_obs, next_env_state, next_rng, done, tr + reward
|
||||
|
||||
# _, _, _, _, total_reward = jax.lax.while_loop(
|
||||
# cond_func,
|
||||
# body_func,
|
||||
# (init_obs, init_env_state, rng_episode, False, 0.0)
|
||||
# )
|
||||
|
||||
# return total_reward
|
||||
|
||||
def evaluate(self, randkey, state, act_func, params):
|
||||
rng_reset, rng_episode = jax.random.split(randkey)
|
||||
init_obs, init_env_state = self.reset(rng_reset)
|
||||
|
||||
def cond_func(carry):
|
||||
_, _, _, done, _ = carry
|
||||
return ~done
|
||||
_, _, _, done, _, count = carry
|
||||
return ~done & (count < self.max_step)
|
||||
|
||||
def body_func(carry):
|
||||
obs, env_state, rng, _, tr = carry # total reward
|
||||
obs, env_state, rng, done, tr, count = carry # tr -> total reward
|
||||
action = act_func(obs, params)
|
||||
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||
next_rng, _ = jax.random.split(rng)
|
||||
return next_obs, next_env_state, next_rng, done, tr + reward
|
||||
return next_obs, next_env_state, next_rng, done, tr + reward, count + 1
|
||||
|
||||
_, _, _, _, total_reward = jax.lax.while_loop(
|
||||
_, _, _, _, total_reward, _ = jax.lax.while_loop(
|
||||
cond_func,
|
||||
body_func,
|
||||
(init_obs, init_env_state, rng_episode, False, 0.0)
|
||||
(init_obs, init_env_state, rng_episode, False, 0.0, 0)
|
||||
)
|
||||
|
||||
return total_reward
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def step(self, randkey, env_state, action):
|
||||
return self.env_step(randkey, env_state, action)
|
||||
|
||||
BIN
tensorneat/utils/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/activation.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/activation.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/aggregation.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/aggregation.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/graph.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/graph.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/state.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/state.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/tools.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/tools.cpython-311.pyc
Normal file
Binary file not shown.
@@ -11,7 +11,7 @@ class Act:
|
||||
|
||||
@staticmethod
|
||||
def tanh(z):
|
||||
return jnp.tanh(z)
|
||||
return jnp.tanh(0.6 * z)
|
||||
|
||||
@staticmethod
|
||||
def sin(z):
|
||||
|
||||
Reference in New Issue
Block a user