remove create_func....

This commit is contained in:
wls2002
2023-08-02 15:02:08 +08:00
parent 1499e062fe
commit c7fb1ddabe
22 changed files with 425 additions and 21 deletions

View File

@@ -1,24 +1,31 @@
from functools import partial
import jax
from utils import unflatten_conns, act, agg, Activation, Aggregation
from algorithm.neat.gene import RecurrentGeneConfig
config = RecurrentGeneConfig(
activation_options=("tanh", "sigmoid"),
activation_default="tanh",
)
class A:
def __init__(self):
self.a = 1
self.b = 2
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
self.isTrue = False
@partial(jax.jit, static_argnums=(0,))
def step(self):
if self.isTrue:
return self.a + 1
else:
return self.b + 1
i = jax.numpy.array([0, 1])
z = jax.numpy.array([
[1, 1],
[2, 2]
])
print(self.act_funcs)
return jax.vmap(act, in_axes=(0, 0, None))(i, z, self.act_funcs)
AA = A()
print(AA.step(), hash(AA))
print(AA.step(), hash(AA))
print(AA.step(), hash(AA))
AA.a = (2, 3, 4)
print(AA.step(), hash(AA))
print(AA.step())