add gene type RNN

This commit is contained in:
wls2002
2023-07-19 15:43:49 +08:00
parent 0a2a9fd1be
commit a684e6584d
18 changed files with 248 additions and 129 deletions

View File

@@ -1,13 +1,14 @@
[basic]
num_inputs = 2
num_outputs = 1
maximum_nodes = 100
maximum_connections = 100
maximum_species = 100
maximum_nodes = 50
maximum_conns = 100
maximum_species = 10
forward_way = "pop"
batch_size = 4
random_seed = 0
network_type = 'feedforward'
network_type = "feedforward"
activate_times = 10
[population]
fitness_threshold = 3.9999
@@ -18,11 +19,11 @@ pop_size = 1000
[genome]
compatibility_disjoint = 1.0
compatibility_weight = 0.5
conn_add_prob = 0.5
conn_add_prob = 0.4
conn_add_trials = 1
conn_delete_prob = 0.5
conn_delete_prob = 0
node_add_prob = 0.2
node_delete_prob = 0.2
node_delete_prob = 0
[species]
compatibility_threshold = 3.0

View File

@@ -1,3 +1,3 @@
from .neat import NEAT
from .gene import NormalGene
from .gene import NormalGene, RecurrentGene
from .pipeline import Pipeline

View File

@@ -2,4 +2,5 @@ from .base import BaseGene
from .normal import NormalGene
from .activation import Activation
from .aggregation import Aggregation
from .recurrent import RecurrentGene

View File

@@ -86,10 +86,12 @@ class NormalGene(BaseGene):
@staticmethod
def forward_transform(nodes, conns):
u_conns = unflatten_connections(nodes, conns)
u_conns = jnp.where(jnp.isnan(u_conns[0, :]), jnp.nan, u_conns) # enable is false, then the connections is nan
u_conns = u_conns[1:, :] # remove enable attr
conn_exist = jnp.any(~jnp.isnan(u_conns), axis=0)
seqs = topological_sort(nodes, conn_exist)
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
# remove enable attr
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(nodes, conn_enable)
return seqs, nodes, u_conns
@staticmethod
@@ -167,18 +169,8 @@ class NormalGene(BaseGene):
# the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
# if jnp.isin(i, input_idx):
# values = miss()
# else:
# values = hit()
return values, idx + 1
# carry = (ini_vals, 0)
# while cond_fun(carry):
# carry = body_func(carry)
# vals, _ = carry
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[output_idx]
@@ -216,7 +208,3 @@ class NormalGene(BaseGene):
)
return val

View File

@@ -0,0 +1,90 @@
import jax
from jax import Array, numpy as jnp, vmap
from .normal import NormalGene
from .activation import Activation
from .aggregation import Aggregation
from ..utils import unflatten_connections, I_INT
class RecurrentGene(NormalGene):
@staticmethod
def forward_transform(nodes, conns):
u_conns = unflatten_connections(nodes, conns)
# remove un-enable connections and remove enable attr
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return nodes, u_conns
@staticmethod
def create_forward(config):
config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
def act(idx, z):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, config['activation_funcs'], z)
return res
def agg(idx, z):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
def all_nan():
return 0.
def not_all_nan():
return jax.lax.switch(idx, config['aggregation_funcs'], z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
batch_act, batch_agg = vmap(act), vmap(agg)
def forward(inputs, transform) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are a single genome
:argument inputs: (input_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
:return (output_num, )
"""
nodes, cons = transform
input_idx = config['input_idx']
output_idx = config['output_idx']
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
weights = cons[0, :]
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values) # z = act(z)
return values
# for i in range(config['activate_times']):
# vals = body_func(i, vals)
#
# return vals[output_idx]
vals = jax.lax.fori_loop(0, config['activate_times'], body_func, vals)
return vals[output_idx]
return forward

View File

@@ -11,7 +11,7 @@ from ..utils import fetch_first
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
o_conns = np.full((state.N, state.CL), np.nan, dtype=np.float32) # original connections
o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections
input_idx = state.input_idx
output_idx = state.output_idx

View File

@@ -9,6 +9,7 @@ from jax import jit, Array, numpy as jnp
from ..utils import fetch_first, I_INT
@jit
def topological_sort(nodes: Array, conns: Array) -> Array:
"""
a jit-able version of topological_sort! that's crazy!
@@ -60,21 +61,11 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
return res
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
@jit
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
"""
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
:param nodes: JAX array
The array of nodes.
:param connections: JAX array
The array of connections.
:param from_idx: int
The index of the starting node.
:param to_idx: int
The index of the ending node.
:return: JAX array
An array indicating if there is a cycle caused by the new connection.
Example:
nodes = jnp.array([
[0],
@@ -83,28 +74,21 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
[3]
])
connections = jnp.array([
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
],
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
]
])
check_cycles(nodes, connections, 3, 2) -> True
check_cycles(nodes, connections, 2, 3) -> False
check_cycles(nodes, connections, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False
check_cycles(nodes, conns, 3, 2) -> True
check_cycles(nodes, conns, 2, 3) -> False
check_cycles(nodes, conns, 0, 3) -> False
check_cycles(nodes, conns, 1, 0) -> False
"""
connections_enable = ~jnp.isnan(connections[0, :, :])
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
conns = conns.at[from_idx, to_idx].set(True)
# conns_enable = ~jnp.isnan(conns[0, :, :])
# conns_enable = conns_enable.at[from_idx, to_idx].set(True)
visited = jnp.full(nodes.shape[0], False)
new_visited = visited.at[to_idx].set(True)
@@ -117,43 +101,42 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
def body_func(carry):
_, visited_ = carry
new_visited_ = jnp.dot(visited_, connections_enable)
new_visited_ = jnp.dot(visited_, conns)
new_visited_ = jnp.logical_or(visited_, new_visited_)
return visited_, new_visited_
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
return visited[from_idx]
if __name__ == '__main__':
nodes = jnp.array([
[0],
[1],
[2],
[3],
[jnp.nan]
])
connections = jnp.array([
[
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
],
[
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
]
]
)
print(topological_sort(nodes, connections))
print(check_cycles(nodes, connections, 3, 2))
print(check_cycles(nodes, connections, 2, 3))
print(check_cycles(nodes, connections, 0, 3))
print(check_cycles(nodes, connections, 1, 0))
# if __name__ == '__main__':
# nodes = jnp.array([
# [0],
# [1],
# [2],
# [3],
# [jnp.nan]
# ])
# connections = jnp.array([
# [
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
# ],
# [
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
# ]
# ]
# )
#
# print(topological_sort(nodes, connections))
#
# print(check_cycles(nodes, connections, 3, 2))
# print(check_cycles(nodes, connections, 2, 3))
# print(check_cycles(nodes, connections, 0, 3))
# print(check_cycles(nodes, connections, 1, 0))

View File

@@ -91,7 +91,8 @@ def create_mutate(config: Dict, gene_type: Type[BaseGene]):
if config['network_type'] == 'feedforward':
u_cons = unflatten_connections(nodes_, conns_)
is_cycle = check_cycles(nodes_, u_cons, from_idx, to_idx)
cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False)
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
return jax.lax.switch(choice, [already_exist, nothing, successful])

View File

@@ -26,7 +26,7 @@ class NEAT:
state = State(
P=self.config['pop_size'],
N=self.config['maximum_nodes'],
C=self.config['maximum_connections'],
C=self.config['maximum_conns'],
S=self.config['maximum_species'],
NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes
CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes
@@ -64,11 +64,15 @@ class NEAT:
idx2species=idx2species,
center_nodes=center_nodes,
center_conns=center_conns,
generation=generation,
next_node_key=next_node_key,
next_species_key=next_species_key
# avoid jax auto cast from int to float. that would cause re-compilation.
generation=jnp.asarray(generation, dtype=jnp.int32),
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
next_species_key=jnp.asarray(next_species_key)
)
# move to device
state = jax.device_put(state)
return state
def step(self, state, fitness):

View File

@@ -34,9 +34,6 @@ class Pipeline:
def tell(self, fitness):
self.state = self.algorithm.step(self.state, fitness)
from algorithm.neat.genome.basic import count
# print([count(self.state.pop_nodes[i], self.state.pop_conns[i]) for i in range(self.state.P)])
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']):

View File

@@ -7,8 +7,8 @@ from .utils import rank_elements, fetch_first
from .genome import create_mutate, create_distance, crossover
from .gene import BaseGene
def create_tell(config, gene_type: Type[BaseGene]):
def create_tell(config, gene_type: Type[BaseGene]):
mutate = create_mutate(config, gene_type)
distance = create_distance(config, gene_type)
@@ -36,7 +36,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
return state, winner, loser, elite_mask
def update_species_fitness(state, fitness):
"""
obtain the fitness of the species by the fitness of each individual.
@@ -51,7 +50,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
return vmap(aux_func)(jnp.arange(state.species_info.shape[0]))
def stagnation(state, species_fitness):
"""
stagnation species.
@@ -88,7 +86,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
return state, species_fitness
def cal_spawn_numbers(state):
"""
decide the number of members of each species by their fitness rank.
@@ -106,7 +103,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
# jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number)
# Avoid too much variation of numbers in a species
previous_size = state.species_info[:, 3].astype(jnp.int32)
@@ -118,11 +114,11 @@ def create_tell(config, gene_type: Type[BaseGene]):
# must control the sum of spawn_number to be equal to pop_size
error = state.P - jnp.sum(spawn_number)
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
spawn_number = spawn_number.at[0].add(
error) # add error to the first species to control the sum of spawn_number
return spawn_number
def create_crossover_pair(state, randkey, spawn_number, fitness):
species_size = state.species_info.shape[0]
pop_size = fitness.shape[0]
@@ -238,8 +234,8 @@ def create_tell(config, gene_type: Type[BaseGene]):
return i + 1, i2s, cn, cc, o2c
_, idx2specie, center_nodes, center_conns, o2c_distances = \
jax.lax.while_loop(cond_func, body_func, (0, idx2specie, state.center_nodes, state.center_conns, o2c_distances))
jax.lax.while_loop(cond_func, body_func,
(0, idx2specie, state.center_nodes, state.center_conns, o2c_distances))
# part 2: assign members to each species
def cond_func(carry):
@@ -331,7 +327,7 @@ def create_tell(config, gene_type: Type[BaseGene]):
species_info = species_info.at[:, 3].set(species_member_counts)
return state.update(
idx2specie=idx2specie,
idx2species=idx2specie,
center_nodes=center_nodes,
center_conns=center_conns,
species_info=species_info,
@@ -358,11 +354,10 @@ def create_tell(config, gene_type: Type[BaseGene]):
return state
return tell
def argmin_with_mask(arr, mask):
masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr)
return min_idx
return min_idx

View File

@@ -1,4 +0,0 @@
from algorithm.config import Configer
config = Configer.load_config()
print(config)

View File

@@ -0,0 +1,13 @@
import numpy as np
vals = np.array([1, 2])
weights = np.array([[0, 4], [5, 0]])
ins1 = vals * weights[:, 0]
ins2 = vals * weights[:, 1]
ins_all = vals * weights.T
print(ins1)
print(ins2)
print(ins_all)

View File

@@ -1,5 +1,7 @@
[basic]
forward_way = "common"
network_type = "recurrent"
activate_times = 5
[population]
fitness_threshold = 4

View File

@@ -2,7 +2,7 @@ import jax
import numpy as np
from algorithm import Configer, NEAT
from algorithm.neat import NormalGene, Pipeline
from algorithm.neat import NormalGene, RecurrentGene, Pipeline
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
@@ -15,16 +15,17 @@ def evaluate(forward_func):
"""
outs = forward_func(xor_inputs)
outs = jax.device_get(outs)
# print(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses
def main():
config = Configer.load_config("xor.ini")
algorithm = NEAT(config, NormalGene)
# algorithm = NEAT(config, NormalGene)
algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate)
best = pipeline.auto_run(evaluate)
print(best)
if __name__ == '__main__':

View File

@@ -2,31 +2,49 @@ import jax
import numpy as np
from algorithm.config import Configer
from algorithm.neat import NEAT, NormalGene, Pipeline
from algorithm.neat import NEAT, NormalGene, RecurrentGene, Pipeline
from algorithm.neat.genome import create_mutate
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
def single_genome(func, nodes, conns):
t = NormalGene.forward_transform(nodes, conns)
t = RecurrentGene.forward_transform(nodes, conns)
out1 = func(xor_inputs[0], t)
out2 = func(xor_inputs[1], t)
out3 = func(xor_inputs[2], t)
out4 = func(xor_inputs[3], t)
print(out1, out2, out3, out4)
def batch_genome(func, nodes, conns):
t = NormalGene.forward_transform(nodes, conns)
out = jax.vmap(func, in_axes=(0, None))(xor_inputs, t)
print(out)
def pop_batch_genome(func, pop_nodes, pop_conns):
t = jax.vmap(NormalGene.forward_transform)(pop_nodes, pop_conns)
func = jax.vmap(jax.vmap(func, in_axes=(0, None)), in_axes=(None, 0))
out = func(xor_inputs, t)
print(out)
if __name__ == '__main__':
config = Configer.load_config()
neat = NEAT(config, NormalGene)
config = Configer.load_config("xor.ini")
# neat = NEAT(config, NormalGene)
neat = NEAT(config, RecurrentGene)
randkey = jax.random.PRNGKey(42)
state = neat.setup(randkey)
forward_func = NormalGene.create_forward(config)
mutate_func = create_mutate(config, NormalGene)
forward_func = RecurrentGene.create_forward(config)
mutate_func = create_mutate(config, RecurrentGene)
nodes, conns = state.pop_nodes[0], state.pop_conns[0]
single_genome(forward_func, nodes, conns)
# batch_genome(forward_func, nodes, conns)
nodes, conns = mutate_func(state, randkey, nodes, conns, 10000)
single_genome(forward_func, nodes, conns)
# batch_genome(forward_func, nodes, conns)
#

32
test/unit/test_graphs.py Normal file
View File

@@ -0,0 +1,32 @@
import jax.numpy as jnp
from algorithm.neat.genome.graph import topological_sort, check_cycles
from algorithm.neat.utils import I_INT
nodes = jnp.array([
[0],
[1],
[2],
[3],
[jnp.nan]
])
# {(0, 2), (1, 2), (1, 3), (2, 3)}
conns = jnp.array([
[0, 0, 1, 0, 0],
[0, 0, 1, 1, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]
])
def test_topological_sort():
assert jnp.all(topological_sort(nodes, conns) == jnp.array([0, 1, 2, 3, I_INT]))
def test_check_cycles():
assert check_cycles(nodes, conns, 3, 2)
assert ~check_cycles(nodes, conns, 2, 3)
assert ~check_cycles(nodes, conns, 0, 3)
assert ~check_cycles(nodes, conns, 1, 0)

View File

@@ -1,7 +1,5 @@
import pytest
import jax
from algorithm.neat.utils import *
import jax.numpy as jnp
from algorithm.neat.utils import unflatten_connections
def test_unflatten():
@@ -13,7 +11,6 @@ def test_unflatten():
[jnp.nan, jnp.nan, jnp.nan, jnp.nan]
])
conns = jnp.array([
[0, 1, True, 0.1, 0.11],
[0, 2, False, 0.2, 0.22],
@@ -33,4 +30,4 @@ def test_unflatten():
mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False)
# Ensure all other places are jnp.nan
assert jnp.all(jnp.isnan(res[mask]))
assert jnp.all(jnp.isnan(res[mask]))