complete HyperNEAT!

This commit is contained in:
wls2002
2023-07-21 15:03:12 +08:00
parent 80ee5ea2ea
commit 48f90c7eef
32 changed files with 432 additions and 136 deletions

View File

@@ -1,13 +1,44 @@
import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
vals = np.array([1, 2])
weights = np.array([[0, 4], [5, 0]])
@register_pytree_node_class
class Genome:
def __init__(self, nodes, conns):
self.nodes = nodes
self.conns = conns
ins1 = vals * weights[:, 0]
ins2 = vals * weights[:, 1]
ins_all = vals * weights.T
def update_nodes(self, nodes):
return Genome(nodes, self.conns)
print(ins1)
print(ins2)
print(ins_all)
def update_conns(self, conns):
return Genome(self.nodes, conns)
def tree_flatten(self):
children = self.nodes, self.conns
aux_data = None
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
def __repr__(self):
return f"Genome ({self.nodes}, \n\t{self.conns})"
@jax.jit
def add_node(self, a: int):
nodes = self.nodes.at[0, :].set(a)
return self.update_nodes(nodes)
nodes, conns = jnp.array([[1, 2, 3, 4, 5]]), jnp.array([[1, 2, 3, 4]])
g = Genome(nodes, conns)
print(g)
g = g.add_node(1)
print(g)
g = jax.jit(g.add_node)(2)
print(g)