add input_transform and update_input_transform;
change the args for genome.forward. Origin: (state, inputs, transformed) New: (state, transformed, inputs)
This commit is contained in:
@@ -22,7 +22,7 @@ class BaseAlgorithm:
|
||||
def restore(self, state, transformed):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
def forward(self, state, transformed, inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_by_batch(self, state, batch_input, transformed):
|
||||
|
||||
@@ -54,8 +54,8 @@ class HyperNEAT(BaseAlgorithm):
|
||||
|
||||
def transform(self, state, individual):
|
||||
transformed = self.neat.transform(state, individual)
|
||||
query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))(
|
||||
state, self.substrate.query_coors, transformed
|
||||
query_res = jax.vmap(self.neat.forward, in_axes=(None, None, 0))(
|
||||
state, transformed, self.substrate.query_coors
|
||||
)
|
||||
# mute the connection with weight below threshold
|
||||
query_res = jnp.where(
|
||||
|
||||
@@ -163,8 +163,8 @@ class DefaultMutation(BaseMutation):
|
||||
)
|
||||
|
||||
if genome.network_type == "feedforward":
|
||||
u_cons = unflatten_conns(nodes_, conns_)
|
||||
conns_exist = ~jnp.isnan(u_cons[0, :, :])
|
||||
u_conns = unflatten_conns(nodes_, conns_)
|
||||
conns_exist = (u_conns != I_INF)
|
||||
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
|
||||
|
||||
return jax.lax.cond(
|
||||
|
||||
@@ -12,6 +12,13 @@ class BaseNodeGene(BaseGene):
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
raise NotImplementedError
|
||||
|
||||
def input_transform(self, state, attrs, inputs):
|
||||
"""
|
||||
make transformation in the input node.
|
||||
default: do nothing
|
||||
"""
|
||||
return inputs
|
||||
|
||||
def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
|
||||
# default: do not update attrs, but to calculate batch_res
|
||||
return (
|
||||
@@ -20,3 +27,15 @@ class BaseNodeGene(BaseGene):
|
||||
),
|
||||
attrs,
|
||||
)
|
||||
|
||||
def update_input_transform(self, state, attrs, batch_inputs):
|
||||
"""
|
||||
update the attrs for transformation in the input node.
|
||||
default: do nothing
|
||||
"""
|
||||
return (
|
||||
jax.vmap(self.input_transform, in_axes=(None, None, 0))(
|
||||
state, attrs, batch_inputs
|
||||
),
|
||||
attrs,
|
||||
)
|
||||
|
||||
@@ -157,6 +157,15 @@ class NormalizedNode(BaseNodeGene):
|
||||
|
||||
return z
|
||||
|
||||
def input_transform(self, state, attrs, inputs):
|
||||
"""
|
||||
make transform in the input node.
|
||||
the normalization also need be done in the first node.
|
||||
"""
|
||||
bias, agg, act, mean, std, alpha, beta = attrs
|
||||
inputs = (inputs - mean) / (std + self.eps) * alpha + beta # normalization
|
||||
return inputs
|
||||
|
||||
def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
|
||||
|
||||
bias, agg, act, mean, std, alpha, beta = attrs
|
||||
@@ -192,3 +201,31 @@ class NormalizedNode(BaseNodeGene):
|
||||
attrs = attrs.at[4].set(std)
|
||||
|
||||
return batch_z, attrs
|
||||
|
||||
def update_input_transform(self, state, attrs, batch_inputs):
|
||||
"""
|
||||
update the attrs for transformation in the input node.
|
||||
default: do nothing
|
||||
"""
|
||||
bias, agg, act, mean, std, alpha, beta = attrs
|
||||
|
||||
# calculate mean
|
||||
valid_values_count = jnp.sum(~jnp.isnan(batch_inputs))
|
||||
valid_values_sum = jnp.sum(jnp.where(jnp.isnan(batch_inputs), 0, batch_inputs))
|
||||
mean = valid_values_sum / valid_values_count
|
||||
|
||||
# calculate std
|
||||
std = jnp.sqrt(
|
||||
jnp.sum(jnp.where(jnp.isnan(batch_inputs), 0, (batch_inputs - mean) ** 2))
|
||||
/ valid_values_count
|
||||
)
|
||||
|
||||
batch_inputs = (batch_inputs - mean) / (
|
||||
std + self.eps
|
||||
) * alpha + beta # normalization
|
||||
|
||||
# update mean and std to the attrs
|
||||
attrs = attrs.at[3].set(mean)
|
||||
attrs = attrs.at[4].set(std)
|
||||
|
||||
return batch_inputs, attrs
|
||||
|
||||
@@ -42,7 +42,7 @@ class BaseGenome:
|
||||
def restore(self, state, transformed):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
def forward(self, state, transformed, inputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
|
||||
|
||||
@@ -1,7 +1,16 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import unflatten_conns, flatten_conns, topological_sort, I_INF
|
||||
from utils import (
|
||||
unflatten_conns,
|
||||
topological_sort,
|
||||
I_INF,
|
||||
extract_node_attrs,
|
||||
extract_conn_attrs,
|
||||
set_node_attrs,
|
||||
set_conn_attrs,
|
||||
attach_with_inf,
|
||||
)
|
||||
|
||||
from . import BaseGenome
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
@@ -45,23 +54,23 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
def transform(self, state, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
conn_exist = ~jnp.isnan(u_conns[0])
|
||||
conn_exist = u_conns != I_INF
|
||||
|
||||
seqs = topological_sort(nodes, conn_exist)
|
||||
|
||||
return seqs, nodes, u_conns
|
||||
return seqs, nodes, conns, u_conns
|
||||
|
||||
def restore(self, state, transformed):
|
||||
seqs, nodes, u_conns = transformed
|
||||
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
||||
seqs, nodes, conns, u_conns = transformed
|
||||
return nodes, conns
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
cal_seqs, nodes, u_conns = transformed
|
||||
def forward(self, state, transformed, inputs):
|
||||
cal_seqs, nodes, conns, u_conns = transformed
|
||||
|
||||
ini_vals = jnp.full((self.max_nodes,), jnp.nan)
|
||||
ini_vals = ini_vals.at[self.input_idx].set(inputs)
|
||||
nodes_attrs = nodes[:, 1:]
|
||||
nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
|
||||
conns_attrs = jax.vmap(extract_conn_attrs)(conns)
|
||||
|
||||
def cond_fun(carry):
|
||||
values, idx = carry
|
||||
@@ -71,9 +80,16 @@ class DefaultGenome(BaseGenome):
|
||||
values, idx = carry
|
||||
i = cal_seqs[idx]
|
||||
|
||||
def hit():
|
||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(
|
||||
state, u_conns[:, :, i], values
|
||||
def input_node():
|
||||
z = self.node_gene.input_transform(state, nodes_attrs[i], values[i])
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
def otherwise():
|
||||
conn_indices = u_conns[:, i]
|
||||
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
|
||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
|
||||
state, hit_attrs, values
|
||||
)
|
||||
|
||||
z = self.node_gene.forward(
|
||||
@@ -86,8 +102,7 @@ class DefaultGenome(BaseGenome):
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
values = jax.lax.cond(jnp.isin(i, self.input_idx), lambda: values, hit)
|
||||
values = jax.lax.cond(jnp.isin(i, self.input_idx), input_node, otherwise)
|
||||
|
||||
return values, idx + 1
|
||||
|
||||
@@ -99,55 +114,72 @@ class DefaultGenome(BaseGenome):
|
||||
return self.output_transform(vals[self.output_idx])
|
||||
|
||||
def update_by_batch(self, state, batch_input, transformed):
|
||||
cal_seqs, nodes, u_conns = transformed
|
||||
cal_seqs, nodes, conns, u_conns = transformed
|
||||
|
||||
batch_size = batch_input.shape[0]
|
||||
batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan)
|
||||
batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input)
|
||||
nodes_attrs = nodes[:, 1:]
|
||||
nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
|
||||
conns_attrs = jax.vmap(extract_conn_attrs)(conns)
|
||||
|
||||
def cond_fun(carry):
|
||||
batch_values, nodes_attrs_, u_conns_, idx = carry
|
||||
batch_values, nodes_attrs_, conns_attrs_, idx = carry
|
||||
return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)
|
||||
|
||||
def body_func(carry):
|
||||
batch_values, nodes_attrs_, u_conns_, idx = carry
|
||||
batch_values, nodes_attrs_, conns_attrs_, idx = carry
|
||||
i = cal_seqs[idx]
|
||||
|
||||
def hit():
|
||||
def input_node():
|
||||
batch, new_attrs = self.node_gene.update_input_transform(
|
||||
state, nodes_attrs_[i], batch_values[:, i]
|
||||
)
|
||||
return (
|
||||
batch_values.at[:, i].set(batch),
|
||||
nodes_attrs_.at[i].set(new_attrs),
|
||||
conns_attrs_,
|
||||
)
|
||||
|
||||
def otherwise():
|
||||
|
||||
conn_indices = u_conns[:, i]
|
||||
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
|
||||
batch_ins, new_conn_attrs = jax.vmap(
|
||||
self.conn_gene.update_by_batch,
|
||||
in_axes=(None, 1, 1),
|
||||
out_axes=(1, 1),
|
||||
)(state, u_conns_[:, :, i], batch_values)
|
||||
in_axes=(None, 0, 1),
|
||||
out_axes=(1, 0),
|
||||
)(state, hit_attrs, batch_values)
|
||||
|
||||
batch_z, new_node_attrs = self.node_gene.update_by_batch(
|
||||
state,
|
||||
nodes_attrs[i],
|
||||
nodes_attrs_[i],
|
||||
batch_ins,
|
||||
is_output_node=jnp.isin(i, self.output_idx),
|
||||
)
|
||||
new_batch_values = batch_values.at[:, i].set(batch_z)
|
||||
|
||||
return (
|
||||
new_batch_values,
|
||||
batch_values.at[:, i].set(batch_z),
|
||||
nodes_attrs_.at[i].set(new_node_attrs),
|
||||
u_conns_.at[:, :, i].set(new_conn_attrs),
|
||||
conns_attrs_.at[conn_indices].set(new_conn_attrs),
|
||||
)
|
||||
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
(batch_values, nodes_attrs_, u_conns_) = jax.lax.cond(
|
||||
(batch_values, nodes_attrs_, conns_attrs_) = jax.lax.cond(
|
||||
jnp.isin(i, self.input_idx),
|
||||
lambda: (batch_values, nodes_attrs_, u_conns_),
|
||||
hit,
|
||||
input_node,
|
||||
otherwise,
|
||||
)
|
||||
|
||||
return batch_values, nodes_attrs_, u_conns_, idx + 1
|
||||
return batch_values, nodes_attrs_, conns_attrs_, idx + 1
|
||||
|
||||
batch_vals, nodes_attrs, u_conns, _ = jax.lax.while_loop(
|
||||
cond_fun, body_func, (batch_ini_vals, nodes_attrs, u_conns, 0)
|
||||
batch_vals, nodes_attrs, conns_attrs, _ = jax.lax.while_loop(
|
||||
cond_fun, body_func, (batch_ini_vals, nodes_attrs, conns_attrs, 0)
|
||||
)
|
||||
|
||||
nodes = nodes.at[:, 1:].set(nodes_attrs)
|
||||
new_transformed = (cal_seqs, nodes, u_conns)
|
||||
nodes = jax.vmap(set_node_attrs)(nodes, nodes_attrs)
|
||||
conns = jax.vmap(set_conn_attrs)(conns, conns_attrs)
|
||||
|
||||
new_transformed = (cal_seqs, nodes, conns, u_conns)
|
||||
|
||||
if self.output_transform is None:
|
||||
return batch_vals[:, self.output_idx], new_transformed
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import unflatten_conns, flatten_conns
|
||||
from utils import unflatten_conns
|
||||
|
||||
from . import BaseGenome
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
@@ -47,11 +47,10 @@ class RecurrentGenome(BaseGenome):
|
||||
|
||||
def transform(self, state, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
return nodes, u_conns
|
||||
return nodes, conns, u_conns
|
||||
|
||||
def restore(self, state, transformed):
|
||||
nodes, u_conns = transformed
|
||||
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
||||
nodes, conns, u_conns = transformed
|
||||
return nodes, conns
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
|
||||
@@ -47,8 +47,8 @@ class NEAT(BaseAlgorithm):
|
||||
def restore(self, state, transformed):
|
||||
return self.genome.restore(state, transformed)
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
return self.genome.forward(state, inputs, transformed)
|
||||
def forward(self, state, transformed, inputs):
|
||||
return self.genome.forward(state, transformed, inputs)
|
||||
|
||||
def update_by_batch(self, state, batch_input, transformed):
|
||||
return self.genome.update_by_batch(state, batch_input, transformed)
|
||||
|
||||
Reference in New Issue
Block a user