diff --git a/src/tensorneat/algorithm/hyperneat/substrate/default.py b/src/tensorneat/algorithm/hyperneat/substrate/default.py index 96d797e..cad47de 100644 --- a/src/tensorneat/algorithm/hyperneat/substrate/default.py +++ b/src/tensorneat/algorithm/hyperneat/substrate/default.py @@ -1,4 +1,4 @@ -from jax import vmap +from jax import vmap, numpy as jnp import numpy as np from .base import BaseSubstrate @@ -12,9 +12,9 @@ class DefaultSubstrate(BaseSubstrate): def __init__(self, num_inputs, num_outputs, coors, nodes, conns): self.inputs = num_inputs self.outputs = num_outputs - self.coors = np.array(coors) - self.nodes = np.array(nodes) - self.conns = np.array(conns) + self.coors = jnp.array(coors) + self.nodes = jnp.array(nodes) + self.conns = jnp.array(conns) def make_nodes(self, query_res): return self.nodes @@ -22,7 +22,8 @@ class DefaultSubstrate(BaseSubstrate): def make_conns(self, query_res): # change weight of conns # the last column is the weight - return self.conns.at[:, -1].set(query_res) + # print(f"{self.conns.shape=}, {query_res.shape=}") + return self.conns.at[:, -1].set(query_res.flatten()) @property def query_coors(self): diff --git a/src/tensorneat/algorithm/hyperneat/substrate/full.py b/src/tensorneat/algorithm/hyperneat/substrate/full.py index 6e07b2b..af63df3 100644 --- a/src/tensorneat/algorithm/hyperneat/substrate/full.py +++ b/src/tensorneat/algorithm/hyperneat/substrate/full.py @@ -63,7 +63,7 @@ def analysis_substrate(input_coors, output_coors, hidden_coors): ) # input_idx, output_idx, weight conns[:, :2] = correspond_keys - print(query_coors, nodes, conns) + # print(query_coors, nodes, conns) return query_coors, nodes, conns