remove create_func....
This commit is contained in:
@@ -1,24 +1,31 @@
|
||||
from functools import partial
|
||||
import jax
|
||||
|
||||
from utils import unflatten_conns, act, agg, Activation, Aggregation
|
||||
from algorithm.neat.gene import RecurrentGeneConfig
|
||||
|
||||
config = RecurrentGeneConfig(
|
||||
activation_options=("tanh", "sigmoid"),
|
||||
activation_default="tanh",
|
||||
)
|
||||
|
||||
|
||||
class A:
|
||||
def __init__(self):
|
||||
self.a = 1
|
||||
self.b = 2
|
||||
self.act_funcs = [Activation.name2func[name] for name in config.activation_options]
|
||||
self.agg_funcs = [Aggregation.name2func[name] for name in config.aggregation_options]
|
||||
self.isTrue = False
|
||||
|
||||
@partial(jax.jit, static_argnums=(0,))
|
||||
def step(self):
|
||||
if self.isTrue:
|
||||
return self.a + 1
|
||||
else:
|
||||
return self.b + 1
|
||||
i = jax.numpy.array([0, 1])
|
||||
z = jax.numpy.array([
|
||||
[1, 1],
|
||||
[2, 2]
|
||||
])
|
||||
print(self.act_funcs)
|
||||
return jax.vmap(act, in_axes=(0, 0, None))(i, z, self.act_funcs)
|
||||
|
||||
|
||||
AA = A()
|
||||
print(AA.step(), hash(AA))
|
||||
print(AA.step(), hash(AA))
|
||||
print(AA.step(), hash(AA))
|
||||
AA.a = (2, 3, 4)
|
||||
print(AA.step(), hash(AA))
|
||||
print(AA.step())
|
||||
|
||||
Reference in New Issue
Block a user