finish ask part of the algorithm;

use jax.lax.while_loop in graph algorithms and forward function;
fix "enabled not care" bug in forward
This commit is contained in:
wls2002
2023-06-25 00:26:52 +08:00
parent 86820db5a6
commit 0cb2f9473d
24 changed files with 485 additions and 1623 deletions

View File

@@ -1,16 +1,25 @@
from functools import partial
import numpy as np
import jax
from jax import jit
from configs import Configer
from neat.pipeline_ import Pipeline
from neat.pipeline import Pipeline
from neat.function_factory import FunctionFactory
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def main():
config = Configer.load_config("xor.ini")
print(config)
pipeline = Pipeline(config)
function_factory = FunctionFactory(config)
pipeline = Pipeline(config, function_factory)
forward_func = pipeline.ask()
# inputs = np.tile(xor_inputs, (150, 1, 1))
outputs = forward_func(xor_inputs)
print(outputs)
@jit