update class State. Add method register and update method update.
This commit is contained in:
@@ -20,7 +20,8 @@ class DefaultSpecies(BaseSpecies):
|
||||
survival_threshold: float = 0.2,
|
||||
min_species_size: int = 1,
|
||||
compatibility_threshold: float = 3.,
|
||||
initialize_method: str = 'one_hidden_node', # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
|
||||
initialize_method: str = 'one_hidden_node',
|
||||
# {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
|
||||
):
|
||||
self.genome = genome
|
||||
self.pop_size = pop_size
|
||||
@@ -41,7 +42,7 @@ class DefaultSpecies(BaseSpecies):
|
||||
|
||||
def setup(self, randkey):
|
||||
k1, k2 = jax.random.split(randkey, 2)
|
||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome,k1, self.initialize_method)
|
||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, self.initialize_method)
|
||||
|
||||
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
|
||||
best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species
|
||||
@@ -492,11 +493,9 @@ class DefaultSpecies(BaseSpecies):
|
||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||
|
||||
|
||||
|
||||
|
||||
def initialize_population(pop_size, genome, randkey, init_method='default'):
|
||||
rand_keys = jax.random.split(randkey, pop_size)
|
||||
|
||||
|
||||
if init_method == 'one_hidden_node':
|
||||
init_func = init_one_hidden_node
|
||||
elif init_method == 'dense_hideen_layer':
|
||||
@@ -510,6 +509,7 @@ def initialize_population(pop_size, genome, randkey, init_method='default'):
|
||||
|
||||
return pop_nodes, pop_conns
|
||||
|
||||
|
||||
# one hidden node
|
||||
def init_one_hidden_node(genome, randkey):
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
@@ -520,12 +520,14 @@ def init_one_hidden_node(genome, randkey):
|
||||
|
||||
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||
nodes = nodes.at[new_node_key, 0].set(new_node_key)
|
||||
nodes = nodes.at[new_node_key, 0].set(new_node_key)
|
||||
|
||||
rand_keys_nodes = jax.random.split(randkey, num=len(input_idx) + len(output_idx) + 1)
|
||||
input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[len(input_idx):len(input_idx) + len(output_idx)], rand_keys_nodes[-1]
|
||||
input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[
|
||||
len(input_idx):len(input_idx) + len(
|
||||
output_idx)], rand_keys_nodes[-1]
|
||||
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None,0))
|
||||
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None, 0))
|
||||
input_attrs = node_attr_func(input_keys)
|
||||
output_attrs = node_attr_func(output_keys)
|
||||
hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key)
|
||||
@@ -540,12 +542,12 @@ def init_one_hidden_node(genome, randkey):
|
||||
|
||||
output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx]
|
||||
conns = conns.at[output_idx, 0:2].set(output_conns)
|
||||
conns = conns.at[output_idx, 2].set(True)
|
||||
conns = conns.at[output_idx, 2].set(True)
|
||||
|
||||
rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx))
|
||||
input_conn_keys, output_conn_keys = rand_keys_conns[:len(input_idx)], rand_keys_conns[len(input_idx):]
|
||||
|
||||
conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None,0))
|
||||
conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None, 0))
|
||||
input_conn_attrs = conn_attr_func(input_conn_keys)
|
||||
output_conn_attrs = conn_attr_func(output_conn_keys)
|
||||
|
||||
@@ -553,10 +555,10 @@ def init_one_hidden_node(genome, randkey):
|
||||
conns = conns.at[output_idx, 3:].set(output_conn_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
|
||||
#random dense connections with 1 hidden layer
|
||||
def init_dense_hideen_layer( genome, randkey,hiddens=20):
|
||||
|
||||
# random dense connections with 1 hidden layer
|
||||
def init_dense_hideen_layer(genome, randkey, hiddens=20):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||
input_size = len(input_idx)
|
||||
@@ -602,6 +604,7 @@ def init_dense_hideen_layer( genome, randkey,hiddens=20):
|
||||
|
||||
return nodes, conns
|
||||
|
||||
|
||||
# random sparse connections with no hidden nodes
|
||||
def init_no_hidden_random(genome, randkey):
|
||||
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||
|
||||
@@ -7,7 +7,19 @@ class State:
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__['state_dict'] = kwargs
|
||||
|
||||
def registered_keys(self):
|
||||
return self.state_dict.keys()
|
||||
|
||||
def register(self, **kwargs):
|
||||
for key in kwargs:
|
||||
if key in self.registered_keys():
|
||||
raise ValueError(f"Key {key} already exists in state")
|
||||
return State(**{**self.state_dict, **kwargs})
|
||||
|
||||
def update(self, **kwargs):
|
||||
for key in kwargs:
|
||||
if key not in self.registered_keys():
|
||||
raise ValueError(f"Key {key} does not exist in state")
|
||||
return State(**{**self.state_dict, **kwargs})
|
||||
|
||||
def __getattr__(self, name):
|
||||
@@ -26,4 +38,4 @@ class State:
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(**dict(zip(aux_data, children)))
|
||||
return cls(**dict(zip(aux_data, children)))
|
||||
|
||||
Reference in New Issue
Block a user