# File: tensorneat-mend/tensorneat/algorithm/neat/genome/hidden.py
# (Python source; listed by the file viewer as 71 lines, 2.5 KiB)

import jax, jax.numpy as jnp
from .default import DefaultGenome
class HiddenInitialize(DefaultGenome):
    """Genome whose initial network contains one fully connected hidden layer.

    The initial network links every input node to every hidden node and every
    hidden node to every output node. Node and connection attributes are
    sampled through the gene objects' ``new_random_attrs``.
    """

    def __init__(self, hidden_cnt=8, *args, **kwargs):
        """Create the genome.

        Args:
            hidden_cnt: Number of hidden nodes placed in the initial network.
            *args: Forwarded unchanged to ``DefaultGenome.__init__``.
            **kwargs: Forwarded unchanged to ``DefaultGenome.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.hidden_cnt = hidden_cnt

    def initialize(self, state, randkey):
        """Build the initial node and connection arrays.

        Args:
            state: Passed through untouched to ``new_random_attrs`` of the
                node and connection genes.
            randkey: JAX PRNG key; split once into a key for node attributes
                and a key for connection attributes.

        Returns:
            A ``(nodes, conns)`` tuple. ``nodes`` is a float32 array of shape
            ``(max_nodes, node_gene.length)`` and ``conns`` a float32 array of
            shape ``(max_conns, conn_gene.length)``; rows that are not used by
            the initial network are filled with NaN.

        Raises:
            ValueError: If the initial network needs more rows than
                ``max_nodes`` or ``max_conns`` provide.
        """
        node_key, conn_key = jax.random.split(randkey, num=2)

        input_idx, output_idx = self.input_idx, self.output_idx
        input_size = len(input_idx)
        output_size = len(output_idx)
        hidden_cnt = self.hidden_cnt

        # All sizes are plain Python integers, so they can be validated up
        # front (also when this method is traced by jit/vmap).
        first_hidden = input_size + output_size
        total_nodes = first_hidden + hidden_cnt
        in_to_hidden_cnt = input_size * hidden_cnt
        hidden_to_out_cnt = hidden_cnt * output_size
        total_conns = in_to_hidden_cnt + hidden_to_out_cnt

        # Without these checks, out-of-range scatter indices are dropped
        # silently and the slice assignments below fail with an opaque
        # shape-mismatch error.
        if total_nodes > self.max_nodes:
            raise ValueError(
                f"initial network needs {total_nodes} nodes "
                f"but max_nodes is {self.max_nodes}"
            )
        if total_conns > self.max_conns:
            raise ValueError(
                f"initial network needs {total_conns} connections "
                f"but max_conns is {self.max_conns}"
            )

        # Hidden nodes take the indices directly after inputs and outputs.
        hidden_idx = jnp.arange(first_hidden, total_nodes)

        # Column 0 of each node row holds the node index; the remaining
        # columns hold the node attributes.
        nodes = jnp.full(
            (self.max_nodes, self.node_gene.length), jnp.nan, dtype=jnp.float32
        )
        nodes = nodes.at[input_idx, 0].set(input_idx)
        nodes = nodes.at[output_idx, 0].set(output_idx)
        nodes = nodes.at[hidden_idx, 0].set(hidden_idx)

        # One independent key per node row.
        # NOTE(review): attributes are written to rows [0, total_nodes), which
        # assumes input_idx and output_idx together cover 0..first_hidden-1 --
        # confirm against DefaultGenome.
        node_keys = jax.random.split(node_key, num=total_nodes)
        sample_node_attrs = jax.vmap(
            self.node_gene.new_random_attrs, in_axes=(None, 0)
        )
        nodes = nodes.at[:total_nodes, 1:].set(sample_node_attrs(state, node_keys))

        # Columns 0-1 of each connection row hold (source, target) node
        # indices; the remaining columns hold the connection attributes.
        conns = jnp.full(
            (self.max_conns, self.conn_gene.length), jnp.nan, dtype=jnp.float32
        )

        # Input -> hidden: every (input, hidden) pair.
        src, dst = jnp.meshgrid(input_idx, hidden_idx, indexing="ij")
        conns = conns.at[:in_to_hidden_cnt, :2].set(
            jnp.column_stack([src.flatten(), dst.flatten()])
        )

        # Hidden -> output: every (hidden, output) pair, stored right after
        # the input -> hidden rows.
        src, dst = jnp.meshgrid(hidden_idx, output_idx, indexing="ij")
        conns = conns.at[in_to_hidden_cnt:total_conns, :2].set(
            jnp.column_stack([src.flatten(), dst.flatten()])
        )

        # One independent key per connection row.
        conn_keys = jax.random.split(conn_key, num=total_conns)
        sample_conn_attrs = jax.vmap(
            self.conn_gene.new_random_attrs, in_axes=(None, 0)
        )
        conns = conns.at[:total_conns, 2:].set(sample_conn_attrs(state, conn_keys))

        return nodes, conns