finish ask part of the algorithm;
use jax.lax.while_loop in graph algorithms and forward function; fix "enabled not care" bug in forward
This commit is contained in:
@@ -3,6 +3,9 @@ import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import jax
|
||||
|
||||
a = {1:2, 2:3, 4:5}
|
||||
print(a.values())
|
||||
|
||||
a = jnp.array([1, 0, 1, 0, np.nan])
|
||||
b = jnp.array([1, 1, 1, 1, 1])
|
||||
c = jnp.array([1, 1, 1, 1, 1])
|
||||
@@ -44,5 +47,9 @@ def func(x):
|
||||
else:
|
||||
return 2
|
||||
|
||||
a = jnp.zeros((3, 3))
|
||||
print(a.dtype)
|
||||
|
||||
print(main())
|
||||
c = None
|
||||
b = 1 or c
|
||||
print(b)
|
||||
@@ -1,16 +1,25 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import jit
|
||||
|
||||
from configs import Configer
|
||||
from neat.pipeline_ import Pipeline
|
||||
from neat.pipeline import Pipeline
|
||||
from neat.function_factory import FunctionFactory
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
print(config)
|
||||
pipeline = Pipeline(config)
|
||||
function_factory = FunctionFactory(config)
|
||||
pipeline = Pipeline(config, function_factory)
|
||||
forward_func = pipeline.ask()
|
||||
# inputs = np.tile(xor_inputs, (150, 1, 1))
|
||||
outputs = forward_func(xor_inputs)
|
||||
print(outputs)
|
||||
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
import cProfile
|
||||
from io import StringIO
|
||||
import pstats
|
||||
|
||||
|
||||
def using_cprofile(func, root_abs_path=None, replace_pattern=None, save_path=None):
|
||||
def inner(*args, **kwargs):
|
||||
pr = cProfile.Profile()
|
||||
pr.enable()
|
||||
ret = func(*args, **kwargs)
|
||||
pr.disable()
|
||||
profile_stats = StringIO()
|
||||
stats = pstats.Stats(pr, stream=profile_stats)
|
||||
if root_abs_path is not None:
|
||||
stats.sort_stats('cumulative').print_stats(root_abs_path)
|
||||
else:
|
||||
stats.sort_stats('cumulative').print_stats()
|
||||
output = profile_stats.getvalue()
|
||||
if replace_pattern is not None:
|
||||
output = output.replace(replace_pattern, "")
|
||||
if save_path is None:
|
||||
print(output)
|
||||
else:
|
||||
with open(save_path, "w") as f:
|
||||
f.write(output)
|
||||
return ret
|
||||
|
||||
return inner
|
||||
@@ -1,2 +1,5 @@
|
||||
[basic]
|
||||
forward_way = "common"
|
||||
|
||||
[population]
|
||||
fitness_threshold = -1e-2
|
||||
Reference in New Issue
Block a user