change repo structure; modify readme

This commit is contained in:
wls2002
2024-03-26 21:58:27 +08:00
parent 6970e6a6d5
commit 47dbcbea80
69 changed files with 74 additions and 60 deletions

View File

@@ -0,0 +1,3 @@
from .base import BaseGenome
from .default import DefaultGenome
from .recurrent import RecurrentGenome

View File

@@ -0,0 +1,65 @@
import jax.numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
from utils import fetch_first
class BaseGenome:
network_type = None
def __init__(
self,
num_inputs: int,
num_outputs: int,
max_nodes: int,
max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
):
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.input_idx = jnp.arange(num_inputs)
self.output_idx = jnp.arange(num_inputs, num_inputs + num_outputs)
self.max_nodes = max_nodes
self.max_conns = max_conns
self.node_gene = node_gene
self.conn_gene = conn_gene
def transform(self, nodes, conns):
raise NotImplementedError
def forward(self, inputs, transformed):
raise NotImplementedError
def add_node(self, nodes, new_key: int, attrs):
"""
Add a new node to the genome.
The new node will place at the first NaN row.
"""
exist_keys = nodes[:, 0]
pos = fetch_first(jnp.isnan(exist_keys))
new_nodes = nodes.at[pos, 0].set(new_key)
return new_nodes.at[pos, 1:].set(attrs)
def delete_node_by_pos(self, nodes, pos):
"""
Delete a node from the genome.
Delete the node by its pos in nodes.
"""
return nodes.at[pos].set(jnp.nan)
def add_conn(self, conns, i_key, o_key, enable: bool, attrs):
"""
Add a new connection to the genome.
The new connection will place at the first NaN row.
"""
con_keys = conns[:, 0]
pos = fetch_first(jnp.isnan(con_keys))
new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
return new_conns.at[pos, 3:].set(attrs)
def delete_conn_by_pos(self, conns, pos):
"""
Delete a connection from the genome.
Delete the connection by its idx.
"""
return conns.at[pos].set(jnp.nan)

View File

@@ -0,0 +1,90 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import unflatten_conns, topological_sort, I_INT
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
class DefaultGenome(BaseGenome):
"""Default genome class, with the same behavior as the NEAT-Python"""
network_type = 'feedforward'
def __init__(self,
num_inputs: int,
num_outputs: int,
max_nodes=5,
max_conns=4,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
output_transform: Callable = None
):
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
if output_transform is not None:
try:
aux = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
# DONE: Seems like there is a bug in this line
# conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
# modified: exist conn and enable is true
# conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False)
# advanced modified: when and only when enabled is True
conn_enable = u_conns[0] == 1
# remove enable attr
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(nodes, conn_enable)
return seqs, nodes, u_conns
def forward(self, inputs, transformed):
cal_seqs, nodes, conns = transformed
N = nodes.shape[0]
ini_vals = jnp.full((N,), jnp.nan)
ini_vals = ini_vals.at[self.input_idx].set(inputs)
nodes_attrs = nodes[:, 1:]
def cond_fun(carry):
values, idx = carry
return (idx < N) & (cal_seqs[idx] != I_INT)
def body_func(carry):
values, idx = carry
i = cal_seqs[idx]
def hit():
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
# ins = values * weights[:, i]
z = self.node_gene.forward(nodes_attrs[i], ins)
# z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
# z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
# z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
new_values = values.at[i].set(z)
return new_values
def miss():
return values
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit)
return values, idx + 1
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
if self.output_transform is None:
return vals[self.output_idx]
else:
return self.output_transform(vals[self.output_idx])

View File

@@ -0,0 +1,60 @@
import jax, jax.numpy as jnp
from utils import unflatten_conns
from . import BaseGenome
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
class RecurrentGenome(BaseGenome):
"""Default genome class, with the same behavior as the NEAT-Python"""
network_type = 'recurrent'
def __init__(self,
num_inputs: int,
num_outputs: int,
max_nodes: int,
max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
activate_time: int = 10,
):
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
self.activate_time = activate_time
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
# remove un-enable connections and remove enable attr
conn_enable = u_conns[0] == 1
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return nodes, u_conns
def forward(self, inputs, transformed):
nodes, conns = transformed
N = nodes.shape[0]
vals = jnp.full((N,), jnp.nan)
nodes_attrs = nodes[:, 1:]
def body_func(_, values):
# set input values
values = values.at[self.input_idx].set(inputs)
# calculate connections
node_ins = jax.vmap(
jax.vmap(
self.conn_gene.forward,
in_axes=(1, None)
),
in_axes=(1, 0)
)(conns, values)
# calculate nodes
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
return values
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
return vals[self.output_idx]