Files
tensorneat-mend/examples/error_forward_fix.py
2023-05-06 18:33:30 +08:00

24 lines
731 B
Python

import numpy as np
from jax import numpy as jnp
from algorithms.neat.genome.genome import analysis
from algorithms.neat.genome import create_forward_function
# Debug script: reload saved "error" genome arrays, print their analysis,
# then run a debug forward pass over the XOR input batch and print the result.
# The .npy files are loaded via relative paths, i.e. from the current working
# directory.

# Node indices of the network inputs and outputs.  Defined once so that
# `analysis` and the forward function are guaranteed to use the same layout.
INPUT_IDX = np.array([0, 1])
OUTPUT_IDX = np.array([2, ])

error_nodes = np.load('error_nodes.npy')
error_connections = np.load('error_connections.npy')

node_dict, connection_dict = analysis(error_nodes, error_connections, INPUT_IDX, OUTPUT_IDX)
print(node_dict, connection_dict, sep='\n')

# Number of rows in the node array, passed through as N.
N = error_nodes.shape[0]
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# NOTE(review): with batch=True the returned function presumably expects a 2-D
# batch (one row per sample) rather than a single 1-D sample -- confirm against
# create_forward_function.  The XOR batch still contains the [1, 0] sample that
# was previously evaluated on its own.
func = create_forward_function(error_nodes, error_connections, N,
                               jnp.array(INPUT_IDX), jnp.array(OUTPUT_IDX),
                               batch=True, debug=True)
out = func(xor_inputs)
print(error_nodes)
print(error_connections)
print(out)