Modified code, prepare for pmap;

This commit is contained in:
wls2002
2023-07-31 15:11:39 +08:00
parent 86d085abcd
commit 85318f98f3
4 changed files with 2 additions and 89 deletions

View File

@@ -1,27 +0,0 @@
import jax
from jax import numpy as jnp
from config import Config
from core import Genome
config = Config()
from dataclasses import asdict
print(asdict(config))
pop_nodes = jnp.ones((Config.basic.pop_size, Config.neat.maximum_nodes, 3))
pop_conns = jnp.ones((Config.basic.pop_size, Config.neat.maximum_conns, 5))
pop_genomes = Genome(pop_nodes, pop_conns)
print(pop_genomes)
print(pop_genomes[0: 20])
@jax.vmap
def pop_cnts(genome):
return genome.count()
cnts = pop_cnts(pop_genomes)
print(cnts)