Files
tensorneat-mend/algorithm/hyperneat/hyperneat_gene.py
2023-07-21 15:03:12 +08:00

55 lines
1.7 KiB
Python

import jax
from jax import numpy as jnp, vmap
from algorithm.neat import BaseGene
from algorithm.neat.gene import Activation
from algorithm.neat.gene import Aggregation
class HyperNEATGene(BaseGene):
    """Gene type used by HyperNEAT substrate networks.

    Substrate nodes carry no evolvable attributes of their own; every
    connection carries exactly one attribute, its weight (queried from
    the CPPN elsewhere in the pipeline).
    """

    node_attrs = []  # substrate nodes have no evolvable attributes
    conn_attrs = ['weight']

    @staticmethod
    def forward_transform(state, nodes, conns):
        """Pack the connection list into a dense (N, N) weight matrix.

        Each row of ``conns`` is ``(in_idx, out_idx, weight)``.  The result
        matrix has ``weight`` at ``[in_idx, out_idx]``; entries with no
        connection stay 0.  ``nodes`` is passed through untouched.
        """
        node_cnt = nodes.shape[0]
        src = jnp.asarray(conns[:, 0], jnp.int32)
        dst = jnp.asarray(conns[:, 1], jnp.int32)
        dense = jnp.zeros((node_cnt, node_cnt), dtype=jnp.float32)
        dense = dense.at[src, dst].set(conns[:, 2])
        return nodes, dense

    @staticmethod
    def create_forward(config):
        """Return the substrate's forward function, specialized by ``config``.

        Looks up the activation/aggregation functions by name, vmaps them
        over all nodes, and closes over the config's input/output index
        arrays and activation count.
        """
        act_fn = Activation.name2func[config['h_activation']]
        agg_fn = Aggregation.name2func[config['h_aggregation']]
        act_all, agg_all = vmap(act_fn), vmap(agg_fn)

        def forward(inputs, transform):
            nodes, dense_w = transform
            # Append a constant 1.0 so the last input node acts as a bias.
            biased_inputs = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
            in_idx = config['h_input_idx']
            out_idx = config['h_output_idx']
            node_cnt = nodes.shape[0]

            def step(_, values):
                # Re-clamp input nodes to the (bias-augmented) inputs each pass.
                values = values.at[in_idx].set(biased_inputs)
                # Column i of dense_w.T-scaled values holds v_j * w[j -> i].
                per_node_ins = values * dense_w.T
                values = agg_all(per_node_ins)
                # nodes[:, 2] / nodes[:, 1] used as response / bias here —
                # assumption carried over from the original comment; confirm
                # against the substrate's node layout.
                values = values * nodes[:, 2] + nodes[:, 1]
                return act_all(values)

            final = jax.lax.fori_loop(
                0, config['h_activate_times'], step, jnp.zeros((node_cnt,))
            )
            return final[out_idx]

        return forward