complete HyperNEAT!
This commit is contained in:
54
algorithm/hyperneat/hyperneat_gene.py
Normal file
54
algorithm/hyperneat/hyperneat_gene.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import jax
|
||||
from jax import numpy as jnp, vmap
|
||||
|
||||
from algorithm.neat import BaseGene
|
||||
from algorithm.neat.gene import Activation
|
||||
from algorithm.neat.gene import Aggregation
|
||||
|
||||
|
||||
class HyperNEATGene(BaseGene):
|
||||
node_attrs = [] # no node attributes
|
||||
conn_attrs = ['weight']
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(state, nodes, conns):
|
||||
N = nodes.shape[0]
|
||||
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
|
||||
|
||||
in_keys = jnp.asarray(conns[:, 0], jnp.int32)
|
||||
out_keys = jnp.asarray(conns[:, 1], jnp.int32)
|
||||
weights = conns[:, 2]
|
||||
|
||||
u_conns = u_conns.at[in_keys, out_keys].set(weights)
|
||||
return nodes, u_conns
|
||||
|
||||
@staticmethod
|
||||
def create_forward(config):
|
||||
act = Activation.name2func[config['h_activation']]
|
||||
agg = Aggregation.name2func[config['h_aggregation']]
|
||||
|
||||
batch_act, batch_agg = vmap(act), vmap(agg)
|
||||
|
||||
def forward(inputs, transform):
|
||||
|
||||
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
|
||||
nodes, weights = transform
|
||||
|
||||
input_idx = config['h_input_idx']
|
||||
output_idx = config['h_output_idx']
|
||||
|
||||
N = nodes.shape[0]
|
||||
vals = jnp.full((N,), 0.)
|
||||
|
||||
def body_func(i, values):
|
||||
values = values.at[input_idx].set(inputs_with_bias)
|
||||
nodes_ins = values * weights.T
|
||||
values = batch_agg(nodes_ins) # z = agg(ins)
|
||||
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||
values = batch_act(values) # z = act(z)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, config['h_activate_times'], body_func, vals)
|
||||
return vals[output_idx]
|
||||
|
||||
return forward
|
||||
Reference in New Issue
Block a user