change a lot a lot a lot!!!!!!!

This commit is contained in:
wls2002
2023-07-24 02:16:02 +08:00
parent 48f90c7eef
commit ac295c1921
49 changed files with 1138 additions and 1460 deletions

View File

@@ -1,11 +1,28 @@
import numpy as np
import jax.numpy as jnp
import jax
from jax import numpy as jnp
a = jnp.zeros((5, 5))
k1 = jnp.array([1, 2, 3])
k2 = jnp.array([2, 3, 4])
v = jnp.array([1, 1, 1])
from config import Config
from core import Genome
a = a.at[k1, k2].set(v)
config = Config()
from dataclasses import asdict
print(asdict(config))
pop_nodes = jnp.ones((Config.basic.pop_size, Config.neat.maximum_nodes, 3))
pop_conns = jnp.ones((Config.basic.pop_size, Config.neat.maximum_conns, 5))
pop_genomes1 = jax.vmap(Genome)(pop_nodes, pop_conns)
pop_genomes2 = Genome(pop_nodes, pop_conns)
print(pop_genomes)
print(pop_genomes[0])
@jax.vmap
def pop_cnts(genome):
return genome.count()
cnts = pop_cnts(pop_genomes)
print(cnts)
print(a)

19
examples/b.py Normal file
View File

@@ -0,0 +1,19 @@
from enum import Enum
from jax import jit
class NetworkType(Enum):
ANN = 0
SNN = 1
LSTM = 2
@jit
def func(d):
return d[0] + 1
d = {0: 1, 1: NetworkType.ANN.value}
print(func(d))

View File

@@ -1,44 +0,0 @@
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
class Genome:
def __init__(self, nodes, conns):
self.nodes = nodes
self.conns = conns
def update_nodes(self, nodes):
return Genome(nodes, self.conns)
def update_conns(self, conns):
return Genome(self.nodes, conns)
def tree_flatten(self):
children = self.nodes, self.conns
aux_data = None
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(*children)
def __repr__(self):
return f"Genome ({self.nodes}, \n\t{self.conns})"
@jax.jit
def add_node(self, a: int):
nodes = self.nodes.at[0, :].set(a)
return self.update_nodes(nodes)
nodes, conns = jnp.array([[1, 2, 3, 4, 5]]), jnp.array([[1, 2, 3, 4]])
g = Genome(nodes, conns)
print(g)
g = g.add_node(1)
print(g)
g = jax.jit(g.add_node)(2)
print(g)

View File

@@ -1,12 +0,0 @@
[basic]
activate_times = 5
fitness_threshold = 4
[population]
pop_size = 1000
[neat]
network_type = "recurrent"
num_inputs = 4
num_outputs = 1

View File

@@ -1,10 +1,10 @@
import jax
import numpy as np
from config import Config, BasicConfig
from pipeline import Pipeline
from config import Configer
from algorithm import NEAT
from algorithm.neat import RecurrentGene
from algorithm.neat.gene import NormalGene, NormalGeneConfig
from algorithm.neat.neat import NEAT
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
@@ -21,13 +21,11 @@ def evaluate(forward_func):
return fitnesses
def main():
config = Configer.load_config("xor.ini")
algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm)
best = pipeline.auto_run(evaluate)
print(best)
if __name__ == '__main__':
main()
config = Config(
basic=BasicConfig(fitness_target=4),
gene=NormalGeneConfig()
)
algorithm = NEAT(config, NormalGene)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)

View File

@@ -1,33 +0,0 @@
import jax
import numpy as np
from pipeline import Pipeline
from config import Configer
from algorithm import NEAT, HyperNEAT
from algorithm.neat import RecurrentGene
from algorithm.hyperneat import BaseSubstrate
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
def evaluate(forward_func):
"""
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
:return:
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
def main():
config = Configer.load_config("xor.ini")
algorithm = HyperNEAT(config, RecurrentGene, BaseSubstrate)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)
if __name__ == '__main__':
main()