prepare for experiment

This commit is contained in:
wls2002
2023-05-14 15:27:17 +08:00
parent 72c9d4167a
commit 2b79f2c903
11 changed files with 252 additions and 62 deletions

View File

@@ -19,7 +19,7 @@ class FunctionFactory:
self.expand_coe = config.basic.expands_coe
self.precompile_times = config.basic.pre_compile_times
self.compiled_function = {}
self.time_cost = {}
self.compile_time = 0
self.load_config_vals(config)
@@ -150,6 +150,8 @@ class FunctionFactory:
return self.compiled_function[key]
def compile_update_speciate(self, N, C, S):
s = time.time()
func = self.update_speciate_with_args
randkey_lower = np.zeros((2,), dtype=np.uint32)
pop_nodes_lower = np.zeros((self.pop_size, N, 5))
@@ -177,16 +179,22 @@ class FunctionFactory:
).compile()
self.compiled_function[("update_speciate", N, C, S)] = compiled_func
self.compile_time += time.time() - s
def create_topological_sort_with_args(self):
self.topological_sort_with_args = topological_sort
def compile_topological_sort(self, n):
s = time.time()
func = self.topological_sort_with_args
nodes_lower = np.zeros((n, 5))
connections_lower = np.zeros((2, n, n))
func = jit(func).lower(nodes_lower, connections_lower).compile()
self.compiled_function[('topological_sort', n)] = func
self.compile_time += time.time() - s
def create_topological_sort(self, n):
key = ('topological_sort', n)
if key not in self.compiled_function:
@@ -194,6 +202,8 @@ class FunctionFactory:
return self.compiled_function[key]
def compile_topological_sort_batch(self, n):
s = time.time()
func = self.topological_sort_with_args
func = vmap(func)
nodes_lower = np.zeros((self.pop_size, n, 5))
@@ -201,6 +211,8 @@ class FunctionFactory:
func = jit(func).lower(nodes_lower, connections_lower).compile()
self.compiled_function[('topological_sort_batch', n)] = func
self.compile_time += time.time() - s
def create_topological_sort_batch(self, n):
key = ('topological_sort_batch', n)
if key not in self.compiled_function:
@@ -215,32 +227,10 @@ class FunctionFactory:
)
self.single_forward_with_args = func
def compile_single_forward(self, n):
"""
single input for a genome
:param n:
:return:
"""
func = self.single_forward_with_args
inputs_lower = np.zeros((self.num_inputs,))
cal_seqs_lower = np.zeros((n,), dtype=np.int32)
nodes_lower = np.zeros((n, 5))
connections_lower = np.zeros((2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('single_forward', n)] = func
def compile_pop_forward(self, n):
func = self.single_forward_with_args
func = vmap(func, in_axes=(None, 0, 0, 0))
inputs_lower = np.zeros((self.num_inputs,))
cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('pop_forward', n)] = func
def compile_batch_forward(self, n):
s = time.time()
func = self.single_forward_with_args
func = vmap(func, in_axes=(0, None, None, None))
@@ -251,19 +241,19 @@ class FunctionFactory:
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('batch_forward', n)] = func
self.compile_time += time.time() - s
def create_batch_forward(self, n):
key = ('batch_forward', n)
if key not in self.compiled_function:
self.compile_batch_forward(n)
if self.debug:
def debug_batch_forward(*args):
return self.compiled_function[key](*args).block_until_ready()
return debug_batch_forward
else:
return self.compiled_function[key]
return self.compiled_function[key]
def compile_pop_batch_forward(self, n):
s = time.time()
func = self.single_forward_with_args
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward
@@ -276,25 +266,24 @@ class FunctionFactory:
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('pop_batch_forward', n)] = func
self.compile_time += time.time() - s
def create_pop_batch_forward(self, n):
key = ('pop_batch_forward', n)
if key not in self.compiled_function:
self.compile_pop_batch_forward(n)
if self.debug:
def debug_pop_batch_forward(*args):
return self.compiled_function[key](*args).block_until_ready()
return debug_pop_batch_forward
else:
return self.compiled_function[key]
return self.compiled_function[key]
def ask_pop_batch_forward(self, pop_nodes, pop_cons):
n, c = pop_nodes.shape[1], pop_cons.shape[1]
batch_unflatten_func = self.create_batch_unflatten_connections(n, c)
pop_cons = batch_unflatten_func(pop_nodes, pop_cons)
ts = self.create_topological_sort_batch(n)
pop_cal_seqs = ts(pop_nodes, pop_cons)
# for connections with enabled is false, set weight to 0)
pop_cal_seqs = ts(pop_nodes, pop_cons)
# print(pop_cal_seqs)
forward_func = self.create_pop_batch_forward(n)
def debug_forward(inputs):
@@ -314,6 +303,9 @@ class FunctionFactory:
return debug_forward
def compile_batch_unflatten_connections(self, n, c):
s = time.time()
func = unflatten_connections
func = vmap(func)
pop_nodes_lower = np.zeros((self.pop_size, n, 5))
@@ -321,14 +313,11 @@ class FunctionFactory:
func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile()
self.compiled_function[('batch_unflatten_connections', n, c)] = func
self.compile_time += time.time() - s
def create_batch_unflatten_connections(self, n, c):
key = ('batch_unflatten_connections', n, c)
if key not in self.compiled_function:
self.compile_batch_unflatten_connections(n, c)
if self.debug:
def debug_batch_unflatten_connections(*args):
return self.compiled_function[key](*args).block_until_ready()
return debug_batch_unflatten_connections
else:
return self.compiled_function[key]
return self.compiled_function[key]

View File

@@ -133,5 +133,8 @@ act_name2key = {
def act(idx, z):
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
res = jax.lax.switch(idx, ACT_TOTAL_LIST, z)
return jnp.where(jnp.isnan(res), jnp.nan, res)
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)

View File

@@ -88,6 +88,12 @@ def mutate(rand_key: Array,
def m_add_connection(rk, n, c):
return mutate_add_connection(rk, n, c, input_idx, output_idx)
def m_delete_node(rk, n, c):
return mutate_delete_node(rk, n, c, input_idx, output_idx)
def m_delete_connection(rk, n, c):
return mutate_delete_connection(rk, n, c)
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
# mutate add node
@@ -100,6 +106,16 @@ def mutate(rand_key: Array,
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
# mutate delete node
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
# mutate delete connection
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
response_mutate_strength, response_mutate_rate, response_replace_rate,

View File

@@ -14,6 +14,8 @@ EMPTY_CON = jnp.full((1, 4), jnp.nan)
def unflatten_connections(nodes, cons):
"""
transform the (C, 4) connections to (2, N, N)
this function is only used for transform a genome to the forward function, so here we set the weight of un=enabled
connections to nan, that means we dont consider such connection when forward;
:param cons:
:param nodes:
:return:
@@ -29,6 +31,10 @@ def unflatten_connections(nodes, cons):
# however, it will do nothing set values in an array
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
# (2, N, N), (2, N, N), (2, N, N)
# res = jnp.where(res[1, :, :] == 0, jnp.nan, res)
return res

View File

@@ -16,9 +16,9 @@ class Pipeline:
Neat algorithm pipeline.
"""
def __init__(self, config, seed=42):
def __init__(self, config, function_factory, seed=42):
self.time_dict = {}
self.function_factory = FunctionFactory(config)
self.function_factory = function_factory
self.randkey = jax.random.PRNGKey(seed)
np.random.seed(seed)
@@ -31,18 +31,21 @@ class Pipeline:
self.pop_size = config.neat.population.pop_size
self.species_controller = SpeciesController(config)
self.initialize_func = self.function_factory.create_initialize()
self.initialize_func = self.function_factory.create_initialize(self.N, self.C)
self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func()
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
self.generation = 0
self.generation_time_list = []
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
self.best_fitness = float('-inf')
self.best_genome = None
self.generation_timestamp = time.time()
self.evaluate_time = 0
def ask(self):
"""
Create a forward function for the population.
@@ -66,7 +69,9 @@ class Pipeline:
new_node_keys,
pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start)
idx2specie, new_center_nodes, new_center_cons, new_species_keys = jax.device_get([idx2specie, new_center_nodes, new_center_cons, new_species_keys])
self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = \
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys])
self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation)
@@ -75,7 +80,12 @@ class Pipeline:
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config.neat.population.generation_limit):
forward_func = self.ask()
tic = time.time()
fitnesses = fitness_func(forward_func)
self.evaluate_time += time.time() - tic
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
if analysis is not None:
if analysis == "default":
@@ -104,6 +114,7 @@ class Pipeline:
max_node_size = np.max(pop_node_sizes)
if max_node_size >= self.N:
self.N = int(self.N * self.expand_coe)
# self.C = int(self.C * self.expand_coe)
print(f"node expand to {self.N}!")
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
@@ -116,6 +127,7 @@ class Pipeline:
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
max_con_size = np.max(pop_node_sizes)
if max_con_size >= self.C:
# self.N = int(self.N * self.expand_coe)
self.C = int(self.C * self.expand_coe)
print(f"connections expand to {self.C}!")
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
@@ -134,6 +146,7 @@ class Pipeline:
new_timestamp = time.time()
cost_time = new_timestamp - self.generation_timestamp
self.generation_time_list.append(cost_time)
self.generation_timestamp = new_timestamp
max_idx = np.argmax(fitnesses)