add KNN
This commit is contained in:
@@ -16,8 +16,8 @@ class HyperNEAT(BaseAlgorithm):
|
|||||||
neat: NEAT,
|
neat: NEAT,
|
||||||
below_threshold: float = 0.3,
|
below_threshold: float = 0.3,
|
||||||
max_weight: float = 5.0,
|
max_weight: float = 5.0,
|
||||||
activation=Act.sigmoid,
|
|
||||||
aggregation=Agg.sum,
|
aggregation=Agg.sum,
|
||||||
|
activation=Act.sigmoid,
|
||||||
activate_time: int = 10,
|
activate_time: int = 10,
|
||||||
output_transform: Callable = Act.sigmoid,
|
output_transform: Callable = Act.sigmoid,
|
||||||
):
|
):
|
||||||
@@ -34,7 +34,7 @@ class HyperNEAT(BaseAlgorithm):
|
|||||||
num_outputs=substrate.num_outputs,
|
num_outputs=substrate.num_outputs,
|
||||||
max_nodes=substrate.nodes_cnt,
|
max_nodes=substrate.nodes_cnt,
|
||||||
max_conns=substrate.conns_cnt,
|
max_conns=substrate.conns_cnt,
|
||||||
node_gene=HyperNodeGene(activation, aggregation),
|
node_gene=HyperNodeGene(aggregation, activation),
|
||||||
conn_gene=HyperNEATConnGene(),
|
conn_gene=HyperNEATConnGene(),
|
||||||
activate_time=activate_time,
|
activate_time=activate_time,
|
||||||
output_transform=output_transform,
|
output_transform=output_transform,
|
||||||
@@ -57,7 +57,6 @@ class HyperNEAT(BaseAlgorithm):
|
|||||||
query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))(
|
query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))(
|
||||||
state, self.substrate.query_coors, transformed
|
state, self.substrate.query_coors, transformed
|
||||||
)
|
)
|
||||||
|
|
||||||
# mute the connection with weight below threshold
|
# mute the connection with weight below threshold
|
||||||
query_res = jnp.where(
|
query_res = jnp.where(
|
||||||
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
|
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
|
||||||
@@ -77,12 +76,15 @@ class HyperNEAT(BaseAlgorithm):
|
|||||||
h_nodes, h_conns = self.substrate.make_nodes(
|
h_nodes, h_conns = self.substrate.make_nodes(
|
||||||
query_res
|
query_res
|
||||||
), self.substrate.make_conn(query_res)
|
), self.substrate.make_conn(query_res)
|
||||||
|
|
||||||
return self.hyper_genome.transform(state, h_nodes, h_conns)
|
return self.hyper_genome.transform(state, h_nodes, h_conns)
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, inputs, transformed):
|
||||||
# add bias
|
# add bias
|
||||||
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
|
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
|
||||||
return self.hyper_genome.forward(state, inputs_with_bias, transformed)
|
|
||||||
|
res = self.hyper_genome.forward(state, inputs_with_bias, transformed)
|
||||||
|
return res
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_inputs(self):
|
def num_inputs(self):
|
||||||
@@ -106,12 +108,12 @@ class HyperNEAT(BaseAlgorithm):
|
|||||||
class HyperNodeGene(BaseNodeGene):
|
class HyperNodeGene(BaseNodeGene):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
activation=Act.sigmoid,
|
|
||||||
aggregation=Agg.sum,
|
aggregation=Agg.sum,
|
||||||
|
activation=Act.sigmoid,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.activation = activation
|
|
||||||
self.aggregation = aggregation
|
self.aggregation = aggregation
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
def forward(self, state, attrs, inputs, is_output_node=False):
|
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||||
return jax.lax.cond(
|
return jax.lax.cond(
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ class DefaultSubstrate(BaseSubstrate):
|
|||||||
return self.nodes
|
return self.nodes
|
||||||
|
|
||||||
def make_conn(self, query_res):
|
def make_conn(self, query_res):
|
||||||
return self.conns.at[:, 3:].set(query_res) # change weight
|
return self.conns.at[:, 2:].set(query_res) # change weight
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def query_coors(self):
|
def query_coors(self):
|
||||||
|
|||||||
@@ -56,10 +56,9 @@ def analysis_substrate(input_coors, output_coors, hidden_coors):
|
|||||||
|
|
||||||
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
|
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
|
||||||
conns = np.zeros(
|
conns = np.zeros(
|
||||||
(correspond_keys.shape[0], 4), dtype=np.float32
|
(correspond_keys.shape[0], 3), dtype=np.float32
|
||||||
) # input_idx, output_idx, enabled, weight
|
) # input_idx, output_idx, weight
|
||||||
conns[:, 0:2] = correspond_keys
|
conns[:, :2] = correspond_keys
|
||||||
conns[:, 2] = 1 # enabled is True
|
|
||||||
|
|
||||||
return query_coors, nodes, conns
|
return query_coors, nodes, conns
|
||||||
|
|
||||||
|
|||||||
27
tensorneat/algorithm/neat/gene/node/kan_node.py
Normal file
27
tensorneat/algorithm/neat/gene/node/kan_node.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import jax.numpy as jnp
|
||||||
|
from . import BaseNodeGene
|
||||||
|
from utils import Agg
|
||||||
|
|
||||||
|
|
||||||
|
class KANNode(BaseNodeGene):
|
||||||
|
"Node gene for KAN, with only a sum aggregation."
|
||||||
|
|
||||||
|
custom_attrs = []
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def new_identity_attrs(self, state):
|
||||||
|
return jnp.array([])
|
||||||
|
|
||||||
|
def new_random_attrs(self, state, randkey):
|
||||||
|
return jnp.array([])
|
||||||
|
|
||||||
|
def mutate(self, state, randkey, attrs):
|
||||||
|
return jnp.array([])
|
||||||
|
|
||||||
|
def distance(self, state, attrs1, attrs2):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def forward(self, state, attrs, inputs, is_output_node=False):
|
||||||
|
return Agg.sum(inputs)
|
||||||
@@ -39,9 +39,9 @@ if __name__ == "__main__":
|
|||||||
),
|
),
|
||||||
output_transform=Act.tanh, # the activation function for output node in NEAT
|
output_transform=Act.tanh, # the activation function for output node in NEAT
|
||||||
),
|
),
|
||||||
pop_size=10000,
|
pop_size=1000,
|
||||||
species_size=10,
|
species_size=10,
|
||||||
compatibility_threshold=3.5,
|
compatibility_threshold=2,
|
||||||
survival_threshold=0.03,
|
survival_threshold=0.03,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
|||||||
48
tensorneat/examples/func_fit/xor_kan.py
Normal file
48
tensorneat/examples/func_fit/xor_kan.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
from pipeline import Pipeline
|
||||||
|
from algorithm.neat import *
|
||||||
|
from algorithm.neat.gene.node.kan_node import KANNode
|
||||||
|
from algorithm.neat.gene.conn.bspline import BSplineConn
|
||||||
|
|
||||||
|
from problem.func_fit import XOR3d
|
||||||
|
from utils import Act
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pipeline = Pipeline(
|
||||||
|
algorithm=NEAT(
|
||||||
|
species=DefaultSpecies(
|
||||||
|
genome=DefaultGenome(
|
||||||
|
num_inputs=3,
|
||||||
|
num_outputs=1,
|
||||||
|
max_nodes=50,
|
||||||
|
max_conns=100,
|
||||||
|
node_gene=KANNode(),
|
||||||
|
conn_gene=BSplineConn(),
|
||||||
|
output_transform=Act.sigmoid, # the activation function for output node
|
||||||
|
mutation=DefaultMutation(
|
||||||
|
node_add=0.1,
|
||||||
|
conn_add=0.1,
|
||||||
|
node_delete=0.05,
|
||||||
|
conn_delete=0.05,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
pop_size=1000,
|
||||||
|
species_size=20,
|
||||||
|
compatibility_threshold=1.5,
|
||||||
|
survival_threshold=0.01, # magic
|
||||||
|
),
|
||||||
|
),
|
||||||
|
# problem=XOR3d(return_data=True),
|
||||||
|
problem=XOR3d(),
|
||||||
|
generation_limit=10000,
|
||||||
|
fitness_target=-1e-8,
|
||||||
|
# update_batch_size=8,
|
||||||
|
# pre_update=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# initialize state
|
||||||
|
state = pipeline.setup()
|
||||||
|
# print(state)
|
||||||
|
# run until terminate
|
||||||
|
state, best = pipeline.auto_run(state)
|
||||||
|
# show result
|
||||||
|
pipeline.show(state, best)
|
||||||
308
tensorneat/test/test_b_spline_conn.ipynb
Normal file
308
tensorneat/test/test_b_spline_conn.ipynb
Normal file
File diff suppressed because one or more lines are too long
112
tensorneat/test/test_compile_time.ipynb
Normal file
112
tensorneat/test/test_compile_time.ipynb
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "initial_id",
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": true,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-06-02T08:29:04.093990Z",
|
||||||
|
"start_time": "2024-06-02T08:29:04.085992900Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import jax\n",
|
||||||
|
"from jax import vmap, jit, numpy as jnp"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 26,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def func(x, y):\n",
|
||||||
|
" return x + y\n",
|
||||||
|
"\n",
|
||||||
|
"def loop2():\n",
|
||||||
|
" s = 0\n",
|
||||||
|
" for i in range(1000):\n",
|
||||||
|
" x = jnp.full((10000, 1), i)\n",
|
||||||
|
" y = jnp.full((10000, 1), i + 1)\n",
|
||||||
|
" s = (vmap(func)(x, y)).sum()\n",
|
||||||
|
" return s\n",
|
||||||
|
"\n",
|
||||||
|
"def loop3():\n",
|
||||||
|
" s = 0\n",
|
||||||
|
" vmap_func = vmap(func)\n",
|
||||||
|
" for i in range(1000):\n",
|
||||||
|
" x = jnp.full((10000, 1), i)\n",
|
||||||
|
" y = jnp.full((10000, 1), i + 1)\n",
|
||||||
|
" s = (vmap_func(x, y)).sum()\n",
|
||||||
|
" return s"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-06-02T08:31:13.023886300Z",
|
||||||
|
"start_time": "2024-06-02T08:31:13.003026800Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "39f803029127aaa8"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 27,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "Array(19990000, dtype=int32)"
|
||||||
|
},
|
||||||
|
"execution_count": 27,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"compile_loop = jit(loop3).lower().compile()\n",
|
||||||
|
"compile_loop()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-06-02T08:31:14.526380100Z",
|
||||||
|
"start_time": "2024-06-02T08:31:13.870916800Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "ab9f83d0a313f51d"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "c1bd963e51aa5fd4"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
164
tensorneat/test/test_kan.ipynb
Normal file
164
tensorneat/test/test_kan.ipynb
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user