initialize methods

This commit is contained in:
Priokin
2024-05-21 14:34:01 +08:00
parent 0e89ed1d7c
commit 40b7d8360c
46 changed files with 222 additions and 40 deletions

Binary file not shown.

Binary file not shown.

View File

@@ -1,5 +1,5 @@
from .ga import *
from .gene import * from .gene import *
from .genome import * from .genome import *
from .species import * from .species import *
from .neat import NEAT from .neat import NEAT

View File

@@ -27,6 +27,17 @@ class DefaultConnGene(BaseConnGene):
def new_custom_attrs(self): def new_custom_attrs(self):
return jnp.array([self.weight_init_mean]) return jnp.array([self.weight_init_mean])
def new_random_attrs(self, key):
return jnp.array([mutate_float(key,
self.weight_init_mean,
self.weight_init_mean,
1.0,
0,
0,
1.0,
)
])
def mutate(self, key, conn): def mutate(self, key, conn):
input_index = conn[0] input_index = conn[0]
output_index = conn[1] output_index = conn[1]

View File

@@ -61,6 +61,16 @@ class DefaultNodeGene(BaseNodeGene):
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default] [self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
) )
def new_random_attrs(self, key):
return jnp.array([
mutate_float(key, self.bias_init_mean, self.bias_init_mean, self.bias_init_std,
self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate),
mutate_float(key, self.response_init_mean, self.response_init_mean, self.response_init_std,
self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate),
self.activation_default,
self.aggregation_default,
])
def mutate(self, key, node): def mutate(self, key, node):
k1, k2, k3, k4 = jax.random.split(key, num=4) k1, k2, k3, k4 = jax.random.split(key, num=4)
index = node[0] index = node[0]

View File

@@ -19,7 +19,8 @@ class DefaultSpecies(BaseSpecies):
genome_elitism: int = 2, genome_elitism: int = 2,
survival_threshold: float = 0.2, survival_threshold: float = 0.2,
min_species_size: int = 1, min_species_size: int = 1,
compatibility_threshold: float = 3. compatibility_threshold: float = 3.,
initialize_method: str = 'one_hidden_node', # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
): ):
self.genome = genome self.genome = genome
self.pop_size = pop_size self.pop_size = pop_size
@@ -34,11 +35,13 @@ class DefaultSpecies(BaseSpecies):
self.survival_threshold = survival_threshold self.survival_threshold = survival_threshold
self.min_species_size = min_species_size self.min_species_size = min_species_size
self.compatibility_threshold = compatibility_threshold self.compatibility_threshold = compatibility_threshold
self.initialize_method = initialize_method
self.species_arange = jnp.arange(self.species_size) self.species_arange = jnp.arange(self.species_size)
def setup(self, randkey): def setup(self, randkey):
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome) k1, k2 = jax.random.split(randkey, 2)
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome,k1, self.initialize_method)
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species
@@ -62,7 +65,7 @@ class DefaultSpecies(BaseSpecies):
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns)) pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
return State( return State(
randkey=randkey, randkey=k2,
pop_nodes=pop_nodes, pop_nodes=pop_nodes,
pop_conns=pop_conns, pop_conns=pop_conns,
species_keys=species_keys, species_keys=species_keys,
@@ -489,31 +492,159 @@ class DefaultSpecies(BaseSpecies):
return jnp.where(max_cnt == 0, 0, val / max_cnt) return jnp.where(max_cnt == 0, 0, val / max_cnt)
def initialize_population(pop_size, genome):
o_nodes = np.full((genome.max_nodes, genome.node_gene.length), np.nan) # original nodes
o_conns = np.full((genome.max_conns, genome.conn_gene.length), np.nan) # original connections
def initialize_population(pop_size, genome, randkey, init_method='default'):
rand_keys = jax.random.split(randkey, pop_size)
if init_method == 'one_hidden_node':
init_func = init_one_hidden_node
elif init_method == 'dense_hideen_layer':
init_func = init_dense_hideen_layer
elif init_method == 'no_hidden_random':
init_func = init_no_hidden_random
else:
raise ValueError("Unknown initialization method: {}".format(init_method))
pop_nodes, pop_conns = jax.vmap(init_func, in_axes=(None, 0))(genome, rand_keys)
return pop_nodes, pop_conns
# one hidden node
def init_one_hidden_node(genome, randkey):
input_idx, output_idx = genome.input_idx, genome.output_idx input_idx, output_idx = genome.input_idx, genome.output_idx
new_node_key = max([*input_idx, *output_idx]) + 1 new_node_key = max([*input_idx, *output_idx]) + 1
o_nodes[input_idx, 0] = genome.input_idx nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
o_nodes[output_idx, 0] = genome.output_idx conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
o_nodes[new_node_key, 0] = new_node_key # one hidden node
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs()
o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden nodes = nodes.at[input_idx, 0].set(input_idx)
o_conns[input_idx, 0:2] = input_conns # in key, out key nodes = nodes.at[output_idx, 0].set(output_idx)
o_conns[input_idx, 2] = True # enabled nodes = nodes.at[new_node_key, 0].set(new_node_key)
o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs()
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes rand_keys_nodes = jax.random.split(randkey, num=len(input_idx) + len(output_idx) + 1)
o_conns[output_idx, 0:2] = output_conns # in key, out key input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[len(input_idx):len(input_idx) + len(output_idx)], rand_keys_nodes[-1]
o_conns[output_idx, 2] = True # enabled
o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs()
# repeat origin genome for P times to create population node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None,0))
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1)) input_attrs = node_attr_func(input_keys)
pop_conns = np.tile(o_conns, (pop_size, 1, 1)) output_attrs = node_attr_func(output_keys)
hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key)
return pop_nodes, pop_conns nodes = nodes.at[input_idx, 1:].set(input_attrs)
nodes = nodes.at[output_idx, 1:].set(output_attrs)
nodes = nodes.at[new_node_key, 1:].set(hidden_attrs)
input_conns = jnp.c_[input_idx, jnp.full_like(input_idx, new_node_key)]
conns = conns.at[input_idx, 0:2].set(input_conns)
conns = conns.at[input_idx, 2].set(True)
output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx]
conns = conns.at[output_idx, 0:2].set(output_conns)
conns = conns.at[output_idx, 2].set(True)
rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx))
input_conn_keys, output_conn_keys = rand_keys_conns[:len(input_idx)], rand_keys_conns[len(input_idx):]
conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None,0))
input_conn_attrs = conn_attr_func(input_conn_keys)
output_conn_attrs = conn_attr_func(output_conn_keys)
conns = conns.at[input_idx, 3:].set(input_conn_attrs)
conns = conns.at[output_idx, 3:].set(output_conn_attrs)
return nodes, conns
#random dense connections with 1 hidden layer
def init_dense_hideen_layer( genome, randkey,hiddens=20):
k1, k2, k3 = jax.random.split(randkey, num=3)
input_idx, output_idx = genome.input_idx, genome.output_idx
input_size = len(input_idx)
output_size = len(output_idx)
hidden_idx = jnp.arange(input_size + output_size, input_size + output_size + hiddens)
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32)
nodes = nodes.at[input_idx, 0].set(input_idx)
nodes = nodes.at[output_idx, 0].set(output_idx)
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
total_idx = input_size + output_size + hiddens
rand_keys_n = jax.random.split(k1, num=total_idx)
input_keys = rand_keys_n[:input_size]
output_keys = rand_keys_n[input_size:input_size + output_size]
hidden_keys = rand_keys_n[input_size + output_size:]
node_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
input_attrs = node_attr_func(input_keys)
output_attrs = node_attr_func(output_keys)
hidden_attrs = node_attr_func(hidden_keys)
nodes = nodes.at[input_idx, 1:].set(input_attrs)
nodes = nodes.at[output_idx, 1:].set(output_attrs)
nodes = nodes.at[hidden_idx, 1:].set(hidden_attrs)
total_connections = input_size * hiddens + hiddens * output_size
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32)
rand_keys_c = jax.random.split(k2, num=total_connections)
conns_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
conns_attrs = conns_attr_func(rand_keys_c)
input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing='ij')
hidden_to_output_ids, output_ids = jnp.meshgrid(hidden_idx, output_idx, indexing='ij')
conns = conns.at[:input_size * hiddens, 0].set(input_to_hidden_ids.flatten())
conns = conns.at[:input_size * hiddens, 1].set(hidden_ids.flatten())
conns = conns.at[input_size * hiddens: total_connections, 0].set(hidden_to_output_ids.flatten())
conns = conns.at[input_size * hiddens: total_connections, 1].set(output_ids.flatten())
conns = conns.at[:input_size * hiddens + hiddens * output_size, 2].set(True)
conns = conns.at[:input_size * hiddens + hiddens * output_size, 3:].set(conns_attrs)
return nodes, conns
# random sparse connections with no hidden nodes
def init_no_hidden_random(genome, randkey):
k1, k2, k3 = jax.random.split(randkey, num=3)
input_idx, output_idx = genome.input_idx, genome.output_idx
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
nodes = nodes.at[input_idx, 0].set(input_idx)
nodes = nodes.at[output_idx, 0].set(output_idx)
total_idx = len(input_idx) + len(output_idx)
rand_keys_n = jax.random.split(k1, num=total_idx)
input_keys = rand_keys_n[:len(input_idx)]
output_keys = rand_keys_n[len(input_idx):]
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
input_attrs = node_attr_func(input_keys)
output_attrs = node_attr_func(output_keys)
nodes = nodes.at[input_idx, 1:].set(input_attrs)
nodes = nodes.at[output_idx, 1:].set(output_attrs)
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
num_connections_per_output = 4
total_connections = len(output_idx) * num_connections_per_output
def create_connections_for_output(key):
permuted_inputs = jax.random.permutation(key, input_idx)
selected_inputs = permuted_inputs[:num_connections_per_output]
return selected_inputs
conn_keys = jax.random.split(k2, num=len(output_idx))
connections = jax.vmap(create_connections_for_output)(conn_keys)
connections = connections.flatten()
output_repeats = jnp.repeat(output_idx, num_connections_per_output)
rand_keys_c = jax.random.split(k3, num=total_connections)
conns_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
conns_attrs = conns_attr_func(rand_keys_c)
conns = conns.at[:total_connections, 0].set(connections)
conns = conns.at[:total_connections, 1].set(output_repeats)
conns = conns.at[:total_connections, 2].set(True) # enabled
conns = conns.at[:total_connections, 3:].set(conns_attrs)
return nodes, conns

View File

@@ -82,12 +82,13 @@ class Pipeline:
state = ini_state state = ini_state
compiled_step = jax.jit(self.step).lower(ini_state).compile() compiled_step = jax.jit(self.step).lower(ini_state).compile()
for _ in range(self.generation_limit): for w in range(self.generation_limit):
self.generation_timestamp = time.time() self.generation_timestamp = time.time()
previous_pop = self.algorithm.ask(state.alg) previous_pop = self.algorithm.ask(state.alg)
state, fitnesses = compiled_step(state) state, fitnesses = compiled_step(state)
fitnesses = jax.device_get(fitnesses) fitnesses = jax.device_get(fitnesses)
@@ -102,7 +103,13 @@ class Pipeline:
if max(fitnesses) >= self.fitness_target: if max(fitnesses) >= self.fitness_target:
print("Fitness limit reached!") print("Fitness limit reached!")
return state, self.best_genome return state, self.best_genome
node= previous_pop[0][0][:,0]
node_count = jnp.sum(~jnp.isnan(node))
conn= previous_pop[1][0][:,0]
conn_count = jnp.sum(~jnp.isnan(conn))
if(w%5==0):
print("node_count",node_count)
print("conn_count",conn_count)
print("Generation limit reached!") print("Generation limit reached!")
return state, self.best_genome return state, self.best_genome

Binary file not shown.

View File

@@ -9,32 +9,55 @@ class RLEnv(BaseProblem):
jitable = True jitable = True
# TODO: move output transform to algorithm # TODO: move output transform to algorithm
def __init__(self): def __init__(self, max_step=1000):
super().__init__() super().__init__()
self.max_step = max_step
# def evaluate(self, randkey, state, act_func, params):
# rng_reset, rng_episode = jax.random.split(randkey)
# init_obs, init_env_state = self.reset(rng_reset)
# def cond_func(carry):
# _, _, _, done, _ = carry
# return ~done
# def body_func(carry):
# obs, env_state, rng, _, tr = carry # total reward
# action = act_func(obs, params)
# next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
# next_rng, _ = jax.random.split(rng)
# return next_obs, next_env_state, next_rng, done, tr + reward
# _, _, _, _, total_reward = jax.lax.while_loop(
# cond_func,
# body_func,
# (init_obs, init_env_state, rng_episode, False, 0.0)
# )
# return total_reward
def evaluate(self, randkey, state, act_func, params): def evaluate(self, randkey, state, act_func, params):
rng_reset, rng_episode = jax.random.split(randkey) rng_reset, rng_episode = jax.random.split(randkey)
init_obs, init_env_state = self.reset(rng_reset) init_obs, init_env_state = self.reset(rng_reset)
def cond_func(carry): def cond_func(carry):
_, _, _, done, _ = carry _, _, _, done, _, count = carry
return ~done return ~done & (count < self.max_step)
def body_func(carry): def body_func(carry):
obs, env_state, rng, _, tr = carry # total reward obs, env_state, rng, done, tr, count = carry # tr -> total reward
action = act_func(obs, params) action = act_func(obs, params)
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
next_rng, _ = jax.random.split(rng) next_rng, _ = jax.random.split(rng)
return next_obs, next_env_state, next_rng, done, tr + reward return next_obs, next_env_state, next_rng, done, tr + reward, count + 1
_, _, _, _, total_reward = jax.lax.while_loop( _, _, _, _, total_reward, _ = jax.lax.while_loop(
cond_func, cond_func,
body_func, body_func,
(init_obs, init_env_state, rng_episode, False, 0.0) (init_obs, init_env_state, rng_episode, False, 0.0, 0)
) )
return total_reward return total_reward
@partial(jax.jit, static_argnums=(0,)) @partial(jax.jit, static_argnums=(0,))
def step(self, randkey, env_state, action): def step(self, randkey, env_state, action):
return self.env_step(randkey, env_state, action) return self.env_step(randkey, env_state, action)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -11,7 +11,7 @@ class Act:
@staticmethod @staticmethod
def tanh(z): def tanh(z):
return jnp.tanh(z) return jnp.tanh(0.6 * z)
@staticmethod @staticmethod
def sin(z): def sin(z):