complete fully stateful!

use black to format all files!
This commit is contained in:
wls2002
2024-05-26 18:08:43 +08:00
parent cf69b916af
commit 18c3d44c79
41 changed files with 620 additions and 495 deletions

View File

@@ -1,6 +1,16 @@
import jax, jax.numpy as jnp
from . import BaseMutation
from utils import fetch_first, fetch_random, I_INF, unflatten_conns, check_cycles
from utils import (
fetch_first,
fetch_random,
I_INF,
unflatten_conns,
check_cycles,
add_node,
add_conn,
delete_node_by_pos,
delete_conn_by_pos,
)
class DefaultMutation(BaseMutation):
@@ -16,15 +26,17 @@ class DefaultMutation(BaseMutation):
self.node_add = node_add
self.node_delete = node_delete
def __call__(self, state, genome, nodes, conns, new_node_key):
k1, k2, randkey = jax.random.split(state.randkey)
def __call__(self, state, randkey, genome, nodes, conns, new_node_key):
k1, k2 = jax.random.split(randkey)
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
nodes, conns = self.mutate_values(k2, genome, nodes, conns)
nodes, conns = self.mutate_structure(
state, k1, genome, nodes, conns, new_node_key
)
nodes, conns = self.mutate_values(state, k2, genome, nodes, conns)
return state.update(randkey=randkey), nodes, conns
return nodes, conns
def mutate_structure(self, key, genome, nodes, conns, new_node_key):
def mutate_structure(self, state, randkey, genome, nodes, conns, new_node_key):
def mutate_add_node(key_, nodes_, conns_):
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
@@ -33,24 +45,24 @@ class DefaultMutation(BaseMutation):
new_conns = conns_.at[idx, 2].set(False)
# add a new node
new_nodes = genome.add_node(
nodes_, new_node_key, genome.node_gene.new_custom_attrs()
new_nodes = add_node(
nodes_, new_node_key, genome.node_gene.new_custom_attrs(state)
)
# add two new connections
new_conns = genome.add_conn(
new_conns = add_conn(
new_conns,
i_key,
new_node_key,
True,
genome.conn_gene.new_custom_attrs(),
genome.conn_gene.new_custom_attrs(state),
)
new_conns = genome.add_conn(
new_conns = add_conn(
new_conns,
new_node_key,
o_key,
True,
genome.conn_gene.new_custom_attrs(),
genome.conn_gene.new_custom_attrs(state),
)
return new_nodes, new_conns
@@ -75,7 +87,7 @@ class DefaultMutation(BaseMutation):
def successful_delete_node():
# delete the node
new_nodes = genome.delete_node_by_pos(nodes_, idx)
new_nodes = delete_node_by_pos(nodes_, idx)
# delete all connections
new_conns = jnp.where(
@@ -123,8 +135,8 @@ class DefaultMutation(BaseMutation):
return nodes_, conns_
def successful():
return nodes_, genome.add_conn(
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs()
return nodes_, add_conn(
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs(state)
)
def already_exist():
@@ -152,7 +164,7 @@ class DefaultMutation(BaseMutation):
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
def successfully_delete_connection():
return nodes_, genome.delete_conn_by_pos(conns_, idx)
return nodes_, delete_conn_by_pos(conns_, idx)
return jax.lax.cond(
idx == I_INF,
@@ -160,7 +172,7 @@ class DefaultMutation(BaseMutation):
successfully_delete_connection,
)
k1, k2, k3, k4 = jax.random.split(key, num=4)
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
def no(key_, nodes_, conns_):
@@ -181,13 +193,17 @@ class DefaultMutation(BaseMutation):
return nodes, conns
def mutate_values(self, key, genome, nodes, conns):
k1, k2 = jax.random.split(key, num=2)
def mutate_values(self, state, randkey, genome, nodes, conns):
k1, k2 = jax.random.split(randkey, num=2)
nodes_keys = jax.random.split(k1, num=nodes.shape[0])
conns_keys = jax.random.split(k2, num=conns.shape[0])
new_nodes = jax.vmap(genome.node_gene.mutate)(nodes_keys, nodes)
new_conns = jax.vmap(genome.conn_gene.mutate)(conns_keys, conns)
new_nodes = jax.vmap(genome.node_gene.mutate, in_axes=(None, 0, 0))(
state, nodes_keys, nodes
)
new_conns = jax.vmap(genome.conn_gene.mutate, in_axes=(None, 0, 0))(
state, conns_keys, conns
)
# nan nodes not changed
new_nodes = jnp.where(jnp.isnan(nodes), jnp.nan, new_nodes)