perfect! fix bug about jax auto recompile

add task xor-3d
This commit is contained in:
wls2002
2023-07-02 22:15:26 +08:00
parent e711146f41
commit c4d34e877b
11 changed files with 234 additions and 104 deletions

View File

@@ -36,8 +36,8 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
assert config['num_inputs'] * config['num_outputs'] + 1 <= C, \
f"Too small C: {C} for input_size: {config['num_inputs']} and output_size: {config['num_outputs']}!"
pop_nodes = np.full((config['pop_size'], N, 5), np.nan)
pop_cons = np.full((config['pop_size'], C, 4), np.nan)
pop_nodes = np.full((config['pop_size'], N, 5), np.nan, dtype=np.float32)
pop_cons = np.full((config['pop_size'], C, 4), np.nan, dtype=np.float32)
input_idx = config['input_idx']
output_idx = config['output_idx']
@@ -59,7 +59,7 @@ def initialize_genomes(N: int, C: int, config: Dict) -> Tuple[NDArray, NDArray]:
pop_cons[:, :p, 0] = grid_a
pop_cons[:, :p, 1] = grid_b
pop_cons[:, :p, 2] = np.random.normal(loc=config['weight_init_mean'], scale=config['weight_init_std'],
size=(config['pop_size'], p))
size=(config['pop_size'], p))
pop_cons[:, :p, 3] = 1
return pop_nodes, pop_cons