complete fully stateful!
use black to format all files!
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import jax.numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
from utils import fetch_first, State
|
||||
import jax, jax.numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene
|
||||
from ..ga import BaseMutation, BaseCrossover
|
||||
from utils import State
|
||||
|
||||
|
||||
class BaseGenome:
|
||||
@@ -12,8 +13,10 @@ class BaseGenome:
|
||||
num_outputs: int,
|
||||
max_nodes: int,
|
||||
max_conns: int,
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
node_gene: BaseNodeGene,
|
||||
conn_gene: BaseConnGene,
|
||||
mutation: BaseMutation,
|
||||
crossover: BaseCrossover,
|
||||
):
|
||||
self.num_inputs = num_inputs
|
||||
self.num_outputs = num_outputs
|
||||
@@ -23,10 +26,14 @@ class BaseGenome:
|
||||
self.max_conns = max_conns
|
||||
self.node_gene = node_gene
|
||||
self.conn_gene = conn_gene
|
||||
self.mutation = mutation
|
||||
self.crossover = crossover
|
||||
|
||||
def setup(self, state=State()):
|
||||
state = self.node_gene.setup(state)
|
||||
state = self.conn_gene.setup(state)
|
||||
state = self.mutation.setup(state)
|
||||
state = self.crossover.setup(state)
|
||||
return state
|
||||
|
||||
def transform(self, state, nodes, conns):
|
||||
@@ -35,36 +42,81 @@ class BaseGenome:
|
||||
def forward(self, state, inputs, transformed):
|
||||
raise NotImplementedError
|
||||
|
||||
def add_node(self, nodes, new_key: int, attrs):
|
||||
"""
|
||||
Add a new node to the genome.
|
||||
The new node will place at the first NaN row.
|
||||
"""
|
||||
exist_keys = nodes[:, 0]
|
||||
pos = fetch_first(jnp.isnan(exist_keys))
|
||||
new_nodes = nodes.at[pos, 0].set(new_key)
|
||||
return new_nodes.at[pos, 1:].set(attrs)
|
||||
def execute_mutation(self, state, randkey, nodes, conns, new_node_key):
|
||||
return self.mutation(state, randkey, self, nodes, conns, new_node_key)
|
||||
|
||||
def delete_node_by_pos(self, nodes, pos):
|
||||
"""
|
||||
Delete a node from the genome.
|
||||
Delete the node by its pos in nodes.
|
||||
"""
|
||||
return nodes.at[pos].set(jnp.nan)
|
||||
def execute_crossover(self, state, randkey, nodes1, conns1, nodes2, conns2):
|
||||
return self.crossover(state, randkey, self, nodes1, conns1, nodes2, conns2)
|
||||
|
||||
def add_conn(self, conns, i_key, o_key, enable: bool, attrs):
|
||||
def initialize(self, state, randkey):
|
||||
"""
|
||||
Add a new connection to the genome.
|
||||
The new connection will place at the first NaN row.
|
||||
"""
|
||||
con_keys = conns[:, 0]
|
||||
pos = fetch_first(jnp.isnan(con_keys))
|
||||
new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
|
||||
return new_conns.at[pos, 3:].set(attrs)
|
||||
Default initialization method for the genome.
|
||||
Add an extra hidden node.
|
||||
Make all input nodes and output nodes connected to the hidden node.
|
||||
All attributes will be initialized randomly using gene.new_random_attrs method.
|
||||
|
||||
def delete_conn_by_pos(self, conns, pos):
|
||||
For example, a network with 2 inputs and 1 output, the structure will be:
|
||||
nodes:
|
||||
[
|
||||
[0, attrs0], # input node 0
|
||||
[1, attrs1], # input node 1
|
||||
[2, attrs2], # output node 0
|
||||
[3, attrs3], # hidden node
|
||||
[NaN, NaN], # empty node
|
||||
]
|
||||
conns:
|
||||
[
|
||||
[0, 3, attrs0], # input node 0 -> hidden node
|
||||
[1, 3, attrs1], # input node 1 -> hidden node
|
||||
[3, 2, attrs2], # hidden node -> output node 0
|
||||
[NaN, NaN],
|
||||
[NaN, NaN],
|
||||
]
|
||||
"""
|
||||
Delete a connection from the genome.
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
return conns.at[pos].set(jnp.nan)
|
||||
|
||||
k1, k2 = jax.random.split(randkey) # k1 for nodes, k2 for conns
|
||||
# initialize nodes
|
||||
new_node_key = (
|
||||
max([*self.input_idx, *self.output_idx]) + 1
|
||||
) # the key for the hidden node
|
||||
node_keys = jnp.concatenate(
|
||||
[self.input_idx, self.output_idx, jnp.array([new_node_key])]
|
||||
) # the list of all node keys
|
||||
|
||||
# initialize nodes and connections with NaN
|
||||
nodes = jnp.full((self.max_nodes, self.node_gene.length), jnp.nan)
|
||||
conns = jnp.full((self.max_conns, self.conn_gene.length), jnp.nan)
|
||||
|
||||
# set keys for input nodes, output nodes and hidden node
|
||||
nodes = nodes.at[node_keys, 0].set(node_keys)
|
||||
|
||||
# generate random attributes for nodes
|
||||
node_keys = jax.random.split(k1, len(node_keys))
|
||||
random_node_attrs = jax.vmap(
|
||||
self.node_gene.new_random_attrs, in_axes=(None, 0)
|
||||
)(state, node_keys)
|
||||
nodes = nodes.at[: len(node_keys), 1:].set(random_node_attrs)
|
||||
|
||||
# initialize conns
|
||||
# input-hidden connections
|
||||
input_conns = jnp.c_[
|
||||
self.input_idx, jnp.full_like(self.input_idx, new_node_key)
|
||||
]
|
||||
conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys
|
||||
conns = conns.at[self.input_idx, 2].set(True) # enable
|
||||
|
||||
# output-hidden connections
|
||||
output_conns = jnp.c_[
|
||||
jnp.full_like(self.output_idx, new_node_key), self.output_idx
|
||||
]
|
||||
conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys
|
||||
conns = conns.at[self.output_idx, 2].set(True) # enable
|
||||
|
||||
conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx))
|
||||
# generate random attributes for conns
|
||||
random_conn_attrs = jax.vmap(
|
||||
self.conn_gene.new_random_attrs, in_axes=(None, 0)
|
||||
)(state, conn_keys)
|
||||
conns = conns.at[: len(conn_keys), 3:].set(random_conn_attrs)
|
||||
|
||||
return nodes, conns
|
||||
|
||||
Reference in New Issue
Block a user