From 20320105e656aad9afe6cb6c4126e3c4efe2afee Mon Sep 17 00:00:00 2001 From: wls2002 Date: Thu, 30 May 2024 15:06:08 +0800 Subject: [PATCH] move crossover_gene from ga.crossover to gene.basegene. --- .../algorithm/neat/ga/crossover/default.py | 14 ++++---------- tensorneat/algorithm/neat/gene/base.py | 3 +++ tensorneat/algorithm/neat/gene/conn/base.py | 19 +++++++++++++++++++ tensorneat/algorithm/neat/gene/node/base.py | 8 ++++++++ 4 files changed, 34 insertions(+), 10 deletions(-) diff --git a/tensorneat/algorithm/neat/ga/crossover/default.py b/tensorneat/algorithm/neat/ga/crossover/default.py index 5c3014f..073f8ef 100644 --- a/tensorneat/algorithm/neat/ga/crossover/default.py +++ b/tensorneat/algorithm/neat/ga/crossover/default.py @@ -10,6 +10,8 @@ class DefaultCrossover(BaseCrossover): notice that genome1 should have higher fitness than genome2 (genome1 is winner!) """ randkey1, randkey2 = jax.random.split(randkey, 2) + randkeys1 = jax.random.split(randkey1, genome.max_nodes) + randkeys2 = jax.random.split(randkey2, genome.max_conns) # crossover nodes keys1, keys2 = nodes1[:, 0], nodes2[:, 0] @@ -21,7 +23,7 @@ class DefaultCrossover(BaseCrossover): new_nodes = jnp.where( jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1, - self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False), + jax.vmap(genome.node_gene.crossover, in_axes=(None, 0, 0, 0))(state, randkeys1, nodes1, nodes2), ) # crossover connections @@ -31,7 +33,7 @@ class DefaultCrossover(BaseCrossover): new_conns = jnp.where( jnp.isnan(conns1) | jnp.isnan(conns2), conns1, - self.crossover_gene(randkey2, conns1, conns2, is_conn=True), + jax.vmap(genome.conn_gene.crossover, in_axes=(None, 0, 0, 0))(state, randkeys2, conns1, conns2), ) return new_nodes, new_conns @@ -64,11 +66,3 @@ class DefaultCrossover(BaseCrossover): ) return refactor_ar2 - - def crossover_gene(self, randkey, g1, g2, is_conn): - r = jax.random.uniform(randkey, shape=g1.shape) - new_gene = jnp.where(r > 0.5, g1, g2) - if is_conn: # fix enabled - enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled - new_gene = new_gene.at[:, 2].set(enabled) - return new_gene diff --git a/tensorneat/algorithm/neat/gene/base.py b/tensorneat/algorithm/neat/gene/base.py index afbd5f6..bef2182 100644 --- a/tensorneat/algorithm/neat/gene/base.py +++ b/tensorneat/algorithm/neat/gene/base.py @@ -23,6 +23,9 @@ class BaseGene: def mutate(self, state, randkey, gene): raise NotImplementedError + def crossover(self, state, randkey, gene1, gene2): + raise NotImplementedError + def distance(self, state, gene1, gene2): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/conn/base.py b/tensorneat/algorithm/neat/gene/conn/base.py index a4ab3dc..9819129 100644 --- a/tensorneat/algorithm/neat/gene/conn/base.py +++ b/tensorneat/algorithm/neat/gene/conn/base.py @@ -1,3 +1,4 @@ +import jax, jax.numpy as jnp from .. import BaseGene @@ -8,5 +9,23 @@ class BaseConnGene(BaseGene): def __init__(self): super().__init__() + def crossover(self, state, randkey, gene1, gene2): + def crossover_attr(): + return jnp.where( + jax.random.normal(randkey, gene1.shape) > 0, + gene1, + gene2, + ) + + return jax.lax.cond( + gene1[2] == gene2[2], # if both genes are enabled or disabled + crossover_attr, # then randomly pick attributes from gene1 or gene2 + lambda: jnp.where( # one gene is enabled and the other is disabled + gene1[2], # if gene1 is enabled + gene1, # then return gene1 + gene2, # else return gene2 + ), + ) + def forward(self, state, attrs, inputs): raise NotImplementedError diff --git a/tensorneat/algorithm/neat/gene/node/base.py b/tensorneat/algorithm/neat/gene/node/base.py index 8b81299..55361f4 100644 --- a/tensorneat/algorithm/neat/gene/node/base.py +++ b/tensorneat/algorithm/neat/gene/node/base.py @@ -1,3 +1,4 @@ +import jax, jax.numpy as jnp from .. import BaseGene @@ -8,5 +9,12 @@ class BaseNodeGene(BaseGene): def __init__(self): super().__init__() + def crossover(self, state, randkey, gene1, gene2): + return jnp.where( + jax.random.normal(randkey, gene1.shape) > 0, + gene1, + gene2, + ) + def forward(self, state, attrs, inputs, is_output_node=False): raise NotImplementedError