remove attr enable for conn
This commit is contained in:
@@ -45,8 +45,8 @@ class DefaultMutation(BaseMutation):
|
|||||||
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
|
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
|
||||||
|
|
||||||
def successful_add_node():
|
def successful_add_node():
|
||||||
# disable the connection
|
# remove the original connection
|
||||||
new_conns = conns_.at[idx, 2].set(False)
|
new_conns = delete_conn_by_pos(conns_, idx)
|
||||||
|
|
||||||
# add a new node
|
# add a new node
|
||||||
new_nodes = add_node(
|
new_nodes = add_node(
|
||||||
@@ -58,14 +58,12 @@ class DefaultMutation(BaseMutation):
|
|||||||
new_conns,
|
new_conns,
|
||||||
i_key,
|
i_key,
|
||||||
new_node_key,
|
new_node_key,
|
||||||
True,
|
|
||||||
genome.conn_gene.new_custom_attrs(state),
|
genome.conn_gene.new_custom_attrs(state),
|
||||||
)
|
)
|
||||||
new_conns = add_conn(
|
new_conns = add_conn(
|
||||||
new_conns,
|
new_conns,
|
||||||
new_node_key,
|
new_node_key,
|
||||||
o_key,
|
o_key,
|
||||||
True,
|
|
||||||
genome.conn_gene.new_custom_attrs(state),
|
genome.conn_gene.new_custom_attrs(state),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -140,27 +138,26 @@ class DefaultMutation(BaseMutation):
|
|||||||
|
|
||||||
def successful():
|
def successful():
|
||||||
return nodes_, add_conn(
|
return nodes_, add_conn(
|
||||||
conns_, i_key, o_key, True, genome.conn_gene.new_custom_attrs(state)
|
conns_, i_key, o_key, genome.conn_gene.new_custom_attrs(state)
|
||||||
)
|
)
|
||||||
|
|
||||||
def already_exist():
|
|
||||||
return nodes_, conns_.at[conn_pos, 2].set(True)
|
|
||||||
|
|
||||||
if genome.network_type == "feedforward":
|
if genome.network_type == "feedforward":
|
||||||
u_cons = unflatten_conns(nodes_, conns_)
|
u_cons = unflatten_conns(nodes_, conns_)
|
||||||
cons_exist = ~jnp.isnan(u_cons[0, :, :])
|
conns_exist = ~jnp.isnan(u_cons[0, :, :])
|
||||||
is_cycle = check_cycles(nodes_, cons_exist, from_idx, to_idx)
|
is_cycle = check_cycles(nodes_, conns_exist, from_idx, to_idx)
|
||||||
|
|
||||||
return jax.lax.cond(
|
return jax.lax.cond(
|
||||||
is_already_exist,
|
is_already_exist | is_cycle | (remain_conn_space < 1),
|
||||||
already_exist,
|
nothing,
|
||||||
lambda: jax.lax.cond(
|
successful,
|
||||||
is_cycle & (remain_conn_space < 1), nothing, successful
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
elif genome.network_type == "recurrent":
|
elif genome.network_type == "recurrent":
|
||||||
return jax.lax.cond(is_already_exist, already_exist, successful)
|
return jax.lax.cond(
|
||||||
|
is_already_exist | (remain_conn_space < 1),
|
||||||
|
nothing,
|
||||||
|
successful,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid network type: {genome.network_type}")
|
raise ValueError(f"Invalid network type: {genome.network_type}")
|
||||||
@@ -169,19 +166,16 @@ class DefaultMutation(BaseMutation):
|
|||||||
# randomly choose a connection
|
# randomly choose a connection
|
||||||
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
|
i_key, o_key, idx = self.choice_connection_key(key_, conns_)
|
||||||
|
|
||||||
def successfully_delete_connection():
|
|
||||||
return nodes_, delete_conn_by_pos(conns_, idx)
|
|
||||||
|
|
||||||
return jax.lax.cond(
|
return jax.lax.cond(
|
||||||
idx == I_INF,
|
idx == I_INF,
|
||||||
lambda: (nodes_, conns_), # nothing
|
lambda: (nodes_, conns_), # nothing
|
||||||
successfully_delete_connection,
|
lambda: (nodes_, delete_conn_by_pos(conns_, idx)), # success
|
||||||
)
|
)
|
||||||
|
|
||||||
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
k1, k2, k3, k4 = jax.random.split(randkey, num=4)
|
||||||
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
|
r1, r2, r3, r4 = jax.random.uniform(k1, shape=(4,))
|
||||||
|
|
||||||
def no(key_, nodes_, conns_):
|
def no(_, nodes_, conns_):
|
||||||
return nodes_, conns_
|
return nodes_, conns_
|
||||||
|
|
||||||
if self.node_add > 0:
|
if self.node_add > 0:
|
||||||
|
|||||||
@@ -4,29 +4,18 @@ from .. import BaseGene
|
|||||||
|
|
||||||
class BaseConnGene(BaseGene):
|
class BaseConnGene(BaseGene):
|
||||||
"Base class for connection genes."
|
"Base class for connection genes."
|
||||||
fixed_attrs = ["input_index", "output_index", "enabled"]
|
fixed_attrs = ["input_index", "output_index"]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def crossover(self, state, randkey, gene1, gene2):
|
def crossover(self, state, randkey, gene1, gene2):
|
||||||
def crossover_attr():
|
|
||||||
return jnp.where(
|
return jnp.where(
|
||||||
jax.random.normal(randkey, gene1.shape) > 0,
|
jax.random.normal(randkey, gene1.shape) > 0,
|
||||||
gene1,
|
gene1,
|
||||||
gene2,
|
gene2,
|
||||||
)
|
)
|
||||||
|
|
||||||
return jax.lax.cond(
|
|
||||||
gene1[2] == gene2[2], # if both genes are enabled or disabled
|
|
||||||
crossover_attr, # then randomly pick attributes from gene1 or gene2
|
|
||||||
lambda: jnp.where( # one gene is enabled and the other is disabled
|
|
||||||
gene1[2], # if gene1 is enabled
|
|
||||||
gene1, # then return gene1
|
|
||||||
gene2, # else return gene2
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, state, attrs, inputs):
|
def forward(self, state, attrs, inputs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -38,10 +38,9 @@ class DefaultConnGene(BaseConnGene):
|
|||||||
def mutate(self, state, randkey, conn):
|
def mutate(self, state, randkey, conn):
|
||||||
input_index = conn[0]
|
input_index = conn[0]
|
||||||
output_index = conn[1]
|
output_index = conn[1]
|
||||||
enabled = conn[2]
|
|
||||||
weight = mutate_float(
|
weight = mutate_float(
|
||||||
randkey,
|
randkey,
|
||||||
conn[3],
|
conn[2],
|
||||||
self.weight_init_mean,
|
self.weight_init_mean,
|
||||||
self.weight_init_std,
|
self.weight_init_std,
|
||||||
self.weight_mutate_power,
|
self.weight_mutate_power,
|
||||||
@@ -49,12 +48,10 @@ class DefaultConnGene(BaseConnGene):
|
|||||||
self.weight_replace_rate,
|
self.weight_replace_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return jnp.array([input_index, output_index, enabled, weight])
|
return jnp.array([input_index, output_index, weight])
|
||||||
|
|
||||||
def distance(self, state, attrs1, attrs2):
|
def distance(self, state, attrs1, attrs2):
|
||||||
return (attrs1[2] != attrs2[2]) + jnp.abs(
|
return jnp.abs(attrs1[0] - attrs2[0])
|
||||||
attrs1[3] - attrs2[3]
|
|
||||||
) # enable + weight
|
|
||||||
|
|
||||||
def forward(self, state, attrs, inputs):
|
def forward(self, state, attrs, inputs):
|
||||||
weight = attrs[0]
|
weight = attrs[0]
|
||||||
|
|||||||
@@ -106,21 +106,19 @@ class BaseGenome:
|
|||||||
self.input_idx, jnp.full_like(self.input_idx, new_node_key)
|
self.input_idx, jnp.full_like(self.input_idx, new_node_key)
|
||||||
]
|
]
|
||||||
conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys
|
conns = conns.at[self.input_idx, :2].set(input_conns) # in-keys, out-keys
|
||||||
conns = conns.at[self.input_idx, 2].set(True) # enable
|
|
||||||
|
|
||||||
# output-hidden connections
|
# output-hidden connections
|
||||||
output_conns = jnp.c_[
|
output_conns = jnp.c_[
|
||||||
jnp.full_like(self.output_idx, new_node_key), self.output_idx
|
jnp.full_like(self.output_idx, new_node_key), self.output_idx
|
||||||
]
|
]
|
||||||
conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys
|
conns = conns.at[self.output_idx, :2].set(output_conns) # in-keys, out-keys
|
||||||
conns = conns.at[self.output_idx, 2].set(True) # enable
|
|
||||||
|
|
||||||
conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx))
|
conn_keys = jax.random.split(k2, num=len(self.input_idx) + len(self.output_idx))
|
||||||
# generate random attributes for conns
|
# generate random attributes for conns
|
||||||
random_conn_attrs = jax.vmap(
|
random_conn_attrs = jax.vmap(
|
||||||
self.conn_gene.new_random_attrs, in_axes=(None, 0)
|
self.conn_gene.new_random_attrs, in_axes=(None, 0)
|
||||||
)(state, conn_keys)
|
)(state, conn_keys)
|
||||||
conns = conns.at[: len(conn_keys), 3:].set(random_conn_attrs)
|
conns = conns.at[: len(conn_keys), 2:].set(random_conn_attrs)
|
||||||
|
|
||||||
return nodes, conns
|
return nodes, conns
|
||||||
|
|
||||||
|
|||||||
@@ -45,19 +45,15 @@ class DefaultGenome(BaseGenome):
|
|||||||
|
|
||||||
def transform(self, state, nodes, conns):
|
def transform(self, state, nodes, conns):
|
||||||
u_conns = unflatten_conns(nodes, conns)
|
u_conns = unflatten_conns(nodes, conns)
|
||||||
conn_enable = u_conns[0] == 1
|
conn_exist = ~jnp.isnan(u_conns[0])
|
||||||
|
|
||||||
# remove enable attr
|
seqs = topological_sort(nodes, conn_exist)
|
||||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
|
||||||
seqs = topological_sort(nodes, conn_enable)
|
|
||||||
|
|
||||||
return seqs, nodes, u_conns
|
return seqs, nodes, u_conns
|
||||||
|
|
||||||
def restore(self, state, transformed):
|
def restore(self, state, transformed):
|
||||||
seqs, nodes, u_conns = transformed
|
seqs, nodes, u_conns = transformed
|
||||||
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
||||||
# restore enable
|
|
||||||
conns = jnp.insert(conns, obj=2, values=1, axis=1)
|
|
||||||
return nodes, conns
|
return nodes, conns
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, inputs, transformed):
|
||||||
@@ -79,14 +75,15 @@ class DefaultGenome(BaseGenome):
|
|||||||
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(
|
ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))(
|
||||||
state, u_conns[:, :, i], values
|
state, u_conns[:, :, i], values
|
||||||
)
|
)
|
||||||
|
|
||||||
z = self.node_gene.forward(
|
z = self.node_gene.forward(
|
||||||
state,
|
state,
|
||||||
nodes_attrs[i],
|
nodes_attrs[i],
|
||||||
ins,
|
ins,
|
||||||
is_output_node=jnp.isin(i, self.output_idx),
|
is_output_node=jnp.isin(i, self.output_idx),
|
||||||
)
|
)
|
||||||
new_values = values.at[i].set(z)
|
|
||||||
|
|
||||||
|
new_values = values.at[i].set(z)
|
||||||
return new_values
|
return new_values
|
||||||
|
|
||||||
# the val of input nodes is obtained by the task, not by calculation
|
# the val of input nodes is obtained by the task, not by calculation
|
||||||
|
|||||||
@@ -47,19 +47,11 @@ class RecurrentGenome(BaseGenome):
|
|||||||
|
|
||||||
def transform(self, state, nodes, conns):
|
def transform(self, state, nodes, conns):
|
||||||
u_conns = unflatten_conns(nodes, conns)
|
u_conns = unflatten_conns(nodes, conns)
|
||||||
|
|
||||||
# remove un-enable connections and remove enable attr
|
|
||||||
conn_enable = u_conns[0] == 1
|
|
||||||
u_conns = jnp.where(conn_enable, u_conns[1:, :], jnp.nan)
|
|
||||||
|
|
||||||
return nodes, u_conns
|
return nodes, u_conns
|
||||||
|
|
||||||
def restore(self, state, transformed):
|
def restore(self, state, transformed):
|
||||||
nodes, u_conns = transformed
|
nodes, u_conns = transformed
|
||||||
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
||||||
|
|
||||||
# restore enable
|
|
||||||
conns = jnp.insert(conns, obj=2, values=1, axis=1)
|
|
||||||
return nodes, conns
|
return nodes, conns
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, inputs, transformed):
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ if __name__ == "__main__":
|
|||||||
genome=DefaultGenome(
|
genome=DefaultGenome(
|
||||||
num_inputs=3,
|
num_inputs=3,
|
||||||
num_outputs=1,
|
num_outputs=1,
|
||||||
max_nodes=5,
|
max_nodes=50,
|
||||||
max_conns=10,
|
max_conns=100,
|
||||||
node_gene=DefaultNodeGene(
|
node_gene=DefaultNodeGene(
|
||||||
activation_default=Act.tanh,
|
activation_default=Act.tanh,
|
||||||
activation_options=(Act.tanh,),
|
activation_options=(Act.tanh,),
|
||||||
@@ -21,8 +21,8 @@ if __name__ == "__main__":
|
|||||||
mutation=DefaultMutation(
|
mutation=DefaultMutation(
|
||||||
node_add=0.1,
|
node_add=0.1,
|
||||||
conn_add=0.1,
|
conn_add=0.1,
|
||||||
node_delete=0.1,
|
node_delete=0.05,
|
||||||
conn_delete=0.1,
|
conn_delete=0.05,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
pop_size=1000,
|
pop_size=1000,
|
||||||
|
|||||||
@@ -21,10 +21,10 @@ def test_default():
|
|||||||
# in_node, out_node, enable, weight
|
# in_node, out_node, enable, weight
|
||||||
conns = jnp.array(
|
conns = jnp.array(
|
||||||
[
|
[
|
||||||
[0, 3, 1, 0.5], # in[0] -> hidden[0]
|
[0, 3, 0.5], # in[0] -> hidden[0]
|
||||||
[1, 4, 1, 0.5], # in[1] -> hidden[1]
|
[1, 4, 0.5], # in[1] -> hidden[1]
|
||||||
[3, 2, 1, 0.5], # hidden[0] -> out[0]
|
[3, 2, 0.5], # hidden[0] -> out[0]
|
||||||
[4, 2, 1, 0.5], # hidden[1] -> out[0]
|
[4, 2, 0.5], # hidden[1] -> out[0]
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -54,22 +54,6 @@ def test_default():
|
|||||||
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
||||||
# expected: [[0.5], [0.75], [0.75], [1]]
|
# expected: [[0.5], [0.75], [0.75], [1]]
|
||||||
|
|
||||||
print("\n-------------------------------------------------------\n")
|
|
||||||
|
|
||||||
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
|
||||||
print(conns)
|
|
||||||
|
|
||||||
transformed = genome.transform(state, nodes, conns)
|
|
||||||
print(*transformed, sep="\n")
|
|
||||||
|
|
||||||
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
|
||||||
outputs = jax.vmap(genome.forward, in_axes=(None, 0, None))(
|
|
||||||
state, inputs, transformed
|
|
||||||
)
|
|
||||||
print(outputs)
|
|
||||||
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
|
||||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
|
||||||
|
|
||||||
|
|
||||||
def test_recurrent():
|
def test_recurrent():
|
||||||
|
|
||||||
@@ -87,10 +71,10 @@ def test_recurrent():
|
|||||||
# in_node, out_node, enable, weight
|
# in_node, out_node, enable, weight
|
||||||
conns = jnp.array(
|
conns = jnp.array(
|
||||||
[
|
[
|
||||||
[0, 3, 1, 0.5], # in[0] -> hidden[0]
|
[0, 3, 0.5], # in[0] -> hidden[0]
|
||||||
[1, 4, 1, 0.5], # in[1] -> hidden[1]
|
[1, 4, 0.5], # in[1] -> hidden[1]
|
||||||
[3, 2, 1, 0.5], # hidden[0] -> out[0]
|
[3, 2, 0.5], # hidden[0] -> out[0]
|
||||||
[4, 2, 1, 0.5], # hidden[1] -> out[0]
|
[4, 2, 0.5], # hidden[1] -> out[0]
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -121,22 +105,6 @@ def test_recurrent():
|
|||||||
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
assert jnp.allclose(outputs, jnp.array([[0.5], [0.75], [0.75], [1]]))
|
||||||
# expected: [[0.5], [0.75], [0.75], [1]]
|
# expected: [[0.5], [0.75], [0.75], [1]]
|
||||||
|
|
||||||
print("\n-------------------------------------------------------\n")
|
|
||||||
|
|
||||||
conns = conns.at[0, 2].set(False) # disable in[0] -> hidden[0]
|
|
||||||
print(conns)
|
|
||||||
|
|
||||||
transformed = genome.transform(state, nodes, conns)
|
|
||||||
print(*transformed, sep="\n")
|
|
||||||
|
|
||||||
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
|
|
||||||
outputs = jax.vmap(genome.forward, in_axes=(None, 0, None))(
|
|
||||||
state, inputs, transformed
|
|
||||||
)
|
|
||||||
print(outputs)
|
|
||||||
assert jnp.allclose(outputs, jnp.array([[0], [0.25], [0], [0.25]]))
|
|
||||||
# expected: [[0.5], [0.75], [0.5], [0.75]]
|
|
||||||
|
|
||||||
|
|
||||||
def test_random_initialize():
|
def test_random_initialize():
|
||||||
genome = DefaultGenome(
|
genome = DefaultGenome(
|
||||||
|
|||||||
@@ -168,15 +168,15 @@ def delete_node_by_pos(nodes, pos):
|
|||||||
return nodes.at[pos].set(jnp.nan)
|
return nodes.at[pos].set(jnp.nan)
|
||||||
|
|
||||||
|
|
||||||
def add_conn(conns, i_key, o_key, enable: bool, attrs):
|
def add_conn(conns, i_key, o_key, attrs):
|
||||||
"""
|
"""
|
||||||
Add a new connection to the genome.
|
Add a new connection to the genome.
|
||||||
The new connection will place at the first NaN row.
|
The new connection will place at the first NaN row.
|
||||||
"""
|
"""
|
||||||
con_keys = conns[:, 0]
|
con_keys = conns[:, 0]
|
||||||
pos = fetch_first(jnp.isnan(con_keys))
|
pos = fetch_first(jnp.isnan(con_keys))
|
||||||
new_conns = conns.at[pos, 0:3].set(jnp.array([i_key, o_key, enable]))
|
new_conns = conns.at[pos, 0:2].set(jnp.array([i_key, o_key]))
|
||||||
return new_conns.at[pos, 3:].set(attrs)
|
return new_conns.at[pos, 2:].set(attrs)
|
||||||
|
|
||||||
|
|
||||||
def delete_conn_by_pos(conns, pos):
|
def delete_conn_by_pos(conns, pos):
|
||||||
|
|||||||
Reference in New Issue
Block a user