change repo structure; modify readme
This commit is contained in:
3
tensorneat/algorithm/neat/genome/__init__.py
Normal file
3
tensorneat/algorithm/neat/genome/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import BaseGenome
|
||||
from .default import DefaultGenome
|
||||
from .recurrent import RecurrentGenome
|
||||
65
tensorneat/algorithm/neat/genome/base.py
Normal file
65
tensorneat/algorithm/neat/genome/base.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import jax.numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
from utils import fetch_first
|
||||
|
||||
|
||||
class BaseGenome:
|
||||
network_type = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_inputs: int,
|
||||
num_outputs: int,
|
||||
max_nodes: int,
|
||||
max_conns: int,
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
):
|
||||
self.num_inputs = num_inputs
|
||||
self.num_outputs = num_outputs
|
||||
self.input_idx = jnp.arange(num_inputs)
|
||||
self.output_idx = jnp.arange(num_inputs, num_inputs + num_outputs)
|
||||
self.max_nodes = max_nodes
|
||||
self.max_conns = max_conns
|
||||
self.node_gene = node_gene
|
||||
self.conn_gene = conn_gene
|
||||
|
||||
def transform(self, nodes, conns):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, inputs, transformed):
|
||||
raise NotImplementedError
|
||||
|
||||
def add_node(self, nodes, new_key: int, attrs):
|
||||
"""
|
||||
Add a new node to the genome.
|
||||
The new node will place at the first NaN row.
|
||||
"""
|
||||
exist_keys = nodes[:, 0]
|
||||
pos = fetch_first(jnp.isnan(exist_keys))
|
||||
new_nodes = nodes.at[pos, 0].set(new_key)
|
||||
return new_nodes.at[pos, 1:].set(attrs)
|
||||
|
||||
def delete_node_by_pos(self, nodes, pos):
|
||||
"""
|
||||
Delete a node from the genome.
|
||||
Delete the node by its pos in nodes.
|
||||
"""
|
||||
return nodes.at[pos].set(jnp.nan)
|
||||
|
||||
def add_conn(self, conns, i_key, o_key, enable: bool, attrs):
|
||||
"""
|
||||
Add a new connection to the genome.
|
||||
The new connection will place at the first NaN row.
|
||||
"""
|
||||
con_keys = conns[:, 0]
|
||||
pos = fetch_first(jnp.isnan(con_keys))
|
||||
new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
|
||||
return new_conns.at[pos, 3:].set(attrs)
|
||||
|
||||
def delete_conn_by_pos(self, conns, pos):
|
||||
"""
|
||||
Delete a connection from the genome.
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
return conns.at[pos].set(jnp.nan)
|
||||
90
tensorneat/algorithm/neat/genome/default.py
Normal file
90
tensorneat/algorithm/neat/genome/default.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Callable
|
||||
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import unflatten_conns, topological_sort, I_INT
|
||||
|
||||
from . import BaseGenome
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
|
||||
|
||||
class DefaultGenome(BaseGenome):
|
||||
"""Default genome class, with the same behavior as the NEAT-Python"""
|
||||
|
||||
network_type = 'feedforward'
|
||||
|
||||
def __init__(self,
|
||||
num_inputs: int,
|
||||
num_outputs: int,
|
||||
max_nodes=5,
|
||||
max_conns=4,
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
output_transform: Callable = None
|
||||
):
|
||||
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
|
||||
|
||||
if output_transform is not None:
|
||||
try:
|
||||
aux = output_transform(jnp.zeros(num_outputs))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Output transform function failed: {e}")
|
||||
self.output_transform = output_transform
|
||||
|
||||
def transform(self, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
|
||||
# DONE: Seems like there is a bug in this line
|
||||
# conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||
# modified: exist conn and enable is true
|
||||
# conn_enable = jnp.where( (~jnp.isnan(u_conns[0])) & (u_conns[0] == 1), True, False)
|
||||
# advanced modified: when and only when enabled is True
|
||||
conn_enable = u_conns[0] == 1
|
||||
|
||||
# remove enable attr
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
seqs = topological_sort(nodes, conn_enable)
|
||||
|
||||
return seqs, nodes, u_conns
|
||||
|
||||
def forward(self, inputs, transformed):
|
||||
cal_seqs, nodes, conns = transformed
|
||||
|
||||
N = nodes.shape[0]
|
||||
ini_vals = jnp.full((N,), jnp.nan)
|
||||
ini_vals = ini_vals.at[self.input_idx].set(inputs)
|
||||
nodes_attrs = nodes[:, 1:]
|
||||
|
||||
def cond_fun(carry):
|
||||
values, idx = carry
|
||||
return (idx < N) & (cal_seqs[idx] != I_INT)
|
||||
|
||||
def body_func(carry):
|
||||
values, idx = carry
|
||||
i = cal_seqs[idx]
|
||||
|
||||
def hit():
|
||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(1, 0))(conns[:, :, i], values)
|
||||
# ins = values * weights[:, i]
|
||||
|
||||
z = self.node_gene.forward(nodes_attrs[i], ins)
|
||||
# z = agg(nodes[i, 4], ins, self.config.aggregation_options) # z = agg(ins)
|
||||
# z = z * nodes[i, 2] + nodes[i, 1] # z = z * response + bias
|
||||
# z = act(nodes[i, 3], z, self.config.activation_options) # z = act(z)
|
||||
|
||||
new_values = values.at[i].set(z)
|
||||
return new_values
|
||||
|
||||
def miss():
|
||||
return values
|
||||
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
values = jax.lax.cond(jnp.isin(i, self.input_idx), miss, hit)
|
||||
|
||||
return values, idx + 1
|
||||
|
||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
||||
|
||||
if self.output_transform is None:
|
||||
return vals[self.output_idx]
|
||||
else:
|
||||
return self.output_transform(vals[self.output_idx])
|
||||
60
tensorneat/algorithm/neat/genome/recurrent.py
Normal file
60
tensorneat/algorithm/neat/genome/recurrent.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from utils import unflatten_conns
|
||||
|
||||
from . import BaseGenome
|
||||
from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene
|
||||
|
||||
|
||||
class RecurrentGenome(BaseGenome):
|
||||
"""Default genome class, with the same behavior as the NEAT-Python"""
|
||||
|
||||
network_type = 'recurrent'
|
||||
|
||||
def __init__(self,
|
||||
num_inputs: int,
|
||||
num_outputs: int,
|
||||
max_nodes: int,
|
||||
max_conns: int,
|
||||
node_gene: BaseNodeGene = DefaultNodeGene(),
|
||||
conn_gene: BaseConnGene = DefaultConnGene(),
|
||||
activate_time: int = 10,
|
||||
):
|
||||
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
|
||||
self.activate_time = activate_time
|
||||
|
||||
def transform(self, nodes, conns):
|
||||
u_conns = unflatten_conns(nodes, conns)
|
||||
|
||||
# remove un-enable connections and remove enable attr
|
||||
conn_enable = u_conns[0] == 1
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
|
||||
return nodes, u_conns
|
||||
|
||||
def forward(self, inputs, transformed):
|
||||
nodes, conns = transformed
|
||||
|
||||
N = nodes.shape[0]
|
||||
vals = jnp.full((N,), jnp.nan)
|
||||
nodes_attrs = nodes[:, 1:]
|
||||
|
||||
def body_func(_, values):
|
||||
# set input values
|
||||
values = values.at[self.input_idx].set(inputs)
|
||||
|
||||
# calculate connections
|
||||
node_ins = jax.vmap(
|
||||
jax.vmap(
|
||||
self.conn_gene.forward,
|
||||
in_axes=(1, None)
|
||||
),
|
||||
in_axes=(1, 0)
|
||||
)(conns, values)
|
||||
|
||||
# calculate nodes
|
||||
values = jax.vmap(self.node_gene.forward)(nodes_attrs, node_ins.T)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
|
||||
|
||||
return vals[self.output_idx]
|
||||
Reference in New Issue
Block a user