The whole NEAT algorithm is written into functional programming.
This commit is contained in:
@@ -5,7 +5,7 @@ from jax import jit, vmap
|
||||
from .utils import I_INT
|
||||
|
||||
|
||||
def create_forward(config):
|
||||
def create_forward_function(config):
|
||||
"""
|
||||
meta method to create forward function
|
||||
"""
|
||||
@@ -83,4 +83,22 @@ def create_forward(config):
|
||||
|
||||
return vals[output_idx]
|
||||
|
||||
# (batch_size, inputs_nums) -> (batch_size, outputs_nums)
|
||||
batch_forward = vmap(forward, in_axes=(0, None, None, None))
|
||||
|
||||
# (pop_size, batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
pop_batch_forward = vmap(batch_forward, in_axes=(0, 0, 0, 0))
|
||||
|
||||
# (batch_size, inputs_nums) -> (pop_size, batch_size, outputs_nums)
|
||||
common_forward = vmap(batch_forward, in_axes=(None, 0, 0, 0))
|
||||
|
||||
if config['forward_way'] == 'single':
|
||||
return jit(batch_forward)
|
||||
|
||||
elif config['forward_way'] == 'pop':
|
||||
return jit(pop_batch_forward)
|
||||
|
||||
elif config['forward_way'] == 'common':
|
||||
return jit(common_forward)
|
||||
|
||||
return forward
|
||||
|
||||
Reference in New Issue
Block a user