Files
tensorneat-mend/examples/evox_test.py
2023-06-27 18:47:47 +08:00

26 lines
586 B
Python

import jax
from jax import numpy as jnp
from evox import algorithms, problems, pipelines
from evox.monitors import StdSOMonitor
monitor = StdSOMonitor()
pso = algorithms.PSO(
lb=jnp.full(shape=(2,), fill_value=-32),
ub=jnp.full(shape=(2,), fill_value=32),
pop_size=100,
)
ackley = problems.classic.Ackley()
pipeline = pipelines.StdPipeline(pso, ackley, fitness_transform=monitor.record_fit)
key = jax.random.PRNGKey(42)
state = pipeline.init(key)
# run the pipeline for 100 steps
for i in range(100):
state = pipeline.step(state)
print(monitor.get_min_fitness())