Modified code, prepare for pmap;

This commit is contained in:
wls2002
2023-07-31 15:11:39 +08:00
parent 86d085abcd
commit 85318f98f3
4 changed files with 2 additions and 89 deletions

View File

@@ -1,27 +0,0 @@
import jax
from jax import numpy as jnp
from config import Config
from core import Genome
config = Config()
from dataclasses import asdict
print(asdict(config))
pop_nodes = jnp.ones((Config.basic.pop_size, Config.neat.maximum_nodes, 3))
pop_conns = jnp.ones((Config.basic.pop_size, Config.neat.maximum_conns, 5))
pop_genomes = Genome(pop_nodes, pop_conns)
print(pop_genomes)
print(pop_genomes[0: 20])
@jax.vmap
def pop_cnts(genome):
return genome.count()
cnts = pop_cnts(pop_genomes)
print(cnts)

View File

@@ -1,23 +0,0 @@
from enum import Enum
from jax import jit
class NetworkType(Enum):
ANN = 0
SNN = 1
LSTM = 2
@jit
def func(d):
return d[0] + 1
d = {0: 1, 1: NetworkType.ANN.value}
n = None
print(n or d)
print(d)
print(func(d))

View File

@@ -27,8 +27,8 @@ if __name__ == '__main__':
pop_size=10000
),
neat=NeatConfig(
maximum_nodes=50,
maximum_conns=100,
maximum_nodes=20,
maximum_conns=50,
),
gene=NormalGeneConfig()
)