Files
tensorneat-mend/examples/a.py
2023-07-24 02:16:02 +08:00

29 lines
561 B
Python

import jax
from jax import numpy as jnp
from config import Config
from core import Genome
config = Config()
from dataclasses import asdict
print(asdict(config))
pop_nodes = jnp.ones((Config.basic.pop_size, Config.neat.maximum_nodes, 3))
pop_conns = jnp.ones((Config.basic.pop_size, Config.neat.maximum_conns, 5))
pop_genomes1 = jax.vmap(Genome)(pop_nodes, pop_conns)
pop_genomes2 = Genome(pop_nodes, pop_conns)
print(pop_genomes)
print(pop_genomes[0])
@jax.vmap
def pop_cnts(genome):
return genome.count()
cnts = pop_cnts(pop_genomes)
print(cnts)