The whole NEAT algorithm is written into functional programming.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from .mutate import mutate
|
||||
from .distance import distance
|
||||
from .crossover import crossover
|
||||
from .forward import create_forward
|
||||
from .graph import topological_sort, check_cycles
|
||||
from .utils import unflatten_connections
|
||||
from .genome import initialize_genomes, expand, expand_single
|
||||
from .forward import create_forward_function
|
||||
|
||||
@@ -5,7 +5,7 @@ from jax import jit, vmap
|
||||
from .utils import I_INT
|
||||
|
||||
|
||||
def create_forward(config):
|
||||
def create_forward_function(config):
|
||||
"""
|
||||
meta method to create forward function
|
||||
"""
|
||||
@@ -83,4 +83,22 @@ def create_forward(config):
|
||||
|
||||
return vals[output_idx]
|
||||
|
||||
# (batch_size, inputs_nums) -> (batch_size, outputs_nums)
|
||||
batch_forward = vmap(forward, in_axes=(0, None, None, None))
|
||||
|
||||
# (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))
|
||||
|
||||
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
|
||||
|
||||
if config['forward_way'] == 'single':
|
||||
return jit(batch_forward)
|
||||
|
||||
elif config['forward_way'] == 'pop':
|
||||
return jit(pop_batch_forward)
|
||||
|
||||
elif config['forward_way'] == 'common':
|
||||
return jit(common_forward)
|
||||
|
||||
return forward
|
||||
|
||||
@@ -65,55 +65,6 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
|
||||
return pop_nodes, pop_cons
|
||||
|
||||
|
||||
def expand_single(nodes: NDArray, cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
|
||||
"""
|
||||
Expand a single genome to accommodate more nodes or connections.
|
||||
:param nodes: (N, 5)
|
||||
:param cons: (C, 4)
|
||||
:param new_N:
|
||||
:param new_C:
|
||||
:return: (new_N, 5), (new_C, 4)
|
||||
"""
|
||||
old_N, old_C = nodes.shape[0], cons.shape[0]
|
||||
new_nodes = np.full((new_N, 5), np.nan)
|
||||
new_nodes[:old_N, :] = nodes
|
||||
|
||||
new_cons = np.full((new_C, 4), np.nan)
|
||||
new_cons[:old_C, :] = cons
|
||||
|
||||
return new_nodes, new_cons
|
||||
|
||||
|
||||
def expand(pop_nodes: NDArray, pop_cons: NDArray, new_N: int, new_C: int) -> Tuple[NDArray, NDArray]:
|
||||
"""
|
||||
Expand the population to accommodate more nodes or connections.
|
||||
:param pop_nodes: (pop_size, N, 5)
|
||||
:param pop_cons: (pop_size, C, 4)
|
||||
:param new_N:
|
||||
:param new_C:
|
||||
:return: (pop_size, new_N, 5), (pop_size, new_C, 4)
|
||||
"""
|
||||
pop_size, old_N, old_C = pop_nodes.shape[0], pop_nodes.shape[1], pop_cons.shape[1]
|
||||
|
||||
new_pop_nodes = np.full((pop_size, new_N, 5), np.nan)
|
||||
new_pop_nodes[:, :old_N, :] = pop_nodes
|
||||
|
||||
new_pop_cons = np.full((pop_size, new_C, 4), np.nan)
|
||||
new_pop_cons[:, :old_C, :] = pop_cons
|
||||
|
||||
return new_pop_nodes, new_pop_cons
|
||||
|
||||
|
||||
@jit
|
||||
def count(nodes: NDArray, cons: NDArray) -> Tuple[NDArray, NDArray]:
|
||||
"""
|
||||
Count how many nodes and connections are in the genome.
|
||||
"""
|
||||
node_cnt = jnp.sum(~jnp.isnan(nodes[:, 0]))
|
||||
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
|
||||
return node_cnt, cons_cnt
|
||||
|
||||
|
||||
@jit
|
||||
def add_node(nodes: NDArray, cons: NDArray, new_key: int,
|
||||
bias: float = 0.0, response: float = 1.0, act: int = 0, agg: int = 0) -> Tuple[NDArray, NDArray]:
|
||||
|
||||
@@ -59,12 +59,13 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
target = jax.random.randint(rand_key, shape=(), minval=1, maxval=true_cnt + 1)
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
@partial(jit, static_argnames=['reverse'])
|
||||
def rank_elements(array, reverse=False):
|
||||
"""
|
||||
rank the element in the array.
|
||||
if reverse is True, the rank is from large to small.
|
||||
if reverse is True, the rank is from small to large. default large to small
|
||||
"""
|
||||
if reverse:
|
||||
if not reverse:
|
||||
array = -array
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
Reference in New Issue
Block a user