complete fully stateful!
use black to format all files!
This commit is contained in:
@@ -5,5 +5,5 @@ class BaseCrossover:
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
def __call__(self, state, genome, nodes1, nodes2, conns1, conns2):
|
||||
def __call__(self, state, randkey, genome, nodes1, nodes2, conns1, conns2):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -4,12 +4,12 @@ from .base import BaseCrossover
|
||||
|
||||
|
||||
class DefaultCrossover(BaseCrossover):
|
||||
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
|
||||
def __call__(self, state, randkey, genome, nodes1, conns1, nodes2, conns2):
|
||||
"""
|
||||
use genome1 and genome2 to generate a new genome
|
||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||
"""
|
||||
randkey1, randkey2, randkey = jax.random.split(state.randkey, 3)
|
||||
randkey1, randkey2 = jax.random.split(randkey, 2)
|
||||
|
||||
# crossover nodes
|
||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||
@@ -34,11 +34,12 @@ class DefaultCrossover(BaseCrossover):
|
||||
self.crossover_gene(randkey2, conns1, conns2, is_conn=True),
|
||||
)
|
||||
|
||||
return state.update(randkey=randkey), new_nodes, new_conns
|
||||
return new_nodes, new_conns
|
||||
|
||||
def align_array(self, seq1, seq2, ar2, is_conn: bool):
|
||||
"""
|
||||
After I review this code, I found that it is the most difficult part of the code. Please never change it!
|
||||
After I review this code, I found that it is the most difficult part of the code.
|
||||
Please consider carefully before change it!
|
||||
make ar2 align with ar1.
|
||||
:param seq1:
|
||||
:param seq2:
|
||||
@@ -64,8 +65,8 @@ class DefaultCrossover(BaseCrossover):
|
||||
|
||||
return refactor_ar2
|
||||
|
||||
def crossover_gene(self, rand_key, g1, g2, is_conn):
|
||||
r = jax.random.uniform(rand_key, shape=g1.shape)
|
||||
def crossover_gene(self, randkey, g1, g2, is_conn):
|
||||
r = jax.random.uniform(randkey, shape=g1.shape)
|
||||
new_gene = jnp.where(r > 0.5, g1, g2)
|
||||
if is_conn: # fix enabled
|
||||
enabled = jnp.where(g1[:, 2] + g2[:, 2] > 0, 1, 0) # any of them is enabled
|
||||
|
||||
@@ -5,5 +5,5 @@ class BaseMutation:
|
||||
def setup(self, state=State()):
|
||||
return state
|
||||
|
||||
def __call__(self, state, genome, nodes, conns, new_node_key):
|
||||
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,6 +1,16 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from . import BaseMutation
|
||||
from utils import fetch_first, fetch_random, I_INF, unflatten_conns, check_cycles
|
||||
from utils import (
|
||||
fetch_first,
|
||||
fetch_random,
|
||||
I_INF,
|
||||
unflatten_conns,
|
||||
check_cycles,
|
||||
add_node,
|
||||
add_conn,
|
||||
delete_node_by_pos,
|
||||
delete_conn_by_pos,
|
||||
)
|
||||
|
||||
|
||||
class DefaultMutation(BaseMutation):
|
||||
@@ -16,15 +26,17 @@ class DefaultMutation(BaseMutation):
|
||||
self.node_add = node_add
|
||||
self.node_delete = node_delete
|
||||
|
||||
def __call__(self, state, genome, nodes, conns, new_node_key):
|
||||
k1, k2, randkey = jax.random.split(state.randkey)
|
||||
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
k1, k2 = jax.random.split(randkey)
|
||||
|
||||
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
|
||||
nodes, conns = self.mutate_values(k2, genome, nodes, conns)
|
||||
nodes, conns = self.mutate_structure(
|
||||
state, k1, genome, nodes, conns, new_node_key
|
||||
)
|
||||
nodes, conns = self.mutate_values(state, k2, genome, nodes, conns)
|
||||
|
||||
return state.update(randkey=randkey), nodes, conns
|
||||
return nodes, conns
|
||||
|
||||
def mutate_structure(self, key, genome, nodes, conns, new_node_key):
|
||||
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
|
||||
def mutate_add_node(key_, nodes_, conns_):
|
||||
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
|
||||
|
||||
@@ -33,24 +45,24 @@ class DefaultMutation(BaseMutation):
|
||||
new_conns = conns_.at[idx, 2].set(False)
|
||||
|
||||
# add a new node
|
||||
new_nodes = genome.add_node(
|
||||
nodes_, new_node_key, genome.node_gene.new_custom_attrs()
|
||||
new_nodes = add_node(
|
||||
nodes_, new_node_key, genome.node_gene.new_custom_attrs(state)
|
||||
)
|
||||
|
||||
# add two new connections
|
||||
new_conns = genome.add_conn(
|
||||
new_conns = add_conn(
|
||||
new_conns,
|
||||
i_key,
|
||||
new_node_key,
|
||||
True,
|
||||
genome.conn_gene.new_custom_attrs(),
|
||||
genome.conn_gene.new_custom_attrs(state),
|
||||
)
|
||||
new_conns = genome.add_conn(
|
||||
new_conns = add_conn(
|
||||
new_conns,
|
||||
new_node_key,
|
||||
o_key,
|
||||
True,
|
||||
genome.conn_gene.new_custom_attrs(),
|
||||
genome.conn_gene.new_custom_attrs(state),
|
||||
)
|
||||
|
||||
return new_nodes, new_conns
|
||||
@@ -75,7 +87,7 @@ class DefaultMutation(BaseMutation):
|
||||
|
||||
def successful_delete_node():
|
||||
# delete the node
|
||||
new_nodes = genome.delete_node_by_pos(nodes_, idx)
|
||||
new_nodes = delete_node_by_pos(nodes_, idx)
|
||||
|
||||
# delete all connections
|
||||
new_conns = jnp.where(
|
||||
@@ -123,8 +135,8 @@ class DefaultMutation(BaseMutation):
|
||||
return nodes_, conns_
|
||||
|
||||
def successful():
|
||||
return nodes_, genome.add_conn(
|
||||
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs()
|
||||
return nodes_, add_conn(
|
||||
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs(state)
|
||||
)
|
||||
|
||||
def already_exist():
|
||||
@@ -152,7 +164,7 @@ class DefaultMutation(BaseMutation):
|
||||
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
|
||||
|
||||
def successfully_delete_connection():
|
||||
return nodes_, genome.delete_conn_by_pos(conns_, idx)
|
||||
return nodes_, delete_conn_by_pos(conns_, idx)
|
||||
|
||||
return jax.lax.cond(
|
||||
idx == I_INF,
|
||||
@@ -160,7 +172,7 @@ class DefaultMutation(BaseMutation):
|
||||
successfully_delete_connection,
|
||||
)
|
||||
|
||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
|
||||
|
||||
def no(key_, nodes_, conns_):
|
||||
@@ -181,13 +193,17 @@ class DefaultMutation(BaseMutation):
|
||||
|
||||
return nodes, conns
|
||||
|
||||
def mutate_values(self, key, genome, nodes, conns):
|
||||
k1, k2 = jax.random.split(key, num=2)
|
||||
def mutate_values(self, state, randkey, genome, nodes, conns):
|
||||
k1, k2 = jax.random.split(randkey, num=2)
|
||||
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
|
||||
conns_keys = jax.random.split(k2, num=conns.shape[0])
|
||||
|
||||
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
|
||||
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
|
||||
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
|
||||
state, nodes_keys, nodes
|
||||
)
|
||||
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
|
||||
state, conns_keys, conns
|
||||
)
|
||||
|
||||
# nan nodes not changed
|
||||
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)
|
||||
|
||||
Reference in New Issue
Block a user