complete HyperNEAT!
This commit is contained in:
@@ -1,13 +1,44 @@
|
||||
import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
|
||||
|
||||
vals = np.array([1, 2])
|
||||
weights = np.array([[0, 4], [5, 0]])
|
||||
@register_pytree_node_class
|
||||
class Genome:
|
||||
def __init__(self, nodes, conns):
|
||||
self.nodes = nodes
|
||||
self.conns = conns
|
||||
|
||||
ins1 = vals * weights[:, 0]
|
||||
ins2 = vals * weights[:, 1]
|
||||
ins_all = vals * weights.T
|
||||
def update_nodes(self, nodes):
|
||||
return Genome(nodes, self.conns)
|
||||
|
||||
print(ins1)
|
||||
print(ins2)
|
||||
print(ins_all)
|
||||
def update_conns(self, conns):
|
||||
return Genome(self.nodes, conns)
|
||||
|
||||
def tree_flatten(self):
|
||||
children = self.nodes, self.conns
|
||||
aux_data = None
|
||||
return children, aux_data
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Genome ({self.nodes}, \n\t{self.conns})"
|
||||
|
||||
@jax.jit
|
||||
def add_node(self, a: int):
|
||||
nodes = self.nodes.at[0, :].set(a)
|
||||
return self.update_nodes(nodes)
|
||||
|
||||
|
||||
nodes, conns = jnp.array([[1, 2, 3, 4, 5]]), jnp.array([[1, 2, 3, 4]])
|
||||
g = Genome(nodes, conns)
|
||||
print(g)
|
||||
|
||||
g = g.add_node(1)
|
||||
print(g)
|
||||
|
||||
g = jax.jit(g.add_node)(2)
|
||||
print(g)
|
||||
|
||||
Reference in New Issue
Block a user