modify the behavior for mutate_add_node and mutate_add_conn. Currently, this two mutation will just change the structure of the network, but not influence the output for the network.

This commit is contained in:
wls2002
2024-06-01 20:42:42 +08:00
parent 4ad9f0a85a
commit e65200a94e
14 changed files with 281 additions and 204 deletions

View File

@@ -1,5 +1,5 @@
from .activation import Act, act, ACT_ALL
from .aggregation import Agg, agg, AGG_ALL
from .activation import Act, act_func, ACT_ALL
from .aggregation import Agg, agg_func, AGG_ALL
from .tools import *
from .graph import *
from .state import State

View File

@@ -68,11 +68,18 @@ ACT_ALL = (
)
def act(idx, z, act_funcs):
def act_func(idx, z, act_funcs):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, act_funcs, z)
# -1 means identity activation
res = jax.lax.cond(
idx == -1,
lambda: z,
lambda: jax.lax.switch(idx, act_funcs, z),
)
return res

View File

@@ -53,7 +53,7 @@ class Agg:
AGG_ALL = (Agg.sum, Agg.product, Agg.max, Agg.min, Agg.maxabs, Agg.median, Agg.mean)
def agg(idx, z, agg_funcs):
def agg_func(idx, z, agg_funcs):
"""
calculate activation function for inputs of node
"""

View File

@@ -47,7 +47,9 @@ def flatten_conns(nodes, unflatten, C):
return jnp.where(
jnp.isnan(unflatten[0, i, j]),
jnp.nan,
jnp.concatenate([jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]),
jnp.concatenate(
[jnp.array([node_keys[i], node_keys[j]]), unflatten[:, i, j]]
),
)
x, y = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij")
@@ -64,6 +66,40 @@ def flatten_conns(nodes, unflatten, C):
return conns
def extract_node_attrs(node):
"""
node: Array(NL, )
extract the attributes of a node
"""
return node[1:] # 0 is for idx
def set_node_attrs(node, attrs):
"""
node: Array(NL, )
attrs: Array(NL-1, )
set the attributes of a node
"""
return node.at[1:].set(attrs) # 0 is for idx
def extract_conn_attrs(conn):
"""
conn: Array(CL, )
extract the attributes of a connection
"""
return conn[2:] # 0, 1 is for in-idx and out-idx
def set_conn_attrs(conn, attrs):
"""
conn: Array(CL, )
attrs: Array(CL-2, )
set the attributes of a connection
"""
return conn.at[2:].set(attrs) # 0, 1 is for in-idx and out-idx
@jit
def fetch_first(mask, default=I_INF) -> Array:
"""