91 lines
2.9 KiB
Python
91 lines
2.9 KiB
Python
from __future__ import annotations
|
|
|
|
from jax.tree_util import register_pytree_node_class
|
|
from jax import numpy as jnp
|
|
|
|
from utils.tools import fetch_first
|
|
|
|
|
|
@register_pytree_node_class
class Genome:
    """A genome held as two JAX arrays, registered as a JAX pytree.

    ``nodes`` and ``conns`` are 2-D tables in which unused rows are filled
    with NaN (see :meth:`count`, :meth:`add_node`, :meth:`add_conn`):

    * a node row is ``[key, *attrs]``;
    * a connection row is ``[in_key, out_key, enable, *attrs]``.

    Every mutator is functional: it builds new arrays with ``.at[...]``
    updates and returns a new instance of the same class, never modifying
    ``self``. :meth:`__getitem__` and :meth:`set` index both tables with the
    same index, presumably so that one object can also hold a batch of
    genomes stacked along a leading axis (TODO confirm against callers).
    """

    def __init__(self, nodes, conns):
        self.nodes = nodes
        self.conns = conns

    def __repr__(self):
        return f"{type(self).__name__}(nodes={self.nodes}, conns={self.conns})"

    def __getitem__(self, idx):
        """Index ``nodes`` and ``conns`` with the same ``idx``; return a new Genome."""
        return self.__class__(self.nodes[idx], self.conns[idx])

    @staticmethod
    def _nan_equal(a, b):
        """Return a JAX bool scalar: True iff ``a`` and ``b`` match element-wise.

        NaN entries compare equal to NaN (plain ``==`` would make every empty
        row unequal). Arrays of different shape are reported unequal rather
        than being broadcast against each other; shapes are static, so this
        check is safe under ``jit``.
        """
        if jnp.shape(a) != jnp.shape(b):
            return jnp.array(False)
        # jnp.all replaces the deprecated/removed jnp.alltrue alias.
        return jnp.all((a == b) | (jnp.isnan(a) & jnp.isnan(b)))

    def __eq__(self, other):
        """Compare two genomes table-by-table, treating NaN == NaN.

        Returns a JAX boolean scalar (usable with ``bool(...)`` on concrete
        values). Comparing with a non-Genome returns ``NotImplemented`` so
        Python falls back to its default handling instead of raising.
        """
        if not isinstance(other, Genome):
            return NotImplemented
        nodes_eq = self._nan_equal(self.nodes, other.nodes)
        conns_eq = self._nan_equal(self.conns, other.conns)
        return nodes_eq & conns_eq

    # Defining __eq__ already disables hashing; make that explicit.
    __hash__ = None

    def set(self, idx, value: "Genome"):
        """Return a copy with ``value.nodes`` / ``value.conns`` written at ``idx``."""
        return self.__class__(
            self.nodes.at[idx].set(value.nodes),
            self.conns.at[idx].set(value.conns),
        )

    def update(self, nodes, conns):
        """Return a new instance of the same class wrapping ``nodes`` and ``conns``."""
        return self.__class__(nodes, conns)

    def update_nodes(self, nodes):
        """Return a new genome with ``nodes`` replaced and ``conns`` kept."""
        return self.update(nodes, self.conns)

    def update_conns(self, conns):
        """Return a new genome with ``conns`` replaced and ``nodes`` kept."""
        return self.update(self.nodes, conns)

    def count(self):
        """Count how many nodes and connections are in the genome.

        A row counts as used when its first column is not NaN.

        Returns:
            ``(nodes_cnt, conns_cnt)`` as JAX integer scalars.
        """
        nodes_cnt = jnp.sum(~jnp.isnan(self.nodes[:, 0]))
        conns_cnt = jnp.sum(~jnp.isnan(self.conns[:, 0]))
        return nodes_cnt, conns_cnt

    def add_node(self, new_key: int, attrs):
        """Add a new node to the genome.

        The node is written into the slot picked from the NaN-key mask:
        column 0 receives ``new_key`` and columns ``1:`` receive ``attrs``.

        NOTE(review): the slot is chosen by the project helper ``fetch_first``,
        presumably the first True position of the mask; what happens when no
        NaN row remains depends on that helper — TODO confirm.
        """
        pos = fetch_first(jnp.isnan(self.nodes[:, 0]))
        new_nodes = self.nodes.at[pos, 0].set(new_key)
        new_nodes = new_nodes.at[pos, 1:].set(attrs)
        return self.update_nodes(new_nodes)

    def delete_node_by_pos(self, pos):
        """Delete the node stored at row ``pos`` by filling that row with NaN."""
        nodes = self.nodes.at[pos].set(jnp.nan)
        return self.update_nodes(nodes)

    def add_conn(self, i_key, o_key, enable: bool, attrs):
        """Add a new connection to the genome.

        The connection is written into the slot picked from the NaN mask of
        column 0: columns ``0:3`` receive ``[i_key, o_key, enable]`` and
        columns ``3:`` receive ``attrs``.

        NOTE(review): slot selection relies on the project helper
        ``fetch_first`` — behavior on a full table is not visible here.
        """
        pos = fetch_first(jnp.isnan(self.conns[:, 0]))
        new_conns = self.conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
        new_conns = new_conns.at[pos, 3:].set(attrs)
        return self.update_conns(new_conns)

    def delete_conn_by_pos(self, pos):
        """Delete the connection stored at row ``pos`` by filling that row with NaN."""
        conns = self.conns.at[pos].set(jnp.nan)
        return self.update_conns(conns)

    def tree_flatten(self):
        """Pytree protocol: both tables are leaves; there is no static aux data."""
        children = self.nodes, self.conns
        aux_data = None
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Pytree protocol: rebuild an instance from ``(nodes, conns)``."""
        return cls(*children)
|