add gene type RNN

This commit is contained in:
wls2002
2023-07-19 15:43:49 +08:00
parent 0a2a9fd1be
commit a684e6584d
18 changed files with 248 additions and 129 deletions

View File

@@ -1,13 +1,14 @@
[basic] [basic]
num_inputs = 2 num_inputs = 2
num_outputs = 1 num_outputs = 1
maximum_nodes = 100 maximum_nodes = 50
maximum_connections = 100 maximum_conns = 100
maximum_species = 100 maximum_species = 10
forward_way = "pop" forward_way = "pop"
batch_size = 4 batch_size = 4
random_seed = 0 random_seed = 0
network_type = 'feedforward' network_type = "feedforward"
activate_times = 10
[population] [population]
fitness_threshold = 3.9999 fitness_threshold = 3.9999
@@ -18,11 +19,11 @@ pop_size = 1000
[genome] [genome]
compatibility_disjoint = 1.0 compatibility_disjoint = 1.0
compatibility_weight = 0.5 compatibility_weight = 0.5
conn_add_prob = 0.5 conn_add_prob = 0.4
conn_add_trials = 1 conn_add_trials = 1
conn_delete_prob = 0.5 conn_delete_prob = 0
node_add_prob = 0.2 node_add_prob = 0.2
node_delete_prob = 0.2 node_delete_prob = 0
[species] [species]
compatibility_threshold = 3.0 compatibility_threshold = 3.0

View File

@@ -1,3 +1,3 @@
from .neat import NEAT from .neat import NEAT
from .gene import NormalGene from .gene import NormalGene, RecurrentGene
from .pipeline import Pipeline from .pipeline import Pipeline

View File

@@ -2,4 +2,5 @@ from .base import BaseGene
from .normal import NormalGene from .normal import NormalGene
from .activation import Activation from .activation import Activation
from .aggregation import Aggregation from .aggregation import Aggregation
from .recurrent import RecurrentGene

View File

@@ -86,10 +86,12 @@ class NormalGene(BaseGene):
@staticmethod @staticmethod
def forward_transform(nodes, conns): def forward_transform(nodes, conns):
u_conns = unflatten_connections(nodes, conns) u_conns = unflatten_connections(nodes, conns)
u_conns = jnp.where(jnp.isnan(u_conns[0, :]), jnp.nan, u_conns) # enable is false, then the connections is nan conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
u_conns = u_conns[1:, :] # remove enable attr
conn_exist = jnp.any(~jnp.isnan(u_conns), axis=0) # remove enable attr
seqs = topological_sort(nodes, conn_exist) u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
seqs = topological_sort(nodes, conn_enable)
return seqs, nodes, u_conns return seqs, nodes, u_conns
@staticmethod @staticmethod
@@ -167,18 +169,8 @@ class NormalGene(BaseGene):
# the val of input nodes is obtained by the task, not by calculation # the val of input nodes is obtained by the task, not by calculation
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit) values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
# if jnp.isin(i, input_idx):
# values = miss()
# else:
# values = hit()
return values, idx + 1 return values, idx + 1
# carry = (ini_vals, 0)
# while cond_fun(carry):
# carry = body_func(carry)
# vals, _ = carry
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0)) vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
return vals[output_idx] return vals[output_idx]
@@ -216,7 +208,3 @@ class NormalGene(BaseGene):
) )
return val return val

View File

@@ -0,0 +1,90 @@
import jax
from jax import Array, numpy as jnp, vmap
from .normal import NormalGene
from .activation import Activation
from .aggregation import Aggregation
from ..utils import unflatten_connections, I_INT
class RecurrentGene(NormalGene):
@staticmethod
def forward_transform(nodes, conns):
u_conns = unflatten_connections(nodes, conns)
# remove un-enable connections and remove enable attr
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
return nodes, u_conns
@staticmethod
def create_forward(config):
config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
def act(idx, z):
"""
calculate activation function for each node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
# change idx from float to int
res = jax.lax.switch(idx, config['activation_funcs'], z)
return res
def agg(idx, z):
"""
calculate activation function for inputs of node
"""
idx = jnp.asarray(idx, dtype=jnp.int32)
def all_nan():
return 0.
def not_all_nan():
return jax.lax.switch(idx, config['aggregation_funcs'], z)
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
batch_act, batch_agg = vmap(act), vmap(agg)
def forward(inputs, transform) -> Array:
"""
jax forward for single input shaped (input_num, )
nodes, connections are a single genome
:argument inputs: (input_num, )
:argument cal_seqs: (N, )
:argument nodes: (N, 5)
:argument connections: (2, N, N)
:return (output_num, )
"""
nodes, cons = transform
input_idx = config['input_idx']
output_idx = config['output_idx']
N = nodes.shape[0]
vals = jnp.full((N,), 0.)
weights = cons[0, :]
def body_func(i, values):
values = values.at[input_idx].set(inputs)
nodes_ins = values * weights.T
values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins)
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
values = batch_act(nodes[:, 3], values) # z = act(z)
return values
# for i in range(config['activate_times']):
# vals = body_func(i, vals)
#
# return vals[output_idx]
vals = jax.lax.fori_loop(0, config['activate_times'], body_func, vals)
return vals[output_idx]
return forward

View File

@@ -11,7 +11,7 @@ from ..utils import fetch_first
def initialize_genomes(state: State, gene_type: Type[BaseGene]): def initialize_genomes(state: State, gene_type: Type[BaseGene]):
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
o_conns = np.full((state.N, state.CL), np.nan, dtype=np.float32) # original connections o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections
input_idx = state.input_idx input_idx = state.input_idx
output_idx = state.output_idx output_idx = state.output_idx

View File

@@ -9,6 +9,7 @@ from jax import jit, Array, numpy as jnp
from ..utils import fetch_first, I_INT from ..utils import fetch_first, I_INT
@jit
def topological_sort(nodes: Array, conns: Array) -> Array: def topological_sort(nodes: Array, conns: Array) -> Array:
""" """
a jit-able version of topological_sort! that's crazy! a jit-able version of topological_sort! that's crazy!
@@ -60,21 +61,11 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
return res return res
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array: @jit
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
""" """
Check whether a new connection (from_idx -> to_idx) will cause a cycle. Check whether a new connection (from_idx -> to_idx) will cause a cycle.
:param nodes: JAX array
The array of nodes.
:param connections: JAX array
The array of connections.
:param from_idx: int
The index of the starting node.
:param to_idx: int
The index of the ending node.
:return: JAX array
An array indicating if there is a cycle caused by the new connection.
Example: Example:
nodes = jnp.array([ nodes = jnp.array([
[0], [0],
@@ -83,28 +74,21 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
[3] [3]
]) ])
connections = jnp.array([ connections = jnp.array([
[
[0, 0, 1, 0], [0, 0, 1, 0],
[0, 0, 1, 1], [0, 0, 1, 1],
[0, 0, 0, 1], [0, 0, 0, 1],
[0, 0, 0, 0] [0, 0, 0, 0]
],
[
[0, 0, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1],
[0, 0, 0, 0]
]
]) ])
check_cycles(nodes, connections, 3, 2) -> True check_cycles(nodes, conns, 3, 2) -> True
check_cycles(nodes, connections, 2, 3) -> False check_cycles(nodes, conns, 2, 3) -> False
check_cycles(nodes, connections, 0, 3) -> False check_cycles(nodes, conns, 0, 3) -> False
check_cycles(nodes, connections, 1, 0) -> False check_cycles(nodes, conns, 1, 0) -> False
""" """
connections_enable = ~jnp.isnan(connections[0, :, :]) conns = conns.at[from_idx, to_idx].set(True)
connections_enable = connections_enable.at[from_idx, to_idx].set(True) # conns_enable = ~jnp.isnan(conns[0, :, :])
# conns_enable = conns_enable.at[from_idx, to_idx].set(True)
visited = jnp.full(nodes.shape[0], False) visited = jnp.full(nodes.shape[0], False)
new_visited = visited.at[to_idx].set(True) new_visited = visited.at[to_idx].set(True)
@@ -117,43 +101,42 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
def body_func(carry): def body_func(carry):
_, visited_ = carry _, visited_ = carry
new_visited_ = jnp.dot(visited_, connections_enable) new_visited_ = jnp.dot(visited_, conns)
new_visited_ = jnp.logical_or(visited_, new_visited_) new_visited_ = jnp.logical_or(visited_, new_visited_)
return visited_, new_visited_ return visited_, new_visited_
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited)) _, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
return visited[from_idx] return visited[from_idx]
# if __name__ == '__main__':
if __name__ == '__main__': # nodes = jnp.array([
nodes = jnp.array([ # [0],
[0], # [1],
[1], # [2],
[2], # [3],
[3], # [jnp.nan]
[jnp.nan] # ])
]) # connections = jnp.array([
connections = jnp.array([ # [
[ # [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], # [jnp.nan, jnp.nan, 1, 1, jnp.nan],
[jnp.nan, jnp.nan, 1, 1, jnp.nan], # [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], # [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], # [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] # ],
], # [
[ # [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan], # [jnp.nan, jnp.nan, 1, 1, jnp.nan],
[jnp.nan, jnp.nan, 1, 1, jnp.nan], # [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan], # [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan], # [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan] # ]
] # ]
] # )
) #
# print(topological_sort(nodes, connections))
print(topological_sort(nodes, connections)) #
# print(check_cycles(nodes, connections, 3, 2))
print(check_cycles(nodes, connections, 3, 2)) # print(check_cycles(nodes, connections, 2, 3))
print(check_cycles(nodes, connections, 2, 3)) # print(check_cycles(nodes, connections, 0, 3))
print(check_cycles(nodes, connections, 0, 3)) # print(check_cycles(nodes, connections, 1, 0))
print(check_cycles(nodes, connections, 1, 0))

View File

@@ -91,7 +91,8 @@ def create_mutate(config: Dict, gene_type: Type[BaseGene]):
if config['network_type'] == 'feedforward': if config['network_type'] == 'feedforward':
u_cons = unflatten_connections(nodes_, conns_) u_cons = unflatten_connections(nodes_, conns_)
is_cycle = check_cycles(nodes_, u_cons, from_idx, to_idx) cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False)
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2)) choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
return jax.lax.switch(choice, [already_exist, nothing, successful]) return jax.lax.switch(choice, [already_exist, nothing, successful])

View File

@@ -26,7 +26,7 @@ class NEAT:
state = State( state = State(
P=self.config['pop_size'], P=self.config['pop_size'],
N=self.config['maximum_nodes'], N=self.config['maximum_nodes'],
C=self.config['maximum_connections'], C=self.config['maximum_conns'],
S=self.config['maximum_species'], S=self.config['maximum_species'],
NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes
CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes
@@ -64,11 +64,15 @@ class NEAT:
idx2species=idx2species, idx2species=idx2species,
center_nodes=center_nodes, center_nodes=center_nodes,
center_conns=center_conns, center_conns=center_conns,
generation=generation, # avoid jax auto cast from int to float. that would cause re-compilation.
next_node_key=next_node_key, generation=jnp.asarray(generation, dtype=jnp.int32),
next_species_key=next_species_key next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
next_species_key=jnp.asarray(next_species_key)
) )
# move to device
state = jax.device_put(state)
return state return state
def step(self, state, fitness): def step(self, state, fitness):

View File

@@ -34,9 +34,6 @@ class Pipeline:
def tell(self, fitness): def tell(self, fitness):
self.state = self.algorithm.step(self.state, fitness) self.state = self.algorithm.step(self.state, fitness)
from algorithm.neat.genome.basic import count
# print([count(self.state.pop_nodes[i], self.state.pop_conns[i]) for i in range(self.state.P)])
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"): def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
for _ in range(self.config['generation_limit']): for _ in range(self.config['generation_limit']):

View File

@@ -7,8 +7,8 @@ from .utils import rank_elements, fetch_first
from .genome import create_mutate, create_distance, crossover from .genome import create_mutate, create_distance, crossover
from .gene import BaseGene from .gene import BaseGene
def create_tell(config, gene_type: Type[BaseGene]):
def create_tell(config, gene_type: Type[BaseGene]):
mutate = create_mutate(config, gene_type) mutate = create_mutate(config, gene_type)
distance = create_distance(config, gene_type) distance = create_distance(config, gene_type)
@@ -36,7 +36,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
return state, winner, loser, elite_mask return state, winner, loser, elite_mask
def update_species_fitness(state, fitness): def update_species_fitness(state, fitness):
""" """
obtain the fitness of the species by the fitness of each individual. obtain the fitness of the species by the fitness of each individual.
@@ -51,7 +50,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
return vmap(aux_func)(jnp.arange(state.species_info.shape[0])) return vmap(aux_func)(jnp.arange(state.species_info.shape[0]))
def stagnation(state, species_fitness): def stagnation(state, species_fitness):
""" """
stagnation species. stagnation species.
@@ -88,7 +86,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
return state, species_fitness return state, species_fitness
def cal_spawn_numbers(state): def cal_spawn_numbers(state):
""" """
decide the number of members of each species by their fitness rank. decide the number of members of each species by their fitness rank.
@@ -106,7 +103,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0 spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
# jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number)
# Avoid too much variation of numbers in a species # Avoid too much variation of numbers in a species
previous_size = state.species_info[:, 3].astype(jnp.int32) previous_size = state.species_info[:, 3].astype(jnp.int32)
@@ -118,11 +114,11 @@ def create_tell(config, gene_type: Type[BaseGene]):
# must control the sum of spawn_number to be equal to pop_size # must control the sum of spawn_number to be equal to pop_size
error = state.P - jnp.sum(spawn_number) error = state.P - jnp.sum(spawn_number)
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number spawn_number = spawn_number.at[0].add(
error) # add error to the first species to control the sum of spawn_number
return spawn_number return spawn_number
def create_crossover_pair(state, randkey, spawn_number, fitness): def create_crossover_pair(state, randkey, spawn_number, fitness):
species_size = state.species_info.shape[0] species_size = state.species_info.shape[0]
pop_size = fitness.shape[0] pop_size = fitness.shape[0]
@@ -238,8 +234,8 @@ def create_tell(config, gene_type: Type[BaseGene]):
return i + 1, i2s, cn, cc, o2c return i + 1, i2s, cn, cc, o2c
_, idx2specie, center_nodes, center_conns, o2c_distances = \ _, idx2specie, center_nodes, center_conns, o2c_distances = \
jax.lax.while_loop(cond_func, body_func, (0, idx2specie, state.center_nodes, state.center_conns, o2c_distances)) jax.lax.while_loop(cond_func, body_func,
(0, idx2specie, state.center_nodes, state.center_conns, o2c_distances))
# part 2: assign members to each species # part 2: assign members to each species
def cond_func(carry): def cond_func(carry):
@@ -331,7 +327,7 @@ def create_tell(config, gene_type: Type[BaseGene]):
species_info = species_info.at[:, 3].set(species_member_counts) species_info = species_info.at[:, 3].set(species_member_counts)
return state.update( return state.update(
idx2specie=idx2specie, idx2species=idx2specie,
center_nodes=center_nodes, center_nodes=center_nodes,
center_conns=center_conns, center_conns=center_conns,
species_info=species_info, species_info=species_info,
@@ -358,11 +354,10 @@ def create_tell(config, gene_type: Type[BaseGene]):
return state return state
return tell return tell
def argmin_with_mask(arr, mask): def argmin_with_mask(arr, mask):
masked_arr = jnp.where(mask, arr, jnp.inf) masked_arr = jnp.where(mask, arr, jnp.inf)
min_idx = jnp.argmin(masked_arr) min_idx = jnp.argmin(masked_arr)
return min_idx return min_idx

View File

@@ -1,4 +0,0 @@
from algorithm.config import Configer
config = Configer.load_config()
print(config)

View File

@@ -0,0 +1,13 @@
import numpy as np
vals = np.array([1, 2])
weights = np.array([[0, 4], [5, 0]])
ins1 = vals * weights[:, 0]
ins2 = vals * weights[:, 1]
ins_all = vals * weights.T
print(ins1)
print(ins2)
print(ins_all)

View File

@@ -1,5 +1,7 @@
[basic] [basic]
forward_way = "common" forward_way = "common"
network_type = "recurrent"
activate_times = 5
[population] [population]
fitness_threshold = 4 fitness_threshold = 4

View File

@@ -2,7 +2,7 @@ import jax
import numpy as np import numpy as np
from algorithm import Configer, NEAT from algorithm import Configer, NEAT
from algorithm.neat import NormalGene, Pipeline from algorithm.neat import NormalGene, RecurrentGene, Pipeline
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32) xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
@@ -15,16 +15,17 @@ def evaluate(forward_func):
""" """
outs = forward_func(xor_inputs) outs = forward_func(xor_inputs)
outs = jax.device_get(outs) outs = jax.device_get(outs)
# print(outs)
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2)) fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
return fitnesses return fitnesses
def main(): def main():
config = Configer.load_config("xor.ini") config = Configer.load_config("xor.ini")
algorithm = NEAT(config, NormalGene) # algorithm = NEAT(config, NormalGene)
algorithm = NEAT(config, RecurrentGene)
pipeline = Pipeline(config, algorithm) pipeline = Pipeline(config, algorithm)
pipeline.auto_run(evaluate) best = pipeline.auto_run(evaluate)
print(best)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -2,31 +2,49 @@ import jax
import numpy as np import numpy as np
from algorithm.config import Configer from algorithm.config import Configer
from algorithm.neat import NEAT, NormalGene, Pipeline from algorithm.neat import NEAT, NormalGene, RecurrentGene, Pipeline
from algorithm.neat.genome import create_mutate from algorithm.neat.genome import create_mutate
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32) xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
def single_genome(func, nodes, conns): def single_genome(func, nodes, conns):
t = NormalGene.forward_transform(nodes, conns) t = RecurrentGene.forward_transform(nodes, conns)
out1 = func(xor_inputs[0], t) out1 = func(xor_inputs[0], t)
out2 = func(xor_inputs[1], t) out2 = func(xor_inputs[1], t)
out3 = func(xor_inputs[2], t) out3 = func(xor_inputs[2], t)
out4 = func(xor_inputs[3], t) out4 = func(xor_inputs[3], t)
print(out1, out2, out3, out4) print(out1, out2, out3, out4)
def batch_genome(func, nodes, conns):
t = NormalGene.forward_transform(nodes, conns)
out = jax.vmap(func, in_axes=(0, None))(xor_inputs, t)
print(out)
def pop_batch_genome(func, pop_nodes, pop_conns):
t = jax.vmap(NormalGene.forward_transform)(pop_nodes, pop_conns)
func = jax.vmap(jax.vmap(func, in_axes=(0, None)), in_axes=(None, 0))
out = func(xor_inputs, t)
print(out)
if __name__ == '__main__': if __name__ == '__main__':
config = Configer.load_config() config = Configer.load_config("xor.ini")
neat = NEAT(config, NormalGene) # neat = NEAT(config, NormalGene)
neat = NEAT(config, RecurrentGene)
randkey = jax.random.PRNGKey(42) randkey = jax.random.PRNGKey(42)
state = neat.setup(randkey) state = neat.setup(randkey)
forward_func = NormalGene.create_forward(config) forward_func = RecurrentGene.create_forward(config)
mutate_func = create_mutate(config, NormalGene) mutate_func = create_mutate(config, RecurrentGene)
nodes, conns = state.pop_nodes[0], state.pop_conns[0] nodes, conns = state.pop_nodes[0], state.pop_conns[0]
single_genome(forward_func, nodes, conns) single_genome(forward_func, nodes, conns)
# batch_genome(forward_func, nodes, conns)
nodes, conns = mutate_func(state, randkey, nodes, conns, 10000) nodes, conns = mutate_func(state, randkey, nodes, conns, 10000)
single_genome(forward_func, nodes, conns) single_genome(forward_func, nodes, conns)
# batch_genome(forward_func, nodes, conns)
#

32
test/unit/test_graphs.py Normal file
View File

@@ -0,0 +1,32 @@
import jax.numpy as jnp
from algorithm.neat.genome.graph import topological_sort, check_cycles
from algorithm.neat.utils import I_INT
nodes = jnp.array([
[0],
[1],
[2],
[3],
[jnp.nan]
])
# {(0, 2), (1, 2), (1, 3), (2, 3)}
conns = jnp.array([
[0, 0, 1, 0, 0],
[0, 0, 1, 1, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]
])
def test_topological_sort():
assert jnp.all(topological_sort(nodes, conns) == jnp.array([0, 1, 2, 3, I_INT]))
def test_check_cycles():
assert check_cycles(nodes, conns, 3, 2)
assert ~check_cycles(nodes, conns, 2, 3)
assert ~check_cycles(nodes, conns, 0, 3)
assert ~check_cycles(nodes, conns, 1, 0)

View File

@@ -1,7 +1,5 @@
import pytest import jax.numpy as jnp
import jax from algorithm.neat.utils import unflatten_connections
from algorithm.neat.utils import *
def test_unflatten(): def test_unflatten():
@@ -13,7 +11,6 @@ def test_unflatten():
[jnp.nan, jnp.nan, jnp.nan, jnp.nan] [jnp.nan, jnp.nan, jnp.nan, jnp.nan]
]) ])
conns = jnp.array([ conns = jnp.array([
[0, 1, True, 0.1, 0.11], [0, 1, True, 0.1, 0.11],
[0, 2, False, 0.2, 0.22], [0, 2, False, 0.2, 0.22],
@@ -33,4 +30,4 @@ def test_unflatten():
mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False) mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False)
# Ensure all other places are jnp.nan # Ensure all other places are jnp.nan
assert jnp.all(jnp.isnan(res[mask])) assert jnp.all(jnp.isnan(res[mask]))