Perfect!
Next is to connect with Evox!
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
from .mutate import mutate
|
||||
from .distance import distance
|
||||
from .crossover import crossover
|
||||
from .forward import create_forward
|
||||
from .graph import topological_sort, check_cycles
|
||||
from .utils import unflatten_connections
|
||||
from .genome import initialize_genomes, expand, expand_single
|
||||
@@ -1,34 +1,27 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax import jit
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
def sum_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
return jnp.sum(z, axis=0)
|
||||
|
||||
|
||||
@jit
|
||||
def product_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 1, z)
|
||||
return jnp.prod(z, axis=0)
|
||||
|
||||
|
||||
@jit
|
||||
def max_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), -jnp.inf, z)
|
||||
return jnp.max(z, axis=0)
|
||||
|
||||
|
||||
@jit
|
||||
def min_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), jnp.inf, z)
|
||||
return jnp.min(z, axis=0)
|
||||
|
||||
|
||||
@jit
|
||||
def maxabs_agg(z):
|
||||
z = jnp.where(jnp.isnan(z), 0, z)
|
||||
abs_z = jnp.abs(z)
|
||||
@@ -36,7 +29,6 @@ def maxabs_agg(z):
|
||||
return z[max_abs_index]
|
||||
|
||||
|
||||
@jit
|
||||
def median_agg(z):
|
||||
non_nan_mask = ~jnp.isnan(z)
|
||||
n = jnp.sum(non_nan_mask, axis=0)
|
||||
@@ -49,7 +41,6 @@ def median_agg(z):
|
||||
return median
|
||||
|
||||
|
||||
@jit
|
||||
def mean_agg(z):
|
||||
non_zero_mask = ~jnp.isnan(z)
|
||||
valid_values_sum = sum_agg(z)
|
||||
|
||||
@@ -10,7 +10,7 @@ import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import jit, Array
|
||||
|
||||
from .utils import fetch_random, fetch_first, I_INT
|
||||
from .utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||
from .genome import add_node, delete_node_by_idx, delete_connection_by_idx, add_connection
|
||||
from .graph import check_cycles
|
||||
|
||||
@@ -273,7 +273,8 @@ def mutate_add_connection(rand_key: Array, nodes: Array, cons: Array, jit_config
|
||||
|
||||
is_already_exist = con_idx != I_INT
|
||||
|
||||
is_cycle = check_cycles(nodes, cons, from_idx, to_idx)
|
||||
u_cons = unflatten_connections(nodes, cons)
|
||||
is_cycle = check_cycles(nodes, u_cons, from_idx, to_idx)
|
||||
|
||||
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
||||
nodes, cons = jax.lax.switch(choice, [already_exist, cycle, successful])
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import numpy as jnp, Array
|
||||
from jax import jit, vmap
|
||||
|
||||
I_INT = jnp.iinfo(jnp.int32).max # infinite int
|
||||
EMPTY_NODE = jnp.full((1, 5), jnp.nan)
|
||||
EMPTY_CON = jnp.full((1, 4), jnp.nan)
|
||||
I_INT = np.iinfo(jnp.int32).max # infinite int
|
||||
EMPTY_NODE = np.full((1, 5), jnp.nan)
|
||||
EMPTY_CON = np.full((1, 4), jnp.nan)
|
||||
|
||||
|
||||
@jit
|
||||
@@ -58,8 +57,3 @@ def fetch_random(rand_key, mask, default=I_INT) -> Array:
|
||||
mask = jnp.where(true_cnt == 0, False, cumsum >= target)
|
||||
return fetch_first(mask, default)
|
||||
|
||||
@jit
|
||||
def argmin_with_mask(arr: Array, mask: Array) -> Array:
|
||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||
min_idx = jnp.argmin(masked_arr)
|
||||
return min_idx
|
||||
Reference in New Issue
Block a user