complete fully stateful!

use black to format all files!
This commit is contained in:
wls2002
2024-05-26 18:08:43 +08:00
parent cf69b916af
commit 18c3d44c79
41 changed files with 620 additions and 495 deletions

View File

@@ -1,6 +1,7 @@
import jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from utils import fetch_first, State
import jax, jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene
from ..ga import BaseMutation, BaseCrossover
from utils import State
class BaseGenome:
@@ -12,8 +13,10 @@ class BaseGenome:
num_outputs: int,
max_nodes: int,
max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
node_gene: BaseNodeGene,
conn_gene: BaseConnGene,
mutation: BaseMutation,
crossover: BaseCrossover,
):
self.num_inputs = num_inputs
self.num_outputs = num_outputs
@@ -23,10 +26,14 @@ class BaseGenome:
self.max_conns = max_conns
self.node_gene = node_gene
self.conn_gene = conn_gene
self.mutation = mutation
self.crossover = crossover
def setup(self, state=State()):
state = self.node_gene.setup(state)
state = self.conn_gene.setup(state)
state = self.mutation.setup(state)
state = self.crossover.setup(state)
return state
def transform(self, state, nodes, conns):
@@ -35,36 +42,81 @@ class BaseGenome:
def forward(self, state, inputs, transformed):
raise NotImplementedError
def add_node(self, nodes, new_key: int, attrs):
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = nodes[:, 0]
pos = fetch_first(jnp.isnan(exist_keys))
new_nodes = nodes.at[pos, 0].set(new_key)
return new_nodes.at[pos, 1:].set(attrs)
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
return self.mutation(state, randkey, self, nodes, conns, new_node_key)
def delete_node_by_pos(self, nodes, pos):
"""
Delete a node from the genome.
Delete the node by its pos in nodes.
"""
return nodes.at[pos].set(jnp.nan)
def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
return self.crossover(state, randkey, self, nodes1, conns1, nodes2, conns2)
def add_conn(self, conns, i_key, o_key, enable: bool, attrs):
def initialize(self, state, randkey):
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = conns[:, 0]
pos = fetch_first(jnp.isnan(con_keys))
new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
return new_conns.at[pos, 3:].set(attrs)
Default initialization method for the genome.
Add an extra hidden node.
Make all input nodes and output nodes connected to the hidden node.
All attributes will be initialized randomly using gene.new_random_attrs method.
def delete_conn_by_pos(self, conns, pos):
For example, a network with 2 inputs and 1 output, the structure will be:
nodes:
[
[0, attrs0], # input node 0
[1, attrs1], # input node 1
[2, attrs2], # output node 0
[3, attrs3], # hidden node
[NaN, NaN], # empty node
]
conns:
[
[0, 3, attrs0], # input node 0 -> hidden node
[1, 3, attrs1], # input node 1 -> hidden node
[3, 2, attrs2], # hidden node -> output node 0
[NaN, NaN],
[NaN, NaN],
]
"""
Delete a connection from the genome.
Delete the connection by its idx.
"""
return conns.at[pos].set(jnp.nan)
k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns
# initialize nodes
new_node_key = (
max([*self.input_idx, *self.output_idx]) + 1
) # the key for the hidden node
node_keys = jnp.concatenate(
[self.input_idx, self.output_idx, jnp.array([new_node_key])]
) # the list of all node keys
# initialize nodes and connections with NaN
nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan)
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
# set keys for input nodes, output nodes and hidden node
nodes = nodes.at[node_keys, 0].set(node_keys)
# generate random attributes for nodes
node_keys = jax.random.split(k1, len(node_keys))
random_node_attrs = jax.vmap(
self.node_gene.new_random_attrs, in_axes=(None, 0)
)(state, node_keys)
nodes = nodes.at[: len(node_keys), 1:].set(random_node_attrs)
# initialize conns
# input-hidden connections
input_conns = jnp.c_[
self.input_idx, jnp.full_like(self.input_idx, new_node_key)
]
conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys
conns = conns.at[self.input_idx, 2].set(True) # enable
# output-hidden connections
output_conns = jnp.c_[
jnp.full_like(self.output_idx, new_node_key), self.output_idx
]
conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys
conns = conns.at[self.output_idx, 2].set(True) # enable
conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx))
# generate random attributes for conns
random_conn_attrs = jax.vmap(
self.conn_gene.new_random_attrs, in_axes=(None, 0)
)(state, conn_keys)
conns = conns.at[: len(conn_keys), 3:].set(random_conn_attrs)
return nodes, conns