create state
This commit is contained in:
@@ -1,115 +0,0 @@
|
||||
import pickle
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp, jit, vmap
|
||||
|
||||
import numpy as np
|
||||
|
||||
from configs import Configer
|
||||
from algorithms.neat import initialize_genomes
|
||||
from algorithms.neat import tell
|
||||
from algorithms.neat import unflatten_connections, topological_sort, create_forward_function
|
||||
|
||||
jax.config.update("jax_disable_jit", True)
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
def evaluate(forward_func):
|
||||
"""
|
||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return fitnesses
|
||||
|
||||
|
||||
def get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward_func):
|
||||
u_pop_cons = pop_unflatten_connections(pop_nodes, pop_cons)
|
||||
pop_seqs = pop_topological_sort(pop_nodes, u_pop_cons)
|
||||
func = lambda x: forward_func(x, pop_seqs, pop_nodes, u_pop_cons)
|
||||
|
||||
return evaluate(func)
|
||||
|
||||
|
||||
def equal(ar1, ar2):
|
||||
if ar1.shape != ar2.shape:
|
||||
return False
|
||||
|
||||
nan_mask1 = jnp.isnan(ar1)
|
||||
nan_mask2 = jnp.isnan(ar2)
|
||||
|
||||
return jnp.all((ar1 == ar2) | (nan_mask1 & nan_mask2))
|
||||
|
||||
def main():
|
||||
# initialize
|
||||
config = Configer.load_config("xor.ini")
|
||||
jit_config = Configer.create_jit_config(config) # config used in jit-able functions
|
||||
|
||||
P = config['pop_size']
|
||||
N = config['init_maximum_nodes']
|
||||
C = config['init_maximum_connections']
|
||||
S = config['init_maximum_species']
|
||||
randkey = jax.random.PRNGKey(6)
|
||||
np.random.seed(6)
|
||||
|
||||
pop_nodes, pop_cons = initialize_genomes(N, C, config)
|
||||
species_info = np.full((S, 4), np.nan)
|
||||
species_info[0, :] = 0, -np.inf, 0, P
|
||||
idx2species = np.zeros(P, dtype=np.float32)
|
||||
center_nodes = np.full((S, N, 5), np.nan)
|
||||
center_cons = np.full((S, C, 4), np.nan)
|
||||
center_nodes[0, :, :] = pop_nodes[0, :, :]
|
||||
center_cons[0, :, :] = pop_cons[0, :, :]
|
||||
generation = 0
|
||||
|
||||
pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons = jax.device_put(
|
||||
[pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons])
|
||||
|
||||
pop_unflatten_connections = jit(vmap(unflatten_connections))
|
||||
pop_topological_sort = jit(vmap(topological_sort))
|
||||
forward = create_forward_function(config)
|
||||
|
||||
|
||||
while True:
|
||||
fitness = get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward)
|
||||
|
||||
last_max = np.max(fitness)
|
||||
|
||||
info = [fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation,
|
||||
jit_config]
|
||||
|
||||
with open('list.pkl', 'wb') as f:
|
||||
# 使用pickle模块的dump函数来保存list
|
||||
pickle.dump(info, f)
|
||||
|
||||
randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation = tell(
|
||||
fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, generation,
|
||||
jit_config)
|
||||
|
||||
fitness = get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward)
|
||||
current_max = np.max(fitness)
|
||||
print(last_max, current_max)
|
||||
assert current_max >= last_max, f"current_max: {current_max}, last_max: {last_max}"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# main()
|
||||
config = Configer.load_config("xor.ini")
|
||||
pop_unflatten_connections = jit(vmap(unflatten_connections))
|
||||
pop_topological_sort = jit(vmap(topological_sort))
|
||||
forward = create_forward_function(config)
|
||||
|
||||
with open('list.pkl', 'rb') as f:
|
||||
# 使用pickle模块的dump函数来保存list
|
||||
fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, i, jit_config = pickle.load(
|
||||
f)
|
||||
|
||||
print(np.max(fitness))
|
||||
randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, _ = tell(
|
||||
fitness, randkey, pop_nodes, pop_cons, species_info, idx2species, center_nodes, center_cons, i,
|
||||
jit_config)
|
||||
fitness = get_fitnesses(pop_nodes, pop_cons, pop_unflatten_connections, pop_topological_sort, forward)
|
||||
print(np.max(fitness))
|
||||
@@ -1,22 +0,0 @@
|
||||
[basic]
|
||||
num_inputs = 6
|
||||
num_outputs = 3
|
||||
maximum_nodes = 50
|
||||
maximum_connections = 50
|
||||
maximum_species = 10
|
||||
forward_way = "single"
|
||||
random_seed = 42
|
||||
|
||||
[population]
|
||||
pop_size = 100
|
||||
|
||||
[gene-activation]
|
||||
activation_default = "sigmoid"
|
||||
activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square']
|
||||
activation_replace_rate = 0.1
|
||||
|
||||
[gene-aggregation]
|
||||
aggregation_default = "sum"
|
||||
aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean']
|
||||
aggregation_replace_rate = 0.1
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
import evox
|
||||
import jax
|
||||
from jax import jit, vmap, numpy as jnp
|
||||
|
||||
from configs import Configer
|
||||
from algorithms.neat import create_forward_function, topological_sort, unflatten_connections
|
||||
from evox_adaptor import NEAT, Gym
|
||||
|
||||
if __name__ == '__main__':
|
||||
batch_policy = True
|
||||
key = jax.random.PRNGKey(42)
|
||||
|
||||
monitor = evox.monitors.StdSOMonitor()
|
||||
neat_config = Configer.load_config('acrobot.ini')
|
||||
origin_forward_func = create_forward_function(neat_config)
|
||||
|
||||
|
||||
def neat_transform(pop):
|
||||
P = neat_config['pop_size']
|
||||
N = neat_config['maximum_nodes']
|
||||
C = neat_config['maximum_connections']
|
||||
|
||||
pop_nodes = pop[:P * N * 5].reshape((P, N, 5))
|
||||
pop_cons = pop[P * N * 5:].reshape((P, C, 4))
|
||||
|
||||
u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons)
|
||||
pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons)
|
||||
return pop_seqs, pop_nodes, u_pop_cons
|
||||
|
||||
# special policy for mountain car
|
||||
def neat_forward(genome, x):
|
||||
res = origin_forward_func(x, *genome)
|
||||
out = jnp.argmax(res) # {0, 1, 2}
|
||||
return out
|
||||
|
||||
|
||||
forward_func = lambda pop, x: origin_forward_func(x, *pop)
|
||||
|
||||
problem = Gym(
|
||||
policy=jit(vmap(neat_forward)),
|
||||
env_name="Acrobot-v1",
|
||||
env_options={"new_step_api": True},
|
||||
pop_size=100,
|
||||
)
|
||||
|
||||
# create a pipeline
|
||||
pipeline = evox.pipelines.StdPipeline(
|
||||
algorithm=NEAT(neat_config),
|
||||
problem=problem,
|
||||
pop_transform=jit(neat_transform),
|
||||
fitness_transform=monitor.record_fit,
|
||||
)
|
||||
# init the pipeline
|
||||
state = pipeline.init(key)
|
||||
|
||||
# run the pipeline for 10 steps
|
||||
for i in range(30):
|
||||
state = pipeline.step(state)
|
||||
print(i, monitor.get_min_fitness())
|
||||
|
||||
# obtain -62.0
|
||||
min_fitness = monitor.get_min_fitness()
|
||||
print(min_fitness)
|
||||
@@ -1,22 +0,0 @@
|
||||
[basic]
|
||||
num_inputs = 24
|
||||
num_outputs = 4
|
||||
maximum_nodes = 100
|
||||
maximum_connections = 200
|
||||
maximum_species = 10
|
||||
forward_way = "single"
|
||||
random_seed = 42
|
||||
|
||||
[population]
|
||||
pop_size = 100
|
||||
|
||||
[gene-activation]
|
||||
activation_default = "sigmoid"
|
||||
activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square']
|
||||
activation_replace_rate = 0.1
|
||||
|
||||
[gene-aggregation]
|
||||
aggregation_default = "sum"
|
||||
aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean']
|
||||
aggregation_replace_rate = 0.1
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
import evox
|
||||
import jax
|
||||
from jax import jit, vmap, numpy as jnp
|
||||
|
||||
from configs import Configer
|
||||
from algorithms.neat import create_forward_function, topological_sort, unflatten_connections
|
||||
from evox_adaptor import NEAT, Gym
|
||||
|
||||
if __name__ == '__main__':
|
||||
batch_policy = True
|
||||
key = jax.random.PRNGKey(42)
|
||||
|
||||
monitor = evox.monitors.StdSOMonitor()
|
||||
neat_config = Configer.load_config('bipedalwalker.ini')
|
||||
origin_forward_func = create_forward_function(neat_config)
|
||||
|
||||
|
||||
def neat_transform(pop):
|
||||
P = neat_config['pop_size']
|
||||
N = neat_config['maximum_nodes']
|
||||
C = neat_config['maximum_connections']
|
||||
|
||||
pop_nodes = pop[:P * N * 5].reshape((P, N, 5))
|
||||
pop_cons = pop[P * N * 5:].reshape((P, C, 4))
|
||||
|
||||
u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons)
|
||||
pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons)
|
||||
return pop_seqs, pop_nodes, u_pop_cons
|
||||
|
||||
# special policy for mountain car
|
||||
def neat_forward(genome, x):
|
||||
res = origin_forward_func(x, *genome)
|
||||
out = jnp.tanh(res) # (-1, 1)
|
||||
return out
|
||||
|
||||
|
||||
forward_func = lambda pop, x: origin_forward_func(x, *pop)
|
||||
|
||||
problem = Gym(
|
||||
policy=jit(vmap(neat_forward)),
|
||||
env_name="BipedalWalker-v3",
|
||||
pop_size=100,
|
||||
)
|
||||
|
||||
# create a pipeline
|
||||
pipeline = evox.pipelines.StdPipeline(
|
||||
algorithm=NEAT(neat_config),
|
||||
problem=problem,
|
||||
pop_transform=jit(neat_transform),
|
||||
fitness_transform=monitor.record_fit,
|
||||
)
|
||||
# init the pipeline
|
||||
state = pipeline.init(key)
|
||||
|
||||
# run the pipeline for 10 steps
|
||||
for i in range(30):
|
||||
state = pipeline.step(state)
|
||||
print(i, monitor.get_min_fitness())
|
||||
|
||||
# obtain 98.91529684268514
|
||||
min_fitness = monitor.get_min_fitness()
|
||||
print(min_fitness)
|
||||
@@ -1,11 +0,0 @@
|
||||
[basic]
|
||||
num_inputs = 4
|
||||
num_outputs = 1
|
||||
maximum_nodes = 50
|
||||
maximum_connections = 50
|
||||
maximum_species = 10
|
||||
forward_way = "single"
|
||||
random_seed = 42
|
||||
|
||||
[population]
|
||||
pop_size = 40
|
||||
@@ -1,63 +0,0 @@
|
||||
import evox
|
||||
import jax
|
||||
from jax import jit, vmap, numpy as jnp
|
||||
|
||||
from configs import Configer
|
||||
from algorithms.neat import create_forward_function, topological_sort, unflatten_connections
|
||||
from evox_adaptor import NEAT, Gym
|
||||
|
||||
if __name__ == '__main__':
|
||||
batch_policy = True
|
||||
key = jax.random.PRNGKey(42)
|
||||
|
||||
monitor = evox.monitors.StdSOMonitor()
|
||||
neat_config = Configer.load_config('cartpole.ini')
|
||||
origin_forward_func = create_forward_function(neat_config)
|
||||
|
||||
|
||||
def neat_transform(pop):
|
||||
P = neat_config['pop_size']
|
||||
N = neat_config['maximum_nodes']
|
||||
C = neat_config['maximum_connections']
|
||||
|
||||
pop_nodes = pop[:P * N * 5].reshape((P, N, 5))
|
||||
pop_cons = pop[P * N * 5:].reshape((P, C, 4))
|
||||
|
||||
u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons)
|
||||
pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons)
|
||||
return pop_seqs, pop_nodes, u_pop_cons
|
||||
|
||||
# special policy for cartpole
|
||||
def neat_forward(genome, x):
|
||||
res = origin_forward_func(x, *genome)[0]
|
||||
out = jnp.where(res > 0.5, 1, 0)
|
||||
return out
|
||||
|
||||
|
||||
forward_func = lambda pop, x: origin_forward_func(x, *pop)
|
||||
|
||||
problem = Gym(
|
||||
policy=jit(vmap(neat_forward)),
|
||||
env_name="CartPole-v1",
|
||||
env_options={"new_step_api": True},
|
||||
pop_size=40,
|
||||
)
|
||||
|
||||
# create a pipeline
|
||||
pipeline = evox.pipelines.StdPipeline(
|
||||
algorithm=NEAT(neat_config),
|
||||
problem=problem,
|
||||
pop_transform=jit(neat_transform),
|
||||
fitness_transform=monitor.record_fit,
|
||||
)
|
||||
# init the pipeline
|
||||
state = pipeline.init(key)
|
||||
|
||||
# run the pipeline for 10 steps
|
||||
for i in range(10):
|
||||
state = pipeline.step(state)
|
||||
print(monitor.get_min_fitness())
|
||||
|
||||
# obtain 500
|
||||
min_fitness = monitor.get_min_fitness()
|
||||
print(min_fitness)
|
||||
@@ -1,14 +0,0 @@
|
||||
import gym
|
||||
|
||||
env = gym.make("CartPole-v1", new_step_api=True)
|
||||
print(env.reset())
|
||||
obs = env.reset()
|
||||
|
||||
print(obs)
|
||||
while True:
|
||||
action = env.action_space.sample()
|
||||
obs, reward, terminate, truncate, info = env.step(action)
|
||||
print(obs, info)
|
||||
if terminate | truncate:
|
||||
break
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
[basic]
|
||||
num_inputs = 2
|
||||
num_outputs = 1
|
||||
maximum_nodes = 50
|
||||
maximum_connections = 50
|
||||
maximum_species = 10
|
||||
forward_way = "single"
|
||||
random_seed = 42
|
||||
|
||||
[population]
|
||||
pop_size = 100
|
||||
|
||||
[gene-activation]
|
||||
activation_default = "sigmoid"
|
||||
activation_option_names = ['sigmoid', 'tanh', 'sin', 'gauss', 'relu', 'identity', 'inv', 'log', 'exp', 'abs', 'hat', 'square']
|
||||
activation_replace_rate = 0.1
|
||||
|
||||
[gene-aggregation]
|
||||
aggregation_default = "sum"
|
||||
aggregation_option_names = ['sum', 'product', 'max', 'min', 'maxabs', 'median', 'mean']
|
||||
aggregation_replace_rate = 0.1
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
import evox
|
||||
import jax
|
||||
from jax import jit, vmap, numpy as jnp
|
||||
|
||||
from configs import Configer
|
||||
from algorithms.neat import create_forward_function, topological_sort, unflatten_connections
|
||||
from evox_adaptor import NEAT, Gym
|
||||
|
||||
if __name__ == '__main__':
|
||||
batch_policy = True
|
||||
key = jax.random.PRNGKey(42)
|
||||
|
||||
monitor = evox.monitors.StdSOMonitor()
|
||||
neat_config = Configer.load_config('mountain_car.ini')
|
||||
origin_forward_func = create_forward_function(neat_config)
|
||||
|
||||
|
||||
def neat_transform(pop):
|
||||
P = neat_config['pop_size']
|
||||
N = neat_config['maximum_nodes']
|
||||
C = neat_config['maximum_connections']
|
||||
|
||||
pop_nodes = pop[:P * N * 5].reshape((P, N, 5))
|
||||
pop_cons = pop[P * N * 5:].reshape((P, C, 4))
|
||||
|
||||
u_pop_cons = vmap(unflatten_connections)(pop_nodes, pop_cons)
|
||||
pop_seqs = vmap(topological_sort)(pop_nodes, u_pop_cons)
|
||||
return pop_seqs, pop_nodes, u_pop_cons
|
||||
|
||||
# special policy for mountain car
|
||||
def neat_forward(genome, x):
|
||||
res = origin_forward_func(x, *genome)
|
||||
out = jnp.tanh(res) # (-1, 1)
|
||||
return out
|
||||
|
||||
|
||||
forward_func = lambda pop, x: origin_forward_func(x, *pop)
|
||||
|
||||
problem = Gym(
|
||||
policy=jit(vmap(neat_forward)),
|
||||
env_name="MountainCarContinuous-v0",
|
||||
env_options={"new_step_api": True},
|
||||
pop_size=100,
|
||||
)
|
||||
|
||||
# create a pipeline
|
||||
pipeline = evox.pipelines.StdPipeline(
|
||||
algorithm=NEAT(neat_config),
|
||||
problem=problem,
|
||||
pop_transform=jit(neat_transform),
|
||||
fitness_transform=monitor.record_fit,
|
||||
)
|
||||
# init the pipeline
|
||||
state = pipeline.init(key)
|
||||
|
||||
# run the pipeline for 10 steps
|
||||
for i in range(30):
|
||||
state = pipeline.step(state)
|
||||
print(i, monitor.get_min_fitness())
|
||||
|
||||
# obtain 98.91529684268514
|
||||
min_fitness = monitor.get_min_fitness()
|
||||
print(min_fitness)
|
||||
@@ -1,18 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
from jax import numpy as jnp, jit
|
||||
|
||||
|
||||
@partial(jit, static_argnames=['reverse'])
|
||||
def rank_element(array, reverse=False):
|
||||
"""
|
||||
rank the element in the array.
|
||||
if reverse is True, the rank is from large to small.
|
||||
"""
|
||||
if reverse:
|
||||
array = -array
|
||||
return jnp.argsort(jnp.argsort(array))
|
||||
|
||||
|
||||
a = jnp.array([1, 5, 3, 5, 2, 1, 0])
|
||||
print(rank_element(a, reverse=True))
|
||||
14
examples/state_test.py
Normal file
14
examples/state_test.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import jax
|
||||
from algorithm.state import State
|
||||
|
||||
@jax.jit
|
||||
def func(state: State, a):
|
||||
return state.update(a=a)
|
||||
|
||||
|
||||
state = State(c=1, b=2)
|
||||
print(state)
|
||||
|
||||
state = func(state, 1111111)
|
||||
|
||||
print(state)
|
||||
@@ -1,5 +0,0 @@
|
||||
[basic]
|
||||
forward_way = "common"
|
||||
|
||||
[population]
|
||||
fitness_threshold = 4
|
||||
@@ -1,31 +0,0 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from configs import Configer
|
||||
from pipeline import Pipeline
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
|
||||
|
||||
def evaluate(forward_func):
|
||||
"""
|
||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return fitnesses
|
||||
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
pipeline = Pipeline(config)
|
||||
nodes, cons = pipeline.auto_run(evaluate)
|
||||
# g = Genome(nodes, cons, config)
|
||||
# print(g)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,47 +0,0 @@
|
||||
[basic]
|
||||
num_inputs = 3
|
||||
num_outputs = 1
|
||||
maximum_nodes = 50
|
||||
maximum_connections = 50
|
||||
maximum_species = 10
|
||||
forward_way = "common"
|
||||
batch_size = 4
|
||||
random_seed = 42
|
||||
|
||||
[population]
|
||||
fitness_threshold = 8
|
||||
generation_limit = 1000
|
||||
fitness_criterion = "max"
|
||||
pop_size = 10000
|
||||
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
compatibility_weight = 0.5
|
||||
conn_add_prob = 0.4
|
||||
conn_add_trials = 1
|
||||
conn_delete_prob = 0
|
||||
node_add_prob = 0.2
|
||||
node_delete_prob = 0
|
||||
|
||||
[species]
|
||||
compatibility_threshold = 3
|
||||
species_elitism = 1
|
||||
max_stagnation = 15
|
||||
genome_elitism = 2
|
||||
survival_threshold = 0.2
|
||||
min_species_size = 1
|
||||
spawn_number_move_rate = 0.5
|
||||
|
||||
[gene-bias]
|
||||
bias_init_mean = 0.0
|
||||
bias_init_std = 1.0
|
||||
bias_mutate_power = 0.5
|
||||
bias_mutate_rate = 0.7
|
||||
bias_replace_rate = 0.1
|
||||
|
||||
[gene-weight]
|
||||
weight_init_mean = 0.0
|
||||
weight_init_std = 1.0
|
||||
weight_mutate_power = 0.5
|
||||
weight_mutate_rate = 0.8
|
||||
weight_replace_rate = 0.1
|
||||
@@ -1,31 +0,0 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
from configs import Configer
|
||||
from pipeline import Pipeline
|
||||
|
||||
xor_inputs = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0], [1], [0], [0], [1]], dtype=np.float32)
|
||||
|
||||
|
||||
def evaluate(forward_func):
|
||||
"""
|
||||
:param forward_func: (4: batch, 2: input size) -> (pop_size, 4: batch, 1: output size)
|
||||
:return:
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
fitnesses = 8 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return fitnesses
|
||||
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor3d.ini")
|
||||
pipeline = Pipeline(config)
|
||||
nodes, cons = pipeline.auto_run(evaluate)
|
||||
# g = Genome(nodes, cons, config)
|
||||
# print(g)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user