change a lot a lot a lot!!!!!!!
This commit is contained in:
5
core/__init__.py
Normal file
5
core/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .algorithm import Algorithm
|
||||
from .state import State
|
||||
from .genome import Genome
|
||||
from .gene import Gene
|
||||
|
||||
28
core/algorithm.py
Normal file
28
core/algorithm.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from jax import Array
|
||||
from .state import State
|
||||
from .genome import Genome
|
||||
|
||||
EMPTY = lambda *args: args
|
||||
|
||||
|
||||
class Algorithm:
|
||||
|
||||
def setup(self, randkey, state: State = State()):
|
||||
"""initialize the state of the algorithm"""
|
||||
pass
|
||||
|
||||
def ask(self, state: State):
|
||||
"""require the population to be evaluated"""
|
||||
pass
|
||||
|
||||
def tell(self, state: State, fitness):
|
||||
"""update the state of the algorithm"""
|
||||
pass
|
||||
|
||||
def forward(self, inputs: Array, transformed: Array):
|
||||
"""the forward function of a single forward transformation"""
|
||||
pass
|
||||
|
||||
def forward_transform(self, state: State, genome: Genome):
|
||||
"""create the forward transformation of a genome"""
|
||||
pass
|
||||
46
core/gene.py
Normal file
46
core/gene.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from jax import Array, numpy as jnp
|
||||
|
||||
from config import GeneConfig
|
||||
from .state import State
|
||||
from .genome import Genome
|
||||
|
||||
|
||||
class Gene:
|
||||
node_attrs = []
|
||||
conn_attrs = []
|
||||
|
||||
@staticmethod
|
||||
def setup(config: GeneConfig, state: State):
|
||||
return state
|
||||
|
||||
@staticmethod
|
||||
def new_node_attrs(state: State):
|
||||
return jnp.zeros(0)
|
||||
|
||||
@staticmethod
|
||||
def new_conn_attrs(state: State):
|
||||
return jnp.zeros(0)
|
||||
|
||||
@staticmethod
|
||||
def mutate_node(state: State, attrs: Array, randkey: Array):
|
||||
return attrs
|
||||
|
||||
@staticmethod
|
||||
def mutate_conn(state: State, attrs: Array, randkey: Array):
|
||||
return attrs
|
||||
|
||||
@staticmethod
|
||||
def distance_node(state: State, node1: Array, node2: Array):
|
||||
return node1
|
||||
|
||||
@staticmethod
|
||||
def distance_conn(state: State, conn1: Array, conn2: Array):
|
||||
return conn1
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(state: State, genome: Genome):
|
||||
return jnp.zeros(0) # transformed
|
||||
@staticmethod
|
||||
def create_forward(state: State, config: GeneConfig):
|
||||
return lambda *args: args # forward function
|
||||
|
||||
77
core/genome.py
Normal file
77
core/genome.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
from jax import numpy as jnp
|
||||
|
||||
from utils.tools import fetch_first
|
||||
|
||||
|
||||
@register_pytree_node_class
|
||||
class Genome:
|
||||
|
||||
def __init__(self, nodes, conns):
|
||||
self.nodes = nodes
|
||||
self.conns = conns
|
||||
|
||||
def update(self, nodes, conns):
|
||||
return self.__class__(nodes, conns)
|
||||
|
||||
def update_nodes(self, nodes):
|
||||
return self.update(nodes, self.conns)
|
||||
|
||||
def update_conns(self, conns):
|
||||
return self.update(self.nodes, conns)
|
||||
|
||||
def count(self):
|
||||
"""Count how many nodes and connections are in the genome."""
|
||||
nodes_cnt = jnp.sum(~jnp.isnan(self.nodes[:, 0]))
|
||||
conns_cnt = jnp.sum(~jnp.isnan(self.conns[:, 0]))
|
||||
return nodes_cnt, conns_cnt
|
||||
|
||||
def add_node(self, new_key: int, attrs):
|
||||
"""
|
||||
Add a new node to the genome.
|
||||
The new node will place at the first NaN row.
|
||||
"""
|
||||
exist_keys = self.nodes[:, 0]
|
||||
pos = fetch_first(jnp.isnan(exist_keys))
|
||||
new_nodes = self.nodes.at[pos, 0].set(new_key)
|
||||
new_nodes = new_nodes.at[pos, 1:].set(attrs)
|
||||
return self.update_nodes(new_nodes)
|
||||
|
||||
def delete_node_by_pos(self, pos):
|
||||
"""
|
||||
Delete a node from the genome.
|
||||
Delete the node by its pos in nodes.
|
||||
"""
|
||||
nodes = self.nodes.at[pos].set(jnp.nan)
|
||||
return self.update_nodes(nodes)
|
||||
|
||||
def add_conn(self, i_key, o_key, enable: bool, attrs):
|
||||
"""
|
||||
Add a new connection to the genome.
|
||||
The new connection will place at the first NaN row.
|
||||
"""
|
||||
con_keys = self.conns[:, 0]
|
||||
pos = fetch_first(jnp.isnan(con_keys))
|
||||
new_conns = self.conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
|
||||
new_conns = new_conns.at[pos, 3:].set(attrs)
|
||||
return self.update_conns(new_conns)
|
||||
|
||||
def delete_conn_by_pos(self, pos):
|
||||
"""
|
||||
Delete a connection from the genome.
|
||||
Delete the connection by its idx.
|
||||
"""
|
||||
conns = self.conns.at[pos].set(jnp.nan)
|
||||
return self.update_conns(conns)
|
||||
|
||||
def tree_flatten(self):
|
||||
children = self.nodes, self.conns
|
||||
aux_data = None
|
||||
return children, aux_data
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Genome(nodes={self.nodes}, conns={self.conns})"
|
||||
29
core/state.py
Normal file
29
core/state.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
|
||||
|
||||
@register_pytree_node_class
|
||||
class State:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__['state_dict'] = kwargs
|
||||
|
||||
def update(self, **kwargs):
|
||||
return State(**{**self.state_dict, **kwargs})
|
||||
|
||||
def __getattr__(self, name):
|
||||
return self.state_dict[name]
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
raise AttributeError("State is immutable")
|
||||
|
||||
def __repr__(self):
|
||||
return f"State ({self.state_dict})"
|
||||
|
||||
def tree_flatten(self):
|
||||
children = list(self.state_dict.values())
|
||||
aux_data = list(self.state_dict.keys())
|
||||
return children, aux_data
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(**dict(zip(aux_data, children)))
|
||||
Reference in New Issue
Block a user