add gene type RNN
This commit is contained in:
@@ -1,13 +1,14 @@
|
|||||||
[basic]
|
[basic]
|
||||||
num_inputs = 2
|
num_inputs = 2
|
||||||
num_outputs = 1
|
num_outputs = 1
|
||||||
maximum_nodes = 100
|
maximum_nodes = 50
|
||||||
maximum_connections = 100
|
maximum_conns = 100
|
||||||
maximum_species = 100
|
maximum_species = 10
|
||||||
forward_way = "pop"
|
forward_way = "pop"
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
random_seed = 0
|
random_seed = 0
|
||||||
network_type = 'feedforward'
|
network_type = "feedforward"
|
||||||
|
activate_times = 10
|
||||||
|
|
||||||
[population]
|
[population]
|
||||||
fitness_threshold = 3.9999
|
fitness_threshold = 3.9999
|
||||||
@@ -18,11 +19,11 @@ pop_size = 1000
|
|||||||
[genome]
|
[genome]
|
||||||
compatibility_disjoint = 1.0
|
compatibility_disjoint = 1.0
|
||||||
compatibility_weight = 0.5
|
compatibility_weight = 0.5
|
||||||
conn_add_prob = 0.5
|
conn_add_prob = 0.4
|
||||||
conn_add_trials = 1
|
conn_add_trials = 1
|
||||||
conn_delete_prob = 0.5
|
conn_delete_prob = 0
|
||||||
node_add_prob = 0.2
|
node_add_prob = 0.2
|
||||||
node_delete_prob = 0.2
|
node_delete_prob = 0
|
||||||
|
|
||||||
[species]
|
[species]
|
||||||
compatibility_threshold = 3.0
|
compatibility_threshold = 3.0
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from .neat import NEAT
|
from .neat import NEAT
|
||||||
from .gene import NormalGene
|
from .gene import NormalGene, RecurrentGene
|
||||||
from .pipeline import Pipeline
|
from .pipeline import Pipeline
|
||||||
|
|||||||
@@ -2,4 +2,5 @@ from .base import BaseGene
|
|||||||
from .normal import NormalGene
|
from .normal import NormalGene
|
||||||
from .activation import Activation
|
from .activation import Activation
|
||||||
from .aggregation import Aggregation
|
from .aggregation import Aggregation
|
||||||
|
from .recurrent import RecurrentGene
|
||||||
|
|
||||||
|
|||||||
@@ -86,10 +86,12 @@ class NormalGene(BaseGene):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def forward_transform(nodes, conns):
|
def forward_transform(nodes, conns):
|
||||||
u_conns = unflatten_connections(nodes, conns)
|
u_conns = unflatten_connections(nodes, conns)
|
||||||
u_conns = jnp.where(jnp.isnan(u_conns[0, :]), jnp.nan, u_conns) # enable is false, then the connections is nan
|
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||||
u_conns = u_conns[1:, :] # remove enable attr
|
|
||||||
conn_exist = jnp.any(~jnp.isnan(u_conns), axis=0)
|
# remove enable attr
|
||||||
seqs = topological_sort(nodes, conn_exist)
|
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||||
|
seqs = topological_sort(nodes, conn_enable)
|
||||||
|
|
||||||
return seqs, nodes, u_conns
|
return seqs, nodes, u_conns
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -167,18 +169,8 @@ class NormalGene(BaseGene):
|
|||||||
# the val of input nodes is obtained by the task, not by calculation
|
# the val of input nodes is obtained by the task, not by calculation
|
||||||
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
|
values = jax.lax.cond(jnp.isin(i, input_idx), miss, hit)
|
||||||
|
|
||||||
# if jnp.isin(i, input_idx):
|
|
||||||
# values = miss()
|
|
||||||
# else:
|
|
||||||
# values = hit()
|
|
||||||
|
|
||||||
return values, idx + 1
|
return values, idx + 1
|
||||||
|
|
||||||
# carry = (ini_vals, 0)
|
|
||||||
# while cond_fun(carry):
|
|
||||||
# carry = body_func(carry)
|
|
||||||
# vals, _ = carry
|
|
||||||
|
|
||||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
||||||
|
|
||||||
return vals[output_idx]
|
return vals[output_idx]
|
||||||
@@ -216,7 +208,3 @@ class NormalGene(BaseGene):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
90
algorithm/neat/gene/recurrent.py
Normal file
90
algorithm/neat/gene/recurrent.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
import jax
|
||||||
|
from jax import Array, numpy as jnp, vmap
|
||||||
|
|
||||||
|
from .normal import NormalGene
|
||||||
|
from .activation import Activation
|
||||||
|
from .aggregation import Aggregation
|
||||||
|
from ..utils import unflatten_connections, I_INT
|
||||||
|
|
||||||
|
|
||||||
|
class RecurrentGene(NormalGene):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward_transform(nodes, conns):
|
||||||
|
u_conns = unflatten_connections(nodes, conns)
|
||||||
|
|
||||||
|
# remove un-enable connections and remove enable attr
|
||||||
|
conn_enable = jnp.where(~jnp.isnan(u_conns[0]), True, False)
|
||||||
|
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||||
|
|
||||||
|
return nodes, u_conns
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_forward(config):
|
||||||
|
config['activation_funcs'] = [Activation.name2func[name] for name in config['activation_option_names']]
|
||||||
|
config['aggregation_funcs'] = [Aggregation.name2func[name] for name in config['aggregation_option_names']]
|
||||||
|
|
||||||
|
def act(idx, z):
|
||||||
|
"""
|
||||||
|
calculate activation function for each node
|
||||||
|
"""
|
||||||
|
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||||
|
# change idx from float to int
|
||||||
|
res = jax.lax.switch(idx, config['activation_funcs'], z)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def agg(idx, z):
|
||||||
|
"""
|
||||||
|
calculate activation function for inputs of node
|
||||||
|
"""
|
||||||
|
idx = jnp.asarray(idx, dtype=jnp.int32)
|
||||||
|
|
||||||
|
def all_nan():
|
||||||
|
return 0.
|
||||||
|
|
||||||
|
def not_all_nan():
|
||||||
|
return jax.lax.switch(idx, config['aggregation_funcs'], z)
|
||||||
|
|
||||||
|
return jax.lax.cond(jnp.all(jnp.isnan(z)), all_nan, not_all_nan)
|
||||||
|
|
||||||
|
batch_act, batch_agg = vmap(act), vmap(agg)
|
||||||
|
|
||||||
|
def forward(inputs, transform) -> Array:
|
||||||
|
"""
|
||||||
|
jax forward for single input shaped (input_num, )
|
||||||
|
nodes, connections are a single genome
|
||||||
|
|
||||||
|
:argument inputs: (input_num, )
|
||||||
|
:argument cal_seqs: (N, )
|
||||||
|
:argument nodes: (N, 5)
|
||||||
|
:argument connections: (2, N, N)
|
||||||
|
|
||||||
|
:return (output_num, )
|
||||||
|
"""
|
||||||
|
|
||||||
|
nodes, cons = transform
|
||||||
|
|
||||||
|
input_idx = config['input_idx']
|
||||||
|
output_idx = config['output_idx']
|
||||||
|
|
||||||
|
N = nodes.shape[0]
|
||||||
|
vals = jnp.full((N,), 0.)
|
||||||
|
|
||||||
|
weights = cons[0, :]
|
||||||
|
|
||||||
|
def body_func(i, values):
|
||||||
|
values = values.at[input_idx].set(inputs)
|
||||||
|
nodes_ins = values * weights.T
|
||||||
|
values = batch_agg(nodes[:, 4], nodes_ins) # z = agg(ins)
|
||||||
|
values = values * nodes[:, 2] + nodes[:, 1] # z = z * response + bias
|
||||||
|
values = batch_act(nodes[:, 3], values) # z = act(z)
|
||||||
|
return values
|
||||||
|
|
||||||
|
# for i in range(config['activate_times']):
|
||||||
|
# vals = body_func(i, vals)
|
||||||
|
#
|
||||||
|
# return vals[output_idx]
|
||||||
|
vals = jax.lax.fori_loop(0, config['activate_times'], body_func, vals)
|
||||||
|
return vals[output_idx]
|
||||||
|
|
||||||
|
return forward
|
||||||
@@ -11,7 +11,7 @@ from ..utils import fetch_first
|
|||||||
|
|
||||||
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
|
def initialize_genomes(state: State, gene_type: Type[BaseGene]):
|
||||||
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
|
o_nodes = np.full((state.N, state.NL), np.nan, dtype=np.float32) # original nodes
|
||||||
o_conns = np.full((state.N, state.CL), np.nan, dtype=np.float32) # original connections
|
o_conns = np.full((state.C, state.CL), np.nan, dtype=np.float32) # original connections
|
||||||
|
|
||||||
input_idx = state.input_idx
|
input_idx = state.input_idx
|
||||||
output_idx = state.output_idx
|
output_idx = state.output_idx
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from jax import jit, Array, numpy as jnp
|
|||||||
from ..utils import fetch_first, I_INT
|
from ..utils import fetch_first, I_INT
|
||||||
|
|
||||||
|
|
||||||
|
@jit
|
||||||
def topological_sort(nodes: Array, conns: Array) -> Array:
|
def topological_sort(nodes: Array, conns: Array) -> Array:
|
||||||
"""
|
"""
|
||||||
a jit-able version of topological_sort! that's crazy!
|
a jit-able version of topological_sort! that's crazy!
|
||||||
@@ -60,21 +61,11 @@ def topological_sort(nodes: Array, conns: Array) -> Array:
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Array) -> Array:
|
@jit
|
||||||
|
def check_cycles(nodes: Array, conns: Array, from_idx, to_idx) -> Array:
|
||||||
"""
|
"""
|
||||||
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
|
Check whether a new connection (from_idx -> to_idx) will cause a cycle.
|
||||||
|
|
||||||
:param nodes: JAX array
|
|
||||||
The array of nodes.
|
|
||||||
:param connections: JAX array
|
|
||||||
The array of connections.
|
|
||||||
:param from_idx: int
|
|
||||||
The index of the starting node.
|
|
||||||
:param to_idx: int
|
|
||||||
The index of the ending node.
|
|
||||||
:return: JAX array
|
|
||||||
An array indicating if there is a cycle caused by the new connection.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
nodes = jnp.array([
|
nodes = jnp.array([
|
||||||
[0],
|
[0],
|
||||||
@@ -83,28 +74,21 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
|
|||||||
[3]
|
[3]
|
||||||
])
|
])
|
||||||
connections = jnp.array([
|
connections = jnp.array([
|
||||||
[
|
|
||||||
[0, 0, 1, 0],
|
[0, 0, 1, 0],
|
||||||
[0, 0, 1, 1],
|
[0, 0, 1, 1],
|
||||||
[0, 0, 0, 1],
|
[0, 0, 0, 1],
|
||||||
[0, 0, 0, 0]
|
[0, 0, 0, 0]
|
||||||
],
|
|
||||||
[
|
|
||||||
[0, 0, 1, 0],
|
|
||||||
[0, 0, 1, 1],
|
|
||||||
[0, 0, 0, 1],
|
|
||||||
[0, 0, 0, 0]
|
|
||||||
]
|
|
||||||
])
|
])
|
||||||
|
|
||||||
check_cycles(nodes, connections, 3, 2) -> True
|
check_cycles(nodes, conns, 3, 2) -> True
|
||||||
check_cycles(nodes, connections, 2, 3) -> False
|
check_cycles(nodes, conns, 2, 3) -> False
|
||||||
check_cycles(nodes, connections, 0, 3) -> False
|
check_cycles(nodes, conns, 0, 3) -> False
|
||||||
check_cycles(nodes, connections, 1, 0) -> False
|
check_cycles(nodes, conns, 1, 0) -> False
|
||||||
"""
|
"""
|
||||||
|
|
||||||
connections_enable = ~jnp.isnan(connections[0, :, :])
|
conns = conns.at[from_idx, to_idx].set(True)
|
||||||
connections_enable = connections_enable.at[from_idx, to_idx].set(True)
|
# conns_enable = ~jnp.isnan(conns[0, :, :])
|
||||||
|
# conns_enable = conns_enable.at[from_idx, to_idx].set(True)
|
||||||
|
|
||||||
visited = jnp.full(nodes.shape[0], False)
|
visited = jnp.full(nodes.shape[0], False)
|
||||||
new_visited = visited.at[to_idx].set(True)
|
new_visited = visited.at[to_idx].set(True)
|
||||||
@@ -117,43 +101,42 @@ def check_cycles(nodes: Array, connections: Array, from_idx: Array, to_idx: Arra
|
|||||||
|
|
||||||
def body_func(carry):
|
def body_func(carry):
|
||||||
_, visited_ = carry
|
_, visited_ = carry
|
||||||
new_visited_ = jnp.dot(visited_, connections_enable)
|
new_visited_ = jnp.dot(visited_, conns)
|
||||||
new_visited_ = jnp.logical_or(visited_, new_visited_)
|
new_visited_ = jnp.logical_or(visited_, new_visited_)
|
||||||
return visited_, new_visited_
|
return visited_, new_visited_
|
||||||
|
|
||||||
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
|
_, visited = jax.lax.while_loop(cond_func, body_func, (visited, new_visited))
|
||||||
return visited[from_idx]
|
return visited[from_idx]
|
||||||
|
|
||||||
|
# if __name__ == '__main__':
|
||||||
if __name__ == '__main__':
|
# nodes = jnp.array([
|
||||||
nodes = jnp.array([
|
# [0],
|
||||||
[0],
|
# [1],
|
||||||
[1],
|
# [2],
|
||||||
[2],
|
# [3],
|
||||||
[3],
|
# [jnp.nan]
|
||||||
[jnp.nan]
|
# ])
|
||||||
])
|
# connections = jnp.array([
|
||||||
connections = jnp.array([
|
# [
|
||||||
[
|
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||||
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||||
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||||
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
# ],
|
||||||
],
|
# [
|
||||||
[
|
# [jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
||||||
[jnp.nan, jnp.nan, 1, jnp.nan, jnp.nan],
|
# [jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
||||||
[jnp.nan, jnp.nan, 1, 1, jnp.nan],
|
# [jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
||||||
[jnp.nan, jnp.nan, jnp.nan, 1, jnp.nan],
|
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
||||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan],
|
# [jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
# ]
|
||||||
]
|
# ]
|
||||||
]
|
# )
|
||||||
)
|
#
|
||||||
|
# print(topological_sort(nodes, connections))
|
||||||
print(topological_sort(nodes, connections))
|
#
|
||||||
|
# print(check_cycles(nodes, connections, 3, 2))
|
||||||
print(check_cycles(nodes, connections, 3, 2))
|
# print(check_cycles(nodes, connections, 2, 3))
|
||||||
print(check_cycles(nodes, connections, 2, 3))
|
# print(check_cycles(nodes, connections, 0, 3))
|
||||||
print(check_cycles(nodes, connections, 0, 3))
|
# print(check_cycles(nodes, connections, 1, 0))
|
||||||
print(check_cycles(nodes, connections, 1, 0))
|
|
||||||
|
|||||||
@@ -91,7 +91,8 @@ def create_mutate(config: Dict, gene_type: Type[BaseGene]):
|
|||||||
|
|
||||||
if config['network_type'] == 'feedforward':
|
if config['network_type'] == 'feedforward':
|
||||||
u_cons = unflatten_connections(nodes_, conns_)
|
u_cons = unflatten_connections(nodes_, conns_)
|
||||||
is_cycle = check_cycles(nodes_, u_cons, from_idx, to_idx)
|
cons_exist = jnp.where(~jnp.isnan(u_cons[0, :, :]), True, False)
|
||||||
|
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
|
||||||
|
|
||||||
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
choice = jnp.where(is_already_exist, 0, jnp.where(is_cycle, 1, 2))
|
||||||
return jax.lax.switch(choice, [already_exist, nothing, successful])
|
return jax.lax.switch(choice, [already_exist, nothing, successful])
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class NEAT:
|
|||||||
state = State(
|
state = State(
|
||||||
P=self.config['pop_size'],
|
P=self.config['pop_size'],
|
||||||
N=self.config['maximum_nodes'],
|
N=self.config['maximum_nodes'],
|
||||||
C=self.config['maximum_connections'],
|
C=self.config['maximum_conns'],
|
||||||
S=self.config['maximum_species'],
|
S=self.config['maximum_species'],
|
||||||
NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes
|
NL=1 + len(self.gene_type.node_attrs), # node length = (key) + attributes
|
||||||
CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes
|
CL=3 + len(self.gene_type.conn_attrs), # conn length = (in, out, key) + attributes
|
||||||
@@ -64,11 +64,15 @@ class NEAT:
|
|||||||
idx2species=idx2species,
|
idx2species=idx2species,
|
||||||
center_nodes=center_nodes,
|
center_nodes=center_nodes,
|
||||||
center_conns=center_conns,
|
center_conns=center_conns,
|
||||||
generation=generation,
|
# avoid jax auto cast from int to float. that would cause re-compilation.
|
||||||
next_node_key=next_node_key,
|
generation=jnp.asarray(generation, dtype=jnp.int32),
|
||||||
next_species_key=next_species_key
|
next_node_key=jnp.asarray(next_node_key, dtype=jnp.float32),
|
||||||
|
next_species_key=jnp.asarray(next_species_key)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# move to device
|
||||||
|
state = jax.device_put(state)
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def step(self, state, fitness):
|
def step(self, state, fitness):
|
||||||
|
|||||||
@@ -34,9 +34,6 @@ class Pipeline:
|
|||||||
|
|
||||||
def tell(self, fitness):
|
def tell(self, fitness):
|
||||||
self.state = self.algorithm.step(self.state, fitness)
|
self.state = self.algorithm.step(self.state, fitness)
|
||||||
from algorithm.neat.genome.basic import count
|
|
||||||
# print([count(self.state.pop_nodes[i], self.state.pop_conns[i]) for i in range(self.state.P)])
|
|
||||||
|
|
||||||
|
|
||||||
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
def auto_run(self, fitness_func, analysis: Union[Callable, str] = "default"):
|
||||||
for _ in range(self.config['generation_limit']):
|
for _ in range(self.config['generation_limit']):
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ from .utils import rank_elements, fetch_first
|
|||||||
from .genome import create_mutate, create_distance, crossover
|
from .genome import create_mutate, create_distance, crossover
|
||||||
from .gene import BaseGene
|
from .gene import BaseGene
|
||||||
|
|
||||||
def create_tell(config, gene_type: Type[BaseGene]):
|
|
||||||
|
|
||||||
|
def create_tell(config, gene_type: Type[BaseGene]):
|
||||||
mutate = create_mutate(config, gene_type)
|
mutate = create_mutate(config, gene_type)
|
||||||
distance = create_distance(config, gene_type)
|
distance = create_distance(config, gene_type)
|
||||||
|
|
||||||
@@ -36,7 +36,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
|||||||
|
|
||||||
return state, winner, loser, elite_mask
|
return state, winner, loser, elite_mask
|
||||||
|
|
||||||
|
|
||||||
def update_species_fitness(state, fitness):
|
def update_species_fitness(state, fitness):
|
||||||
"""
|
"""
|
||||||
obtain the fitness of the species by the fitness of each individual.
|
obtain the fitness of the species by the fitness of each individual.
|
||||||
@@ -51,7 +50,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
|||||||
|
|
||||||
return vmap(aux_func)(jnp.arange(state.species_info.shape[0]))
|
return vmap(aux_func)(jnp.arange(state.species_info.shape[0]))
|
||||||
|
|
||||||
|
|
||||||
def stagnation(state, species_fitness):
|
def stagnation(state, species_fitness):
|
||||||
"""
|
"""
|
||||||
stagnation species.
|
stagnation species.
|
||||||
@@ -88,7 +86,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
|||||||
|
|
||||||
return state, species_fitness
|
return state, species_fitness
|
||||||
|
|
||||||
|
|
||||||
def cal_spawn_numbers(state):
|
def cal_spawn_numbers(state):
|
||||||
"""
|
"""
|
||||||
decide the number of members of each species by their fitness rank.
|
decide the number of members of each species by their fitness rank.
|
||||||
@@ -106,7 +103,6 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
|||||||
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
spawn_number_rate = jnp.where(is_species_valid, spawn_number_rate, 0) # set invalid species to 0
|
||||||
|
|
||||||
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
|
target_spawn_number = jnp.floor(spawn_number_rate * state.P) # calculate member
|
||||||
# jax.debug.print("denominator: {}, spawn_number_rate: {}, target_spawn_number: {}", denominator, spawn_number_rate, target_spawn_number)
|
|
||||||
|
|
||||||
# Avoid too much variation of numbers in a species
|
# Avoid too much variation of numbers in a species
|
||||||
previous_size = state.species_info[:, 3].astype(jnp.int32)
|
previous_size = state.species_info[:, 3].astype(jnp.int32)
|
||||||
@@ -118,11 +114,11 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
|||||||
|
|
||||||
# must control the sum of spawn_number to be equal to pop_size
|
# must control the sum of spawn_number to be equal to pop_size
|
||||||
error = state.P - jnp.sum(spawn_number)
|
error = state.P - jnp.sum(spawn_number)
|
||||||
spawn_number = spawn_number.at[0].add(error) # add error to the first species to control the sum of spawn_number
|
spawn_number = spawn_number.at[0].add(
|
||||||
|
error) # add error to the first species to control the sum of spawn_number
|
||||||
|
|
||||||
return spawn_number
|
return spawn_number
|
||||||
|
|
||||||
|
|
||||||
def create_crossover_pair(state, randkey, spawn_number, fitness):
|
def create_crossover_pair(state, randkey, spawn_number, fitness):
|
||||||
species_size = state.species_info.shape[0]
|
species_size = state.species_info.shape[0]
|
||||||
pop_size = fitness.shape[0]
|
pop_size = fitness.shape[0]
|
||||||
@@ -238,8 +234,8 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
|||||||
return i + 1, i2s, cn, cc, o2c
|
return i + 1, i2s, cn, cc, o2c
|
||||||
|
|
||||||
_, idx2specie, center_nodes, center_conns, o2c_distances = \
|
_, idx2specie, center_nodes, center_conns, o2c_distances = \
|
||||||
jax.lax.while_loop(cond_func, body_func, (0, idx2specie, state.center_nodes, state.center_conns, o2c_distances))
|
jax.lax.while_loop(cond_func, body_func,
|
||||||
|
(0, idx2specie, state.center_nodes, state.center_conns, o2c_distances))
|
||||||
|
|
||||||
# part 2: assign members to each species
|
# part 2: assign members to each species
|
||||||
def cond_func(carry):
|
def cond_func(carry):
|
||||||
@@ -331,7 +327,7 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
|||||||
species_info = species_info.at[:, 3].set(species_member_counts)
|
species_info = species_info.at[:, 3].set(species_member_counts)
|
||||||
|
|
||||||
return state.update(
|
return state.update(
|
||||||
idx2specie=idx2specie,
|
idx2species=idx2specie,
|
||||||
center_nodes=center_nodes,
|
center_nodes=center_nodes,
|
||||||
center_conns=center_conns,
|
center_conns=center_conns,
|
||||||
species_info=species_info,
|
species_info=species_info,
|
||||||
@@ -358,11 +354,10 @@ def create_tell(config, gene_type: Type[BaseGene]):
|
|||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
return tell
|
return tell
|
||||||
|
|
||||||
|
|
||||||
def argmin_with_mask(arr, mask):
|
def argmin_with_mask(arr, mask):
|
||||||
masked_arr = jnp.where(mask, arr, jnp.inf)
|
masked_arr = jnp.where(mask, arr, jnp.inf)
|
||||||
min_idx = jnp.argmin(masked_arr)
|
min_idx = jnp.argmin(masked_arr)
|
||||||
return min_idx
|
return min_idx
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
from algorithm.config import Configer
|
|
||||||
|
|
||||||
config = Configer.load_config()
|
|
||||||
print(config)
|
|
||||||
13
examples/rnn_forward_test.py
Normal file
13
examples/rnn_forward_test.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
vals = np.array([1, 2])
|
||||||
|
weights = np.array([[0, 4], [5, 0]])
|
||||||
|
|
||||||
|
ins1 = vals * weights[:, 0]
|
||||||
|
ins2 = vals * weights[:, 1]
|
||||||
|
ins_all = vals * weights.T
|
||||||
|
|
||||||
|
print(ins1)
|
||||||
|
print(ins2)
|
||||||
|
print(ins_all)
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
[basic]
|
[basic]
|
||||||
forward_way = "common"
|
forward_way = "common"
|
||||||
|
network_type = "recurrent"
|
||||||
|
activate_times = 5
|
||||||
|
|
||||||
[population]
|
[population]
|
||||||
fitness_threshold = 4
|
fitness_threshold = 4
|
||||||
@@ -2,7 +2,7 @@ import jax
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from algorithm import Configer, NEAT
|
from algorithm import Configer, NEAT
|
||||||
from algorithm.neat import NormalGene, Pipeline
|
from algorithm.neat import NormalGene, RecurrentGene, Pipeline
|
||||||
|
|
||||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
xor_outputs = np.array([[0], [1], [1], [0]], dtype=np.float32)
|
||||||
@@ -15,16 +15,17 @@ def evaluate(forward_func):
|
|||||||
"""
|
"""
|
||||||
outs = forward_func(xor_inputs)
|
outs = forward_func(xor_inputs)
|
||||||
outs = jax.device_get(outs)
|
outs = jax.device_get(outs)
|
||||||
# print(outs)
|
|
||||||
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
fitnesses = 4 - np.sum((outs - xor_outputs) ** 2, axis=(1, 2))
|
||||||
return fitnesses
|
return fitnesses
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
config = Configer.load_config("xor.ini")
|
config = Configer.load_config("xor.ini")
|
||||||
algorithm = NEAT(config, NormalGene)
|
# algorithm = NEAT(config, NormalGene)
|
||||||
|
algorithm = NEAT(config, RecurrentGene)
|
||||||
pipeline = Pipeline(config, algorithm)
|
pipeline = Pipeline(config, algorithm)
|
||||||
pipeline.auto_run(evaluate)
|
best = pipeline.auto_run(evaluate)
|
||||||
|
print(best)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -2,31 +2,49 @@ import jax
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from algorithm.config import Configer
|
from algorithm.config import Configer
|
||||||
from algorithm.neat import NEAT, NormalGene, Pipeline
|
from algorithm.neat import NEAT, NormalGene, RecurrentGene, Pipeline
|
||||||
from algorithm.neat.genome import create_mutate
|
from algorithm.neat.genome import create_mutate
|
||||||
|
|
||||||
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
xor_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
def single_genome(func, nodes, conns):
|
def single_genome(func, nodes, conns):
|
||||||
t = NormalGene.forward_transform(nodes, conns)
|
t = RecurrentGene.forward_transform(nodes, conns)
|
||||||
out1 = func(xor_inputs[0], t)
|
out1 = func(xor_inputs[0], t)
|
||||||
out2 = func(xor_inputs[1], t)
|
out2 = func(xor_inputs[1], t)
|
||||||
out3 = func(xor_inputs[2], t)
|
out3 = func(xor_inputs[2], t)
|
||||||
out4 = func(xor_inputs[3], t)
|
out4 = func(xor_inputs[3], t)
|
||||||
print(out1, out2, out3, out4)
|
print(out1, out2, out3, out4)
|
||||||
|
|
||||||
|
|
||||||
|
def batch_genome(func, nodes, conns):
|
||||||
|
t = NormalGene.forward_transform(nodes, conns)
|
||||||
|
out = jax.vmap(func, in_axes=(0, None))(xor_inputs, t)
|
||||||
|
print(out)
|
||||||
|
|
||||||
|
|
||||||
|
def pop_batch_genome(func, pop_nodes, pop_conns):
|
||||||
|
t = jax.vmap(NormalGene.forward_transform)(pop_nodes, pop_conns)
|
||||||
|
func = jax.vmap(jax.vmap(func, in_axes=(0, None)), in_axes=(None, 0))
|
||||||
|
out = func(xor_inputs, t)
|
||||||
|
print(out)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
config = Configer.load_config()
|
config = Configer.load_config("xor.ini")
|
||||||
neat = NEAT(config, NormalGene)
|
# neat = NEAT(config, NormalGene)
|
||||||
|
neat = NEAT(config, RecurrentGene)
|
||||||
randkey = jax.random.PRNGKey(42)
|
randkey = jax.random.PRNGKey(42)
|
||||||
state = neat.setup(randkey)
|
state = neat.setup(randkey)
|
||||||
forward_func = NormalGene.create_forward(config)
|
forward_func = RecurrentGene.create_forward(config)
|
||||||
mutate_func = create_mutate(config, NormalGene)
|
mutate_func = create_mutate(config, RecurrentGene)
|
||||||
|
|
||||||
|
|
||||||
nodes, conns = state.pop_nodes[0], state.pop_conns[0]
|
nodes, conns = state.pop_nodes[0], state.pop_conns[0]
|
||||||
single_genome(forward_func, nodes, conns)
|
single_genome(forward_func, nodes, conns)
|
||||||
|
# batch_genome(forward_func, nodes, conns)
|
||||||
|
|
||||||
nodes, conns = mutate_func(state, randkey, nodes, conns, 10000)
|
nodes, conns = mutate_func(state, randkey, nodes, conns, 10000)
|
||||||
single_genome(forward_func, nodes, conns)
|
single_genome(forward_func, nodes, conns)
|
||||||
|
|
||||||
|
# batch_genome(forward_func, nodes, conns)
|
||||||
|
#
|
||||||
|
|||||||
32
test/unit/test_graphs.py
Normal file
32
test/unit/test_graphs.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
from algorithm.neat.genome.graph import topological_sort, check_cycles
|
||||||
|
from algorithm.neat.utils import I_INT
|
||||||
|
|
||||||
|
nodes = jnp.array([
|
||||||
|
[0],
|
||||||
|
[1],
|
||||||
|
[2],
|
||||||
|
[3],
|
||||||
|
[jnp.nan]
|
||||||
|
])
|
||||||
|
|
||||||
|
# {(0, 2), (1, 2), (1, 3), (2, 3)}
|
||||||
|
conns = jnp.array([
|
||||||
|
[0, 0, 1, 0, 0],
|
||||||
|
[0, 0, 1, 1, 0],
|
||||||
|
[0, 0, 0, 1, 0],
|
||||||
|
[0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0]
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def test_topological_sort():
|
||||||
|
assert jnp.all(topological_sort(nodes, conns) == jnp.array([0, 1, 2, 3, I_INT]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_cycles():
|
||||||
|
assert check_cycles(nodes, conns, 3, 2)
|
||||||
|
assert ~check_cycles(nodes, conns, 2, 3)
|
||||||
|
assert ~check_cycles(nodes, conns, 0, 3)
|
||||||
|
assert ~check_cycles(nodes, conns, 1, 0)
|
||||||
@@ -1,7 +1,5 @@
|
|||||||
import pytest
|
import jax.numpy as jnp
|
||||||
import jax
|
from algorithm.neat.utils import unflatten_connections
|
||||||
|
|
||||||
from algorithm.neat.utils import *
|
|
||||||
|
|
||||||
|
|
||||||
def test_unflatten():
|
def test_unflatten():
|
||||||
@@ -13,7 +11,6 @@ def test_unflatten():
|
|||||||
[jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
[jnp.nan, jnp.nan, jnp.nan, jnp.nan]
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
conns = jnp.array([
|
conns = jnp.array([
|
||||||
[0, 1, True, 0.1, 0.11],
|
[0, 1, True, 0.1, 0.11],
|
||||||
[0, 2, False, 0.2, 0.22],
|
[0, 2, False, 0.2, 0.22],
|
||||||
@@ -33,4 +30,4 @@ def test_unflatten():
|
|||||||
mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False)
|
mask = mask.at[:, [0, 0, 1, 1], [1, 2, 2, 3]].set(False)
|
||||||
|
|
||||||
# Ensure all other places are jnp.nan
|
# Ensure all other places are jnp.nan
|
||||||
assert jnp.all(jnp.isnan(res[mask]))
|
assert jnp.all(jnp.isnan(res[mask]))
|
||||||
|
|||||||
Reference in New Issue
Block a user