disable activation in the output node of network;

we recommend to use output_transform;
change hyperparameters (strong) in XOR example;
This commit is contained in:
wls2002
2024-05-22 11:09:25 +08:00
parent bb80f12640
commit 1fe5d5fca2
7 changed files with 59 additions and 14 deletions

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import State, Act, Agg
@@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm):
activation=Act.sigmoid,
aggregation=Agg.sum,
activate_time: int = 10,
output_transform: Callable = Act.sigmoid,
):
assert substrate.query_coors.shape[1] == neat.num_inputs, \
"Substrate input size should be equal to NEAT input size"
@@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm):
node_gene=HyperNodeGene(activation, aggregation),
conn_gene=HyperNEATConnGene(),
activate_time=activate_time,
output_transform=output_transform
)
def setup(self, randkey):
@@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene):
self.activation = activation
self.aggregation = aggregation
def forward(self, attrs, inputs):
return self.activation(
self.aggregation(inputs)
)
def forward(self, attrs, inputs, is_output_node=False):
return jax.lax.cond(
is_output_node,
lambda: self.aggregation(inputs), # output node does not need activation
lambda: self.activation(self.aggregation(inputs))
)
class HyperNEATConnGene(BaseConnGene):
custom_attrs = ['weight']

View File

@@ -92,7 +92,10 @@ class DefaultNodeGene(BaseNodeGene):
z = bias + res * z
# the last output node should not be activated
if not is_output_node:
z = act(act_idx, z, self.activation_options)
z = jax.lax.cond(
is_output_node,
lambda: z,
lambda: act(act_idx, z, self.activation_options)
)
return z

View File

@@ -58,7 +58,7 @@ class DefaultGenome(BaseGenome):
def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
z = self.node_gene.forward(nodes_attrs[i], ins)
z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
new_values = values.at[i].set(z)
return new_values

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import unflatten_conns
@@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome):
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
activate_time: int = 10,
output_transform: Callable = None
):
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
self.activate_time = activate_time
if output_transform is not None:
try:
_ = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
@@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome):
)(conns, values)
# calculate nodes
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
is_output_nodes = jnp.isin(
jnp.arange(N),
self.output_idx
)
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
return values
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)

View File

@@ -2,6 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import *
from problem.func_fit import XOR3d
from utils import Act
if __name__ == '__main__':
pipeline = Pipeline(
@@ -10,13 +11,25 @@ if __name__ == '__main__':
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
max_nodes=100,
max_conns=200,
node_gene=DefaultNodeGene(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
output_transform=Act.sigmoid, # the activation function for output node
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.01, # magic
),
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
)
),
problem=XOR3d(),
generation_limit=10000,

View File

@@ -28,14 +28,17 @@ if __name__ == '__main__':
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
output_transform=Act.tanh, # the activation function for output node in NEAT
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.03,
),
),
activation=Act.sigmoid,
activation=Act.tanh,
activate_time=10,
output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT
),
problem=XOR3d(),
generation_limit=300,

View File

@@ -2,8 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import *
from problem.func_fit import XOR3d
from utils.activation import ACT_ALL
from utils.aggregation import AGG_ALL
from utils.activation import ACT_ALL, Act
if __name__ == '__main__':
pipeline = Pipeline(
@@ -18,14 +17,21 @@ if __name__ == '__main__':
activate_time=5,
node_gene=DefaultNodeGene(
activation_options=ACT_ALL,
# aggregation_options=AGG_ALL,
activation_replace_rate=0.2
),
output_transform=Act.sigmoid
),
pop_size=10000,
species_size=10,
compatibility_threshold=3.5,
survival_threshold=0.03,
),
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
)
),
problem=XOR3d(),
generation_limit=10000,