change str in config (act, agg) from str to callable

This commit is contained in:
wls2002
2023-08-05 03:03:02 +08:00
parent 0e44b13291
commit af54db3b12
10 changed files with 55 additions and 101 deletions

View File

@@ -48,9 +48,9 @@ class RecurrentGene(NormalGene):
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins, self.agg_funcs) # z = agg(ins)
values = batch_agg(nodes[:, 4], nodes_ins, self.config.aggregation_options) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values, self.act_funcs) # z = act(z)
values = batch_act(nodes[:, 3], values, self.config.activation_options) # z = act(z)
return values
vals = jax.lax.fori_loop(0, self.config.activate_times, body_func, vals)