add KNN
This commit is contained in:
@@ -16,8 +16,8 @@ class HyperNEAT(BaseAlgorithm):
|
||||
neat: NEAT,
|
||||
below_threshold: float = 0.3,
|
||||
max_weight: float = 5.0,
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
activation=Act.sigmoid,
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = Act.sigmoid,
|
||||
):
|
||||
@@ -34,7 +34,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
num_outputs=substrate.num_outputs,
|
||||
max_nodes=substrate.nodes_cnt,
|
||||
max_conns=substrate.conns_cnt,
|
||||
node_gene=HyperNodeGene(activation, aggregation),
|
||||
node_gene=HyperNodeGene(aggregation, activation),
|
||||
conn_gene=HyperNEATConnGene(),
|
||||
activate_time=activate_time,
|
||||
output_transform=output_transform,
|
||||
@@ -57,7 +57,6 @@ class HyperNEAT(BaseAlgorithm):
|
||||
query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))(
|
||||
state, self.substrate.query_coors, transformed
|
||||
)
|
||||
|
||||
# mute the connection with weight below threshold
|
||||
query_res = jnp.where(
|
||||
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
|
||||
@@ -77,12 +76,15 @@ class HyperNEAT(BaseAlgorithm):
|
||||
h_nodes, h_conns = self.substrate.make_nodes(
|
||||
query_res
|
||||
), self.substrate.make_conn(query_res)
|
||||
|
||||
return self.hyper_genome.transform(state, h_nodes, h_conns)
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
# add bias
|
||||
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
|
||||
return self.hyper_genome.forward(state, inputs_with_bias, transformed)
|
||||
|
||||
res = self.hyper_genome.forward(state, inputs_with_bias, transformed)
|
||||
return res
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
@@ -106,12 +108,12 @@ class HyperNEAT(BaseAlgorithm):
|
||||
class HyperNodeGene(BaseNodeGene):
|
||||
def __init__(
|
||||
self,
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
activation=Act.sigmoid,
|
||||
):
|
||||
super().__init__()
|
||||
self.activation = activation
|
||||
self.aggregation = aggregation
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
return jax.lax.cond(
|
||||
|
||||
Reference in New Issue
Block a user