complete HyperNEAT!
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from .base import Algorithm
|
||||
from .state import State
|
||||
from .neat import NEAT
|
||||
from .config import Configer
|
||||
from .hyperneat import HyperNEAT
|
||||
|
||||
17
algorithm/base.py
Normal file
17
algorithm/base.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Callable
|
||||
|
||||
from .state import State
|
||||
|
||||
EMPTY = lambda *args: args
|
||||
|
||||
|
||||
class Algorithm:
|
||||
|
||||
def __init__(self):
|
||||
self.tell: Callable = EMPTY
|
||||
self.ask: Callable = EMPTY
|
||||
self.forward: Callable = EMPTY
|
||||
self.forward_transform: Callable = EMPTY
|
||||
|
||||
def setup(self, randkey, state=State()):
|
||||
pass
|
||||
@@ -1,69 +0,0 @@
|
||||
import os
|
||||
import warnings
|
||||
import configparser
|
||||
|
||||
import numpy as np
|
||||
|
||||
class Configer:
|
||||
|
||||
@classmethod
|
||||
def __load_default_config(cls):
|
||||
par_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
default_config_path = os.path.join(par_dir, "default_config.ini")
|
||||
return cls.__load_config(default_config_path)
|
||||
|
||||
@classmethod
|
||||
def __load_config(cls, config_path):
|
||||
c = configparser.ConfigParser()
|
||||
c.read(config_path)
|
||||
config = {}
|
||||
|
||||
for section in c.sections():
|
||||
for key, value in c.items(section):
|
||||
config[key] = eval(value)
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def __check_redundant_config(cls, default_config, config):
|
||||
for key in config:
|
||||
if key not in default_config:
|
||||
warnings.warn(f"Redundant config: {key} in {config.name}")
|
||||
|
||||
@classmethod
|
||||
def __complete_config(cls, default_config, config):
|
||||
for key in default_config:
|
||||
if key not in config:
|
||||
config[key] = default_config[key]
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, config_path=None):
|
||||
default_config = cls.__load_default_config()
|
||||
if config_path is None:
|
||||
config = {}
|
||||
elif not os.path.exists(config_path):
|
||||
warnings.warn(f"config file {config_path} not exist!")
|
||||
config = {}
|
||||
else:
|
||||
config = cls.__load_config(config_path)
|
||||
|
||||
cls.__check_redundant_config(default_config, config)
|
||||
cls.__complete_config(default_config, config)
|
||||
|
||||
cls.refactor_activation(config)
|
||||
cls.refactor_aggregation(config)
|
||||
|
||||
config['input_idx'] = np.arange(config['num_inputs'])
|
||||
config['output_idx'] = np.arange(config['num_inputs'], config['num_inputs'] + config['num_outputs'])
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def refactor_activation(cls, config):
|
||||
config['activation_default'] = 0
|
||||
config['activation_options'] = np.arange(len(config['activation_option_names']))
|
||||
|
||||
@classmethod
|
||||
def refactor_aggregation(cls, config):
|
||||
config['aggregation_default'] = 0
|
||||
config['aggregation_options'] = np.arange(len(config['aggregation_option_names']))
|
||||
@@ -1,72 +0,0 @@
|
||||
[basic]
|
||||
num_inputs = 2
|
||||
num_outputs = 1
|
||||
maximum_nodes = 50
|
||||
maximum_conns = 100
|
||||
maximum_species = 10
|
||||
forward_way = "pop"
|
||||
batch_size = 4
|
||||
random_seed = 0
|
||||
network_type = "feedforward"
|
||||
activate_times = 10
|
||||
|
||||
[population]
|
||||
fitness_threshold = 3.9999
|
||||
generation_limit = 1000
|
||||
fitness_criterion = "max"
|
||||
pop_size = 50000
|
||||
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
compatibility_weight = 0.5
|
||||
conn_add_prob = 0.4
|
||||
conn_add_trials = 1
|
||||
conn_delete_prob = 0
|
||||
node_add_prob = 0.2
|
||||
node_delete_prob = 0
|
||||
|
||||
[species]
|
||||
compatibility_threshold = 3.0
|
||||
species_elitism = 2
|
||||
max_stagnation = 15
|
||||
genome_elitism = 2
|
||||
survival_threshold = 0.2
|
||||
min_species_size = 1
|
||||
spawn_number_change_rate = 0.5
|
||||
|
||||
[gene-bias]
|
||||
bias_init_mean = 0.0
|
||||
bias_init_std = 1.0
|
||||
bias_mutate_power = 0.5
|
||||
bias_mutate_rate = 0.7
|
||||
bias_replace_rate = 0.1
|
||||
|
||||
[gene-response]
|
||||
response_init_mean = 1.0
|
||||
response_init_std = 0.0
|
||||
response_mutate_power = 0.0
|
||||
response_mutate_rate = 0.0
|
||||
response_replace_rate = 0.0
|
||||
|
||||
[gene-activation]
|
||||
activation_default = "sigmoid"
|
||||
activation_option_names = ["sigmoid"]
|
||||
activation_replace_rate = 0.0
|
||||
|
||||
[gene-aggregation]
|
||||
aggregation_default = "sum"
|
||||
aggregation_option_names = ["sum"]
|
||||
aggregation_replace_rate = 0.0
|
||||
|
||||
[gene-weight]
|
||||
weight_init_mean = 0.0
|
||||
weight_init_std = 1.0
|
||||
weight_mutate_power = 0.5
|
||||
weight_mutate_rate = 0.8
|
||||
weight_replace_rate = 0.1
|
||||
|
||||
[gene-enable]
|
||||
enable_mutate_rate = 0.01
|
||||
|
||||
[visualize]
|
||||
renumber_nodes = True
|
||||
@@ -0,0 +1,2 @@
|
||||
from .hyperneat import HyperNEAT
|
||||
from .substrate import BaseSubstrate
|
||||
|
||||
70
algorithm/hyperneat/hyperneat.py
Normal file
70
algorithm/hyperneat/hyperneat.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from typing import Type
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from .substrate import BaseSubstrate, analysis_substrate
|
||||
from .hyperneat_gene import HyperNEATGene
|
||||
from algorithm import State, Algorithm, neat
|
||||
|
||||
|
||||
class HyperNEAT(Algorithm):
|
||||
|
||||
def __init__(self, config, gene_type: Type[neat.BaseGene], substrate: Type[BaseSubstrate]):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.gene_type = gene_type
|
||||
self.substrate = substrate
|
||||
self.neat = neat.NEAT(config, gene_type)
|
||||
|
||||
self.tell = create_tell(self.neat)
|
||||
self.forward_transform = create_forward_transform(config, self.neat)
|
||||
self.forward = HyperNEATGene.create_forward(config)
|
||||
|
||||
def setup(self, randkey, state=State()):
|
||||
state = state.update(
|
||||
below_threshold=self.config['below_threshold'],
|
||||
max_weight=self.config['max_weight']
|
||||
)
|
||||
|
||||
state = self.substrate.setup(state, self.config)
|
||||
h_input_idx, h_output_idx, h_hidden_idx, query_coors, correspond_keys = analysis_substrate(state)
|
||||
h_nodes = np.concatenate((h_input_idx, h_output_idx, h_hidden_idx))[..., np.newaxis]
|
||||
h_conns = np.zeros((correspond_keys.shape[0], 3), dtype=np.float32)
|
||||
h_conns[:, 0:2] = correspond_keys
|
||||
|
||||
state = state.update(
|
||||
# h is short for hyperneat
|
||||
h_input_idx=h_input_idx,
|
||||
h_output_idx=h_output_idx,
|
||||
h_hidden_idx=h_hidden_idx,
|
||||
query_coors=query_coors,
|
||||
correspond_keys=correspond_keys,
|
||||
h_nodes=h_nodes,
|
||||
h_conns=h_conns
|
||||
)
|
||||
state = self.neat.setup(randkey, state=state)
|
||||
|
||||
self.config['h_input_idx'] = h_input_idx
|
||||
self.config['h_output_idx'] = h_output_idx
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def create_tell(neat_instance):
|
||||
def tell(state, fitness):
|
||||
return neat_instance.tell(state, fitness)
|
||||
|
||||
return tell
|
||||
|
||||
|
||||
def create_forward_transform(config, neat_instance):
|
||||
def forward_transform(state, nodes, conns):
|
||||
t = neat_instance.forward_transform(state, nodes, conns)
|
||||
batch_forward_func = jax.vmap(neat_instance.forward, in_axes=(0, None))
|
||||
query_res = batch_forward_func(state.query_coors, t) # hyperneat connections
|
||||
h_nodes = state.h_nodes
|
||||
h_conns = state.h_conns.at[:, 2:].set(query_res)
|
||||
return HyperNEATGene.forward_transform(state, h_nodes, h_conns)
|
||||
|
||||
return forward_transform
|
||||
54
algorithm/hyperneat/hyperneat_gene.py
Normal file
54
algorithm/hyperneat/hyperneat_gene.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import jax
|
||||
from jax import numpy as jnp, vmap
|
||||
|
||||
from algorithm.neat import BaseGene
|
||||
from algorithm.neat.gene import Activation
|
||||
from algorithm.neat.gene import Aggregation
|
||||
|
||||
|
||||
class HyperNEATGene(BaseGene):
|
||||
node_attrs = [] # no node attributes
|
||||
conn_attrs = ['weight']
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(state, nodes, conns):
|
||||
N = nodes.shape[0]
|
||||
u_conns = jnp.zeros((N, N), dtype=jnp.float32)
|
||||
|
||||
in_keys = jnp.asarray(conns[:, 0], jnp.int32)
|
||||
out_keys = jnp.asarray(conns[:, 1], jnp.int32)
|
||||
weights = conns[:, 2]
|
||||
|
||||
u_conns = u_conns.at[in_keys, out_keys].set(weights)
|
||||
return nodes, u_conns
|
||||
|
||||
@staticmethod
|
||||
def create_forward(config):
|
||||
act = Activation.name2func[config['h_activation']]
|
||||
agg = Aggregation.name2func[config['h_aggregation']]
|
||||
|
||||
batch_act, batch_agg = vmap(act), vmap(agg)
|
||||
|
||||
def forward(inputs, transform):
|
||||
|
||||
inputs_with_bias = jnp.concatenate((inputs, jnp.ones((1,))), axis=0)
|
||||
nodes, weights = transform
|
||||
|
||||
input_idx = config['h_input_idx']
|
||||
output_idx = config['h_output_idx']
|
||||
|
||||
N = nodes.shape[0]
|
||||
vals = jnp.full((N,), 0.)
|
||||
|
||||
def body_func(i, values):
|
||||
values = values.at[input_idx].set(inputs_with_bias)
|
||||
nodes_ins = values * weights.T
|
||||
values = batch_agg(nodes_ins) # z = agg(ins)
|
||||
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||
values = batch_act(values) # z = act(z)
|
||||
return values
|
||||
|
||||
vals = jax.lax.fori_loop(0, config['h_activate_times'], body_func, vals)
|
||||
return vals[output_idx]
|
||||
|
||||
return forward
|
||||
2
algorithm/hyperneat/substrate/__init__.py
Normal file
2
algorithm/hyperneat/substrate/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .base import BaseSubstrate
|
||||
from .tools import analysis_substrate
|
||||
12
algorithm/hyperneat/substrate/base.py
Normal file
12
algorithm/hyperneat/substrate/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseSubstrate:
|
||||
|
||||
@staticmethod
|
||||
def setup(state, config):
|
||||
return state.update(
|
||||
input_coors=np.asarray(config['input_coors'], dtype=np.float32),
|
||||
output_coors=np.asarray(config['output_coors'], dtype=np.float32),
|
||||
hidden_coors=np.asarray(config['hidden_coors'], dtype=np.float32),
|
||||
)
|
||||
53
algorithm/hyperneat/substrate/tools.py
Normal file
53
algorithm/hyperneat/substrate/tools.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from typing import Type
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseSubstrate
|
||||
|
||||
|
||||
def analysis_substrate(state):
|
||||
cd = state.input_coors.shape[1] # coordinate dimensions
|
||||
si = state.input_coors.shape[0] # input coordinate size
|
||||
so = state.output_coors.shape[0] # output coordinate size
|
||||
sh = state.hidden_coors.shape[0] # hidden coordinate size
|
||||
|
||||
input_idx = np.arange(si)
|
||||
output_idx = np.arange(si, si + so)
|
||||
hidden_idx = np.arange(si + so, si + so + sh)
|
||||
|
||||
total_conns = si * sh + sh * sh + sh * so
|
||||
query_coors = np.zeros((total_conns, cd * 2))
|
||||
correspond_keys = np.zeros((total_conns, 2))
|
||||
|
||||
# connect input to hidden
|
||||
aux_coors, aux_keys = cartesian_product(input_idx, hidden_idx, state.input_coors, state.hidden_coors)
|
||||
query_coors[0: si * sh, :] = aux_coors
|
||||
correspond_keys[0: si * sh, :] = aux_keys
|
||||
|
||||
# connect hidden to hidden
|
||||
aux_coors, aux_keys = cartesian_product(hidden_idx, hidden_idx, state.hidden_coors, state.hidden_coors)
|
||||
query_coors[si * sh: si * sh + sh * sh, :] = aux_coors
|
||||
correspond_keys[si * sh: si * sh + sh * sh, :] = aux_keys
|
||||
|
||||
# connect hidden to output
|
||||
aux_coors, aux_keys = cartesian_product(hidden_idx, output_idx, state.hidden_coors, state.output_coors)
|
||||
query_coors[si * sh + sh * sh:, :] = aux_coors
|
||||
correspond_keys[si * sh + sh * sh:, :] = aux_keys
|
||||
|
||||
return input_idx, output_idx, hidden_idx, query_coors, correspond_keys
|
||||
|
||||
|
||||
def cartesian_product(keys1, keys2, coors1, coors2):
|
||||
len1 = keys1.shape[0]
|
||||
len2 = keys2.shape[0]
|
||||
|
||||
repeated_coors1 = np.repeat(coors1, len2, axis=0)
|
||||
repeated_keys1 = np.repeat(keys1, len2)
|
||||
|
||||
tiled_coors2 = np.tile(coors2, (len1, 1))
|
||||
tiled_keys2 = np.tile(keys2, len1)
|
||||
|
||||
new_coors = np.concatenate((repeated_coors1, tiled_coors2), axis=1)
|
||||
correspond_keys = np.column_stack((repeated_keys1, tiled_keys2))
|
||||
|
||||
return new_coors, correspond_keys
|
||||
@@ -1,3 +1,2 @@
|
||||
from .neat import NEAT
|
||||
from .gene import NormalGene, RecurrentGene
|
||||
from .pipeline import Pipeline
|
||||
from .gene import BaseGene, NormalGene, RecurrentGene
|
||||
|
||||
@@ -33,12 +33,10 @@ class BaseGene:
|
||||
def distance_conn(state, conn1: Array, conn2: Array):
|
||||
return conn1
|
||||
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(nodes, conns):
|
||||
def forward_transform(state, nodes, conns):
|
||||
return nodes, conns
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create_forward(config):
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -4,7 +4,7 @@ from jax import Array, numpy as jnp
|
||||
from .base import BaseGene
|
||||
from .activation import Activation
|
||||
from .aggregation import Aggregation
|
||||
from ..utils import unflatten_connections, I_INT
|
||||
from algorithm.utils import unflatten_connections, I_INT
|
||||
from ..genome import topological_sort
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ class NormalGene(BaseGene):
|
||||
return (con1[2] != con2[2]) + jnp.abs(con1[3] - con2[3]) # enable + weight
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(nodes, conns):
|
||||
def forward_transform(state, nodes, conns):
|
||||
u_conns = unflatten_connections(nodes, conns)
|
||||
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||
|
||||
|
||||
@@ -4,13 +4,13 @@ from jax import Array, numpy as jnp, vmap
|
||||
from .normal import NormalGene
|
||||
from .activation import Activation
|
||||
from .aggregation import Aggregation
|
||||
from ..utils import unflatten_connections, I_INT
|
||||
from algorithm.utils import unflatten_connections
|
||||
|
||||
|
||||
class RecurrentGene(NormalGene):
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(nodes, conns):
|
||||
def forward_transform(state, nodes, conns):
|
||||
u_conns = unflatten_connections(nodes, conns)
|
||||
|
||||
# remove un-enable connections and remove enable attr
|
||||
|
||||
@@ -6,7 +6,7 @@ from jax import Array, numpy as jnp
|
||||
|
||||
from algorithm import State
|
||||
from ..gene import BaseGene
|
||||
from ..utils import fetch_first
|
||||
from algorithm.utils import fetch_first
|
||||
|
||||
|
||||
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
|
||||
@@ -48,6 +48,7 @@ def count(nodes: Array, cons: Array):
|
||||
cons_cnt = jnp.sum(~jnp.isnan(cons[:, 0]))
|
||||
return node_cnt, cons_cnt
|
||||
|
||||
|
||||
def add_node(nodes: Array, cons: Array, new_key: int, attrs: Array) -> Tuple[Array, Array]:
|
||||
"""
|
||||
Add a new node to the genome.
|
||||
|
||||
@@ -6,7 +6,7 @@ Only used in feed-forward networks.
|
||||
import jax
|
||||
from jax import jit, Array, numpy as jnp
|
||||
|
||||
from ..utils import fetch_first, I_INT
|
||||
from algorithm.utils import fetch_first, I_INT
|
||||
|
||||
|
||||
@jit
|
||||
|
||||
@@ -4,9 +4,9 @@ import jax
|
||||
from jax import Array, numpy as jnp, vmap
|
||||
|
||||
from algorithm import State
|
||||
from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx, count
|
||||
from .basic import add_node, add_connection, delete_node_by_idx, delete_connection_by_idx
|
||||
from .graph import check_cycles
|
||||
from ..utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||
from algorithm.utils import fetch_random, fetch_first, I_INT, unflatten_connections
|
||||
from ..gene import BaseGene
|
||||
|
||||
|
||||
|
||||
@@ -3,22 +3,25 @@ from typing import Type
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from algorithm.state import State
|
||||
from algorithm import Algorithm, State
|
||||
from .gene import BaseGene
|
||||
from .genome import initialize_genomes
|
||||
from .population import create_tell
|
||||
|
||||
|
||||
class NEAT:
|
||||
class NEAT(Algorithm):
|
||||
def __init__(self, config, gene_type: Type[BaseGene]):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.gene_type = gene_type
|
||||
|
||||
self.tell_func = jax.jit(create_tell(config, self.gene_type))
|
||||
self.tell = create_tell(config, self.gene_type)
|
||||
self.ask = None
|
||||
self.forward = self.gene_type.create_forward(config)
|
||||
self.forward_transform = self.gene_type.forward_transform
|
||||
|
||||
def setup(self, randkey):
|
||||
|
||||
state = State(
|
||||
def setup(self, randkey, state=State()):
|
||||
state = state.update(
|
||||
P=self.config['pop_size'],
|
||||
N=self.config['maximum_nodes'],
|
||||
C=self.config['maximum_conns'],
|
||||
@@ -69,7 +72,4 @@ class NEAT:
|
||||
# move to device
|
||||
state = jax.device_put(state)
|
||||
|
||||
return state
|
||||
|
||||
def step(self, state, fitness):
|
||||
return self.tell_func(state, fitness)
|
||||
return state
|
||||
@@ -1,77 +0,0 @@
|
||||
import time
|
||||
from typing import Union, Callable
|
||||
|
||||
import jax
|
||||
from jax import vmap, jit
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
Neat algorithm pipeline.
|
||||
"""
|
||||
|
||||
def __init__(self, config, algorithm):
|
||||
self.config = config
|
||||
self.algorithm = algorithm
|
||||
randkey = jax.random.PRNGKey(config['random_seed'])
|
||||
self.state = algorithm.setup(randkey)
|
||||
|
||||
self.best_genome = None
|
||||
self.best_fitness = float('-inf')
|
||||
self.generation_timestamp = time.time()
|
||||
|
||||
self.evaluate_time = 0
|
||||
|
||||
self.forward_func = algorithm.gene_type.create_forward(config)
|
||||
self.batch_forward_func = jit(vmap(self.forward_func, in_axes=(0, None)))
|
||||
self.pop_batch_forward_func = jit(vmap(self.batch_forward_func, in_axes=(None, 0)))
|
||||
|
||||
self.pop_transform_func = jit(vmap(algorithm.gene_type.forward_transform))
|
||||
|
||||
def ask(self):
|
||||
pop_transforms = self.pop_transform_func(self.state.pop_nodes, self.state.pop_conns)
|
||||
return lambda inputs: self.pop_batch_forward_func(inputs, pop_transforms)
|
||||
|
||||
def tell(self, fitness):
|
||||
self.state = self.algorithm.step(self.state, fitness)
|
||||
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config['generation_limit']):
|
||||
forward_func = self.ask()
|
||||
|
||||
fitnesses = fitness_func(forward_func)
|
||||
|
||||
if analysis is not None:
|
||||
if analysis == "default":
|
||||
self.default_analysis(fitnesses)
|
||||
else:
|
||||
assert callable(analysis), f"What the fuck you passed in? A {analysis}?"
|
||||
analysis(fitnesses)
|
||||
|
||||
if max(fitnesses) >= self.config['fitness_threshold']:
|
||||
print("Fitness limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
self.tell(fitnesses)
|
||||
print("Generation limit reached!")
|
||||
return self.best_genome
|
||||
|
||||
def default_analysis(self, fitnesses):
|
||||
max_f, min_f, mean_f, std_f = max(fitnesses), min(fitnesses), np.mean(fitnesses), np.std(fitnesses)
|
||||
|
||||
new_timestamp = time.time()
|
||||
cost_time = new_timestamp - self.generation_timestamp
|
||||
self.generation_timestamp = new_timestamp
|
||||
|
||||
max_idx = np.argmax(fitnesses)
|
||||
if fitnesses[max_idx] > self.best_fitness:
|
||||
self.best_fitness = fitnesses[max_idx]
|
||||
self.best_genome = (self.state.pop_nodes[max_idx], self.state.pop_conns[max_idx])
|
||||
|
||||
member_count = jax.device_get(self.state.species_info[:, 3])
|
||||
species_sizes = [int(i) for i in member_count if i > 0]
|
||||
|
||||
print(f"Generation: {self.state.generation}",
|
||||
f"species: {len(species_sizes)}, {species_sizes}",
|
||||
f"fitness: {max_f}, {min_f}, {mean_f}, {std_f}, Cost time: {cost_time}")
|
||||
@@ -3,7 +3,7 @@ from typing import Type
|
||||
import jax
|
||||
from jax import numpy as jnp, vmap
|
||||
|
||||
from .utils import rank_elements, fetch_first
|
||||
from algorithm.utils import rank_elements, fetch_first
|
||||
from .genome import create_mutate, create_distance, crossover
|
||||
from .gene import BaseGene
|
||||
|
||||
|
||||
Reference in New Issue
Block a user