add function to put **all** compilation at the beginning of the execution.

This commit is contained in:
wls2002
2023-05-09 02:55:47 +08:00
parent 1f2327bbd6
commit 0fdc856f2d
6 changed files with 183 additions and 75 deletions

View File

@@ -7,6 +7,7 @@ import numpy as np
from jax import jit, vmap
from .genome import act_name2key, agg_name2key, initialize_genomes, mutate, distance, crossover
from .genome import topological_sort, forward_single
class FunctionFactory:
@@ -24,6 +25,8 @@ class FunctionFactory:
pass
def load_config_vals(self, config):
self.problem_batch = config.basic.problem_batch
self.pop_size = config.neat.population.pop_size
self.init_N = config.basic.init_maximum_nodes
@@ -98,12 +101,17 @@ class FunctionFactory:
self.create_mutate_with_args()
self.create_distance_with_args()
self.create_crossover_with_args()
self.create_topological_sort_with_args()
self.create_single_forward_with_args()
n = self.init_N
print("start precompile")
for _ in range(self.precompile_times):
self.compile_mutate(n)
self.compile_distance(n)
self.compile_crossover(n)
self.compile_topological_sort(n)
self.compile_pop_batch_forward(n)
n = int(self.expand_coe * n)
print("end precompile")
@@ -209,3 +217,99 @@ class FunctionFactory:
if key not in self.compiled_function:
self.compile_crossover(n)
return self.compiled_function[key]
def create_topological_sort_with_args(self):
self.topological_sort_with_args = topological_sort
def compile_topological_sort(self, n):
func = self.topological_sort_with_args
func = vmap(func)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(nodes_lower, connections_lower).compile()
self.compiled_function[('topological_sort', n)] = func
def create_topological_sort(self, n):
key = ('topological_sort', n)
if key not in self.compiled_function:
self.compile_topological_sort(n)
return self.compiled_function[key]
def create_single_forward_with_args(self):
func = partial(
forward_single,
input_idx=self.input_idx,
output_idx=self.output_idx
)
self.single_forward_with_args = func
def compile_single_forward(self, n):
"""
single input for a genome
:param n:
:return:
"""
func = self.single_forward_with_args
inputs_lower = np.zeros((self.num_inputs,))
cal_seqs_lower = np.zeros((n,), dtype=np.int32)
nodes_lower = np.zeros((n, 5))
connections_lower = np.zeros((2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('single_forward', n)] = func
def compile_pop_forward(self, n):
func = self.single_forward_with_args
func = vmap(func, in_axes=(None, 0, 0, 0))
inputs_lower = np.zeros((self.num_inputs,))
cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('pop_forward', n)] = func
def compile_batch_forward(self, n):
func = self.single_forward_with_args
func = vmap(func, in_axes=(0, None, None, None))
inputs_lower = np.zeros((self.problem_batch, self.num_inputs))
cal_seqs_lower = np.zeros((n,), dtype=np.int32)
nodes_lower = np.zeros((n, 5))
connections_lower = np.zeros((2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('batch_forward', n)] = func
def compile_pop_batch_forward(self, n):
func = self.single_forward_with_args
func = vmap(func, in_axes=(0, None, None, None)) # batch_forward
func = vmap(func, in_axes=(None, 0, 0, 0)) # pop_batch_forward
inputs_lower = np.zeros((self.problem_batch, self.num_inputs))
cal_seqs_lower = np.zeros((self.pop_size, n), dtype=np.int32)
nodes_lower = np.zeros((self.pop_size, n, 5))
connections_lower = np.zeros((self.pop_size, 2, n, n))
func = jit(func).lower(inputs_lower, cal_seqs_lower, nodes_lower, connections_lower).compile()
self.compiled_function[('pop_batch_forward', n)] = func
def create_pop_batch_forward(self, n):
key = ('pop_batch_forward', n)
if key not in self.compiled_function:
self.compile_pop_batch_forward(n)
return self.compiled_function[key]
def ask(self, pop_nodes, pop_connections):
n = pop_nodes.shape[1]
ts = self.create_topological_sort(n)
pop_cal_seqs = ts(pop_nodes, pop_connections)
forward_func = self.create_pop_batch_forward(n)
return lambda inputs: forward_func(inputs, pop_cal_seqs, pop_nodes, pop_connections)
# return partial(
# forward_func,
# cal_seqs=pop_cal_seqs,
# nodes=pop_nodes,
# connections=pop_connections
# )

View File

@@ -1,7 +1,8 @@
from .genome import expand, expand_single, pop_analysis, initialize_genomes
from .forward import create_forward_function
from .forward import create_forward_function, forward_single
from .activations import act_name2key
from .aggregations import agg_name2key
from .crossover import crossover
from .mutate import mutate
from .distance import distance
from .graph import topological_sort

View File

@@ -46,15 +46,14 @@ def create_forward_function(nodes: NDArray, connections: NDArray,
raise ValueError(f"nodes.ndim should be 2 or 3, but got {nodes.ndim}")
@partial(jit, static_argnames=['N'])
def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
@jit
def forward_single(inputs: Array, cal_seqs: Array, nodes: Array, connections: Array,
input_idx: Array, output_idx: Array) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are single genome
:argument inputs: (input_num, )
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument cal_seqs: (N, )
@@ -63,6 +62,7 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
:return (output_num, )
"""
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[input_idx].set(inputs)
@@ -86,64 +86,64 @@ def forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
return vals[output_idx]
@partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(0, None, None, None, None, None, None))
def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
cal_seqs: Array, nodes: Array, connections: Array) -> Array:
"""
jax forward for batch_inputs shaped (batch_size, input_num)
nodes, connections are single genome
:argument batch_inputs: (batch_size, input_num)
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
:return (batch_size, output_num)
"""
return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
@partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
"""
jax forward for single input shaped (input_num, )
pop_nodes, pop_connections are population of genomes
:argument inputs: (input_num, )
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument pop_cal_seqs: (pop_size, N)
:argument pop_nodes: (pop_size, N, 5)
:argument pop_connections: (pop_size, 2, N, N)
:return (pop_size, output_num)
"""
return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
@partial(jit, static_argnames=['N'])
@partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
"""
jax forward for batch input shaped (batch, input_num)
pop_nodes, pop_connections are population of genomes
:argument batch_inputs: (batch_size, input_num)
:argument N: int
:argument input_idx: (input_num, )
:argument output_idx: (output_num, )
:argument pop_cal_seqs: (pop_size, N)
:argument pop_nodes: (pop_size, N, 5)
:argument pop_connections: (pop_size, 2, N, N)
:return (pop_size, batch_size, output_num)
"""
return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
# @partial(jit, static_argnames=['N'])
# @partial(vmap, in_axes=(0, None, None, None, None, None, None))
# def forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
# cal_seqs: Array, nodes: Array, connections: Array) -> Array:
# """
# jax forward for batch_inputs shaped (batch_size, input_num)
# nodes, connections are single genome
#
# :argument batch_inputs: (batch_size, input_num)
# :argument N: int
# :argument input_idx: (input_num, )
# :argument output_idx: (output_num, )
# :argument cal_seqs: (N, )
# :argument nodes: (N, 5)
# :argument connections: (2, N, N)
#
# :return (batch_size, output_num)
# """
# return forward_single(batch_inputs, N, input_idx, output_idx, cal_seqs, nodes, connections)
#
#
# @partial(jit, static_argnames=['N'])
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
# def pop_forward_single(inputs: Array, N: int, input_idx: Array, output_idx: Array,
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
# """
# jax forward for single input shaped (input_num, )
# pop_nodes, pop_connections are population of genomes
#
# :argument inputs: (input_num, )
# :argument N: int
# :argument input_idx: (input_num, )
# :argument output_idx: (output_num, )
# :argument pop_cal_seqs: (pop_size, N)
# :argument pop_nodes: (pop_size, N, 5)
# :argument pop_connections: (pop_size, 2, N, N)
#
# :return (pop_size, output_num)
# """
# return forward_single(inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)
#
#
# @partial(jit, static_argnames=['N'])
# @partial(vmap, in_axes=(None, None, None, None, 0, 0, 0))
# def pop_forward_batch(batch_inputs: Array, N: int, input_idx: Array, output_idx: Array,
# pop_cal_seqs: Array, pop_nodes: Array, pop_connections: Array) -> Array:
# """
# jax forward for batch input shaped (batch, input_num)
# pop_nodes, pop_connections are population of genomes
#
# :argument batch_inputs: (batch_size, input_num)
# :argument N: int
# :argument input_idx: (input_num, )
# :argument output_idx: (output_num, )
# :argument pop_cal_seqs: (pop_size, N)
# :argument pop_nodes: (pop_size, N, 5)
# :argument pop_connections: (pop_size, 2, N, N)
#
# :return (pop_size, batch_size, output_num)
# """
# return forward_batch(batch_inputs, N, input_idx, output_idx, pop_cal_seqs, pop_nodes, pop_connections)

View File

@@ -37,16 +37,18 @@ class Pipeline:
self.best_fitness = float('-inf')
self.generation_timestamp = time.time()
def ask(self, batch: bool):
def ask(self):
"""
Create a forward function for the population.
:param batch:
:return:
Algorithm gives the population a forward function, then environment gives back the fitnesses.
"""
func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx,
batch=batch)
return func
return self.function_factory.ask(self.pop_nodes, self.pop_connections)
#
# func = create_forward_function(self.pop_nodes, self.pop_connections, self.N, self.input_idx, self.output_idx,
# batch=batch)
# return func
def tell(self, fitnesses):
@@ -65,7 +67,7 @@ class Pipeline:
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config.neat.population.generation_limit):
forward_func = self.ask(batch=True)
forward_func = self.ask()
fitnesses = fitness_func(forward_func)
if analysis is not None:

View File

@@ -8,7 +8,7 @@ from utils import Configer
from algorithms.neat import Pipeline
from time_utils import using_cprofile
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]])

View File

@@ -2,9 +2,10 @@
"basic": {
"num_inputs": 2,
"num_outputs": 1,
"init_maximum_nodes": 10,
"problem_batch": 4,
"init_maximum_nodes": 20,
"expands_coe": 2,
"pre_compile_times": 3
"pre_compile_times": 0
},
"neat": {
"population": {