prepare for experiment
This commit is contained in:
@@ -19,7 +19,7 @@ class FunctionFactory:
|
|||||||
self.expand_coe = config.basic.expands_coe
|
self.expand_coe = config.basic.expands_coe
|
||||||
self.precompile_times = config.basic.pre_compile_times
|
self.precompile_times = config.basic.pre_compile_times
|
||||||
self.compiled_function = {}
|
self.compiled_function = {}
|
||||||
self.time_cost = {}
|
self.compile_time = 0
|
||||||
|
|
||||||
self.load_config_vals(config)
|
self.load_config_vals(config)
|
||||||
|
|
||||||
@@ -150,6 +150,8 @@ class FunctionFactory:
|
|||||||
return self.compiled_function[key]
|
return self.compiled_function[key]
|
||||||
|
|
||||||
def compile_update_speciate(self, N, C, S):
|
def compile_update_speciate(self, N, C, S):
|
||||||
|
s = time.time()
|
||||||
|
|
||||||
func = self.update_speciate_with_args
|
func = self.update_speciate_with_args
|
||||||
randkey_lower = np.zeros((2,), dtype=np.uint32)
|
randkey_lower = np.zeros((2,), dtype=np.uint32)
|
||||||
pop_nodes_lower = np.zeros((self.pop_size, N, 5))
|
pop_nodes_lower = np.zeros((self.pop_size, N, 5))
|
||||||
@@ -177,16 +179,22 @@ class FunctionFactory:
|
|||||||
).compile()
|
).compile()
|
||||||
self.compiled_function[("update_speciate", N, C, S)] = compiled_func
|
self.compiled_function[("update_speciate", N, C, S)] = compiled_func
|
||||||
|
|
||||||
|
self.compile_time += time.time() - s
|
||||||
|
|
||||||
def create_topological_sort_with_args(self):
|
def create_topological_sort_with_args(self):
|
||||||
self.topological_sort_with_args = topological_sort
|
self.topological_sort_with_args = topological_sort
|
||||||
|
|
||||||
def compile_topological_sort(self, n):
|
def compile_topological_sort(self, n):
|
||||||
|
s = time.time()
|
||||||
|
|
||||||
func = self.topological_sort_with_args
|
func = self.topological_sort_with_args
|
||||||
nodes_lower = np.zeros((n, 5))
|
nodes_lower = np.zeros((n, 5))
|
||||||
connections_lower = np.zeros((2, n, n))
|
connections_lower = np.zeros((2, n, n))
|
||||||
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
||||||
self.compiled_function[('topological_sort', n)] = func
|
self.compiled_function[('topological_sort', n)] = func
|
||||||
|
|
||||||
|
self.compile_time += time.time() - s
|
||||||
|
|
||||||
def create_topological_sort(self, n):
|
def create_topological_sort(self, n):
|
||||||
key = ('topological_sort', n)
|
key = ('topological_sort', n)
|
||||||
if key not in self.compiled_function:
|
if key not in self.compiled_function:
|
||||||
@@ -194,6 +202,8 @@ class FunctionFactory:
|
|||||||
return self.compiled_function[key]
|
return self.compiled_function[key]
|
||||||
|
|
||||||
def compile_topological_sort_batch(self, n):
|
def compile_topological_sort_batch(self, n):
|
||||||
|
s = time.time()
|
||||||
|
|
||||||
func = self.topological_sort_with_args
|
func = self.topological_sort_with_args
|
||||||
func = vmap(func)
|
func = vmap(func)
|
||||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||||
@@ -201,6 +211,8 @@ class FunctionFactory:
|
|||||||
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
func = jit(func).lower(nodes_lower, connections_lower).compile()
|
||||||
self.compiled_function[('topological_sort_batch', n)] = func
|
self.compiled_function[('topological_sort_batch', n)] = func
|
||||||
|
|
||||||
|
self.compile_time += time.time() - s
|
||||||
|
|
||||||
def create_topological_sort_batch(self, n):
|
def create_topological_sort_batch(self, n):
|
||||||
key = ('topological_sort_batch', n)
|
key = ('topological_sort_batch', n)
|
||||||
if key not in self.compiled_function:
|
if key not in self.compiled_function:
|
||||||
@@ -215,32 +227,10 @@ class FunctionFactory:
|
|||||||
)
|
)
|
||||||
self.single_forward_with_args = func
|
self.single_forward_with_args = func
|
||||||
|
|
||||||
def compile_single_forward(self, n):
|
|
||||||
"""
|
|
||||||
single input for a genome
|
|
||||||
:param n:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
func = self.single_forward_with_args
|
|
||||||
inputs_lower = np.zeros((self.num_inputs,))
|
|
||||||
cal_seqs_lower = np.zeros((n,), dtype=np.int32)
|
|
||||||
nodes_lower = np.zeros((n, 5))
|
|
||||||
connections_lower = np.zeros((2, n, n))
|
|
||||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
|
||||||
self.compiled_function[('single_forward', n)] = func
|
|
||||||
|
|
||||||
def compile_pop_forward(self, n):
|
|
||||||
func = self.single_forward_with_args
|
|
||||||
func = vmap(func, in_axes=(None, 0, 0, 0))
|
|
||||||
|
|
||||||
inputs_lower = np.zeros((self.num_inputs,))
|
|
||||||
cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32)
|
|
||||||
nodes_lower = np.zeros((self.pop_size, n, 5))
|
|
||||||
connections_lower = np.zeros((self.pop_size, 2, n, n))
|
|
||||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
|
||||||
self.compiled_function[('pop_forward', n)] = func
|
|
||||||
|
|
||||||
def compile_batch_forward(self, n):
|
def compile_batch_forward(self, n):
|
||||||
|
s = time.time()
|
||||||
|
|
||||||
func = self.single_forward_with_args
|
func = self.single_forward_with_args
|
||||||
func = vmap(func, in_axes=(0, None, None, None))
|
func = vmap(func, in_axes=(0, None, None, None))
|
||||||
|
|
||||||
@@ -251,19 +241,19 @@ class FunctionFactory:
|
|||||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
||||||
self.compiled_function[('batch_forward', n)] = func
|
self.compiled_function[('batch_forward', n)] = func
|
||||||
|
|
||||||
|
self.compile_time += time.time() - s
|
||||||
|
|
||||||
def create_batch_forward(self, n):
|
def create_batch_forward(self, n):
|
||||||
key = ('batch_forward', n)
|
key = ('batch_forward', n)
|
||||||
if key not in self.compiled_function:
|
if key not in self.compiled_function:
|
||||||
self.compile_batch_forward(n)
|
self.compile_batch_forward(n)
|
||||||
if self.debug:
|
|
||||||
def debug_batch_forward(*args):
|
|
||||||
return self.compiled_function[key](*args).block_until_ready()
|
|
||||||
|
|
||||||
return debug_batch_forward
|
return self.compiled_function[key]
|
||||||
else:
|
|
||||||
return self.compiled_function[key]
|
|
||||||
|
|
||||||
def compile_pop_batch_forward(self, n):
|
def compile_pop_batch_forward(self, n):
|
||||||
|
|
||||||
|
s = time.time()
|
||||||
|
|
||||||
func = self.single_forward_with_args
|
func = self.single_forward_with_args
|
||||||
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
|
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
|
||||||
func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward
|
func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward
|
||||||
@@ -276,25 +266,24 @@ class FunctionFactory:
|
|||||||
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
|
||||||
self.compiled_function[('pop_batch_forward', n)] = func
|
self.compiled_function[('pop_batch_forward', n)] = func
|
||||||
|
|
||||||
|
self.compile_time += time.time() - s
|
||||||
|
|
||||||
def create_pop_batch_forward(self, n):
|
def create_pop_batch_forward(self, n):
|
||||||
key = ('pop_batch_forward', n)
|
key = ('pop_batch_forward', n)
|
||||||
if key not in self.compiled_function:
|
if key not in self.compiled_function:
|
||||||
self.compile_pop_batch_forward(n)
|
self.compile_pop_batch_forward(n)
|
||||||
if self.debug:
|
|
||||||
def debug_pop_batch_forward(*args):
|
|
||||||
return self.compiled_function[key](*args).block_until_ready()
|
|
||||||
|
|
||||||
return debug_pop_batch_forward
|
return self.compiled_function[key]
|
||||||
else:
|
|
||||||
return self.compiled_function[key]
|
|
||||||
|
|
||||||
def ask_pop_batch_forward(self, pop_nodes, pop_cons):
|
def ask_pop_batch_forward(self, pop_nodes, pop_cons):
|
||||||
n, c = pop_nodes.shape[1], pop_cons.shape[1]
|
n, c = pop_nodes.shape[1], pop_cons.shape[1]
|
||||||
batch_unflatten_func = self.create_batch_unflatten_connections(n, c)
|
batch_unflatten_func = self.create_batch_unflatten_connections(n, c)
|
||||||
pop_cons = batch_unflatten_func(pop_nodes, pop_cons)
|
pop_cons = batch_unflatten_func(pop_nodes, pop_cons)
|
||||||
ts = self.create_topological_sort_batch(n)
|
ts = self.create_topological_sort_batch(n)
|
||||||
pop_cal_seqs = ts(pop_nodes, pop_cons)
|
|
||||||
|
|
||||||
|
# for connections with enabled is false, set weight to 0)
|
||||||
|
pop_cal_seqs = ts(pop_nodes, pop_cons)
|
||||||
|
# print(pop_cal_seqs)
|
||||||
forward_func = self.create_pop_batch_forward(n)
|
forward_func = self.create_pop_batch_forward(n)
|
||||||
|
|
||||||
def debug_forward(inputs):
|
def debug_forward(inputs):
|
||||||
@@ -314,6 +303,9 @@ class FunctionFactory:
|
|||||||
return debug_forward
|
return debug_forward
|
||||||
|
|
||||||
def compile_batch_unflatten_connections(self, n, c):
|
def compile_batch_unflatten_connections(self, n, c):
|
||||||
|
|
||||||
|
s = time.time()
|
||||||
|
|
||||||
func = unflatten_connections
|
func = unflatten_connections
|
||||||
func = vmap(func)
|
func = vmap(func)
|
||||||
pop_nodes_lower = np.zeros((self.pop_size, n, 5))
|
pop_nodes_lower = np.zeros((self.pop_size, n, 5))
|
||||||
@@ -321,14 +313,11 @@ class FunctionFactory:
|
|||||||
func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile()
|
func = jit(func).lower(pop_nodes_lower, pop_connections_lower).compile()
|
||||||
self.compiled_function[('batch_unflatten_connections', n, c)] = func
|
self.compiled_function[('batch_unflatten_connections', n, c)] = func
|
||||||
|
|
||||||
|
self.compile_time += time.time() - s
|
||||||
|
|
||||||
def create_batch_unflatten_connections(self, n, c):
|
def create_batch_unflatten_connections(self, n, c):
|
||||||
key = ('batch_unflatten_connections', n, c)
|
key = ('batch_unflatten_connections', n, c)
|
||||||
if key not in self.compiled_function:
|
if key not in self.compiled_function:
|
||||||
self.compile_batch_unflatten_connections(n, c)
|
self.compile_batch_unflatten_connections(n, c)
|
||||||
if self.debug:
|
|
||||||
def debug_batch_unflatten_connections(*args):
|
|
||||||
return self.compiled_function[key](*args).block_until_ready()
|
|
||||||
|
|
||||||
return debug_batch_unflatten_connections
|
return self.compiled_function[key]
|
||||||
else:
|
|
||||||
return self.compiled_function[key]
|
|
||||||
|
|||||||
@@ -133,5 +133,8 @@ act_name2key = {
|
|||||||
def act(idx, z):
|
def act(idx, z):
|
||||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||||
# change idx from float to int
|
# change idx from float to int
|
||||||
return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
res = jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
||||||
|
return jnp.where(jnp.isnan(res), jnp.nan, res)
|
||||||
|
|
||||||
|
# return jax.lax.switch(idx, ACT_TOTAL_LIST, z)
|
||||||
|
|
||||||
|
|||||||
@@ -88,6 +88,12 @@ def mutate(rand_key: Array,
|
|||||||
def m_add_connection(rk, n, c):
|
def m_add_connection(rk, n, c):
|
||||||
return mutate_add_connection(rk, n, c, input_idx, output_idx)
|
return mutate_add_connection(rk, n, c, input_idx, output_idx)
|
||||||
|
|
||||||
|
def m_delete_node(rk, n, c):
|
||||||
|
return mutate_delete_node(rk, n, c, input_idx, output_idx)
|
||||||
|
|
||||||
|
def m_delete_connection(rk, n, c):
|
||||||
|
return mutate_delete_connection(rk, n, c)
|
||||||
|
|
||||||
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
|
r1, r2, r3, r4, rand_key = jax.random.split(rand_key, 5)
|
||||||
|
|
||||||
# mutate add node
|
# mutate add node
|
||||||
@@ -100,6 +106,16 @@ def mutate(rand_key: Array,
|
|||||||
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
|
nodes = jnp.where(rand(r3) < add_connection_rate, aux_nodes, nodes)
|
||||||
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
|
connections = jnp.where(rand(r3) < add_connection_rate, aux_connections, connections)
|
||||||
|
|
||||||
|
# mutate delete node
|
||||||
|
aux_nodes, aux_connections = m_delete_node(r2, nodes, connections)
|
||||||
|
nodes = jnp.where(rand(r2) < delete_node_rate, aux_nodes, nodes)
|
||||||
|
connections = jnp.where(rand(r2) < delete_node_rate, aux_connections, connections)
|
||||||
|
|
||||||
|
# mutate delete connection
|
||||||
|
aux_nodes, aux_connections = m_delete_connection(r4, nodes, connections)
|
||||||
|
nodes = jnp.where(rand(r4) < delete_connection_rate, aux_nodes, nodes)
|
||||||
|
connections = jnp.where(rand(r4) < delete_connection_rate, aux_connections, connections)
|
||||||
|
|
||||||
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
|
nodes, connections = mutate_values(rand_key, nodes, connections, bias_mean, bias_std, bias_mutate_strength,
|
||||||
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
|
bias_mutate_rate, bias_replace_rate, response_mean, response_std,
|
||||||
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
response_mutate_strength, response_mutate_rate, response_replace_rate,
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ EMPTY_CON = jnp.full((1, 4), jnp.nan)
|
|||||||
def unflatten_connections(nodes, cons):
|
def unflatten_connections(nodes, cons):
|
||||||
"""
|
"""
|
||||||
transform the (C, 4) connections to (2, N, N)
|
transform the (C, 4) connections to (2, N, N)
|
||||||
|
this function is only used for transform a genome to the forward function, so here we set the weight of un=enabled
|
||||||
|
connections to nan, that means we dont consider such connection when forward;
|
||||||
:param cons:
|
:param cons:
|
||||||
:param nodes:
|
:param nodes:
|
||||||
:return:
|
:return:
|
||||||
@@ -29,6 +31,10 @@ def unflatten_connections(nodes, cons):
|
|||||||
# however, it will do nothing set values in an array
|
# however, it will do nothing set values in an array
|
||||||
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
|
res = res.at[0, i_idxs, o_idxs].set(cons[:, 2])
|
||||||
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
|
res = res.at[1, i_idxs, o_idxs].set(cons[:, 3])
|
||||||
|
|
||||||
|
# (2, N, N), (2, N, N), (2, N, N)
|
||||||
|
# res = jnp.where(res[1, :, :] == 0, jnp.nan, res)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ class Pipeline:
|
|||||||
Neat algorithm pipeline.
|
Neat algorithm pipeline.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, seed=42):
|
def __init__(self, config, function_factory, seed=42):
|
||||||
self.time_dict = {}
|
self.time_dict = {}
|
||||||
self.function_factory = FunctionFactory(config)
|
self.function_factory = function_factory
|
||||||
|
|
||||||
self.randkey = jax.random.PRNGKey(seed)
|
self.randkey = jax.random.PRNGKey(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
@@ -31,18 +31,21 @@ class Pipeline:
|
|||||||
self.pop_size = config.neat.population.pop_size
|
self.pop_size = config.neat.population.pop_size
|
||||||
|
|
||||||
self.species_controller = SpeciesController(config)
|
self.species_controller = SpeciesController(config)
|
||||||
self.initialize_func = self.function_factory.create_initialize()
|
self.initialize_func = self.function_factory.create_initialize(self.N, self.C)
|
||||||
self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func()
|
self.pop_nodes, self.pop_cons, self.input_idx, self.output_idx = self.initialize_func()
|
||||||
|
|
||||||
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
|
self.create_and_speciate = self.function_factory.create_update_speciate(self.N, self.C, self.S)
|
||||||
|
|
||||||
self.generation = 0
|
self.generation = 0
|
||||||
|
self.generation_time_list = []
|
||||||
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
|
self.species_controller.init_speciate(self.pop_nodes, self.pop_cons)
|
||||||
|
|
||||||
self.best_fitness = float('-inf')
|
self.best_fitness = float('-inf')
|
||||||
self.best_genome = None
|
self.best_genome = None
|
||||||
self.generation_timestamp = time.time()
|
self.generation_timestamp = time.time()
|
||||||
|
|
||||||
|
self.evaluate_time = 0
|
||||||
|
|
||||||
def ask(self):
|
def ask(self):
|
||||||
"""
|
"""
|
||||||
Create a forward function for the population.
|
Create a forward function for the population.
|
||||||
@@ -66,7 +69,9 @@ class Pipeline:
|
|||||||
new_node_keys,
|
new_node_keys,
|
||||||
pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start)
|
pre_spe_center_nodes, pre_spe_center_cons, pre_species_keys, new_species_key_start)
|
||||||
|
|
||||||
idx2specie, new_center_nodes, new_center_cons, new_species_keys = jax.device_get([idx2specie, new_center_nodes, new_center_cons, new_species_keys])
|
|
||||||
|
self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys = \
|
||||||
|
jax.device_get([self.pop_nodes, self.pop_cons, idx2specie, new_center_nodes, new_center_cons, new_species_keys])
|
||||||
|
|
||||||
self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation)
|
self.species_controller.tell(idx2specie, new_center_nodes, new_center_cons, new_species_keys, self.generation)
|
||||||
|
|
||||||
@@ -75,7 +80,12 @@ class Pipeline:
|
|||||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||||
for _ in range(self.config.neat.population.generation_limit):
|
for _ in range(self.config.neat.population.generation_limit):
|
||||||
forward_func = self.ask()
|
forward_func = self.ask()
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
fitnesses = fitness_func(forward_func)
|
fitnesses = fitness_func(forward_func)
|
||||||
|
self.evaluate_time += time.time() - tic
|
||||||
|
|
||||||
|
assert np.all(~np.isnan(fitnesses)), "fitnesses should not be nan!"
|
||||||
|
|
||||||
if analysis is not None:
|
if analysis is not None:
|
||||||
if analysis == "default":
|
if analysis == "default":
|
||||||
@@ -104,6 +114,7 @@ class Pipeline:
|
|||||||
max_node_size = np.max(pop_node_sizes)
|
max_node_size = np.max(pop_node_sizes)
|
||||||
if max_node_size >= self.N:
|
if max_node_size >= self.N:
|
||||||
self.N = int(self.N * self.expand_coe)
|
self.N = int(self.N * self.expand_coe)
|
||||||
|
# self.C = int(self.C * self.expand_coe)
|
||||||
print(f"node expand to {self.N}!")
|
print(f"node expand to {self.N}!")
|
||||||
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
|
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
|
||||||
|
|
||||||
@@ -116,6 +127,7 @@ class Pipeline:
|
|||||||
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
pop_node_sizes = np.sum(~np.isnan(pop_con_keys), axis=1)
|
||||||
max_con_size = np.max(pop_node_sizes)
|
max_con_size = np.max(pop_node_sizes)
|
||||||
if max_con_size >= self.C:
|
if max_con_size >= self.C:
|
||||||
|
# self.N = int(self.N * self.expand_coe)
|
||||||
self.C = int(self.C * self.expand_coe)
|
self.C = int(self.C * self.expand_coe)
|
||||||
print(f"connections expand to {self.C}!")
|
print(f"connections expand to {self.C}!")
|
||||||
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
|
self.pop_nodes, self.pop_cons = expand(self.pop_nodes, self.pop_cons, self.N, self.C)
|
||||||
@@ -134,6 +146,7 @@ class Pipeline:
|
|||||||
|
|
||||||
new_timestamp = time.time()
|
new_timestamp = time.time()
|
||||||
cost_time = new_timestamp - self.generation_timestamp
|
cost_time = new_timestamp - self.generation_timestamp
|
||||||
|
self.generation_time_list.append(cost_time)
|
||||||
self.generation_timestamp = new_timestamp
|
self.generation_timestamp = new_timestamp
|
||||||
|
|
||||||
max_idx = np.argmax(fitnesses)
|
max_idx = np.argmax(fitnesses)
|
||||||
|
|||||||
44
examples/enhane_xor.py
Normal file
44
examples/enhane_xor.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import numpy as np
|
||||||
|
import jax
|
||||||
|
from utils import Configer
|
||||||
|
from algorithms.neat import Pipeline
|
||||||
|
from time_utils import using_cprofile
|
||||||
|
from algorithms.neat.function_factory import FunctionFactory
|
||||||
|
from problems import EnhanceLogic
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(problem, func):
|
||||||
|
inputs = problem.ask_for_inputs()
|
||||||
|
pop_predict = jax.device_get(func(inputs))
|
||||||
|
# print(pop_predict)
|
||||||
|
fitnesses = []
|
||||||
|
for predict in pop_predict:
|
||||||
|
f = problem.evaluate_predict(predict)
|
||||||
|
fitnesses.append(f)
|
||||||
|
return np.array(fitnesses)
|
||||||
|
|
||||||
|
# @using_cprofile
|
||||||
|
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
|
||||||
|
def main():
|
||||||
|
tic = time.time()
|
||||||
|
config = Configer.load_config()
|
||||||
|
problem = EnhanceLogic("xor", n=3)
|
||||||
|
problem.refactor_config(config)
|
||||||
|
function_factory = FunctionFactory(config)
|
||||||
|
evaluate_func = lambda func: evaluate(problem, func)
|
||||||
|
pipeline = Pipeline(config, function_factory, seed=33413)
|
||||||
|
print("start run")
|
||||||
|
pipeline.auto_run(evaluate_func)
|
||||||
|
|
||||||
|
total_time = time.time() - tic
|
||||||
|
compile_time = pipeline.function_factory.compile_time
|
||||||
|
total_it = pipeline.generation
|
||||||
|
mean_time_per_it = (total_time - compile_time) / total_it
|
||||||
|
evaluate_time = pipeline.evaluate_time
|
||||||
|
print(f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
|
||||||
|
print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
56
examples/final_design_experiement.py
Normal file
56
examples/final_design_experiement.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import numpy as np
|
||||||
|
import jax
|
||||||
|
from utils import Configer
|
||||||
|
from algorithms.neat import Pipeline
|
||||||
|
from time_utils import using_cprofile
|
||||||
|
from algorithms.neat.function_factory import FunctionFactory
|
||||||
|
from problems import EnhanceLogic
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(problem, func):
|
||||||
|
outs = func(problem.inputs)
|
||||||
|
outs = jax.device_get(outs)
|
||||||
|
fitnesses = -np.mean((problem.outputs - outs) ** 2, axis=(1, 2))
|
||||||
|
return fitnesses
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = Configer.load_config()
|
||||||
|
problem = EnhanceLogic("xor", n=3)
|
||||||
|
problem.refactor_config(config)
|
||||||
|
function_factory = FunctionFactory(config)
|
||||||
|
evaluate_func = lambda func: evaluate(problem, func)
|
||||||
|
|
||||||
|
# precompile
|
||||||
|
pipeline = Pipeline(config, function_factory, seed=114514)
|
||||||
|
pipeline.auto_run(evaluate_func)
|
||||||
|
|
||||||
|
for r in range(10):
|
||||||
|
print(f"running: {r}/{10}")
|
||||||
|
tic = time.time()
|
||||||
|
|
||||||
|
pipeline = Pipeline(config, function_factory, seed=r)
|
||||||
|
pipeline.auto_run(evaluate_func)
|
||||||
|
|
||||||
|
total_time = time.time() - tic
|
||||||
|
evaluate_time = pipeline.evaluate_time
|
||||||
|
total_it = pipeline.generation
|
||||||
|
print(f"total time: {total_time:.2f}s, evaluate time: {evaluate_time:.2f}s, total_it: {total_it}")
|
||||||
|
|
||||||
|
if total_it >= 500:
|
||||||
|
res = "fail"
|
||||||
|
else:
|
||||||
|
res = "success"
|
||||||
|
|
||||||
|
with open("log", "wb") as f:
|
||||||
|
f.write(f"{res}, total time: {total_time:.2f}s, evaluate time: {evaluate_time:.2f}s, total_it: {total_it}\n".encode("utf-8"))
|
||||||
|
f.write(str(pipeline.generation_time_list).encode("utf-8"))
|
||||||
|
|
||||||
|
compile_time = function_factory.compile_time
|
||||||
|
|
||||||
|
print("total_compile_time:", compile_time)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -1,19 +1,27 @@
|
|||||||
from functools import partial
|
|
||||||
|
|
||||||
from utils import Configer
|
from utils import Configer
|
||||||
from algorithms.neat import Pipeline
|
from algorithms.neat import Pipeline
|
||||||
from time_utils import using_cprofile
|
from time_utils import using_cprofile
|
||||||
from problems import Sin, Xor, DIY
|
from problems import Sin, Xor, DIY
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
@using_cprofile
|
# @using_cprofile
|
||||||
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
|
# @partial(using_cprofile, root_abs_path='/mnt/e/neatax/', replace_pattern="/mnt/e/neat-jax/")
|
||||||
def main():
|
def main():
|
||||||
|
tic = time.time()
|
||||||
config = Configer.load_config()
|
config = Configer.load_config()
|
||||||
problem = Xor()
|
problem = Xor()
|
||||||
problem.refactor_config(config)
|
problem.refactor_config(config)
|
||||||
pipeline = Pipeline(config, seed=1)
|
pipeline = Pipeline(config, seed=6)
|
||||||
pipeline.auto_run(problem.evaluate)
|
nodes, cons = pipeline.auto_run(problem.evaluate)
|
||||||
|
# print(nodes, cons)
|
||||||
|
total_time = time.time() - tic
|
||||||
|
compile_time = pipeline.function_factory.compile_time
|
||||||
|
total_it = pipeline.generation
|
||||||
|
mean_time_per_it = (total_time - compile_time) / total_it
|
||||||
|
evaluate_time = pipeline.evaluate_time
|
||||||
|
print(f"total time: {total_time:.2f}s, compile time: {compile_time:.2f}s, real_time: {total_time - compile_time:.2f}s, evaluate time: {evaluate_time:.2f}s")
|
||||||
|
print(f"total it: {total_it}, mean time per it: {mean_time_per_it:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from .function_fitting_problem import FunctionFittingProblem
|
from .function_fitting_problem import FunctionFittingProblem
|
||||||
from .xor import *
|
from .xor import *
|
||||||
from .sin import *
|
from .sin import *
|
||||||
from .diy import *
|
from .diy import *
|
||||||
|
from .enhance_logic import *
|
||||||
54
problems/function_fitting/enhance_logic.py
Normal file
54
problems/function_fitting/enhance_logic.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""
|
||||||
|
xor problem in multiple dimensions
|
||||||
|
"""
|
||||||
|
|
||||||
|
from itertools import product
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class EnhanceLogic:
|
||||||
|
def __init__(self, name="xor", n=2):
|
||||||
|
self.name = name
|
||||||
|
self.n = n
|
||||||
|
self.num_inputs = n
|
||||||
|
self.num_outputs = 1
|
||||||
|
self.batch = 2 ** n
|
||||||
|
self.forward_way = 'pop_batch'
|
||||||
|
|
||||||
|
self.inputs = np.array(generate_permutations(n), dtype=np.float32)
|
||||||
|
|
||||||
|
if self.name == "xor":
|
||||||
|
self.outputs = np.sum(self.inputs, axis=1) % 2
|
||||||
|
elif self.name == "and":
|
||||||
|
self.outputs = np.all(self.inputs==1, axis=1)
|
||||||
|
elif self.name == "or":
|
||||||
|
self.outputs = np.any(self.inputs==1, axis=1)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only support xor, and, or")
|
||||||
|
self.outputs = self.outputs[:, np.newaxis]
|
||||||
|
|
||||||
|
|
||||||
|
def refactor_config(self, config):
|
||||||
|
config.basic.forward_way = self.forward_way
|
||||||
|
config.basic.num_inputs = self.num_inputs
|
||||||
|
config.basic.num_outputs = self.num_outputs
|
||||||
|
config.basic.problem_batch = self.batch
|
||||||
|
|
||||||
|
|
||||||
|
def ask_for_inputs(self):
|
||||||
|
return self.inputs
|
||||||
|
|
||||||
|
def evaluate_predict(self, predict):
|
||||||
|
# print((predict - self.outputs) ** 2)
|
||||||
|
return -np.mean((predict - self.outputs) ** 2)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def generate_permutations(n):
|
||||||
|
permutations = [list(i) for i in product([0, 1], repeat=n)]
|
||||||
|
|
||||||
|
return permutations
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
_ = EnhanceLogic(4)
|
||||||
@@ -13,9 +13,9 @@
|
|||||||
"neat": {
|
"neat": {
|
||||||
"population": {
|
"population": {
|
||||||
"fitness_criterion": "max",
|
"fitness_criterion": "max",
|
||||||
"fitness_threshold": -0.001,
|
"fitness_threshold": -1e-2,
|
||||||
"generation_limit": 1000,
|
"generation_limit": 500,
|
||||||
"pop_size": 1000,
|
"pop_size": 5000,
|
||||||
"reset_on_extinction": "False"
|
"reset_on_extinction": "False"
|
||||||
},
|
},
|
||||||
"gene": {
|
"gene": {
|
||||||
@@ -35,7 +35,7 @@
|
|||||||
},
|
},
|
||||||
"activation": {
|
"activation": {
|
||||||
"default": "sigmoid",
|
"default": "sigmoid",
|
||||||
"options": "sigmoid",
|
"options": ["sigmoid"],
|
||||||
"mutate_rate": 0.1
|
"mutate_rate": 0.1
|
||||||
},
|
},
|
||||||
"aggregation": {
|
"aggregation": {
|
||||||
@@ -58,13 +58,13 @@
|
|||||||
"compatibility_disjoint_coefficient": 1.0,
|
"compatibility_disjoint_coefficient": 1.0,
|
||||||
"compatibility_weight_coefficient": 0.5,
|
"compatibility_weight_coefficient": 0.5,
|
||||||
"single_structural_mutation": "False",
|
"single_structural_mutation": "False",
|
||||||
"conn_add_prob": 0.5,
|
"conn_add_prob": 0.6,
|
||||||
"conn_delete_prob": 0,
|
"conn_delete_prob": 0,
|
||||||
"node_add_prob": 0.2,
|
"node_add_prob": 0.3,
|
||||||
"node_delete_prob": 0
|
"node_delete_prob": 0
|
||||||
},
|
},
|
||||||
"species": {
|
"species": {
|
||||||
"compatibility_threshold": 3,
|
"compatibility_threshold": 2.5,
|
||||||
"species_fitness_func": "max",
|
"species_fitness_func": "max",
|
||||||
"max_stagnation": 20,
|
"max_stagnation": 20,
|
||||||
"species_elitism": 2,
|
"species_elitism": 2,
|
||||||
|
|||||||
Reference in New Issue
Block a user