This commit is related to issue: https://github.com/EMI-Group/tensorneat/issues/11
1. Add origin_node and origin_conn. 2. Change the behavior of crossover and mutation. Now, TensorNEAT will use all fix_attrs(include historical marker if it has one) as identifier for gene in crossover and distance calculation. 3. Other slightly change. 4. Add two related examples: xor_origin and hopper_origin 5. Add related test file.
This commit is contained in:
247
test/origin_operations_test.py
Normal file
247
test/origin_operations_test.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import jax, jax.numpy as jnp
|
||||
from tensorneat.genome.operations import (
|
||||
DefaultMutation,
|
||||
DefaultDistance,
|
||||
DefaultCrossover,
|
||||
)
|
||||
from tensorneat.genome import (
|
||||
DefaultGenome,
|
||||
DefaultNode,
|
||||
DefaultConn,
|
||||
OriginNode,
|
||||
OriginConn,
|
||||
)
|
||||
from tensorneat.genome.utils import add_node, add_conn
|
||||
|
||||
origin_genome = DefaultGenome(
|
||||
node_gene=OriginNode(response_init_std=1),
|
||||
conn_gene=OriginConn(),
|
||||
mutation=DefaultMutation(conn_add=1, node_add=1, conn_delete=0, node_delete=0),
|
||||
crossover=DefaultCrossover(),
|
||||
distance=DefaultDistance(),
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=6,
|
||||
max_conns=6,
|
||||
)
|
||||
|
||||
default_genome = DefaultGenome(
|
||||
node_gene=DefaultNode(response_init_std=1),
|
||||
conn_gene=DefaultConn(),
|
||||
mutation=DefaultMutation(conn_add=1, node_add=1, conn_delete=0, node_delete=0),
|
||||
crossover=DefaultCrossover(),
|
||||
distance=DefaultDistance(),
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=6,
|
||||
max_conns=6,
|
||||
)
|
||||
|
||||
state = default_genome.setup()
|
||||
state = origin_genome.setup(state)
|
||||
|
||||
randkey = jax.random.PRNGKey(42)
|
||||
|
||||
|
||||
def mutation_default():
|
||||
nodes, conns = default_genome.initialize(state, randkey)
|
||||
print("old genome:\n", default_genome.repr(state, nodes, conns))
|
||||
|
||||
nodes, conns = default_genome.execute_mutation(
|
||||
state,
|
||||
randkey,
|
||||
nodes,
|
||||
conns,
|
||||
new_node_key=jnp.asarray(10),
|
||||
new_conn_keys=jnp.array([20, 21, 22]),
|
||||
)
|
||||
|
||||
# new_conn_keys is not used in default genome
|
||||
print("new genome:\n", default_genome.repr(state, nodes, conns))
|
||||
|
||||
|
||||
def mutation_origin():
|
||||
nodes, conns = origin_genome.initialize(state, randkey)
|
||||
print(conns)
|
||||
print("old genome:\n", origin_genome.repr(state, nodes, conns))
|
||||
|
||||
nodes, conns = origin_genome.execute_mutation(
|
||||
state,
|
||||
randkey,
|
||||
nodes,
|
||||
conns,
|
||||
new_node_key=jnp.asarray(10),
|
||||
new_conn_keys=jnp.array([20, 21, 22]),
|
||||
)
|
||||
print(conns)
|
||||
# new_conn_keys is used in origin genome
|
||||
print("new genome:\n", origin_genome.repr(state, nodes, conns))
|
||||
|
||||
def distance_default():
|
||||
nodes, conns = default_genome.initialize(state, randkey)
|
||||
nodes = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=default_genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
conns1 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10]), # in-idx, out-idx
|
||||
custom_attrs=default_genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
conns2 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10]), # in-idx, out-idx
|
||||
custom_attrs=default_genome.conn_gene.new_random_attrs(state, randkey)
|
||||
)
|
||||
print("genome1:\n", default_genome.repr(state, nodes, conns1))
|
||||
print("genome2:\n", default_genome.repr(state, nodes, conns2))
|
||||
|
||||
distance = default_genome.execute_distance(state, nodes, conns1, nodes, conns2)
|
||||
print("distance: ", distance)
|
||||
|
||||
def distance_origin_case1():
|
||||
"""
|
||||
distance with different historical marker
|
||||
"""
|
||||
nodes, conns = origin_genome.initialize(state, randkey)
|
||||
nodes = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=origin_genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
conns1 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
conns2 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 88]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_random_attrs(state, randkey)
|
||||
)
|
||||
print("genome1:\n", origin_genome.repr(state, nodes, conns1))
|
||||
print("genome2:\n", origin_genome.repr(state, nodes, conns2))
|
||||
|
||||
distance = origin_genome.execute_distance(state, nodes, conns1, nodes, conns2)
|
||||
print("distance: ", distance)
|
||||
|
||||
def distance_origin_case2():
|
||||
"""
|
||||
distance with same historical marker
|
||||
"""
|
||||
nodes, conns = origin_genome.initialize(state, randkey)
|
||||
nodes = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=origin_genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
conns1 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
conns2 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_random_attrs(state, randkey)
|
||||
)
|
||||
print("genome1:\n", origin_genome.repr(state, nodes, conns1))
|
||||
print("genome2:\n", origin_genome.repr(state, nodes, conns2))
|
||||
|
||||
distance = origin_genome.execute_distance(state, nodes, conns1, nodes, conns2)
|
||||
print("distance: ", distance)
|
||||
|
||||
def crossover_origin_case1():
|
||||
"""
|
||||
crossover with different historical marker
|
||||
"""
|
||||
nodes, conns = origin_genome.initialize(state, randkey)
|
||||
nodes = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=origin_genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
conns1 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
conns2 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 88]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_random_attrs(state, randkey)
|
||||
)
|
||||
print("genome1:\n", origin_genome.repr(state, nodes, conns1))
|
||||
print("genome2:\n", origin_genome.repr(state, nodes, conns2))
|
||||
|
||||
# (0, 10)'s weight must be 0 (disjoint gene, use fitter)
|
||||
child_nodes, child_conns = origin_genome.execute_crossover(state, randkey, nodes, conns1, nodes, conns2)
|
||||
print("child:\n", origin_genome.repr(state, child_nodes, child_conns))
|
||||
|
||||
def crossover_origin_case2():
|
||||
"""
|
||||
crossover with same historical marker
|
||||
"""
|
||||
nodes, conns = origin_genome.initialize(state, randkey)
|
||||
nodes = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=origin_genome.node_gene.new_identity_attrs(state)
|
||||
)
|
||||
conns1 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_zero_attrs(state)
|
||||
)
|
||||
conns2 = add_conn(
|
||||
conns,
|
||||
fix_attrs=jnp.array([0, 10, 99]), # in-idx, out-idx, historical mark
|
||||
custom_attrs=origin_genome.conn_gene.new_random_attrs(state, randkey)
|
||||
)
|
||||
print("genome1:\n", origin_genome.repr(state, nodes, conns1))
|
||||
print("genome2:\n", origin_genome.repr(state, nodes, conns2))
|
||||
|
||||
# (0, 10)'s weight might be random or zero (homologous gene)
|
||||
|
||||
# zero case:
|
||||
child_nodes, child_conns = origin_genome.execute_crossover(state, jax.random.key(99), nodes, conns1, nodes, conns2)
|
||||
print("child_zero:\n", origin_genome.repr(state, child_nodes, child_conns))
|
||||
|
||||
# random case:
|
||||
child_nodes, child_conns = origin_genome.execute_crossover(state, jax.random.key(0), nodes, conns1, nodes, conns2)
|
||||
print("child_random:\n", origin_genome.repr(state, child_nodes, child_conns))
|
||||
|
||||
def crossover_origin_case3():
|
||||
"""
|
||||
test examine it use random gene rather than attribute exchange
|
||||
"""
|
||||
nodes, conns = origin_genome.initialize(state, randkey)
|
||||
nodes1 = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=jnp.array([1, 2, 0, 0])
|
||||
)
|
||||
nodes2 = add_node(
|
||||
nodes,
|
||||
fix_attrs=jnp.asarray([10]),
|
||||
custom_attrs=jnp.array([100, 200, 0, 0])
|
||||
)
|
||||
|
||||
# [1, 2] case
|
||||
child_nodes, child_conns = origin_genome.execute_crossover(state, jax.random.key(99), nodes1, conns, nodes2, conns)
|
||||
print("child1:\n", origin_genome.repr(state, child_nodes, child_conns))
|
||||
|
||||
# [100, 200] case
|
||||
child_nodes, child_conns = origin_genome.execute_crossover(state, jax.random.key(1), nodes1, conns, nodes2, conns)
|
||||
print("child2:\n", origin_genome.repr(state, child_nodes, child_conns))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# mutation_origin()
|
||||
# distance_default()
|
||||
# distance_origin_case1()
|
||||
# distance_origin_case2()
|
||||
# crossover_origin_case1()
|
||||
# crossover_origin_case2()
|
||||
crossover_origin_case3()
|
||||
Reference in New Issue
Block a user