26 lines
586 B
Python
26 lines
586 B
Python
import jax
|
|
from jax import numpy as jnp
|
|
|
|
from evox import algorithms, problems, pipelines
|
|
from evox.monitors import StdSOMonitor
|
|
|
|
monitor = StdSOMonitor()
|
|
|
|
pso = algorithms.PSO(
|
|
lb=jnp.full(shape=(2,), fill_value=-32),
|
|
ub=jnp.full(shape=(2,), fill_value=32),
|
|
pop_size=100,
|
|
)
|
|
|
|
ackley = problems.classic.Ackley()
|
|
|
|
pipeline = pipelines.StdPipeline(pso, ackley, fitness_transform=monitor.record_fit)
|
|
|
|
key = jax.random.PRNGKey(42)
|
|
state = pipeline.init(key)
|
|
|
|
# run the pipeline for 100 steps
|
|
for i in range(100):
|
|
state = pipeline.step(state)
|
|
|
|
print(monitor.get_min_fitness()) |