add function to put **all** compilation at the beginning of the execution.

This commit is contained in:
wls2002
2023-05-09 02:55:47 +08:00
parent 1f2327bbd6
commit 0fdc856f2d
6 changed files with 183 additions and 75 deletions

View File

@@ -7,6 +7,7 @@ import numpy as np
from jax import jit, vmap from jax import jit, vmap
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
from .genome import topological_sort, forward_single
class FunctionFactory: class FunctionFactory:
@@ -24,6 +25,8 @@ class FunctionFactory:
pass pass
def load_config_vals(self, config): def load_config_vals(self, config):
self.problem_batch = config.basic.problem_batch
self.pop_size = config.neat.population.pop_size self.pop_size = config.neat.population.pop_size
self.init_N = config.basic.init_maximum_nodes self.init_N = config.basic.init_maximum_nodes
@@ -98,12 +101,17 @@ class FunctionFactory:
self.create_mutate_with_args() self.create_mutate_with_args()
self.create_distance_with_args() self.create_distance_with_args()
self.create_crossover_with_args() self.create_crossover_with_args()
self.create_topological_sort_with_args()
self.create_single_forward_with_args()
n = self.init_N n = self.init_N
print("start precompile") print("start precompile")
for _ in range(self.precompile_times): for _ in range(self.precompile_times):
self.compile_mutate(n) self.compile_mutate(n)
self.compile_distance(n) self.compile_distance(n)
self.compile_crossover(n) self.compile_crossover(n)
self.compile_topological_sort(n)
self.compile_pop_batch_forward(n)
n = int(self.expand_coe * n) n = int(self.expand_coe * n)
print("end precompile") print("end precompile")
@@ -209,3 +217,99 @@ class FunctionFactory:
if key not in self.compiled_function: if key not in self.compiled_function:
self.compile_crossover(n) self.compile_crossover(n)
return self.compiled_function[key] return self.compiled_function[key]
def create_topological_sort_with_args(self):
self.topological_sort_with_args = topological_sort
def compile_topological_sort(self, n):
func = self.topological_sort_with_args
func = vmap(func)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(nodes_lower, connections_lower).compile()
self.compiled_function[('topological_sort', n)] = func
def create_topological_sort(self, n):
key = ('topological_sort', n)
if key not in self.compiled_function:
self.compile_topological_sort(n)
return self.compiled_function[key]
def create_single_forward_with_args(self):
func = partial(
forward_single,
input_idx=self.input_idx,
output_idx=self.output_idx
)
self.single_forward_with_args = func
def compile_single_forward(self, n):
"""
single input for a genome
:param n:
:return:
"""
func = self.single_forward_with_args
inputs_lower = np.zeros((self.num_inputs,))
cal_seqs_lower = np.zeros((n,), dtype=np.int32)
nodes_lower = np.zeros((n, 5))
connections_lower = np.zeros((2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('single_forward', n)] = func
def compile_pop_forward(self, n):
func = self.single_forward_with_args
func = vmap(func, in_axes=(None, 0, 0, 0))
inputs_lower = np.zeros((self.num_inputs,))
cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('pop_forward', n)] = func
def compile_batch_forward(self, n):
func = self.single_forward_with_args
func = vmap(func, in_axes=(0, None, None, None))
inputs_lower = np.zeros((self.problem_batch, self.num_inputs))
cal_seqs_lower = np.zeros((n,), dtype=np.int32)
nodes_lower = np.zeros((n, 5))
connections_lower = np.zeros((2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('batch_forward', n)] = func
def compile_pop_batch_forward(self, n):
func = self.single_forward_with_args
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward
inputs_lower = np.zeros((self.problem_batch, self.num_inputs))
cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('pop_batch_forward', n)] = func
def create_pop_batch_forward(self, n):
key = ('pop_batch_forward', n)
if key not in self.compiled_function:
self.compile_pop_batch_forward(n)
return self.compiled_function[key]
def ask(self, pop_nodes, pop_connections):
n = pop_nodes.shape[1]
ts = self.create_topological_sort(n)
pop_cal_seqs = ts(pop_nodes, pop_connections)
forward_func = self.create_pop_batch_forward(n)
return lambda inputs: forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections)
# return partial(
# forward_func,
# cal_seqs=pop_cal_seqs,
# nodes=pop_nodes,
# connections=pop_connections
# )

View File

@@ -1,7 +1,8 @@
from .genome import expand, expand_single, pop_analysis, initialize_genomes from .genome import expand, expand_single, pop_analysis, initialize_genomes
from .forward import create_forward_function from .forward import create_forward_function, forward_single
from .activations import act_name2key from .activations import act_name2key
from .aggregations import agg_name2key from .aggregations import agg_name2key
from .crossover import crossover from .crossover import crossover
from .mutate import mutate from .mutate import mutate
from .distance import distance from .distance import distance
from .graph import topological_sort

View File

@@ -46,15 +46,14 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}") raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
@partial(jit, static_argnames=['N']) @jit
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array: input_idx: Array, output_idx: Array) -> Array:
""" """
jax forward for single input shaped (input_num, ) jax forward for single input shaped (input_num, )
nodes, connections are single genome nodes, connections are single genome
:argument inputs: (input_num, ) :argument inputs: (input_num, )
:argument N: int
:argument input_idx: (input_num, ) :argument input_idx: (input_num, )
:argument output_idx: (output_num, ) :argument output_idx: (output_num, )
:argument cal_seqs: (N, ) :argument cal_seqs: (N, )
@@ -63,6 +62,7 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
:return (output_num, ) :return (output_num, )
""" """
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan) ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs) ini_vals = ini_vals.at[input_idx].set(inputs)
@@ -86,64 +86,64 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
return vals[output_idx] return vals[output_idx]
@partial(jit, static_argnames=['N']) # @partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(0, None, None, None, None, None, None)) # @partial(vmap, in_axes=(0, None, None, None, None, None, None))
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, # def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array: # cal_seqs: Array, nodes: Array, connections: Array) -> Array:
""" # """
jax forward for batch_inputs shaped (batch_size, input_num) # jax forward for batch_inputs shaped (batch_size, input_num)
nodes, connections are single genome # nodes, connections are single genome
#
:argument batch_inputs: (batch_size, input_num) # :argument batch_inputs: (batch_size, input_num)
:argument N: int # :argument N: int
:argument input_idx: (input_num, ) # :argument input_idx: (input_num, )
:argument output_idx: (output_num, ) # :argument output_idx: (output_num, )
:argument cal_seqs: (N, ) # :argument cal_seqs: (N, )
:argument nodes: (N, 5) # :argument nodes: (N, 5)
:argument connections: (2, N, N) # :argument connections: (2, N, N)
#
:return (batch_size, output_num) # :return (batch_size, output_num)
""" # """
return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections) # return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
#
#
@partial(jit, static_argnames=['N']) # @partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0)) # @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array, # def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: # pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
""" # """
jax forward for single input shaped (input_num, ) # jax forward for single input shaped (input_num, )
pop_nodes, pop_connections are population of genomes # pop_nodes, pop_connections are population of genomes
#
:argument inputs: (input_num, ) # :argument inputs: (input_num, )
:argument N: int # :argument N: int
:argument input_idx: (input_num, ) # :argument input_idx: (input_num, )
:argument output_idx: (output_num, ) # :argument output_idx: (output_num, )
:argument pop_cal_seqs: (pop_size, N) # :argument pop_cal_seqs: (pop_size, N)
:argument pop_nodes: (pop_size, N, 5) # :argument pop_nodes: (pop_size, N, 5)
:argument pop_connections: (pop_size, 2, N, N) # :argument pop_connections: (pop_size, 2, N, N)
#
:return (pop_size, output_num) # :return (pop_size, output_num)
""" # """
return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections) # return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
#
#
@partial(jit, static_argnames=['N']) # @partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0)) # @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array, # def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array: # pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
""" # """
jax forward for batch input shaped (batch, input_num) # jax forward for batch input shaped (batch, input_num)
pop_nodes, pop_connections are population of genomes # pop_nodes, pop_connections are population of genomes
#
:argument batch_inputs: (batch_size, input_num) # :argument batch_inputs: (batch_size, input_num)
:argument N: int # :argument N: int
:argument input_idx: (input_num, ) # :argument input_idx: (input_num, )
:argument output_idx: (output_num, ) # :argument output_idx: (output_num, )
:argument pop_cal_seqs: (pop_size, N) # :argument pop_cal_seqs: (pop_size, N)
:argument pop_nodes: (pop_size, N, 5) # :argument pop_nodes: (pop_size, N, 5)
:argument pop_connections: (pop_size, 2, N, N) # :argument pop_connections: (pop_size, 2, N, N)
#
:return (pop_size, batch_size, output_num) # :return (pop_size, batch_size, output_num)
""" # """
return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections) # return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)

View File

@@ -37,16 +37,18 @@ class Pipeline:
self.best_fitness = float('-inf') self.best_fitness = float('-inf')
self.generation_timestamp = time.time() self.generation_timestamp = time.time()
def ask(self, batch: bool): def ask(self):
""" """
Create a forward function for the population. Create a forward function for the population.
:param batch:
:return: :return:
Algorithm gives the population a forward function, then environment gives back the fitnesses. Algorithm gives the population a forward function, then environment gives back the fitnesses.
""" """
func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx, return self.function_factory.ask(self.pop_nodes, self.pop_connections)
batch=batch)
return func #
# func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx,
# batch=batch)
# return func
def tell(self, fitnesses): def tell(self, fitnesses):
@@ -65,7 +67,7 @@ class Pipeline:
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config.neat.population.generation_limit): for _ in range(self.config.neat.population.generation_limit):
forward_func = self.ask(batch=True) forward_func = self.ask()
fitnesses = fitness_func(forward_func) fitnesses = fitness_func(forward_func)
if analysis is not None: if analysis is not None:

View File

@@ -8,7 +8,7 @@ from utils import Configer
from algorithms.neat import Pipeline from algorithms.neat import Pipeline
from time_utils import using_cprofile from time_utils import using_cprofile
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]]) xor_outputs = np.array([[0], [1], [1], [0]])

View File

@@ -2,9 +2,10 @@
"basic": { "basic": {
"num_inputs": 2, "num_inputs": 2,
"num_outputs": 1, "num_outputs": 1,
"init_maximum_nodes": 10, "problem_batch": 4,
"init_maximum_nodes": 20,
"expands_coe": 2, "expands_coe": 2,
"pre_compile_times": 3 "pre_compile_times": 0
}, },
"neat": { "neat": {
"population": { "population": {