add KNN
This commit is contained in:
@@ -16,8 +16,8 @@ class HyperNEAT(BaseAlgorithm):
|
||||
neat: NEAT,
|
||||
below_threshold: float = 0.3,
|
||||
max_weight: float = 5.0,
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
activation=Act.sigmoid,
|
||||
activate_time: int = 10,
|
||||
output_transform: Callable = Act.sigmoid,
|
||||
):
|
||||
@@ -34,7 +34,7 @@ class HyperNEAT(BaseAlgorithm):
|
||||
num_outputs=substrate.num_outputs,
|
||||
max_nodes=substrate.nodes_cnt,
|
||||
max_conns=substrate.conns_cnt,
|
||||
node_gene=HyperNodeGene(activation, aggregation),
|
||||
node_gene=HyperNodeGene(aggregation, activation),
|
||||
conn_gene=HyperNEATConnGene(),
|
||||
activate_time=activate_time,
|
||||
output_transform=output_transform,
|
||||
@@ -57,7 +57,6 @@ class HyperNEAT(BaseAlgorithm):
|
||||
query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))(
|
||||
state, self.substrate.query_coors, transformed
|
||||
)
|
||||
|
||||
# mute the connection with weight below threshold
|
||||
query_res = jnp.where(
|
||||
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
|
||||
@@ -77,12 +76,15 @@ class HyperNEAT(BaseAlgorithm):
|
||||
h_nodes, h_conns = self.substrate.make_nodes(
|
||||
query_res
|
||||
), self.substrate.make_conn(query_res)
|
||||
|
||||
return self.hyper_genome.transform(state, h_nodes, h_conns)
|
||||
|
||||
def forward(self, state, inputs, transformed):
|
||||
# add bias
|
||||
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
|
||||
return self.hyper_genome.forward(state, inputs_with_bias, transformed)
|
||||
|
||||
res = self.hyper_genome.forward(state, inputs_with_bias, transformed)
|
||||
return res
|
||||
|
||||
@property
|
||||
def num_inputs(self):
|
||||
@@ -106,12 +108,12 @@ class HyperNEAT(BaseAlgorithm):
|
||||
class HyperNodeGene(BaseNodeGene):
|
||||
def __init__(
|
||||
self,
|
||||
activation=Act.sigmoid,
|
||||
aggregation=Agg.sum,
|
||||
activation=Act.sigmoid,
|
||||
):
|
||||
super().__init__()
|
||||
self.activation = activation
|
||||
self.aggregation = aggregation
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
return jax.lax.cond(
|
||||
|
||||
@@ -14,7 +14,7 @@ class DefaultSubstrate(BaseSubstrate):
|
||||
return self.nodes
|
||||
|
||||
def make_conn(self, query_res):
|
||||
return self.conns.at[:, 3:].set(query_res) # change weight
|
||||
return self.conns.at[:, 2:].set(query_res) # change weight
|
||||
|
||||
@property
|
||||
def query_coors(self):
|
||||
|
||||
@@ -56,10 +56,9 @@ def analysis_substrate(input_coors, output_coors, hidden_coors):
|
||||
|
||||
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
|
||||
conns = np.zeros(
|
||||
(correspond_keys.shape[0], 4), dtype=np.float32
|
||||
) # input_idx, output_idx, enabled, weight
|
||||
conns[:, 0:2] = correspond_keys
|
||||
conns[:, 2] = 1 # enabled is True
|
||||
(correspond_keys.shape[0], 3), dtype=np.float32
|
||||
) # input_idx, output_idx, weight
|
||||
conns[:, :2] = correspond_keys
|
||||
|
||||
return query_coors, nodes, conns
|
||||
|
||||
|
||||
27
tensorneat/algorithm/neat/gene/node/kan_node.py
Normal file
27
tensorneat/algorithm/neat/gene/node/kan_node.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import jax.numpy as jnp
|
||||
from . import BaseNodeGene
|
||||
from utils import Agg
|
||||
|
||||
|
||||
class KANNode(BaseNodeGene):
|
||||
"Node gene for KAN, with only a sum aggregation."
|
||||
|
||||
custom_attrs = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def new_identity_attrs(self, state):
|
||||
return jnp.array([])
|
||||
|
||||
def new_random_attrs(self, state, randkey):
|
||||
return jnp.array([])
|
||||
|
||||
def mutate(self, state, randkey, attrs):
|
||||
return jnp.array([])
|
||||
|
||||
def distance(self, state, attrs1, attrs2):
|
||||
return 0
|
||||
|
||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||
return Agg.sum(inputs)
|
||||
@@ -39,9 +39,9 @@ if __name__ == "__main__":
|
||||
),
|
||||
output_transform=Act.tanh, # the activation function for output node in NEAT
|
||||
),
|
||||
pop_size=10000,
|
||||
pop_size=1000,
|
||||
species_size=10,
|
||||
compatibility_threshold=3.5,
|
||||
compatibility_threshold=2,
|
||||
survival_threshold=0.03,
|
||||
),
|
||||
),
|
||||
|
||||
48
tensorneat/examples/func_fit/xor_kan.py
Normal file
48
tensorneat/examples/func_fit/xor_kan.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from pipeline import Pipeline
|
||||
from algorithm.neat import *
|
||||
from algorithm.neat.gene.node.kan_node import KANNode
|
||||
from algorithm.neat.gene.conn.bspline import BSplineConn
|
||||
|
||||
from problem.func_fit import XOR3d
|
||||
from utils import Act
|
||||
|
||||
if __name__ == "__main__":
|
||||
pipeline = Pipeline(
|
||||
algorithm=NEAT(
|
||||
species=DefaultSpecies(
|
||||
genome=DefaultGenome(
|
||||
num_inputs=3,
|
||||
num_outputs=1,
|
||||
max_nodes=50,
|
||||
max_conns=100,
|
||||
node_gene=KANNode(),
|
||||
conn_gene=BSplineConn(),
|
||||
output_transform=Act.sigmoid, # the activation function for output node
|
||||
mutation=DefaultMutation(
|
||||
node_add=0.1,
|
||||
conn_add=0.1,
|
||||
node_delete=0.05,
|
||||
conn_delete=0.05,
|
||||
),
|
||||
),
|
||||
pop_size=1000,
|
||||
species_size=20,
|
||||
compatibility_threshold=1.5,
|
||||
survival_threshold=0.01, # magic
|
||||
),
|
||||
),
|
||||
# problem=XOR3d(return_data=True),
|
||||
problem=XOR3d(),
|
||||
generation_limit=10000,
|
||||
fitness_target=-1e-8,
|
||||
# update_batch_size=8,
|
||||
# pre_update=True,
|
||||
)
|
||||
|
||||
# initialize state
|
||||
state = pipeline.setup()
|
||||
# print(state)
|
||||
# run until terminate
|
||||
state, best = pipeline.auto_run(state)
|
||||
# show result
|
||||
pipeline.show(state, best)
|
||||
308
tensorneat/test/test_b_spline_conn.ipynb
Normal file
308
tensorneat/test/test_b_spline_conn.ipynb
Normal file
File diff suppressed because one or more lines are too long
112
tensorneat/test/test_compile_time.ipynb
Normal file
112
tensorneat/test/test_compile_time.ipynb
Normal file
@@ -0,0 +1,112 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "initial_id",
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-02T08:29:04.093990Z",
|
||||
"start_time": "2024-06-02T08:29:04.085992900Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import jax\n",
|
||||
"from jax import vmap, jit, numpy as jnp"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def func(x, y):\n",
|
||||
" return x + y\n",
|
||||
"\n",
|
||||
"def loop2():\n",
|
||||
" s = 0\n",
|
||||
" for i in range(1000):\n",
|
||||
" x = jnp.full((10000, 1), i)\n",
|
||||
" y = jnp.full((10000, 1), i + 1)\n",
|
||||
" s = (vmap(func)(x, y)).sum()\n",
|
||||
" return s\n",
|
||||
"\n",
|
||||
"def loop3():\n",
|
||||
" s = 0\n",
|
||||
" vmap_func = vmap(func)\n",
|
||||
" for i in range(1000):\n",
|
||||
" x = jnp.full((10000, 1), i)\n",
|
||||
" y = jnp.full((10000, 1), i + 1)\n",
|
||||
" s = (vmap_func(x, y)).sum()\n",
|
||||
" return s"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-02T08:31:13.023886300Z",
|
||||
"start_time": "2024-06-02T08:31:13.003026800Z"
|
||||
}
|
||||
},
|
||||
"id": "39f803029127aaa8"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "Array(19990000, dtype=int32)"
|
||||
},
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"compile_loop = jit(loop3).lower().compile()\n",
|
||||
"compile_loop()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2024-06-02T08:31:14.526380100Z",
|
||||
"start_time": "2024-06-02T08:31:13.870916800Z"
|
||||
}
|
||||
},
|
||||
"id": "ab9f83d0a313f51d"
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"id": "c1bd963e51aa5fd4"
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
164
tensorneat/test/test_kan.ipynb
Normal file
164
tensorneat/test/test_kan.ipynb
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user