disable activation in the output node of network;
we recommend to use output_transform; change hyperparameters (strong) in XOR example;
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from utils import State, Act, Agg
|
||||
@@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = Act.sigmoid,
|
||||
):
|
||||
assert substrate.query_coors.shape[1] == neat.num_inputs, \
|
||||
"Substrate input size should be equal to NEAT input size"
|
||||
@@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
node_gene=HyperNodeGene(activation, aggregation),
|
||||
conn_gene=HyperNEATConnGene(),
|
||||
activate_time=activate_time,
|
||||
output_transform=output_transform
|
||||
)
|
||||
|
||||
def setup(self, randkey):
|
||||
@@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene):
|
||||
self.activation = activation
|
||||
self.aggregation = aggregation
|
||||
|
||||
def forward(self, attrs, inputs):
|
||||
return self.activation(
|
||||
self.aggregation(inputs)
|
||||
)
|
||||
def forward(self, attrs, inputs, is_output_node=False):
|
||||
return jax.lax.cond(
|
||||
is_output_node,
|
||||
lambda: self.aggregation(inputs), # output node does not need activation
|
||||
lambda: self.activation(self.aggregation(inputs))
|
||||
|
||||
)
|
||||
|
||||
class HyperNEATConnGene(BaseConnGene):
|
||||
custom_attrs = ['weight']
|
||||
|
||||
@@ -92,7 +92,10 @@ class DefaultNodeGene(BaseNodeGene):
|
||||
z = bias + res * z
|
||||
|
||||
# the last output node should not be activated
|
||||
if not is_output_node:
|
||||
z = act(act_idx, z, self.activation_options)
|
||||
z = jax.lax.cond(
|
||||
is_output_node,
|
||||
lambda: z,
|
||||
lambda: act(act_idx, z, self.activation_options)
|
||||
)
|
||||
|
||||
return z
|
||||
|
||||
@@ -58,7 +58,7 @@ class DefaultGenome(BaseGenome):
|
||||
|
||||
def hit():
|
||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
|
||||
z = self.node_gene.forward(nodes_attrs[i], ins)
|
||||
z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import unflatten_conns
|
||||
|
||||
@@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome):
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = None
|
||||
):
|
||||
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
|
||||
self.activate_time = activate_time
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
_ = output_transform(jnp.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.output_transform = output_transform
|
||||
|
||||
def transform(self, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
|
||||
@@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome):
|
||||
)(conns, values)
|
||||
|
||||
# calculate nodes
|
||||
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
|
||||
is_output_nodes = jnp.isin(
|
||||
jnp.arange(N),
|
||||
self.output_idx
|
||||
)
|
||||
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
|
||||
|
||||
@@ -2,6 +2,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
from utils import Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
pipeline = Pipeline(
|
||||
@@ -10,13 +11,25 @@ if __name__ == '__main__':
|
||||
genome=DefaultGenome(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
max_nodes=100,
|
||||
max_conns=200,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
output_transform=Act.sigmoid, # the activation function for output node
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
survival_threshold=0.01, # magic
|
||||
),
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.05,
|
||||
conn_add=0.2,
|
||||
node_delete=0,
|
||||
conn_delete=0,
|
||||
)
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
|
||||
@@ -28,14 +28,17 @@ if __name__ == '__main__':
|
||||
activation_default=Act.tanh,
|
||||
activation_options=(Act.tanh,),
|
||||
),
|
||||
output_transform=Act.tanh, # the activation function for output node in NEAT
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
survival_threshold=0.03,
|
||||
),
|
||||
),
|
||||
activation=Act.sigmoid,
|
||||
activation=Act.tanh,
|
||||
activate_time=10,
|
||||
output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=300,
|
||||
|
||||
@@ -2,8 +2,7 @@ from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
from utils.activation import ACT_ALL
|
||||
from utils.aggregation import AGG_ALL
|
||||
from utils.activation import ACT_ALL, Act
|
||||
|
||||
if __name__ == '__main__':
|
||||
pipeline = Pipeline(
|
||||
@@ -18,14 +17,21 @@ if __name__ == '__main__':
|
||||
activate_time=5,
|
||||
node_gene=DefaultNodeGene(
|
||||
activation_options=ACT_ALL,
|
||||
# aggregation_options=AGG_ALL,
|
||||
activation_replace_rate=0.2
|
||||
),
|
||||
output_transform=Act.sigmoid
|
||||
),
|
||||
pop_size=10000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
survival_threshold=0.03,
|
||||
),
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.05,
|
||||
conn_add=0.2,
|
||||
node_delete=0,
|
||||
conn_delete=0,
|
||||
)
|
||||
),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
|
||||
Reference in New Issue
Block a user