disable activation in the output node of the network;

we recommend using output_transform instead;
change hyperparameters (strong) in the XOR examples;
This commit is contained in:
wls2002
2024-05-22 11:09:25 +08:00
parent bb80f12640
commit 1fe5d5fca2
7 changed files with 59 additions and 14 deletions

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import State, Act, Agg from utils import State, Act, Agg
@@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm):
activation=Act.sigmoid, activation=Act.sigmoid,
aggregation=Agg.sum, aggregation=Agg.sum,
activate_time: int = 10, activate_time: int = 10,
output_transform: Callable = Act.sigmoid,
): ):
assert substrate.query_coors.shape[1] == neat.num_inputs, \ assert substrate.query_coors.shape[1] == neat.num_inputs, \
"Substrate input size should be equal to NEAT input size" "Substrate input size should be equal to NEAT input size"
@@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm):
node_gene=HyperNodeGene(activation, aggregation), node_gene=HyperNodeGene(activation, aggregation),
conn_gene=HyperNEATConnGene(), conn_gene=HyperNEATConnGene(),
activate_time=activate_time, activate_time=activate_time,
output_transform=output_transform
) )
def setup(self, randkey): def setup(self, randkey):
@@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene):
self.activation = activation self.activation = activation
self.aggregation = aggregation self.aggregation = aggregation
def forward(self, attrs, inputs): def forward(self, attrs, inputs, is_output_node=False):
return self.activation( return jax.lax.cond(
self.aggregation(inputs) is_output_node,
) lambda: self.aggregation(inputs), # output node does not need activation
lambda: self.activation(self.aggregation(inputs))
)
class HyperNEATConnGene(BaseConnGene): class HyperNEATConnGene(BaseConnGene):
custom_attrs = ['weight'] custom_attrs = ['weight']

View File

@@ -92,7 +92,10 @@ class DefaultNodeGene(BaseNodeGene):
z = bias + res * z z = bias + res * z
# the last output node should not be activated # the last output node should not be activated
if not is_output_node: z = jax.lax.cond(
z = act(act_idx, z, self.activation_options) is_output_node,
lambda: z,
lambda: act(act_idx, z, self.activation_options)
)
return z return z

View File

@@ -58,7 +58,7 @@ class DefaultGenome(BaseGenome):
def hit(): def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values) ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
z = self.node_gene.forward(nodes_attrs[i], ins) z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
new_values = values.at[i].set(z) new_values = values.at[i].set(z)
return new_values return new_values

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp import jax, jax.numpy as jnp
from utils import unflatten_conns from utils import unflatten_conns
@@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome):
node_gene: BaseNodeGene = DefaultNodeGene(), node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(), conn_gene: BaseConnGene = DefaultConnGene(),
activate_time: int = 10, activate_time: int = 10,
output_transform: Callable = None
): ):
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene) super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
self.activate_time = activate_time self.activate_time = activate_time
if output_transform is not None:
try:
_ = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns): def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns) u_conns = unflatten_conns(nodes, conns)
@@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome):
)(conns, values) )(conns, values)
# calculate nodes # calculate nodes
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T) is_output_nodes = jnp.isin(
jnp.arange(N),
self.output_idx
)
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
return values return values
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals) vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)

View File

@@ -2,6 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from problem.func_fit import XOR3d from problem.func_fit import XOR3d
from utils import Act
if __name__ == '__main__': if __name__ == '__main__':
pipeline = Pipeline( pipeline = Pipeline(
@@ -10,13 +11,25 @@ if __name__ == '__main__':
genome=DefaultGenome( genome=DefaultGenome(
num_inputs=3, num_inputs=3,
num_outputs=1, num_outputs=1,
max_nodes=50, max_nodes=100,
max_conns=100, max_conns=200,
node_gene=DefaultNodeGene(
activation_default=Act.tanh,
activation_options=(Act.tanh,),
),
output_transform=Act.sigmoid, # the activation function for output node
), ),
pop_size=10000, pop_size=10000,
species_size=10, species_size=10,
compatibility_threshold=3.5, compatibility_threshold=3.5,
survival_threshold=0.01, # magic
), ),
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
)
), ),
problem=XOR3d(), problem=XOR3d(),
generation_limit=10000, generation_limit=10000,

View File

@@ -28,14 +28,17 @@ if __name__ == '__main__':
activation_default=Act.tanh, activation_default=Act.tanh,
activation_options=(Act.tanh,), activation_options=(Act.tanh,),
), ),
output_transform=Act.tanh, # the activation function for output node in NEAT
), ),
pop_size=10000, pop_size=10000,
species_size=10, species_size=10,
compatibility_threshold=3.5, compatibility_threshold=3.5,
survival_threshold=0.03,
), ),
), ),
activation=Act.sigmoid, activation=Act.tanh,
activate_time=10, activate_time=10,
output_transform=Act.sigmoid, # the activation function for output node in HyperNEAT
), ),
problem=XOR3d(), problem=XOR3d(),
generation_limit=300, generation_limit=300,

View File

@@ -2,8 +2,7 @@ from pipeline import Pipeline
from algorithm.neat import * from algorithm.neat import *
from problem.func_fit import XOR3d from problem.func_fit import XOR3d
from utils.activation import ACT_ALL from utils.activation import ACT_ALL, Act
from utils.aggregation import AGG_ALL
if __name__ == '__main__': if __name__ == '__main__':
pipeline = Pipeline( pipeline = Pipeline(
@@ -18,14 +17,21 @@ if __name__ == '__main__':
activate_time=5, activate_time=5,
node_gene=DefaultNodeGene( node_gene=DefaultNodeGene(
activation_options=ACT_ALL, activation_options=ACT_ALL,
# aggregation_options=AGG_ALL,
activation_replace_rate=0.2 activation_replace_rate=0.2
), ),
output_transform=Act.sigmoid
), ),
pop_size=10000, pop_size=10000,
species_size=10, species_size=10,
compatibility_threshold=3.5, compatibility_threshold=3.5,
survival_threshold=0.03,
), ),
mutation=DefaultMutation(
node_add=0.05,
conn_add=0.2,
node_delete=0,
conn_delete=0,
)
), ),
problem=XOR3d(), problem=XOR3d(),
generation_limit=10000, generation_limit=10000,