use black format all files;

remove "return state" for functions which will be executed in vmap;
recover randkey as args in mutation methods
This commit is contained in:
wls2002
2024-05-26 15:46:04 +08:00
parent 79d53ea7af
commit cf69b916af
38 changed files with 932 additions and 582 deletions

View File

@@ -0,0 +1,52 @@
import jax, jax.numpy as jnp
from utils import Act
from algorithm.neat import *
import numpy as np
def main():
algorithm = NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=100,
max_conns=100,
),
pop_size=1000,
species_size=10,
compatibility_threshold=3.5,
),
mutation=DefaultMutation(
conn_add=0.4,
conn_delete=0,
node_add=0.9,
node_delete=0,
),
)
state = algorithm.setup(jax.random.key(0))
pop_nodes, pop_conns = algorithm.species.ask(state.species)
batch_transform = jax.vmap(algorithm.genome.transform)
batch_forward = jax.vmap(algorithm.forward, in_axes=(None, 0))
for _ in range(50):
winner, losser = jax.random.randint(state.randkey, (2, 1000), 0, 1000)
elite_mask = jnp.zeros((1000,), dtype=jnp.bool_)
elite_mask = elite_mask.at[:5].set(1)
state = algorithm.create_next_generation(jax.random.key(0), state, winner, losser, elite_mask)
pop_nodes, pop_conns = algorithm.species.ask(state.species)
transforms = batch_transform(pop_nodes, pop_conns)
outputs = batch_forward(jnp.array([1, 0, 1]), transforms)
try:
assert not jnp.any(jnp.isnan(outputs))
except:
print(_)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,42 @@
import jax, jax.numpy as jnp
from utils import Act
from algorithm.neat import *
import numpy as np
def main():
node_path = "../examples/brax/nan_node.npy"
conn_path = "../examples/brax/nan_conn.npy"
nodes = np.load(node_path)
conns = np.load(conn_path)
nodes, conns = jax.device_put([nodes, conns])
genome = DefaultGenome(
num_inputs=8,
num_outputs=2,
max_nodes=20,
max_conns=20,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
)
transformed = genome.transform(nodes, conns)
seq, nodes, conns = transformed
print(seq)
exit(0)
# print(*transformed, sep='\n')
key = jax.random.key(0)
dummy_input = jnp.zeros((8,))
output = genome.forward(dummy_input, transformed)
print(output)
if __name__ == '__main__':
a = jnp.array([1, 3, 5, 6, 8])
b = jnp.array([1, 2, 3])
print(jnp.isin(a, b))
# main()

View File

@@ -7,21 +7,25 @@ import jax, jax.numpy as jnp
def test_default():
# index, bias, response, activation, aggregation
nodes = jnp.array([
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
])
nodes = jnp.array(
[
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
]
)
# in_node, out_node, enable, weight
conns = jnp.array([
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
])
conns = jnp.array(
[
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
]
)
genome = DefaultGenome(
num_inputs=2,
@@ -30,34 +34,37 @@ def test_default():
max_conns=4,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity, ),
activation_options=(Act.identity,),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum, ),
aggregation_options=(Agg.sum,),
),
)
state = genome.setup(State(randkey=jax.random.key(0)))
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n')
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
state, outputs = jax.jit(jax.vmap(genome.forward,
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
state, inputs, transformed
)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
print('\n-------------------------------------------------------\n')
print("\n-------------------------------------------------------\n")
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n')
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(None, 0, None))(
state, inputs, transformed
)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]
@@ -66,21 +73,25 @@ def test_default():
def test_recurrent():
# index, bias, response, activation, aggregation
nodes = jnp.array([
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
])
nodes = jnp.array(
[
[0, 0, 1, 0, 0], # in[0]
[1, 0, 1, 0, 0], # in[1]
[2, 0.5, 1, 0, 0], # out[0],
[3, 1, 1, 0, 0], # hidden[0],
[4, -1, 1, 0, 0], # hidden[1],
]
)
# in_node, out_node, enable, weight
conns = jnp.array([
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
])
conns = jnp.array(
[
[0, 3, 1, 0.5], # in[0] -> hidden[0]
[1, 4, 1, 0.5], # in[1] -> hidden[1]
[3, 2, 1, 0.5], # hidden[0] -> out[0]
[4, 2, 1, 0.5], # hidden[1] -> out[0]
]
)
genome = RecurrentGenome(
num_inputs=2,
@@ -89,35 +100,38 @@ def test_recurrent():
max_conns=4,
node_gene=DefaultNodeGene(
activation_default=Act.identity,
activation_options=(Act.identity, ),
activation_options=(Act.identity,),
aggregation_default=Agg.sum,
aggregation_options=(Agg.sum, ),
aggregation_options=(Agg.sum,),
),
activate_time=3,
)
state = genome.setup(State(randkey=jax.random.key(0)))
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n')
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
state, outputs = jax.jit(jax.vmap(genome.forward,
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(None, 0, None)))(
state, inputs, transformed
)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
# expected: [[0.5], [0.75], [0.75], [1]]
print('\n-------------------------------------------------------\n')
print("\n-------------------------------------------------------\n")
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
print(conns)
state, *transformed = genome.transform(state, nodes, conns)
print(*transformed, sep='\n')
transformed = genome.transform(state, nodes, conns)
print(*transformed, sep="\n")
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
outputs = jax.vmap(genome.forward, in_axes=(None, 0, None))(
state, inputs, transformed
)
print(outputs)
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
# expected: [[0.5], [0.75], [0.5], [0.75]]
# expected: [[0.5], [0.75], [0.5], [0.75]]

View File

@@ -0,0 +1,35 @@
import jax, jax.numpy as jnp
from utils import Act
from algorithm.neat import *
import numpy as np
def main():
node_path = "../examples/brax/nan_node.npy"
conn_path = "../examples/brax/nan_conn.npy"
nodes = np.load(node_path)
conns = np.load(conn_path)
nodes, conns = jax.device_put([nodes, conns])
genome = DefaultGenome(
num_inputs=8,
num_outputs=2,
max_nodes=20,
max_conns=20,
node_gene=DefaultNodeGene(
activation_options=(Act.tanh,),
activation_default=Act.tanh,
)
)
transformed = genome.transform(nodes, conns)
print(*transformed, sep='\n')
key = jax.random.key(0)
dummy_input = jnp.zeros((8,))
output = genome.forward(dummy_input, transformed)
print(output)
if __name__ == '__main__':
main()