modify the behavior for mutate_add_node and mutate_add_conn. Currently, this two mutation will just change the structure of the network, but not influence the output for the network.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from .activation import Act, act, ACT_ALL
|
||||
from .aggregation import Agg, agg, AGG_ALL
|
||||
from .activation import Act, act_func, ACT_ALL
|
||||
from .aggregation import Agg, agg_func, AGG_ALL
|
||||
from .tools import *
|
||||
from .graph import *
|
||||
from .state import State
|
||||
|
||||
@@ -68,11 +68,18 @@ ACT_ALL = (
|
||||
)
|
||||
|
||||
|
||||
def act(idx, z, act_funcs):
|
||||
def act_func(idx, z, act_funcs):
|
||||
"""
|
||||
calculate activation function for each node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
# change idx from float to int
|
||||
res = jax.lax.switch(idx, act_funcs, z)
|
||||
|
||||
# -1 means identity activation
|
||||
res = jax.lax.cond(
|
||||
idx == -1,
|
||||
lambda: z,
|
||||
lambda: jax.lax.switch(idx, act_funcs, z),
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
@@ -53,7 +53,7 @@ class Agg:
|
||||
AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean)
|
||||
|
||||
|
||||
def agg(idx, z, agg_funcs):
|
||||
def agg_func(idx, z, agg_funcs):
|
||||
"""
|
||||
calculate activation function for inputs of node
|
||||
"""
|
||||
|
||||
@@ -47,7 +47,9 @@ def flatten_conns(nodes, unflatten, C):
|
||||
return jnp.where(
|
||||
jnp.isnan(unflatten[0, i, j]),
|
||||
jnp.nan,
|
||||
jnp.concatenate([jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]),
|
||||
jnp.concatenate(
|
||||
[jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]
|
||||
),
|
||||
)
|
||||
|
||||
x, y = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij")
|
||||
@@ -64,6 +66,40 @@ def flatten_conns(nodes, unflatten, C):
|
||||
return conns
|
||||
|
||||
|
||||
def extract_node_attrs(node):
|
||||
"""
|
||||
node: Array(NL, )
|
||||
extract the attributes of a node
|
||||
"""
|
||||
return node[1:] # 0 is for idx
|
||||
|
||||
|
||||
def set_node_attrs(node, attrs):
|
||||
"""
|
||||
node: Array(NL, )
|
||||
attrs: Array(NL-1, )
|
||||
set the attributes of a node
|
||||
"""
|
||||
return node.at[1:].set(attrs) # 0 is for idx
|
||||
|
||||
|
||||
def extract_conn_attrs(conn):
|
||||
"""
|
||||
conn: Array(CL, )
|
||||
extract the attributes of a connection
|
||||
"""
|
||||
return conn[2:] # 0, 1 is for in-idx and out-idx
|
||||
|
||||
|
||||
def set_conn_attrs(conn, attrs):
|
||||
"""
|
||||
conn: Array(CL, )
|
||||
attrs: Array(CL-2, )
|
||||
set the attributes of a connection
|
||||
"""
|
||||
return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx
|
||||
|
||||
|
||||
@jit
|
||||
def fetch_first(mask, default=I_INF) -> Array:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user