modifying

This commit is contained in:
wls2002
2023-06-27 18:47:47 +08:00
parent ba369db0b2
commit 114ff2b0cc
28 changed files with 451 additions and 123 deletions

View File

@@ -1,27 +1,18 @@
import numpy as np
from jax import jit
from functools import partial
from configs import Configer
from neat.pipeline import Pipeline
import jax
from jax import numpy as jnp, jit
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def main():
config = Configer.load_config("xor.ini")
print(config)
pipeline = Pipeline(config)
forward_func = pipeline.ask()
# inputs = np.tile(xor_inputs, (150, 1, 1))
outputs = forward_func(xor_inputs)
print(outputs)
@partial(jit, static_argnames=['reverse'])
def rank_element(array, reverse=False):
"""
rank the element in the array.
if reverse is True, the rank is from large to small.
"""
if reverse:
array = -array
return jnp.argsort(jnp.argsort(array))
@jit
def f(x, jit_config):
return x + jit_config["bias_mutate_rate"]
if __name__ == '__main__':
main()
a = jnp.array([1 ,5, 3, 5, 2, 1, 0])
print(rank_element(a, reverse=True))