initialize methods
This commit is contained in:
@@ -19,7 +19,8 @@ class DefaultSpecies(BaseSpecies):
|
||||
genome_elitism: int = 2,
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.
|
||||
compatibility_threshold: float = 3.,
|
||||
initialize_method: str = 'one_hidden_node', # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
|
||||
):
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
@@ -34,11 +35,13 @@ class DefaultSpecies(BaseSpecies):
|
||||
self.survival_threshold = survival_threshold
|
||||
self.min_species_size = min_species_size
|
||||
self.compatibility_threshold = compatibility_threshold
|
||||
self.initialize_method = initialize_method
|
||||
|
||||
self.species_arange = jnp.arange(self.species_size)
|
||||
|
||||
def setup(self, randkey):
|
||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome)
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome,k1, self.initialize_method)
|
||||
|
||||
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
|
||||
best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species
|
||||
@@ -62,7 +65,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||
|
||||
return State(
|
||||
randkey=randkey,
|
||||
randkey=k2,
|
||||
pop_nodes=pop_nodes,
|
||||
pop_conns=pop_conns,
|
||||
species_keys=species_keys,
|
||||
@@ -489,31 +492,159 @@ class DefaultSpecies(BaseSpecies):
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
|
||||
|
||||
def initialize_population(pop_size, genome):
|
||||
o_nodes = np.full((genome.max_nodes, genome.node_gene.length), np.nan) # original nodes
|
||||
o_conns = np.full((genome.max_conns, genome.conn_gene.length), np.nan) # original connections
|
||||
|
||||
|
||||
def initialize_population(pop_size, genome, randkey, init_method='default'):
|
||||
rand_keys = jax.random.split(randkey, pop_size)
|
||||
|
||||
if init_method == 'one_hidden_node':
|
||||
init_func = init_one_hidden_node
|
||||
elif init_method == 'dense_hideen_layer':
|
||||
init_func = init_dense_hideen_layer
|
||||
elif init_method == 'no_hidden_random':
|
||||
init_func = init_no_hidden_random
|
||||
else:
|
||||
raise ValueError("Unknown initialization method: {}".format(init_method))
|
||||
|
||||
pop_nodes, pop_conns = jax.vmap(init_func, in_axes=(None, 0))(genome, rand_keys)
|
||||
|
||||
return pop_nodes, pop_conns
|
||||
|
||||
# one hidden node
|
||||
def init_one_hidden_node(genome, randkey):
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
new_node_key = max([*input_idx, *output_idx]) + 1
|
||||
|
||||
o_nodes[input_idx, 0] = genome.input_idx
|
||||
o_nodes[output_idx, 0] = genome.output_idx
|
||||
o_nodes[new_node_key, 0] = new_node_key # one hidden node
|
||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs()
|
||||
o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
|
||||
|
||||
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
|
||||
o_conns[input_idx, 0:2] = input_conns # in key, out key
|
||||
o_conns[input_idx, 2] = True # enabled
|
||||
o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[new_node_key, 0].set(new_node_key)
|
||||
|
||||
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
|
||||
o_conns[output_idx, 0:2] = output_conns # in key, out key
|
||||
o_conns[output_idx, 2] = True # enabled
|
||||
o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
||||
rand_keys_nodes = jax.random.split(randkey, num=len(input_idx) + len(output_idx) + 1)
|
||||
input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[len(input_idx):len(input_idx) + len(output_idx)], rand_keys_nodes[-1]
|
||||
|
||||
# repeat origin genome for P times to create population
|
||||
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))
|
||||
pop_conns = np.tile(o_conns, (pop_size, 1, 1))
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None,0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key)
|
||||
|
||||
return pop_nodes, pop_conns
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
nodes = nodes.at[new_node_key, 1:].set(hidden_attrs)
|
||||
|
||||
input_conns = jnp.c_[input_idx, jnp.full_like(input_idx, new_node_key)]
|
||||
conns = conns.at[input_idx, 0:2].set(input_conns)
|
||||
conns = conns.at[input_idx, 2].set(True)
|
||||
|
||||
output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx]
|
||||
conns = conns.at[output_idx, 0:2].set(output_conns)
|
||||
conns = conns.at[output_idx, 2].set(True)
|
||||
|
||||
rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx))
|
||||
input_conn_keys, output_conn_keys = rand_keys_conns[:len(input_idx)], rand_keys_conns[len(input_idx):]
|
||||
|
||||
conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None,0))
|
||||
input_conn_attrs = conn_attr_func(input_conn_keys)
|
||||
output_conn_attrs = conn_attr_func(output_conn_keys)
|
||||
|
||||
conns = conns.at[input_idx, 3:].set(input_conn_attrs)
|
||||
conns = conns.at[output_idx, 3:].set(output_conn_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
|
||||
#random dense connections with 1 hidden layer
|
||||
def init_dense_hideen_layer( genome, randkey,hiddens=20):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
input_size = len(input_idx)
|
||||
output_size = len(output_idx)
|
||||
|
||||
hidden_idx = jnp.arange(input_size + output_size, input_size + output_size + hiddens)
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32)
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
|
||||
|
||||
total_idx = input_size + output_size + hiddens
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
input_keys = rand_keys_n[:input_size]
|
||||
output_keys = rand_keys_n[input_size:input_size + output_size]
|
||||
hidden_keys = rand_keys_n[input_size + output_size:]
|
||||
|
||||
node_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
hidden_attrs = node_attr_func(hidden_keys)
|
||||
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
nodes = nodes.at[hidden_idx, 1:].set(hidden_attrs)
|
||||
|
||||
total_connections = input_size * hiddens + hiddens * output_size
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32)
|
||||
|
||||
rand_keys_c = jax.random.split(k2, num=total_connections)
|
||||
conns_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||
conns_attrs = conns_attr_func(rand_keys_c)
|
||||
|
||||
input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing='ij')
|
||||
hidden_to_output_ids, output_ids = jnp.meshgrid(hidden_idx, output_idx, indexing='ij')
|
||||
|
||||
conns = conns.at[:input_size * hiddens, 0].set(input_to_hidden_ids.flatten())
|
||||
conns = conns.at[:input_size * hiddens, 1].set(hidden_ids.flatten())
|
||||
conns = conns.at[input_size * hiddens: total_connections, 0].set(hidden_to_output_ids.flatten())
|
||||
conns = conns.at[input_size * hiddens: total_connections, 1].set(output_ids.flatten())
|
||||
conns = conns.at[:input_size * hiddens + hiddens * output_size, 2].set(True)
|
||||
conns = conns.at[:input_size * hiddens + hiddens * output_size, 3:].set(conns_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
# random sparse connections with no hidden nodes
|
||||
def init_no_hidden_random(genome, randkey):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
|
||||
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
|
||||
total_idx = len(input_idx) + len(output_idx)
|
||||
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||
input_keys = rand_keys_n[:len(input_idx)]
|
||||
output_keys = rand_keys_n[len(input_idx):]
|
||||
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||
|
||||
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
|
||||
|
||||
num_connections_per_output = 4
|
||||
total_connections = len(output_idx) * num_connections_per_output
|
||||
|
||||
def create_connections_for_output(key):
|
||||
permuted_inputs = jax.random.permutation(key, input_idx)
|
||||
selected_inputs = permuted_inputs[:num_connections_per_output]
|
||||
return selected_inputs
|
||||
|
||||
conn_keys = jax.random.split(k2, num=len(output_idx))
|
||||
connections = jax.vmap(create_connections_for_output)(conn_keys)
|
||||
connections = connections.flatten()
|
||||
|
||||
output_repeats = jnp.repeat(output_idx, num_connections_per_output)
|
||||
|
||||
rand_keys_c = jax.random.split(k3, num=total_connections)
|
||||
conns_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||
conns_attrs = conns_attr_func(rand_keys_c)
|
||||
|
||||
conns = conns.at[:total_connections, 0].set(connections)
|
||||
conns = conns.at[:total_connections, 1].set(output_repeats)
|
||||
conns = conns.at[:total_connections, 2].set(True) # enabled
|
||||
conns = conns.at[:total_connections, 3:].set(conns_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
Reference in New Issue
Block a user