The whole NEAT algorithm is written into functional programming.

This commit is contained in:
wls2002
2023-06-29 09:28:49 +08:00
parent 114ff2b0cc
commit d28cef1a87
16 changed files with 371 additions and 1102 deletions

View File

@@ -5,7 +5,7 @@ from jax import jit, vmap
from .utils import I_INT
def create_forward(config):
def create_forward_function(config):
"""
meta method to create forward function
"""
@@ -83,4 +83,22 @@ def create_forward(config):
return vals[output_idx]
# (batch_size, inputs_nums) -> (batch_size, outputs_nums)
batch_forward = vmap(forward, in_axes=(0, None, None, None))
# (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
if config['forward_way'] == 'single':
return jit(batch_forward)
elif config['forward_way'] == 'pop':
return jit(pop_batch_forward)
elif config['forward_way'] == 'common':
return jit(common_forward)
return forward