Modified code, prepare for pmap;
This commit is contained in:
@@ -1,27 +0,0 @@
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
|
||||
from config import Config
|
||||
from core import Genome
|
||||
|
||||
config = Config()
|
||||
from dataclasses import asdict
|
||||
|
||||
print(asdict(config))
|
||||
|
||||
pop_nodes = jnp.ones((Config.basic.pop_size, Config.neat.maximum_nodes, 3))
|
||||
pop_conns = jnp.ones((Config.basic.pop_size, Config.neat.maximum_conns, 5))
|
||||
|
||||
pop_genomes = Genome(pop_nodes, pop_conns)
|
||||
|
||||
print(pop_genomes)
|
||||
print(pop_genomes[0: 20])
|
||||
|
||||
@jax.vmap
|
||||
def pop_cnts(genome):
|
||||
return genome.count()
|
||||
|
||||
cnts = pop_cnts(pop_genomes)
|
||||
|
||||
print(cnts)
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
from enum import Enum
|
||||
from jax import jit
|
||||
|
||||
class NetworkType(Enum):
|
||||
ANN = 0
|
||||
SNN = 1
|
||||
LSTM = 2
|
||||
|
||||
|
||||
|
||||
|
||||
@jit
|
||||
def func(d):
|
||||
return d[0] + 1
|
||||
|
||||
|
||||
d = {0: 1, 1: NetworkType.ANN.value}
|
||||
n = None
|
||||
|
||||
print(n or d)
|
||||
print(d)
|
||||
|
||||
print(func(d))
|
||||
@@ -27,8 +27,8 @@ if __name__ == '__main__':
|
||||
pop_size=10000
|
||||
),
|
||||
neat=NeatConfig(
|
||||
maximum_nodes=50,
|
||||
maximum_conns=100,
|
||||
maximum_nodes=20,
|
||||
maximum_conns=50,
|
||||
),
|
||||
gene=NormalGeneConfig()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user