Merge branch 'main' into advance

This commit is contained in:
WLS2002
2024-05-24 19:42:03 +08:00
committed by GitHub
17 changed files with 156 additions and 82 deletions

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import State, Act, Agg
@@ -18,6 +20,7 @@ class HyperNEAT(BaseAlgorithm):
activation=Act.sigmoid,
aggregation=Agg.sum,
activate_time: int = 10,
output_transform: Callable = Act.sigmoid,
):
assert substrate.query_coors.shape[1] == neat.num_inputs, \
"Substrate input size should be equal to NEAT input size"
@@ -34,6 +37,7 @@ class HyperNEAT(BaseAlgorithm):
node_gene=HyperNodeGene(activation, aggregation),
conn_gene=HyperNEATConnGene(),
activate_time=activate_time,
output_transform=output_transform
)
def setup(self, randkey):
@@ -102,11 +106,13 @@ class HyperNodeGene(BaseNodeGene):
self.activation = activation
self.aggregation = aggregation
def forward(self, attrs, inputs):
return self.activation(
self.aggregation(inputs)
)
def forward(self, attrs, inputs, is_output_node=False):
return jax.lax.cond(
is_output_node,
lambda: self.aggregation(inputs), # output node does not need activation
lambda: self.activation(self.aggregation(inputs))
)
class HyperNEATConnGene(BaseConnGene):
custom_attrs = ['weight']

View File

@@ -1,3 +1,3 @@
class BaseCrossover:
def __call__(self, randkey, genome, nodes1, nodes2, conns1, conns2):
raise NotImplementedError
raise NotImplementedError

View File

@@ -2,6 +2,7 @@ import jax, jax.numpy as jnp
from .base import BaseCrossover
class DefaultCrossover(BaseCrossover):
def __call__(self, randkey, genome, nodes1, conns1, nodes2, conns2):
@@ -14,17 +15,19 @@ class DefaultCrossover(BaseCrossover):
# crossover nodes
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
# make homologous genes align in nodes2 align with nodes1
nodes2 = self.align_array(keys1, keys2, nodes2, False)
nodes2 = self.align_array(keys1, keys2, nodes2, is_conn=False)
# For not homologous genes, use the value of nodes1(winner)
# For homologous genes, use the crossover result between nodes1 and nodes2
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, self.crossover_gene(randkey_1, nodes1, nodes2))
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
# crossover connections
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
conns2 = self.align_array(con_keys1, con_keys2, conns2, True)
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1, self.crossover_gene(randkey_2, conns1, conns2))
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
return new_nodes, new_conns
@@ -54,14 +57,14 @@ class DefaultCrossover(BaseCrossover):
return refactor_ar2
def crossover_gene(self, rand_key, g1, g2):
"""
crossover two genes
:param rand_key:
:param g1:
:param g2:
:return:
only gene with the same key will be crossover, thus don't need to consider change key
"""
def crossover_gene(self, rand_key, g1, g2, is_conn):
r = jax.random.uniform(rand_key, shape=g1.shape)
return jnp.where(r > 0.5, g1, g2)
new_gene = jnp.where(r > 0.5, g1, g2)
if is_conn: # fix enabled
enabled = jnp.where(
g1[:, 2] + g2[:, 2] > 0, # any of them is enabled
1,
0
)
new_gene = new_gene.at[:, 2].set(enabled)
return new_gene

View File

@@ -154,8 +154,8 @@ class DefaultMutation(BaseMutation):
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
conns_keys = jax.random.split(k2, num=conns.shape[0])
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(0, 0))(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(0, 0))(conns_keys, conns)
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
# nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)

View File

@@ -8,5 +8,5 @@ class BaseNodeGene(BaseGene):
def __init__(self):
super().__init__()
def forward(self, attrs, inputs):
def forward(self, attrs, inputs, is_output_node=False):
raise NotImplementedError

View File

@@ -95,11 +95,17 @@ class DefaultNodeGene(BaseNodeGene):
(node1[4] != node2[4])
)
def forward(self, attrs, inputs):
def forward(self, attrs, inputs, is_output_node=False):
bias, res, act_idx, agg_idx = attrs
z = agg(agg_idx, inputs, self.aggregation_options)
z = bias + res * z
z = act(act_idx, z, self.activation_options)
# the last output node should not be activated
z = jax.lax.cond(
is_output_node,
lambda: z,
lambda: act(act_idx, z, self.activation_options)
)
return z

View File

@@ -25,19 +25,13 @@ class DefaultGenome(BaseGenome):
if output_transform is not None:
try:
aux = output_transform(jnp.zeros(num_outputs))
_ = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
# DONE: Seems like there is a bug in this line
# conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
# modified: exist conn and enable is true
# conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False)
# advanced modified: when and only when enabled is True
conn_enable = u_conns[0] == 1
# remove enable attr
@@ -64,13 +58,7 @@ class DefaultGenome(BaseGenome):
def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
# ins = values * weights[:, i]
z = self.node_gene.forward(nodes_attrs[i], ins)
# z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
# z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
# z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
z = self.node_gene.forward(nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
new_values = values.at[i].set(z)
return new_values
@@ -78,7 +66,11 @@ class DefaultGenome(BaseGenome):
return values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit)
values = jax.lax.cond(
jnp.isin(i, self.input_idx),
miss,
hit
)
return values, idx + 1

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import unflatten_conns
@@ -18,10 +20,18 @@ class RecurrentGenome(BaseGenome):
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
activate_time: int = 10,
output_transform: Callable = None
):
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
self.activate_time = activate_time
if output_transform is not None:
try:
_ = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
@@ -52,7 +62,11 @@ class RecurrentGenome(BaseGenome):
)(conns, values)
# calculate nodes
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
is_output_nodes = jnp.isin(
jnp.arange(N),
self.output_idx
)
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T, is_output_nodes)
return values
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)