finish all refactoring

This commit is contained in:
wls2002
2024-02-21 15:41:08 +08:00
parent aac41a089d
commit 6970e6a6d5
44 changed files with 856 additions and 825 deletions

View File

@@ -4,7 +4,6 @@ from utils import fetch_first
class BaseGenome:
network_type = None
def __init__(

View File

@@ -1,3 +1,5 @@
from typing import Callable
import jax, jax.numpy as jnp
from utils import unflatten_conns, topological_sort, I_INT
@@ -13,10 +15,20 @@ class DefaultGenome(BaseGenome):
def __init__(self,
num_inputs: int,
num_outputs: int,
max_nodes=5,
max_conns=4,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
output_transform: Callable = None
):
super().__init__(num_inputs, num_outputs, node_gene, conn_gene)
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
if output_transform is not None:
try:
aux = output_transform(jnp.zeros(num_outputs))
except Exception as e:
raise ValueError(f"Output transform function failed: {e}")
self.output_transform = output_transform
def transform(self, nodes, conns):
u_conns = unflatten_conns(nodes, conns)
@@ -72,4 +84,7 @@ class DefaultGenome(BaseGenome):
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[self.output_idx]
if self.output_transform is None:
return vals[self.output_idx]
else:
return self.output_transform(vals[self.output_idx])

View File

@@ -13,11 +13,13 @@ class RecurrentGenome(BaseGenome):
def __init__(self,
num_inputs: int,
num_outputs: int,
max_nodes: int,
max_conns: int,
node_gene: BaseNodeGene = DefaultNodeGene(),
conn_gene: BaseConnGene = DefaultConnGene(),
activate_time: int = 10,
):
super().__init__(num_inputs, num_outputs, node_gene, conn_gene)
super().__init__(num_inputs, num_outputs, max_nodes, max_conns, node_gene, conn_gene)
self.activate_time = activate_time
def transform(self, nodes, conns):