The whole NEAT algorithm is written into functional programming.

This commit is contained in:
wls2002
2023-06-29 09:28:49 +08:00
parent 114ff2b0cc
commit d28cef1a87
16 changed files with 371 additions and 1102 deletions

View File

@@ -1,7 +1,6 @@
from .mutate import mutate
from .distance import distance
from .crossover import crossover
from .forward import create_forward
from .graph import topological_sort, check_cycles
from .utils import unflatten_connections
from .genome import initialize_genomes, expand, expand_single
from .forward import create_forward_function

View File

@@ -5,7 +5,7 @@ from jax import jit, vmap
from .utils import I_INT
def create_forward(config):
def create_forward_function(config):
"""
meta method to create forward function
"""
@@ -83,4 +83,22 @@ def create_forward(config):
return vals[output_idx]
# (batch_size, inputs_nums) -> (batch_size, outputs_nums)
batch_forward = vmap(forward, in_axes=(0, None, None, None))
# (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
if config['forward_way'] == 'single':
return jit(batch_forward)
elif config['forward_way'] == 'pop':
return jit(pop_batch_forward)
elif config['forward_way'] == 'common':
return jit(common_forward)
return forward

View File

@@ -65,55 +65,6 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
return pop_nodes, pop_cons
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
"""
Expand a single genome to accommodate more nodes or connections.
:param nodes: (N, 5)
:param cons: (C, 4)
:param new_N:
:param new_C:
:return: (new_N, 5), (new_C, 4)
"""
old_N, old_C = nodes.shape[0], cons.shape[0]
new_nodes = np.full((new_N, 5), np.nan)
new_nodes[:old_N, :] = nodes
new_cons = np.full((new_C, 4), np.nan)
new_cons[:old_C, :] = cons
return new_nodes, new_cons
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
"""
Expand the population to accommodate more nodes or connections.
:param pop_nodes: (pop_size, N, 5)
:param pop_cons: (pop_size, C, 4)
:param new_N:
:param new_C:
:return: (pop_size, new_N, 5), (pop_size, new_C, 4)
"""
pop_size, old_N, old_C = pop_nodes.shape[0], pop_nodes.shape[1], pop_cons.shape[1]
new_pop_nodes = np.full((pop_size, new_N, 5), np.nan)
new_pop_nodes[:, :old_N, :] = pop_nodes
new_pop_cons = np.full((pop_size, new_C, 4), np.nan)
new_pop_cons[:, :old_C, :] = pop_cons
return new_pop_nodes, new_pop_cons
@jit
def count(nodes: NDArray, cons: NDArray) -> Tuple[NDArray, NDArray]:
"""
Count how many nodes and connections are in the genome.
"""
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
return node_cnt, cons_cnt
@jit
def add_node(nodes: NDArray, cons: NDArray, new_key: int,
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]:

View File

@@ -59,12 +59,13 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
return fetch_first(mask, default)
@partial(jit, static_argnames=['reverse'])
def rank_elements(array, reverse=False):
"""
rank the element in the array.
if reverse is True, the rank is from large to small.
if reverse is True, the rank is from small to large. default large to small
"""
if reverse:
if not reverse:
array = -array
return jnp.argsort(jnp.argsort(array))