add gene type RNN
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
[basic]
|
||||
num_inputs = 2
|
||||
num_outputs = 1
|
||||
maximum_nodes = 100
|
||||
maximum_connections = 100
|
||||
maximum_species = 100
|
||||
maximum_nodes = 50
|
||||
maximum_conns = 100
|
||||
maximum_species = 10
|
||||
forward_way = "pop"
|
||||
batch_size = 4
|
||||
random_seed = 0
|
||||
network_type = 'feedforward'
|
||||
network_type = "feedforward"
|
||||
activate_times = 10
|
||||
|
||||
[population]
|
||||
fitness_threshold = 3.9999
|
||||
@@ -18,11 +19,11 @@ pop_size = 1000
|
||||
[genome]
|
||||
compatibility_disjoint = 1.0
|
||||
compatibility_weight = 0.5
|
||||
conn_add_prob = 0.5
|
||||
conn_add_prob = 0.4
|
||||
conn_add_trials = 1
|
||||
conn_delete_prob = 0.5
|
||||
conn_delete_prob = 0
|
||||
node_add_prob = 0.2
|
||||
node_delete_prob = 0.2
|
||||
node_delete_prob = 0
|
||||
|
||||
[species]
|
||||
compatibility_threshold = 3.0
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .neat import NEAT
|
||||
from .gene import NormalGene
|
||||
from .gene import NormalGene, RecurrentGene
|
||||
from .pipeline import Pipeline
|
||||
|
||||
@@ -2,4 +2,5 @@ from .base import BaseGene
|
||||
from .normal import NormalGene
|
||||
from .activation import Activation
|
||||
from .aggregation import Aggregation
|
||||
from .recurrent import RecurrentGene
|
||||
|
||||
|
||||
@@ -86,10 +86,12 @@ class NormalGene(BaseGene):
|
||||
@staticmethod
|
||||
def forward_transform(nodes, conns):
|
||||
u_conns = unflatten_connections(nodes, conns)
|
||||
u_conns = jnp.where(jnp.isnan(u_conns[0, :]), jnp.nan, u_conns) # enable is false, then the connections is nan
|
||||
u_conns = u_conns[1:, :] # remove enable attr
|
||||
conn_exist = jnp.any(~jnp.isnan(u_conns), axis=0)
|
||||
seqs = topological_sort(nodes, conn_exist)
|
||||
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||
|
||||
# remove enable attr
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
seqs = topological_sort(nodes, conn_enable)
|
||||
|
||||
return seqs, nodes, u_conns
|
||||
|
||||
@staticmethod
|
||||
@@ -167,18 +169,8 @@ class NormalGene(BaseGene):
|
||||
# the val of input nodes is obtained by the task, not by calculation
|
||||
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
|
||||
|
||||
# if jnp.isin(i, input_idx):
|
||||
# values = miss()
|
||||
# else:
|
||||
# values = hit()
|
||||
|
||||
return values, idx + 1
|
||||
|
||||
# carry = (ini_vals, 0)
|
||||
# while cond_fun(carry):
|
||||
# carry = body_func(carry)
|
||||
# vals, _ = carry
|
||||
|
||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
||||
|
||||
return vals[output_idx]
|
||||
@@ -216,7 +208,3 @@ class NormalGene(BaseGene):
|
||||
)
|
||||
|
||||
return val
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
90
algorithm/neat/gene/recurrent.py
Normal file
90
algorithm/neat/gene/recurrent.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import jax
|
||||
from jax import Array, numpy as jnp, vmap
|
||||
|
||||
from .normal import NormalGene
|
||||
from .activation import Activation
|
||||
from .aggregation import Aggregation
|
||||
from ..utils import unflatten_connections, I_INT
|
||||
|
||||
|
||||
class RecurrentGene(NormalGene):
|
||||
|
||||
@staticmethod
|
||||
def forward_transform(nodes, conns):
|
||||
u_conns = unflatten_connections(nodes, conns)
|
||||
|
||||
# remove un-enable connections and remove enable attr
|
||||
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||
|
||||
return nodes, u_conns
|
||||
|
||||
@staticmethod
|
||||
def create_forward(config):
|
||||
config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
|
||||
config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
|
||||
|
||||
def act(idx, z):
|
||||
"""
|
||||
calculate activation function for each node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
# change idx from float to int
|
||||
res = jax.lax.switch(idx, config['activation_funcs'], z)
|
||||
return res
|
||||
|
||||
def agg(idx, z):
|
||||
"""
|
||||
calculate activation function for inputs of node
|
||||
"""
|
||||
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||
|
||||
def all_nan():
|
||||
return 0.
|
||||
|
||||
def not_all_nan():
|
||||
return jax.lax.switch(idx, config['aggregation_funcs'], z)
|
||||
|
||||
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
|
||||
|
||||
batch_act, batch_agg = vmap(act), vmap(agg)
|
||||
|
||||
def forward(inputs, transform) -> Array:
|
||||
"""
|
||||
jax forward for single input shaped (input_num, )
|
||||
nodes, connections are a single genome
|
||||
|
||||
:argument inputs: (input_num, )
|
||||
:argument cal_seqs: (N, )
|
||||
:argument nodes: (N, 5)
|
||||
:argument connections: (2, N, N)
|
||||
|
||||
:return (output_num, )
|
||||
"""
|
||||
|
||||
nodes, cons = transform
|
||||
|
||||
input_idx = config['input_idx']
|
||||
output_idx = config['output_idx']
|
||||
|
||||
N = nodes.shape[0]
|
||||
vals = jnp.full((N,), 0.)
|
||||
|
||||
weights = cons[0, :]
|
||||
|
||||
def body_func(i, values):
|
||||
values = values.at[input_idx].set(inputs)
|
||||
nodes_ins = values * weights.T
|
||||
values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins)
|
||||
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||
values = batch_act(nodes[:, 3], values) # z = act(z)
|
||||
return values
|
||||
|
||||
# for i in range(config['activate_times']):
|
||||
# vals = body_func(i, vals)
|
||||
#
|
||||
# return vals[output_idx]
|
||||
vals = jax.lax.fori_loop(0, config['activate_times'], body_func, vals)
|
||||
return vals[output_idx]
|
||||
|
||||
return forward
|
||||
@@ -11,7 +11,7 @@ from ..utils import fetch_first
|
||||
|
||||
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
|
||||
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
|
||||
o_conns = np.full((state.N, state.CL), np.nan, dtype=np.float32) # original connections
|
||||
o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections
|
||||
|
||||
input_idx = state.input_idx
|
||||
output_idx = state.output_idx
|
||||
|
||||
@@ -9,6 +9,7 @@ from jax import jit, Array, numpy as jnp
|
||||
from ..utils import fetch_first, I_INT
|
||||
|
||||
|
||||
@jit
|
||||
def topological_sort(nodes: Array, conns: Array) -> Array:
|
||||
"""
|
||||
a jit-able version of topological_sort! that's crazy!
|
||||
@@ -60,21 +61,11 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
|
||||
return res
|
||||
|
||||
|
||||
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
|
||||
@jit
|
||||
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
||||
"""
|
||||
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
|
||||
|
||||
:param nodes: JAX array
|
||||
The array of nodes.
|
||||
:param connections: JAX array
|
||||
The array of connections.
|
||||
:param from_idx: int
|
||||
The index of the starting node.
|
||||
:param to_idx: int
|
||||
The index of the ending node.
|
||||
:return: JAX array
|
||||
An array indicating if there is a cycle caused by the new connection.
|
||||
|
||||
Example:
|
||||
nodes = jnp.array([
|
||||
[0],
|
||||
@@ -83,28 +74,21 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
|
||||
[3]
|
||||
])
|
||||
connections = jnp.array([
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 1, 1],
|
||||
[0, 0, 0, 1],
|
||||
[0, 0, 0, 0]
|
||||
],
|
||||
[
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 1, 1],
|
||||
[0, 0, 0, 1],
|
||||
[0, 0, 0, 0]
|
||||
]
|
||||
])
|
||||
|
||||
check_cycles(nodes, connections, 3, 2) -> True
|
||||
check_cycles(nodes, connections, 2, 3) -> False
|
||||
check_cycles(nodes, connections, 0, 3) -> False
|
||||
check_cycles(nodes, connections, 1, 0) -> False
|
||||
check_cycles(nodes, conns, 3, 2) -> True
|
||||
check_cycles(nodes, conns, 2, 3) -> False
|
||||
check_cycles(nodes, conns, 0, 3) -> False
|
||||
check_cycles(nodes, conns, 1, 0) -> False
|
||||
"""
|
||||
|
||||
connections_enable = ~jnp.isnan(connections[0, :, :])
|
||||
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
|
||||
conns = conns.at[from_idx, to_idx].set(True)
|
||||
# conns_enable = ~jnp.isnan(conns[0, :, :])
|
||||
# conns_enable = conns_enable.at[from_idx, to_idx].set(True)
|
||||
|
||||
visited = jnp.full(nodes.shape[0], False)
|
||||
new_visited = visited.at[to_idx].set(True)
|
||||
@@ -117,43 +101,42 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
|
||||
|
||||
def body_func(carry):
|
||||
_, visited_ = carry
|
||||
new_visited_ = jnp.dot(visited_, connections_enable)
|
||||
new_visited_ = jnp.dot(visited_, conns)
|
||||
new_visited_ = jnp.logical_or(visited_, new_visited_)
|
||||
return visited_, new_visited_
|
||||
|
||||
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
|
||||
return visited[from_idx]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
nodes = jnp.array([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[3],
|
||||
[jnp.nan]
|
||||
])
|
||||
connections = jnp.array([
|
||||
[
|
||||
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
],
|
||||
[
|
||||
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
print(topological_sort(nodes, connections))
|
||||
|
||||
print(check_cycles(nodes, connections, 3, 2))
|
||||
print(check_cycles(nodes, connections, 2, 3))
|
||||
print(check_cycles(nodes, connections, 0, 3))
|
||||
print(check_cycles(nodes, connections, 1, 0))
|
||||
# if __name__ == '__main__':
|
||||
# nodes = jnp.array([
|
||||
# [0],
|
||||
# [1],
|
||||
# [2],
|
||||
# [3],
|
||||
# [jnp.nan]
|
||||
# ])
|
||||
# connections = jnp.array([
|
||||
# [
|
||||
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
# ],
|
||||
# [
|
||||
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
# ]
|
||||
# ]
|
||||
# )
|
||||
#
|
||||
# print(topological_sort(nodes, connections))
|
||||
#
|
||||
# print(check_cycles(nodes, connections, 3, 2))
|
||||
# print(check_cycles(nodes, connections, 2, 3))
|
||||
# print(check_cycles(nodes, connections, 0, 3))
|
||||
# print(check_cycles(nodes, connections, 1, 0))
|
||||
|
||||
@@ -91,7 +91,8 @@ def create_mutate(config: Dict, gene_type: Type[BaseGene]):
|
||||
|
||||
if config['network_type'] == 'feedforward':
|
||||
u_cons = unflatten_connections(nodes_, conns_)
|
||||
is_cycle = check_cycles(nodes_, u_cons, from_idx, to_idx)
|
||||
cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False)
|
||||
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
|
||||
|
||||
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
||||
return jax.lax.switch(choice, [already_exist, nothing, successful])
|
||||
|
||||
@@ -26,7 +26,7 @@ class NEAT:
|
||||
state = State(
|
||||
P=self.config['pop_size'],
|
||||
N=self.config['maximum_nodes'],
|
||||
C=self.config['maximum_connections'],
|
||||
C=self.config['maximum_conns'],
|
||||
S=self.config['maximum_species'],
|
||||
NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes
|
||||
CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes
|
||||
@@ -64,11 +64,15 @@ class NEAT:
|
||||
idx2species=idx2species,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
generation=generation,
|
||||
next_node_key=next_node_key,
|
||||
next_species_key=next_species_key
|
||||
# avoid jax auto cast from int to float. that would cause re-compilation.
|
||||
generation=jnp.asarray(generation, dtype=jnp.int32),
|
||||
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
|
||||
next_species_key=jnp.asarray(next_species_key)
|
||||
)
|
||||
|
||||
# move to device
|
||||
state = jax.device_put(state)
|
||||
|
||||
return state
|
||||
|
||||
def step(self, state, fitness):
|
||||
|
||||
@@ -34,9 +34,6 @@ class Pipeline:
|
||||
|
||||
def tell(self, fitness):
|
||||
self.state = self.algorithm.step(self.state, fitness)
|
||||
from algorithm.neat.genome.basic import count
|
||||
# print([count(self.state.pop_nodes[i], self.state.pop_conns[i]) for i in range(self.state.P)])
|
||||
|
||||
|
||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||
for _ in range(self.config['generation_limit']):
|
||||
|
||||
@@ -7,8 +7,8 @@ from .utils import rank_elements, fetch_first
|
||||
from .genome import create_mutate, create_distance, crossover
|
||||
from .gene import BaseGene
|
||||
|
||||
def create_tell(config, gene_type: Type[BaseGene]):
|
||||
|
||||
def create_tell(config, gene_type: Type[BaseGene]):
|
||||
mutate = create_mutate(config, gene_type)
|
||||
distance = create_distance(config, gene_type)
|
||||
|
||||
@@ -36,7 +36,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
|
||||
return state, winner, loser, elite_mask
|
||||
|
||||
|
||||
def update_species_fitness(state, fitness):
|
||||
"""
|
||||
obtain the fitness of the species by the fitness of each individual.
|
||||
@@ -51,7 +50,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
|
||||
return vmap(aux_func)(jnp.arange(state.species_info.shape[0]))
|
||||
|
||||
|
||||
def stagnation(state, species_fitness):
|
||||
"""
|
||||
stagnation species.
|
||||
@@ -88,7 +86,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
|
||||
return state, species_fitness
|
||||
|
||||
|
||||
def cal_spawn_numbers(state):
|
||||
"""
|
||||
decide the number of members of each species by their fitness rank.
|
||||
@@ -106,7 +103,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||
|
||||
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
|
||||
# jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number)
|
||||
|
||||
# Avoid too much variation of numbers in a species
|
||||
previous_size = state.species_info[:, 3].astype(jnp.int32)
|
||||
@@ -118,11 +114,11 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
|
||||
# must control the sum of spawn_number to be equal to pop_size
|
||||
error = state.P - jnp.sum(spawn_number)
|
||||
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
|
||||
spawn_number = spawn_number.at[0].add(
|
||||
error) # add error to the first species to control the sum of spawn_number
|
||||
|
||||
return spawn_number
|
||||
|
||||
|
||||
def create_crossover_pair(state, randkey, spawn_number, fitness):
|
||||
species_size = state.species_info.shape[0]
|
||||
pop_size = fitness.shape[0]
|
||||
@@ -238,8 +234,8 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
return i + 1, i2s, cn, cc, o2c
|
||||
|
||||
_, idx2specie, center_nodes, center_conns, o2c_distances = \
|
||||
jax.lax.while_loop(cond_func, body_func, (0, idx2specie, state.center_nodes, state.center_conns, o2c_distances))
|
||||
|
||||
jax.lax.while_loop(cond_func, body_func,
|
||||
(0, idx2specie, state.center_nodes, state.center_conns, o2c_distances))
|
||||
|
||||
# part 2: assign members to each species
|
||||
def cond_func(carry):
|
||||
@@ -331,7 +327,7 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
species_info = species_info.at[:, 3].set(species_member_counts)
|
||||
|
||||
return state.update(
|
||||
idx2specie=idx2specie,
|
||||
idx2species=idx2specie,
|
||||
center_nodes=center_nodes,
|
||||
center_conns=center_conns,
|
||||
species_info=species_info,
|
||||
@@ -358,7 +354,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
||||
|
||||
return state
|
||||
|
||||
|
||||
return tell
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from algorithm.config import Configer
|
||||
|
||||
config = Configer.load_config()
|
||||
print(config)
|
||||
13
examples/rnn_forward_test.py
Normal file
13
examples/rnn_forward_test.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
vals = np.array([1, 2])
|
||||
weights = np.array([[0, 4], [5, 0]])
|
||||
|
||||
ins1 = vals * weights[:, 0]
|
||||
ins2 = vals * weights[:, 1]
|
||||
ins_all = vals * weights.T
|
||||
|
||||
print(ins1)
|
||||
print(ins2)
|
||||
print(ins_all)
|
||||
@@ -1,5 +1,7 @@
|
||||
[basic]
|
||||
forward_way = "common"
|
||||
network_type = "recurrent"
|
||||
activate_times = 5
|
||||
|
||||
[population]
|
||||
fitness_threshold = 4
|
||||
@@ -2,7 +2,7 @@ import jax
|
||||
import numpy as np
|
||||
|
||||
from algorithm import Configer, NEAT
|
||||
from algorithm.neat import NormalGene, Pipeline
|
||||
from algorithm.neat import NormalGene, RecurrentGene, Pipeline
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||
@@ -15,16 +15,17 @@ def evaluate(forward_func):
|
||||
"""
|
||||
outs = forward_func(xor_inputs)
|
||||
outs = jax.device_get(outs)
|
||||
# print(outs)
|
||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||
return fitnesses
|
||||
|
||||
|
||||
def main():
|
||||
config = Configer.load_config("xor.ini")
|
||||
algorithm = NEAT(config, NormalGene)
|
||||
# algorithm = NEAT(config, NormalGene)
|
||||
algorithm = NEAT(config, RecurrentGene)
|
||||
pipeline = Pipeline(config, algorithm)
|
||||
pipeline.auto_run(evaluate)
|
||||
best = pipeline.auto_run(evaluate)
|
||||
print(best)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -2,31 +2,49 @@ import jax
|
||||
import numpy as np
|
||||
|
||||
from algorithm.config import Configer
|
||||
from algorithm.neat import NEAT, NormalGene, Pipeline
|
||||
from algorithm.neat import NEAT, NormalGene, RecurrentGene, Pipeline
|
||||
from algorithm.neat.genome import create_mutate
|
||||
|
||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||
|
||||
|
||||
def single_genome(func, nodes, conns):
|
||||
t = NormalGene.forward_transform(nodes, conns)
|
||||
t = RecurrentGene.forward_transform(nodes, conns)
|
||||
out1 = func(xor_inputs[0], t)
|
||||
out2 = func(xor_inputs[1], t)
|
||||
out3 = func(xor_inputs[2], t)
|
||||
out4 = func(xor_inputs[3], t)
|
||||
print(out1, out2, out3, out4)
|
||||
|
||||
|
||||
def batch_genome(func, nodes, conns):
|
||||
t = NormalGene.forward_transform(nodes, conns)
|
||||
out = jax.vmap(func, in_axes=(0, None))(xor_inputs, t)
|
||||
print(out)
|
||||
|
||||
|
||||
def pop_batch_genome(func, pop_nodes, pop_conns):
|
||||
t = jax.vmap(NormalGene.forward_transform)(pop_nodes, pop_conns)
|
||||
func = jax.vmap(jax.vmap(func, in_axes=(0, None)), in_axes=(None, 0))
|
||||
out = func(xor_inputs, t)
|
||||
print(out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = Configer.load_config()
|
||||
neat = NEAT(config, NormalGene)
|
||||
config = Configer.load_config("xor.ini")
|
||||
# neat = NEAT(config, NormalGene)
|
||||
neat = NEAT(config, RecurrentGene)
|
||||
randkey = jax.random.PRNGKey(42)
|
||||
state = neat.setup(randkey)
|
||||
forward_func = NormalGene.create_forward(config)
|
||||
mutate_func = create_mutate(config, NormalGene)
|
||||
|
||||
forward_func = RecurrentGene.create_forward(config)
|
||||
mutate_func = create_mutate(config, RecurrentGene)
|
||||
|
||||
nodes, conns = state.pop_nodes[0], state.pop_conns[0]
|
||||
single_genome(forward_func, nodes, conns)
|
||||
# batch_genome(forward_func, nodes, conns)
|
||||
|
||||
nodes, conns = mutate_func(state, randkey, nodes, conns, 10000)
|
||||
single_genome(forward_func, nodes, conns)
|
||||
|
||||
|
||||
# batch_genome(forward_func, nodes, conns)
|
||||
#
|
||||
|
||||
32
test/unit/test_graphs.py
Normal file
32
test/unit/test_graphs.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import jax.numpy as jnp
|
||||
|
||||
from algorithm.neat.genome.graph import topological_sort, check_cycles
|
||||
from algorithm.neat.utils import I_INT
|
||||
|
||||
nodes = jnp.array([
|
||||
[0],
|
||||
[1],
|
||||
[2],
|
||||
[3],
|
||||
[jnp.nan]
|
||||
])
|
||||
|
||||
# {(0, 2), (1, 2), (1, 3), (2, 3)}
|
||||
conns = jnp.array([
|
||||
[0, 0, 1, 0, 0],
|
||||
[0, 0, 1, 1, 0],
|
||||
[0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]
|
||||
])
|
||||
|
||||
|
||||
def test_topological_sort():
|
||||
assert jnp.all(topological_sort(nodes, conns) == jnp.array([0, 1, 2, 3, I_INT]))
|
||||
|
||||
|
||||
def test_check_cycles():
|
||||
assert check_cycles(nodes, conns, 3, 2)
|
||||
assert ~check_cycles(nodes, conns, 2, 3)
|
||||
assert ~check_cycles(nodes, conns, 0, 3)
|
||||
assert ~check_cycles(nodes, conns, 1, 0)
|
||||
@@ -1,7 +1,5 @@
|
||||
import pytest
|
||||
import jax
|
||||
|
||||
from algorithm.neat.utils import *
|
||||
import jax.numpy as jnp
|
||||
from algorithm.neat.utils import unflatten_connections
|
||||
|
||||
|
||||
def test_unflatten():
|
||||
@@ -13,7 +11,6 @@ def test_unflatten():
|
||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||
])
|
||||
|
||||
|
||||
conns = jnp.array([
|
||||
[0, 1, True, 0.1, 0.11],
|
||||
[0, 2, False, 0.2, 0.22],
|
||||
|
||||
Reference in New Issue
Block a user