This commit is contained in:
wls2002
2024-06-02 19:38:48 +08:00
parent e65200a94e
commit a07a3b1cb2
9 changed files with 673 additions and 13 deletions

View File

@@ -16,8 +16,8 @@ class HyperNEAT(BaseAlgorithm):
neat: NEAT,
below_threshold: float = 0.3,
max_weight: float = 5.0,
activation=Act.sigmoid,
aggregation=Agg.sum,
activation=Act.sigmoid,
activate_time: int = 10,
output_transform: Callable = Act.sigmoid,
):
@@ -34,7 +34,7 @@ class HyperNEAT(BaseAlgorithm):
num_outputs=substrate.num_outputs,
max_nodes=substrate.nodes_cnt,
max_conns=substrate.conns_cnt,
node_gene=HyperNodeGene(activation, aggregation),
node_gene=HyperNodeGene(aggregation, activation),
conn_gene=HyperNEATConnGene(),
activate_time=activate_time,
output_transform=output_transform,
@@ -57,7 +57,6 @@ class HyperNEAT(BaseAlgorithm):
query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))(
state, self.substrate.query_coors, transformed
)
# mute the connection with weight below threshold
query_res = jnp.where(
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
@@ -77,12 +76,15 @@ class HyperNEAT(BaseAlgorithm):
h_nodes, h_conns = self.substrate.make_nodes(
query_res
), self.substrate.make_conn(query_res)
return self.hyper_genome.transform(state, h_nodes, h_conns)
def forward(self, state, inputs, transformed):
# add bias
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
return self.hyper_genome.forward(state, inputs_with_bias, transformed)
res = self.hyper_genome.forward(state, inputs_with_bias, transformed)
return res
@property
def num_inputs(self):
@@ -106,12 +108,12 @@ class HyperNEAT(BaseAlgorithm):
class HyperNodeGene(BaseNodeGene):
def __init__(
self,
activation=Act.sigmoid,
aggregation=Agg.sum,
activation=Act.sigmoid,
):
super().__init__()
self.activation = activation
self.aggregation = aggregation
self.activation = activation
def forward(self, state, attrs, inputs, is_output_node=False):
return jax.lax.cond(

View File

@@ -14,7 +14,7 @@ class DefaultSubstrate(BaseSubstrate):
return self.nodes
def make_conn(self, query_res):
return self.conns.at[:, 3:].set(query_res) # change weight
return self.conns.at[:, 2:].set(query_res) # change weight
@property
def query_coors(self):

View File

@@ -56,10 +56,9 @@ def analysis_substrate(input_coors, output_coors, hidden_coors):
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
conns = np.zeros(
(correspond_keys.shape[0], 4), dtype=np.float32
) # input_idx, output_idx, enabled, weight
conns[:, 0:2] = correspond_keys
conns[:, 2] = 1 # enabled is True
(correspond_keys.shape[0], 3), dtype=np.float32
) # input_idx, output_idx, weight
conns[:, :2] = correspond_keys
return query_coors, nodes, conns