Add input_transform and update_input_transform to the node genes;
change the argument order of genome.forward. Old: (state, inputs, transformed). New: (state, transformed, inputs).
This commit is contained in:
@@ -22,7 +22,7 @@ class BaseAlgorithm:
|
|||||||
def restore(self, state, transformed):
|
def restore(self, state, transformed):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, transformed, inputs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def update_by_batch(self, state, batch_input, transformed):
|
def update_by_batch(self, state, batch_input, transformed):
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ class HyperNEAT(BaseAlgorithm):
|
|||||||
|
|
||||||
def transform(self, state, individual):
|
def transform(self, state, individual):
|
||||||
transformed = self.neat.transform(state, individual)
|
transformed = self.neat.transform(state, individual)
|
||||||
query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))(
|
query_res = jax.vmap(self.neat.forward, in_axes=(None, None, 0))(
|
||||||
state, self.substrate.query_coors, transformed
|
state, transformed, self.substrate.query_coors
|
||||||
)
|
)
|
||||||
# mute the connection with weight below threshold
|
# mute the connection with weight below threshold
|
||||||
query_res = jnp.where(
|
query_res = jnp.where(
|
||||||
|
|||||||
@@ -163,8 +163,8 @@ class DefaultMutation(BaseMutation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if genome.network_type == "feedforward":
|
if genome.network_type == "feedforward":
|
||||||
u_cons = unflatten_conns(nodes_, conns_)
|
u_conns = unflatten_conns(nodes_, conns_)
|
||||||
conns_exist = ~jnp.isnan(u_cons[0, :, :])
|
conns_exist = (u_conns != I_INF)
|
||||||
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
|
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
|
||||||
|
|
||||||
return jax.lax.cond(
|
return jax.lax.cond(
|
||||||
|
|||||||
@@ -12,6 +12,13 @@ class BaseNodeGene(BaseGene):
|
|||||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def input_transform(self, state, attrs, inputs):
|
||||||
|
"""
|
||||||
|
make transformation in the input node.
|
||||||
|
default: do nothing
|
||||||
|
"""
|
||||||
|
return inputs
|
||||||
|
|
||||||
def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
|
def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
|
||||||
# default: do not update attrs, but to calculate batch_res
|
# default: do not update attrs, but to calculate batch_res
|
||||||
return (
|
return (
|
||||||
@@ -20,3 +27,15 @@ class BaseNodeGene(BaseGene):
|
|||||||
),
|
),
|
||||||
attrs,
|
attrs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_input_transform(self, state, attrs, batch_inputs):
|
||||||
|
"""
|
||||||
|
update the attrs for transformation in the input node.
|
||||||
|
default: do nothing
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
jax.vmap(self.input_transform, in_axes=(None, None, 0))(
|
||||||
|
state, attrs, batch_inputs
|
||||||
|
),
|
||||||
|
attrs,
|
||||||
|
)
|
||||||
|
|||||||
@@ -157,6 +157,15 @@ class NormalizedNode(BaseNodeGene):
|
|||||||
|
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
def input_transform(self, state, attrs, inputs):
|
||||||
|
"""
|
||||||
|
make transform in the input node.
|
||||||
|
the normalization also needs to be applied in the first (input) node.
|
||||||
|
"""
|
||||||
|
bias, agg, act, mean, std, alpha, beta = attrs
|
||||||
|
inputs = (inputs - mean) / (std + self.eps) * alpha + beta # normalization
|
||||||
|
return inputs
|
||||||
|
|
||||||
def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
|
def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
|
||||||
|
|
||||||
bias, agg, act, mean, std, alpha, beta = attrs
|
bias, agg, act, mean, std, alpha, beta = attrs
|
||||||
@@ -192,3 +201,31 @@ class NormalizedNode(BaseNodeGene):
|
|||||||
attrs = attrs.at[4].set(std)
|
attrs = attrs.at[4].set(std)
|
||||||
|
|
||||||
return batch_z, attrs
|
return batch_z, attrs
|
||||||
|
|
||||||
|
def update_input_transform(self, state, attrs, batch_inputs):
|
||||||
|
"""
|
||||||
|
update the attrs for transformation in the input node.
|
||||||
|
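recompute mean and std from the batch (NaN entries are ignored), normalize the batch, and store the new mean and std in attrs.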
|
||||||
|
"""
|
||||||
|
bias, agg, act, mean, std, alpha, beta = attrs
|
||||||
|
|
||||||
|
# calculate mean
|
||||||
|
valid_values_count = jnp.sum(~jnp.isnan(batch_inputs))
|
||||||
|
valid_values_sum = jnp.sum(jnp.where(jnp.isnan(batch_inputs), 0, batch_inputs))
|
||||||
|
mean = valid_values_sum / valid_values_count
|
||||||
|
|
||||||
|
# calculate std
|
||||||
|
std = jnp.sqrt(
|
||||||
|
jnp.sum(jnp.where(jnp.isnan(batch_inputs), 0, (batch_inputs - mean) ** 2))
|
||||||
|
/ valid_values_count
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_inputs = (batch_inputs - mean) / (
|
||||||
|
std + self.eps
|
||||||
|
) * alpha + beta # normalization
|
||||||
|
|
||||||
|
# update mean and std to the attrs
|
||||||
|
attrs = attrs.at[3].set(mean)
|
||||||
|
attrs = attrs.at[4].set(std)
|
||||||
|
|
||||||
|
return batch_inputs, attrs
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class BaseGenome:
|
|||||||
def restore(self, state, transformed):
|
def restore(self, state, transformed):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, transformed, inputs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
|
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
|
||||||
|
|||||||
@@ -1,7 +1,16 @@
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
from utils import unflatten_conns, flatten_conns, topological_sort, I_INF
|
from utils import (
|
||||||
|
unflatten_conns,
|
||||||
|
topological_sort,
|
||||||
|
I_INF,
|
||||||
|
extract_node_attrs,
|
||||||
|
extract_conn_attrs,
|
||||||
|
set_node_attrs,
|
||||||
|
set_conn_attrs,
|
||||||
|
attach_with_inf,
|
||||||
|
)
|
||||||
|
|
||||||
from . import BaseGenome
|
from . import BaseGenome
|
||||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||||
@@ -45,23 +54,23 @@ class DefaultGenome(BaseGenome):
|
|||||||
|
|
||||||
def transform(self, state, nodes, conns):
|
def transform(self, state, nodes, conns):
|
||||||
u_conns = unflatten_conns(nodes, conns)
|
u_conns = unflatten_conns(nodes, conns)
|
||||||
conn_exist = ~jnp.isnan(u_conns[0])
|
conn_exist = u_conns != I_INF
|
||||||
|
|
||||||
seqs = topological_sort(nodes, conn_exist)
|
seqs = topological_sort(nodes, conn_exist)
|
||||||
|
|
||||||
return seqs, nodes, u_conns
|
return seqs, nodes, conns, u_conns
|
||||||
|
|
||||||
def restore(self, state, transformed):
|
def restore(self, state, transformed):
|
||||||
seqs, nodes, u_conns = transformed
|
seqs, nodes, conns, u_conns = transformed
|
||||||
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
|
||||||
return nodes, conns
|
return nodes, conns
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, transformed, inputs):
|
||||||
cal_seqs, nodes, u_conns = transformed
|
cal_seqs, nodes, conns, u_conns = transformed
|
||||||
|
|
||||||
ini_vals = jnp.full((self.max_nodes,), jnp.nan)
|
ini_vals = jnp.full((self.max_nodes,), jnp.nan)
|
||||||
ini_vals = ini_vals.at[self.input_idx].set(inputs)
|
ini_vals = ini_vals.at[self.input_idx].set(inputs)
|
||||||
nodes_attrs = nodes[:, 1:]
|
nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
|
||||||
|
conns_attrs = jax.vmap(extract_conn_attrs)(conns)
|
||||||
|
|
||||||
def cond_fun(carry):
|
def cond_fun(carry):
|
||||||
values, idx = carry
|
values, idx = carry
|
||||||
@@ -71,9 +80,16 @@ class DefaultGenome(BaseGenome):
|
|||||||
values, idx = carry
|
values, idx = carry
|
||||||
i = cal_seqs[idx]
|
i = cal_seqs[idx]
|
||||||
|
|
||||||
def hit():
|
def input_node():
|
||||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(
|
z = self.node_gene.input_transform(state, nodes_attrs[i], values[i])
|
||||||
state, u_conns[:, :, i], values
|
new_values = values.at[i].set(z)
|
||||||
|
return new_values
|
||||||
|
|
||||||
|
def otherwise():
|
||||||
|
conn_indices = u_conns[:, i]
|
||||||
|
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
|
||||||
|
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 0, 0))(
|
||||||
|
state, hit_attrs, values
|
||||||
)
|
)
|
||||||
|
|
||||||
z = self.node_gene.forward(
|
z = self.node_gene.forward(
|
||||||
@@ -86,8 +102,7 @@ class DefaultGenome(BaseGenome):
|
|||||||
new_values = values.at[i].set(z)
|
new_values = values.at[i].set(z)
|
||||||
return new_values
|
return new_values
|
||||||
|
|
||||||
# the val of input nodes is obtained by the task, not by calculation
|
values = jax.lax.cond(jnp.isin(i, self.input_idx), input_node, otherwise)
|
||||||
values = jax.lax.cond(jnp.isin(i, self.input_idx), lambda: values, hit)
|
|
||||||
|
|
||||||
return values, idx + 1
|
return values, idx + 1
|
||||||
|
|
||||||
@@ -99,55 +114,72 @@ class DefaultGenome(BaseGenome):
|
|||||||
return self.output_transform(vals[self.output_idx])
|
return self.output_transform(vals[self.output_idx])
|
||||||
|
|
||||||
def update_by_batch(self, state, batch_input, transformed):
|
def update_by_batch(self, state, batch_input, transformed):
|
||||||
cal_seqs, nodes, u_conns = transformed
|
cal_seqs, nodes, conns, u_conns = transformed
|
||||||
|
|
||||||
batch_size = batch_input.shape[0]
|
batch_size = batch_input.shape[0]
|
||||||
batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan)
|
batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan)
|
||||||
batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input)
|
batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input)
|
||||||
nodes_attrs = nodes[:, 1:]
|
nodes_attrs = jax.vmap(extract_node_attrs)(nodes)
|
||||||
|
conns_attrs = jax.vmap(extract_conn_attrs)(conns)
|
||||||
|
|
||||||
def cond_fun(carry):
|
def cond_fun(carry):
|
||||||
batch_values, nodes_attrs_, u_conns_, idx = carry
|
batch_values, nodes_attrs_, conns_attrs_, idx = carry
|
||||||
return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)
|
return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF)
|
||||||
|
|
||||||
def body_func(carry):
|
def body_func(carry):
|
||||||
batch_values, nodes_attrs_, u_conns_, idx = carry
|
batch_values, nodes_attrs_, conns_attrs_, idx = carry
|
||||||
i = cal_seqs[idx]
|
i = cal_seqs[idx]
|
||||||
|
|
||||||
def hit():
|
def input_node():
|
||||||
|
batch, new_attrs = self.node_gene.update_input_transform(
|
||||||
|
state, nodes_attrs_[i], batch_values[:, i]
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
batch_values.at[:, i].set(batch),
|
||||||
|
nodes_attrs_.at[i].set(new_attrs),
|
||||||
|
conns_attrs_,
|
||||||
|
)
|
||||||
|
|
||||||
|
def otherwise():
|
||||||
|
|
||||||
|
conn_indices = u_conns[:, i]
|
||||||
|
hit_attrs = attach_with_inf(conns_attrs, conn_indices)
|
||||||
batch_ins, new_conn_attrs = jax.vmap(
|
batch_ins, new_conn_attrs = jax.vmap(
|
||||||
self.conn_gene.update_by_batch,
|
self.conn_gene.update_by_batch,
|
||||||
in_axes=(None, 1, 1),
|
in_axes=(None, 0, 1),
|
||||||
out_axes=(1, 1),
|
out_axes=(1, 0),
|
||||||
)(state, u_conns_[:, :, i], batch_values)
|
)(state, hit_attrs, batch_values)
|
||||||
|
|
||||||
batch_z, new_node_attrs = self.node_gene.update_by_batch(
|
batch_z, new_node_attrs = self.node_gene.update_by_batch(
|
||||||
state,
|
state,
|
||||||
nodes_attrs[i],
|
nodes_attrs_[i],
|
||||||
batch_ins,
|
batch_ins,
|
||||||
is_output_node=jnp.isin(i, self.output_idx),
|
is_output_node=jnp.isin(i, self.output_idx),
|
||||||
)
|
)
|
||||||
new_batch_values = batch_values.at[:, i].set(batch_z)
|
|
||||||
return (
|
return (
|
||||||
new_batch_values,
|
batch_values.at[:, i].set(batch_z),
|
||||||
nodes_attrs_.at[i].set(new_node_attrs),
|
nodes_attrs_.at[i].set(new_node_attrs),
|
||||||
u_conns_.at[:, :, i].set(new_conn_attrs),
|
conns_attrs_.at[conn_indices].set(new_conn_attrs),
|
||||||
)
|
)
|
||||||
|
|
||||||
# the val of input nodes is obtained by the task, not by calculation
|
# the val of input nodes is obtained by the task, not by calculation
|
||||||
(batch_values, nodes_attrs_, u_conns_) = jax.lax.cond(
|
(batch_values, nodes_attrs_, conns_attrs_) = jax.lax.cond(
|
||||||
jnp.isin(i, self.input_idx),
|
jnp.isin(i, self.input_idx),
|
||||||
lambda: (batch_values, nodes_attrs_, u_conns_),
|
input_node,
|
||||||
hit,
|
otherwise,
|
||||||
)
|
)
|
||||||
|
|
||||||
return batch_values, nodes_attrs_, u_conns_, idx + 1
|
return batch_values, nodes_attrs_, conns_attrs_, idx + 1
|
||||||
|
|
||||||
batch_vals, nodes_attrs, u_conns, _ = jax.lax.while_loop(
|
batch_vals, nodes_attrs, conns_attrs, _ = jax.lax.while_loop(
|
||||||
cond_fun, body_func, (batch_ini_vals, nodes_attrs, u_conns, 0)
|
cond_fun, body_func, (batch_ini_vals, nodes_attrs, conns_attrs, 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
nodes = nodes.at[:, 1:].set(nodes_attrs)
|
nodes = jax.vmap(set_node_attrs)(nodes, nodes_attrs)
|
||||||
new_transformed = (cal_seqs, nodes, u_conns)
|
conns = jax.vmap(set_conn_attrs)(conns, conns_attrs)
|
||||||
|
|
||||||
|
new_transformed = (cal_seqs, nodes, conns, u_conns)
|
||||||
|
|
||||||
if self.output_transform is None:
|
if self.output_transform is None:
|
||||||
return batch_vals[:, self.output_idx], new_transformed
|
return batch_vals[:, self.output_idx], new_transformed
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
from utils import unflatten_conns, flatten_conns
|
from utils import unflatten_conns
|
||||||
|
|
||||||
from . import BaseGenome
|
from . import BaseGenome
|
||||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||||
@@ -47,11 +47,10 @@ class RecurrentGenome(BaseGenome):
|
|||||||
|
|
||||||
def transform(self, state, nodes, conns):
|
def transform(self, state, nodes, conns):
|
||||||
u_conns = unflatten_conns(nodes, conns)
|
u_conns = unflatten_conns(nodes, conns)
|
||||||
return nodes, u_conns
|
return nodes, conns, u_conns
|
||||||
|
|
||||||
def restore(self, state, transformed):
|
def restore(self, state, transformed):
|
||||||
nodes, u_conns = transformed
|
nodes, conns, u_conns = transformed
|
||||||
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
|
||||||
return nodes, conns
|
return nodes, conns
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, inputs, transformed):
|
||||||
|
|||||||
@@ -47,8 +47,8 @@ class NEAT(BaseAlgorithm):
|
|||||||
def restore(self, state, transformed):
|
def restore(self, state, transformed):
|
||||||
return self.genome.restore(state, transformed)
|
return self.genome.restore(state, transformed)
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, transformed, inputs):
|
||||||
return self.genome.forward(state, inputs, transformed)
|
return self.genome.forward(state, transformed, inputs)
|
||||||
|
|
||||||
def update_by_batch(self, state, batch_input, transformed):
|
def update_by_batch(self, state, batch_input, transformed):
|
||||||
return self.genome.update_by_batch(state, batch_input, transformed)
|
return self.genome.update_by_batch(state, batch_input, transformed)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from pipeline import Pipeline
|
|||||||
from algorithm.neat import *
|
from algorithm.neat import *
|
||||||
|
|
||||||
from problem.func_fit import XOR3d
|
from problem.func_fit import XOR3d
|
||||||
from utils import Act
|
from utils import ACT_ALL, AGG_ALL, Act, Agg
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pipeline = Pipeline(
|
pipeline = Pipeline(
|
||||||
@@ -15,17 +15,21 @@ if __name__ == "__main__":
|
|||||||
max_conns=100,
|
max_conns=100,
|
||||||
node_gene=DefaultNodeGene(
|
node_gene=DefaultNodeGene(
|
||||||
activation_default=Act.tanh,
|
activation_default=Act.tanh,
|
||||||
activation_options=(Act.tanh,),
|
# activation_options=(Act.tanh,),
|
||||||
|
activation_options=ACT_ALL,
|
||||||
|
aggregation_default=Agg.sum,
|
||||||
|
# aggregation_options=(Agg.sum,),
|
||||||
|
aggregation_options=AGG_ALL,
|
||||||
),
|
),
|
||||||
output_transform=Act.sigmoid, # the activation function for output node
|
output_transform=Act.sigmoid, # the activation function for output node
|
||||||
mutation=DefaultMutation(
|
mutation=DefaultMutation(
|
||||||
node_add=0.1,
|
node_add=0.1,
|
||||||
conn_add=0.1,
|
conn_add=0.1,
|
||||||
node_delete=0.05,
|
node_delete=0,
|
||||||
conn_delete=0.05,
|
conn_delete=0,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
pop_size=1000,
|
pop_size=100000,
|
||||||
species_size=20,
|
species_size=20,
|
||||||
compatibility_threshold=2,
|
compatibility_threshold=2,
|
||||||
survival_threshold=0.01, # magic
|
survival_threshold=0.01, # magic
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ if __name__ == "__main__":
|
|||||||
max_nodes=50,
|
max_nodes=50,
|
||||||
max_conns=100,
|
max_conns=100,
|
||||||
node_gene=KANNode(),
|
node_gene=KANNode(),
|
||||||
conn_gene=BSplineConn(),
|
conn_gene=BSplineConn(grid_cnt=10),
|
||||||
output_transform=Act.sigmoid, # the activation function for output node
|
output_transform=Act.sigmoid, # the activation function for output node
|
||||||
mutation=DefaultMutation(
|
mutation=DefaultMutation(
|
||||||
node_add=0.1,
|
node_add=0.1,
|
||||||
@@ -25,7 +25,7 @@ if __name__ == "__main__":
|
|||||||
conn_delete=0.05,
|
conn_delete=0.05,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
pop_size=1000,
|
pop_size=10000,
|
||||||
species_size=20,
|
species_size=20,
|
||||||
compatibility_threshold=1.5,
|
compatibility_threshold=1.5,
|
||||||
survival_threshold=0.01, # magic
|
survival_threshold=0.01, # magic
|
||||||
@@ -34,7 +34,7 @@ if __name__ == "__main__":
|
|||||||
# problem=XOR3d(return_data=True),
|
# problem=XOR3d(return_data=True),
|
||||||
problem=XOR3d(),
|
problem=XOR3d(),
|
||||||
generation_limit=10000,
|
generation_limit=10000,
|
||||||
fitness_target=-1e-8,
|
fitness_target=-1e-5,
|
||||||
# update_batch_size=8,
|
# update_batch_size=8,
|
||||||
# pre_update=True,
|
# pre_update=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ class FuncFit(BaseProblem):
|
|||||||
|
|
||||||
def evaluate(self, state, randkey, act_func, params):
|
def evaluate(self, state, randkey, act_func, params):
|
||||||
|
|
||||||
predict = jax.vmap(act_func, in_axes=(None, 0, None))(
|
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||||
state, self.inputs, params
|
state, params, self.inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.error_method == "mse":
|
if self.error_method == "mse":
|
||||||
@@ -45,8 +45,8 @@ class FuncFit(BaseProblem):
|
|||||||
return -loss
|
return -loss
|
||||||
|
|
||||||
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
def show(self, state, randkey, act_func, params, *args, **kwargs):
|
||||||
predict = jax.vmap(act_func, in_axes=(None, 0, None))(
|
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
|
||||||
state, self.inputs, params
|
state, params, self.inputs, params
|
||||||
)
|
)
|
||||||
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
inputs, target, predict = jax.device_get([self.inputs, self.targets, predict])
|
||||||
if self.return_data:
|
if self.return_data:
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class BraxEnv(RLEnv):
|
|||||||
|
|
||||||
def step(key, env_state, obs):
|
def step(key, env_state, obs):
|
||||||
key, _ = jax.random.split(key)
|
key, _ = jax.random.split(key)
|
||||||
action = act_func(obs, params)
|
action = act_func(params, obs)
|
||||||
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
|
obs, env_state, r, done, _ = self.step(randkey, env_state, action)
|
||||||
return key, env_state, obs, r, done
|
return key, env_state, obs, r, done
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class RLEnv(BaseProblem):
|
|||||||
|
|
||||||
def body_func(carry):
|
def body_func(carry):
|
||||||
obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward
|
obs, env_state, rng, done, tr, count, epis = carry # tr -> total reward
|
||||||
action = act_func(state, obs, params)
|
action = act_func(state, params, obs)
|
||||||
next_obs, next_env_state, reward, done, _ = self.step(
|
next_obs, next_env_state, reward, done, _ = self.step(
|
||||||
rng, env_state, action
|
rng, env_state, action
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,132 +1,27 @@
|
|||||||
|
import jax
|
||||||
from algorithm.neat import *
|
from algorithm.neat import *
|
||||||
from utils import Act, Agg, State
|
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
genome = DefaultGenome(
|
||||||
from algorithm.neat.gene.node.default_without_response import NodeGeneWithoutResponse
|
num_inputs=3,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=5,
|
||||||
|
max_conns=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_default():
|
def test_output_work():
|
||||||
|
randkey = jax.random.PRNGKey(0)
|
||||||
# index, bias, response, activation, aggregation
|
|
||||||
nodes = jnp.array(
|
|
||||||
[
|
|
||||||
[0, 0, 1, 0, 0], # in[0]
|
|
||||||
[1, 0, 1, 0, 0], # in[1]
|
|
||||||
[2, 0.5, 1, 0, 0], # out[0],
|
|
||||||
[3, 1, 1, 0, 0], # hidden[0],
|
|
||||||
[4, -1, 1, 0, 0], # hidden[1],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# in_node, out_node, enable, weight
|
|
||||||
conns = jnp.array(
|
|
||||||
[
|
|
||||||
[0, 3, 0.5], # in[0] -> hidden[0]
|
|
||||||
[1, 4, 0.5], # in[1] -> hidden[1]
|
|
||||||
[3, 2, 0.5], # hidden[0] -> out[0]
|
|
||||||
[4, 2, 0.5], # hidden[1] -> out[0]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
genome = DefaultGenome(
|
|
||||||
num_inputs=2,
|
|
||||||
num_outputs=1,
|
|
||||||
max_nodes=5,
|
|
||||||
max_conns=4,
|
|
||||||
node_gene=DefaultNodeGene(
|
|
||||||
activation_default=Act.identity,
|
|
||||||
activation_options=(Act.identity,),
|
|
||||||
aggregation_default=Agg.sum,
|
|
||||||
aggregation_options=(Agg.sum,),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
state = genome.setup(State(randkey=jax.random.key(0)))
|
|
||||||
|
|
||||||
transformed = genome.transform(state, nodes, conns)
|
|
||||||
print(*transformed, sep="\n")
|
|
||||||
|
|
||||||
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
|
||||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
|
|
||||||
state, inputs, transformed
|
|
||||||
)
|
|
||||||
print(outputs)
|
|
||||||
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
|
||||||
# expected: [[0.5], [0.75], [0.75], [1]]
|
|
||||||
|
|
||||||
|
|
||||||
def test_recurrent():
|
|
||||||
|
|
||||||
# index, bias, response, activation, aggregation
|
|
||||||
nodes = jnp.array(
|
|
||||||
[
|
|
||||||
[0, 0, 1, 0, 0], # in[0]
|
|
||||||
[1, 0, 1, 0, 0], # in[1]
|
|
||||||
[2, 0.5, 1, 0, 0], # out[0],
|
|
||||||
[3, 1, 1, 0, 0], # hidden[0],
|
|
||||||
[4, -1, 1, 0, 0], # hidden[1],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# in_node, out_node, enable, weight
|
|
||||||
conns = jnp.array(
|
|
||||||
[
|
|
||||||
[0, 3, 0.5], # in[0] -> hidden[0]
|
|
||||||
[1, 4, 0.5], # in[1] -> hidden[1]
|
|
||||||
[3, 2, 0.5], # hidden[0] -> out[0]
|
|
||||||
[4, 2, 0.5], # hidden[1] -> out[0]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
genome = RecurrentGenome(
|
|
||||||
num_inputs=2,
|
|
||||||
num_outputs=1,
|
|
||||||
max_nodes=5,
|
|
||||||
max_conns=4,
|
|
||||||
node_gene=DefaultNodeGene(
|
|
||||||
activation_default=Act.identity,
|
|
||||||
activation_options=(Act.identity,),
|
|
||||||
aggregation_default=Agg.sum,
|
|
||||||
aggregation_options=(Agg.sum,),
|
|
||||||
),
|
|
||||||
activate_time=3,
|
|
||||||
)
|
|
||||||
|
|
||||||
state = genome.setup(State(randkey=jax.random.key(0)))
|
|
||||||
|
|
||||||
transformed = genome.transform(state, nodes, conns)
|
|
||||||
print(*transformed, sep="\n")
|
|
||||||
|
|
||||||
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
|
||||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
|
|
||||||
state, inputs, transformed
|
|
||||||
)
|
|
||||||
print(outputs)
|
|
||||||
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
|
||||||
# expected: [[0.5], [0.75], [0.75], [1]]
|
|
||||||
|
|
||||||
|
|
||||||
def test_random_initialize():
|
|
||||||
genome = DefaultGenome(
|
|
||||||
num_inputs=2,
|
|
||||||
num_outputs=1,
|
|
||||||
max_nodes=5,
|
|
||||||
max_conns=4,
|
|
||||||
node_gene=NodeGeneWithoutResponse(
|
|
||||||
activation_default=Act.identity,
|
|
||||||
activation_options=(Act.identity,),
|
|
||||||
aggregation_default=Agg.sum,
|
|
||||||
aggregation_options=(Agg.sum,),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
state = genome.setup()
|
state = genome.setup()
|
||||||
key = jax.random.PRNGKey(0)
|
nodes, conns = genome.initialize(state, randkey)
|
||||||
nodes, conns = genome.initialize(state, key)
|
|
||||||
transformed = genome.transform(state, nodes, conns)
|
transformed = genome.transform(state, nodes, conns)
|
||||||
print(*transformed, sep="\n")
|
inputs = jax.random.normal(randkey, (3,))
|
||||||
|
output = genome.forward(state, transformed, inputs)
|
||||||
|
print(output)
|
||||||
|
|
||||||
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
batch_inputs = jax.random.normal(randkey, (10, 3))
|
||||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
|
batch_output = jax.vmap(genome.forward, in_axes=(None, None, 0))(
|
||||||
state, inputs, transformed
|
state, transformed, batch_inputs
|
||||||
)
|
)
|
||||||
print(outputs)
|
print(batch_output)
|
||||||
|
|
||||||
|
assert True
|
||||||
|
|||||||
@@ -9,12 +9,12 @@ I_INF = np.iinfo(jnp.int32).max # infinite int
|
|||||||
|
|
||||||
def unflatten_conns(nodes, conns):
|
def unflatten_conns(nodes, conns):
|
||||||
"""
|
"""
|
||||||
transform the (C, CL) connections to (CL-2, N, N), 2 is for the input index and output index), which CL means
|
transform the (C, CL) connections to (N, N), which contains the idx of the connection in conns
|
||||||
connection length, N means the number of nodes, C means the number of connections
|
connection length, N means the number of nodes, C means the number of connections
|
||||||
returns the un_flattened connections with shape (CL-2, N, N)
|
returns the unflatten connection indices with shape (N, N)
|
||||||
"""
|
"""
|
||||||
N = nodes.shape[0] # max_nodes
|
N = nodes.shape[0] # max_nodes
|
||||||
CL = conns.shape[1] # connection length = (fix_attrs + custom_attrs)
|
C = conns.shape[0] # max_conns
|
||||||
node_keys = nodes[:, 0]
|
node_keys = nodes[:, 0]
|
||||||
i_keys, o_keys = conns[:, 0], conns[:, 1]
|
i_keys, o_keys = conns[:, 0], conns[:, 1]
|
||||||
|
|
||||||
@@ -23,47 +23,25 @@ def unflatten_conns(nodes, conns):
|
|||||||
|
|
||||||
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
|
i_idxs = vmap(key_to_indices, in_axes=(0, None))(i_keys, node_keys)
|
||||||
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
|
o_idxs = vmap(key_to_indices, in_axes=(0, None))(o_keys, node_keys)
|
||||||
unflatten = jnp.full((CL - 2, N, N), jnp.nan)
|
|
||||||
|
|
||||||
# Is interesting that jax use clip when attach data in array
|
# Is interesting that jax use clip when attach data in array
|
||||||
# however, it will do nothing set values in an array
|
# however, it will do nothing when setting values in an array
|
||||||
# put all attributes include enable in res
|
# put the index of connections in the unflatten array
|
||||||
unflatten = unflatten.at[:, i_idxs, o_idxs].set(conns[:, 2:].T)
|
unflatten = (
|
||||||
assert unflatten.shape == (CL - 2, N, N)
|
jnp.full((N, N), I_INF, dtype=jnp.int32)
|
||||||
|
.at[i_idxs, o_idxs]
|
||||||
|
.set(jnp.arange(C, dtype=jnp.int32))
|
||||||
|
)
|
||||||
|
|
||||||
return unflatten
|
return unflatten
|
||||||
|
|
||||||
|
|
||||||
def flatten_conns(nodes, unflatten, C):
|
def attach_with_inf(arr, idx):
|
||||||
"""
|
expand_size = arr.ndim - idx.ndim
|
||||||
the inverse function of unflatten_conns
|
expand_idx = jnp.expand_dims(
|
||||||
transform the unflatten conn (CL-2, N, N) to (C, CL)
|
idx, axis=tuple(range(idx.ndim, expand_size + idx.ndim))
|
||||||
"""
|
)
|
||||||
N = nodes.shape[0]
|
return jnp.where(expand_idx == I_INF, jnp.nan, arr[idx])
|
||||||
CL = unflatten.shape[0] + 2
|
|
||||||
node_keys = nodes[:, 0]
|
|
||||||
|
|
||||||
def extract_conn(i, j):
|
|
||||||
return jnp.where(
|
|
||||||
jnp.isnan(unflatten[0, i, j]),
|
|
||||||
jnp.nan,
|
|
||||||
jnp.concatenate(
|
|
||||||
[jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
x, y = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij")
|
|
||||||
conns = vmap(extract_conn)(x.flatten(), y.flatten())
|
|
||||||
assert conns.shape == (N * N, CL)
|
|
||||||
|
|
||||||
# put nan to the tail of the conns
|
|
||||||
sorted_idx = jnp.argsort(conns[:, 0])
|
|
||||||
sorted_conn = conns[sorted_idx]
|
|
||||||
|
|
||||||
# truncate the conns to the number of connections
|
|
||||||
conns = sorted_conn[:C]
|
|
||||||
assert conns.shape == (C, CL)
|
|
||||||
return conns
|
|
||||||
|
|
||||||
|
|
||||||
def extract_node_attrs(node):
|
def extract_node_attrs(node):
|
||||||
|
|||||||
Reference in New Issue
Block a user