All function with state will update the state and return it.
Remove randkey args in functions with state, since it can attach the randkey by states.
This commit is contained in:
@@ -3,8 +3,8 @@ from utils import State
|
|||||||
|
|
||||||
class BaseCrossover:
|
class BaseCrossover:
|
||||||
|
|
||||||
def setup(self, key, state=State()):
|
def setup(self, state=State()):
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def __call__(self, state, key, genome, nodes1, nodes2, conns1, conns2):
|
def __call__(self, state, genome, nodes1, nodes2, conns1, conns2):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ from .base import BaseCrossover
|
|||||||
|
|
||||||
class DefaultCrossover(BaseCrossover):
|
class DefaultCrossover(BaseCrossover):
|
||||||
|
|
||||||
def __call__(self, state, key, genome, nodes1, conns1, nodes2, conns2):
|
def __call__(self, state, genome, nodes1, conns1, nodes2, conns2):
|
||||||
"""
|
"""
|
||||||
use genome1 and genome2 to generate a new genome
|
use genome1 and genome2 to generate a new genome
|
||||||
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
notice that genome1 should have higher fitness than genome2 (genome1 is winner!)
|
||||||
"""
|
"""
|
||||||
randkey_1, randkey_2, key = jax.random.split(key, 3)
|
randkey1, randkey2, randkey = jax.random.split(state.randkey, 3)
|
||||||
|
|
||||||
# crossover nodes
|
# crossover nodes
|
||||||
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
keys1, keys2 = nodes1[:, 0], nodes2[:, 0]
|
||||||
@@ -20,16 +20,16 @@ class DefaultCrossover(BaseCrossover):
|
|||||||
# For not homologous genes, use the value of nodes1(winner)
|
# For not homologous genes, use the value of nodes1(winner)
|
||||||
# For homologous genes, use the crossover result between nodes1 and nodes2
|
# For homologous genes, use the crossover result between nodes1 and nodes2
|
||||||
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
|
new_nodes = jnp.where(jnp.isnan(nodes1) | jnp.isnan(nodes2), nodes1,
|
||||||
self.crossover_gene(randkey_1, nodes1, nodes2, is_conn=False))
|
self.crossover_gene(randkey1, nodes1, nodes2, is_conn=False))
|
||||||
|
|
||||||
# crossover connections
|
# crossover connections
|
||||||
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
con_keys1, con_keys2 = conns1[:, :2], conns2[:, :2]
|
||||||
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
conns2 = self.align_array(con_keys1, con_keys2, conns2, is_conn=True)
|
||||||
|
|
||||||
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
|
new_conns = jnp.where(jnp.isnan(conns1) | jnp.isnan(conns2), conns1,
|
||||||
self.crossover_gene(randkey_2, conns1, conns2, is_conn=True))
|
self.crossover_gene(randkey2, conns1, conns2, is_conn=True))
|
||||||
|
|
||||||
return new_nodes, new_conns
|
return state.update(randkey=randkey), new_nodes, new_conns
|
||||||
|
|
||||||
def align_array(self, seq1, seq2, ar2, is_conn: bool):
|
def align_array(self, seq1, seq2, ar2, is_conn: bool):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ from utils import State
|
|||||||
|
|
||||||
class BaseMutation:
|
class BaseMutation:
|
||||||
|
|
||||||
def setup(self, key, state=State()):
|
def setup(self, state=State()):
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def __call__(self, state, key, genome, nodes, conns, new_node_key):
|
def __call__(self, state, genome, nodes, conns, new_node_key):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -17,13 +17,13 @@ class DefaultMutation(BaseMutation):
|
|||||||
self.node_add = node_add
|
self.node_add = node_add
|
||||||
self.node_delete = node_delete
|
self.node_delete = node_delete
|
||||||
|
|
||||||
def __call__(self, state, key, genome, nodes, conns, new_node_key):
|
def __call__(self, state, genome, nodes, conns, new_node_key):
|
||||||
k1, k2 = jax.random.split(key)
|
k1, k2, randkey = jax.random.split(state.randkey)
|
||||||
|
|
||||||
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
|
nodes, conns = self.mutate_structure(k1, genome, nodes, conns, new_node_key)
|
||||||
nodes, conns = self.mutate_values(k2, genome, nodes, conns)
|
nodes, conns = self.mutate_values(k2, genome, nodes, conns)
|
||||||
|
|
||||||
return nodes, conns
|
return state.update(randkey=randkey), nodes, conns
|
||||||
|
|
||||||
def mutate_structure(self, key, genome, nodes, conns, new_node_key):
|
def mutate_structure(self, key, genome, nodes, conns, new_node_key):
|
||||||
def mutate_add_node(key_, nodes_, conns_):
|
def mutate_add_node(key_, nodes_, conns_):
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ class BaseGene:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def setup(self, key, state=State()):
|
def setup(self, state=State()):
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def new_attrs(self, state, key):
|
def new_attrs(self, state):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def mutate(self, state, key, gene):
|
def mutate(self, state, gene):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def distance(self, state, gene1, gene2):
|
def distance(self, state, gene1, gene2):
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import jax.random
|
||||||
|
|
||||||
from utils import mutate_float
|
from utils import mutate_float
|
||||||
from . import BaseConnGene
|
from . import BaseConnGene
|
||||||
@@ -24,14 +25,15 @@ class DefaultConnGene(BaseConnGene):
|
|||||||
self.weight_mutate_rate = weight_mutate_rate
|
self.weight_mutate_rate = weight_mutate_rate
|
||||||
self.weight_replace_rate = weight_replace_rate
|
self.weight_replace_rate = weight_replace_rate
|
||||||
|
|
||||||
def new_attrs(self, state, key):
|
def new_attrs(self, state):
|
||||||
return jnp.array([self.weight_init_mean])
|
return state, jnp.array([self.weight_init_mean])
|
||||||
|
|
||||||
def mutate(self, state, key, conn):
|
def mutate(self, state, conn):
|
||||||
|
randkey_, randkey = jax.random.split(state.randkey, 2)
|
||||||
input_index = conn[0]
|
input_index = conn[0]
|
||||||
output_index = conn[1]
|
output_index = conn[1]
|
||||||
enabled = conn[2]
|
enabled = conn[2]
|
||||||
weight = mutate_float(key,
|
weight = mutate_float(randkey_,
|
||||||
conn[3],
|
conn[3],
|
||||||
self.weight_init_mean,
|
self.weight_init_mean,
|
||||||
self.weight_init_std,
|
self.weight_init_std,
|
||||||
@@ -40,11 +42,11 @@ class DefaultConnGene(BaseConnGene):
|
|||||||
self.weight_replace_rate
|
self.weight_replace_rate
|
||||||
)
|
)
|
||||||
|
|
||||||
return jnp.array([input_index, output_index, enabled, weight])
|
return state.update(randkey=randkey), jnp.array([input_index, output_index, enabled, weight])
|
||||||
|
|
||||||
def distance(self, state, attrs1, attrs2):
|
def distance(self, state, attrs1, attrs2):
|
||||||
return (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
|
return state, (attrs1[2] != attrs2[2]) + jnp.abs(attrs1[3] - attrs2[3]) # enable + weight
|
||||||
|
|
||||||
def forward(self, state, attrs, inputs):
|
def forward(self, state, attrs, inputs):
|
||||||
weight = attrs[0]
|
weight = attrs[0]
|
||||||
return inputs * weight
|
return state, inputs * weight
|
||||||
|
|||||||
@@ -56,13 +56,13 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
self.aggregation_indices = jnp.arange(len(aggregation_options))
|
||||||
self.aggregation_replace_rate = aggregation_replace_rate
|
self.aggregation_replace_rate = aggregation_replace_rate
|
||||||
|
|
||||||
def new_attrs(self, state, key):
|
def new_attrs(self, state):
|
||||||
return jnp.array(
|
return state, jnp.array(
|
||||||
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
||||||
)
|
)
|
||||||
|
|
||||||
def mutate(self, state, key, node):
|
def mutate(self, state, node):
|
||||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
k1, k2, k3, k4, randkey = jax.random.split(state.randkey, num=5)
|
||||||
index = node[0]
|
index = node[0]
|
||||||
|
|
||||||
bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
|
bias = mutate_float(k1, node[1], self.bias_init_mean, self.bias_init_std,
|
||||||
@@ -75,10 +75,10 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
|
|
||||||
agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)
|
agg = mutate_int(k4, node[4], self.aggregation_indices, self.aggregation_replace_rate)
|
||||||
|
|
||||||
return jnp.array([index, bias, res, act, agg])
|
return state.update(randkey=randkey), jnp.array([index, bias, res, act, agg])
|
||||||
|
|
||||||
def distance(self, state, node1, node2):
|
def distance(self, state, node1, node2):
|
||||||
return (
|
return state, (
|
||||||
jnp.abs(node1[1] - node2[1]) +
|
jnp.abs(node1[1] - node2[1]) +
|
||||||
jnp.abs(node1[2] - node2[2]) +
|
jnp.abs(node1[2] - node2[2]) +
|
||||||
(node1[3] != node2[3]) +
|
(node1[3] != node2[3]) +
|
||||||
@@ -98,4 +98,4 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
lambda: act(act_idx, z, self.activation_options)
|
lambda: act(act_idx, z, self.activation_options)
|
||||||
)
|
)
|
||||||
|
|
||||||
return z
|
return state, z
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class BaseGenome:
|
|||||||
self.node_gene = node_gene
|
self.node_gene = node_gene
|
||||||
self.conn_gene = conn_gene
|
self.conn_gene = conn_gene
|
||||||
|
|
||||||
def setup(self, key, state=State()):
|
def setup(self, state=State()):
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def transform(self, state, nodes, conns):
|
def transform(self, state, nodes, conns):
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class DefaultGenome(BaseGenome):
|
|||||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||||
seqs = topological_sort(nodes, conn_enable)
|
seqs = topological_sort(nodes, conn_enable)
|
||||||
|
|
||||||
return seqs, nodes, u_conns
|
return state, seqs, nodes, u_conns
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, inputs, transformed):
|
||||||
cal_seqs, nodes, conns = transformed
|
cal_seqs, nodes, conns = transformed
|
||||||
@@ -49,34 +49,32 @@ class DefaultGenome(BaseGenome):
|
|||||||
nodes_attrs = nodes[:, 1:]
|
nodes_attrs = nodes[:, 1:]
|
||||||
|
|
||||||
def cond_fun(carry):
|
def cond_fun(carry):
|
||||||
values, idx = carry
|
state_, values, idx = carry
|
||||||
return (idx < N) & (cal_seqs[idx] != I_INT)
|
return (idx < N) & (cal_seqs[idx] != I_INT)
|
||||||
|
|
||||||
def body_func(carry):
|
def body_func(carry):
|
||||||
values, idx = carry
|
state_, values, idx = carry
|
||||||
i = cal_seqs[idx]
|
i = cal_seqs[idx]
|
||||||
|
|
||||||
def hit():
|
def hit():
|
||||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(conns[:, :, i], values)
|
s, ins = jax.vmap(self.conn_gene.forward,
|
||||||
z = self.node_gene.forward(state, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
in_axes=(None, 1, 0), out_axes=(None, 0))(state_, conns[:, :, i], values)
|
||||||
|
s, z = self.node_gene.forward(s, nodes_attrs[i], ins, is_output_node=jnp.isin(i, self.output_idx))
|
||||||
new_values = values.at[i].set(z)
|
new_values = values.at[i].set(z)
|
||||||
return new_values
|
return s, new_values
|
||||||
|
|
||||||
def miss():
|
|
||||||
return values
|
|
||||||
|
|
||||||
# the val of input nodes is obtained by the task, not by calculation
|
# the val of input nodes is obtained by the task, not by calculation
|
||||||
values = jax.lax.cond(
|
state_, values = jax.lax.cond(
|
||||||
jnp.isin(i, self.input_idx),
|
jnp.isin(i, self.input_idx),
|
||||||
miss,
|
lambda: (state_, values),
|
||||||
hit
|
hit
|
||||||
)
|
)
|
||||||
|
|
||||||
return values, idx + 1
|
return state_, values, idx + 1
|
||||||
|
|
||||||
vals, _ = jax.lax.while_loop(cond_fun, body_func, (ini_vals, 0))
|
state, vals, _ = jax.lax.while_loop(cond_fun, body_func, (state, ini_vals, 0))
|
||||||
|
|
||||||
if self.output_transform is None:
|
if self.output_transform is None:
|
||||||
return vals[self.output_idx]
|
return state, vals[self.output_idx]
|
||||||
else:
|
else:
|
||||||
return self.output_transform(vals[self.output_idx])
|
return state, self.output_transform(vals[self.output_idx])
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class RecurrentGenome(BaseGenome):
|
|||||||
conn_enable = u_conns[0] == 1
|
conn_enable = u_conns[0] == 1
|
||||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
||||||
|
|
||||||
return nodes, u_conns
|
return state, nodes, u_conns
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, inputs, transformed):
|
||||||
nodes, conns = transformed
|
nodes, conns = transformed
|
||||||
@@ -48,27 +48,36 @@ class RecurrentGenome(BaseGenome):
|
|||||||
vals = jnp.full((N,), jnp.nan)
|
vals = jnp.full((N,), jnp.nan)
|
||||||
nodes_attrs = nodes[:, 1:]
|
nodes_attrs = nodes[:, 1:]
|
||||||
|
|
||||||
def body_func(_, values):
|
def body_func(_, carry):
|
||||||
|
state_, values = carry
|
||||||
|
|
||||||
# set input values
|
# set input values
|
||||||
values = values.at[self.input_idx].set(inputs)
|
values = values.at[self.input_idx].set(inputs)
|
||||||
|
|
||||||
# calculate connections
|
# calculate connections
|
||||||
node_ins = jax.vmap(
|
state_, node_ins = jax.vmap(
|
||||||
jax.vmap(
|
jax.vmap(
|
||||||
self.conn_gene.forward,
|
self.conn_gene.forward,
|
||||||
in_axes=(None, 1, None)
|
in_axes=(None, 1, None),
|
||||||
|
out_axes=(None, 0)
|
||||||
),
|
),
|
||||||
in_axes=(None, 1, 0)
|
in_axes=(None, 1, 0),
|
||||||
)(state, conns, values)
|
out_axes=(None, 0)
|
||||||
|
)(state_, conns, values)
|
||||||
|
|
||||||
# calculate nodes
|
# calculate nodes
|
||||||
is_output_nodes = jnp.isin(
|
is_output_nodes = jnp.isin(
|
||||||
jnp.arange(N),
|
jnp.arange(N),
|
||||||
self.output_idx
|
self.output_idx
|
||||||
)
|
)
|
||||||
values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0))(nodes_attrs, node_ins.T, is_output_nodes)
|
state_, values = jax.vmap(
|
||||||
return values
|
self.node_gene.forward,
|
||||||
|
in_axes=(None, 0, 0, 0),
|
||||||
|
out_axes=(None, 0)
|
||||||
|
)(state_, nodes_attrs, node_ins.T, is_output_nodes)
|
||||||
|
|
||||||
vals = jax.lax.fori_loop(0, self.activate_time, body_func, vals)
|
return state_, values
|
||||||
|
|
||||||
return vals[self.output_idx]
|
state, vals = jax.lax.fori_loop(0, self.activate_time, body_func, (state, vals))
|
||||||
|
|
||||||
|
return state, vals[self.output_idx]
|
||||||
|
|||||||
@@ -40,8 +40,8 @@ class DefaultSpecies(BaseSpecies):
|
|||||||
|
|
||||||
self.species_arange = jnp.arange(self.species_size)
|
self.species_arange = jnp.arange(self.species_size)
|
||||||
|
|
||||||
def setup(self, randkey):
|
def setup(self, key, state=State()):
|
||||||
k1, k2 = jax.random.split(randkey, 2)
|
k1, k2 = jax.random.split(key, 2)
|
||||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, self.initialize_method)
|
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome, k1, self.initialize_method)
|
||||||
|
|
||||||
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
|
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
|
||||||
@@ -65,8 +65,8 @@ class DefaultSpecies(BaseSpecies):
|
|||||||
|
|
||||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||||
|
|
||||||
return State(
|
return state.register(
|
||||||
randkey=k2,
|
species_randkey=k2,
|
||||||
pop_nodes=pop_nodes,
|
pop_nodes=pop_nodes,
|
||||||
pop_conns=pop_conns,
|
pop_conns=pop_conns,
|
||||||
species_keys=species_keys,
|
species_keys=species_keys,
|
||||||
@@ -134,8 +134,7 @@ class DefaultSpecies(BaseSpecies):
|
|||||||
def check_stagnation(idx):
|
def check_stagnation(idx):
|
||||||
# determine whether the species stagnation
|
# determine whether the species stagnation
|
||||||
st = (
|
st = (
|
||||||
(species_fitness[idx] <= state.best_fitness[
|
(species_fitness[idx] <= state.best_fitness[idx]) & # not better than the best fitness of the species
|
||||||
idx]) & # not better than the best fitness of the species
|
|
||||||
(generation - state.last_improved[idx] > self.max_stagnation) # for a long time
|
(generation - state.last_improved[idx] > self.max_stagnation) # for a long time
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from algorithm.neat import *
|
from algorithm.neat import *
|
||||||
from utils import Act, Agg
|
from utils import Act, Agg, State
|
||||||
|
|
||||||
import jax, jax.numpy as jnp
|
import jax, jax.numpy as jnp
|
||||||
|
|
||||||
@@ -36,11 +36,14 @@ def test_default():
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
transformed = genome.transform(nodes, conns)
|
state = genome.setup(State(randkey=jax.random.key(0)))
|
||||||
|
|
||||||
|
state, *transformed = genome.transform(state, nodes, conns)
|
||||||
print(*transformed, sep='\n')
|
print(*transformed, sep='\n')
|
||||||
|
|
||||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
|
state, outputs = jax.jit(jax.vmap(genome.forward,
|
||||||
|
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
|
||||||
print(outputs)
|
print(outputs)
|
||||||
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
||||||
# expected: [[0.5], [0.75], [0.75], [1]]
|
# expected: [[0.5], [0.75], [0.75], [1]]
|
||||||
@@ -50,11 +53,11 @@ def test_default():
|
|||||||
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
||||||
print(conns)
|
print(conns)
|
||||||
|
|
||||||
transformed = genome.transform(nodes, conns)
|
state, *transformed = genome.transform(state, nodes, conns)
|
||||||
print(*transformed, sep='\n')
|
print(*transformed, sep='\n')
|
||||||
|
|
||||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||||
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
|
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
|
||||||
print(outputs)
|
print(outputs)
|
||||||
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
||||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
||||||
@@ -93,11 +96,14 @@ def test_recurrent():
|
|||||||
activate_time=3,
|
activate_time=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
transformed = genome.transform(nodes, conns)
|
state = genome.setup(State(randkey=jax.random.key(0)))
|
||||||
|
|
||||||
|
state, *transformed = genome.transform(state, nodes, conns)
|
||||||
print(*transformed, sep='\n')
|
print(*transformed, sep='\n')
|
||||||
|
|
||||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||||
outputs = jax.jit(jax.vmap(genome.forward, in_axes=(0, None)))(inputs, transformed)
|
state, outputs = jax.jit(jax.vmap(genome.forward,
|
||||||
|
in_axes=(None, 0, None), out_axes=(None, 0)))(state, inputs, transformed)
|
||||||
print(outputs)
|
print(outputs)
|
||||||
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
||||||
# expected: [[0.5], [0.75], [0.75], [1]]
|
# expected: [[0.5], [0.75], [0.75], [1]]
|
||||||
@@ -107,11 +113,11 @@ def test_recurrent():
|
|||||||
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
||||||
print(conns)
|
print(conns)
|
||||||
|
|
||||||
transformed = genome.transform(nodes, conns)
|
state, *transformed = genome.transform(state, nodes, conns)
|
||||||
print(*transformed, sep='\n')
|
print(*transformed, sep='\n')
|
||||||
|
|
||||||
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
inputs = jnp.array([[0, 0],[0, 1], [1, 0], [1, 1]])
|
||||||
outputs = jax.vmap(genome.forward, in_axes=(0, None))(inputs, transformed)
|
state, outputs = jax.vmap(genome.forward, in_axes=(None, 0, None), out_axes=(None, 0))(state, inputs, transformed)
|
||||||
print(outputs)
|
print(outputs)
|
||||||
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
||||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
||||||
Reference in New Issue
Block a user