This commit is contained in:
wls2002
2024-06-02 19:38:48 +08:00
parent e65200a94e
commit a07a3b1cb2
9 changed files with 673 additions and 13 deletions

View File

@@ -16,8 +16,8 @@ class HyperNEAT(BaseAlgorithm):
neat: NEAT,
below_threshold: float = 0.3,
max_weight: float = 5.0,
activation=Act.sigmoid,
aggregation=Agg.sum,
activation=Act.sigmoid,
activate_time: int = 10,
output_transform: Callable = Act.sigmoid,
):
@@ -34,7 +34,7 @@ class HyperNEAT(BaseAlgorithm):
num_outputs=substrate.num_outputs,
max_nodes=substrate.nodes_cnt,
max_conns=substrate.conns_cnt,
node_gene=HyperNodeGene(activation, aggregation),
node_gene=HyperNodeGene(aggregation, activation),
conn_gene=HyperNEATConnGene(),
activate_time=activate_time,
output_transform=output_transform,
@@ -57,7 +57,6 @@ class HyperNEAT(BaseAlgorithm):
query_res = jax.vmap(self.neat.forward, in_axes=(None, 0, None))(
state, self.substrate.query_coors, transformed
)
# mute the connection with weight below threshold
query_res = jnp.where(
(-self.below_threshold < query_res) & (query_res < self.below_threshold),
@@ -77,12 +76,15 @@ class HyperNEAT(BaseAlgorithm):
h_nodes, h_conns = self.substrate.make_nodes(
query_res
), self.substrate.make_conn(query_res)
return self.hyper_genome.transform(state, h_nodes, h_conns)
def forward(self, state, inputs, transformed):
# add bias
inputs_with_bias = jnp.concatenate([inputs, jnp.array([1])])
return self.hyper_genome.forward(state, inputs_with_bias, transformed)
res = self.hyper_genome.forward(state, inputs_with_bias, transformed)
return res
@property
def num_inputs(self):
@@ -106,12 +108,12 @@ class HyperNEAT(BaseAlgorithm):
class HyperNodeGene(BaseNodeGene):
def __init__(
self,
activation=Act.sigmoid,
aggregation=Agg.sum,
activation=Act.sigmoid,
):
super().__init__()
self.activation = activation
self.aggregation = aggregation
self.activation = activation
def forward(self, state, attrs, inputs, is_output_node=False):
return jax.lax.cond(

View File

@@ -14,7 +14,7 @@ class DefaultSubstrate(BaseSubstrate):
return self.nodes
def make_conn(self, query_res):
return self.conns.at[:, 3:].set(query_res) # change weight
return self.conns.at[:, 2:].set(query_res) # change weight
@property
def query_coors(self):

View File

@@ -56,10 +56,9 @@ def analysis_substrate(input_coors, output_coors, hidden_coors):
nodes = np.concatenate((input_idx, output_idx, hidden_idx))[..., np.newaxis]
conns = np.zeros(
(correspond_keys.shape[0], 4), dtype=np.float32
) # input_idx, output_idx, enabled, weight
conns[:, 0:2] = correspond_keys
conns[:, 2] = 1 # enabled is True
(correspond_keys.shape[0], 3), dtype=np.float32
) # input_idx, output_idx, weight
conns[:, :2] = correspond_keys
return query_coors, nodes, conns

View File

@@ -0,0 +1,27 @@
import jax.numpy as jnp
from . import BaseNodeGene
from utils import Agg
class KANNode(BaseNodeGene):
"Node gene for KAN, with only a sum aggregation."
custom_attrs = []
def __init__(self):
super().__init__()
def new_identity_attrs(self, state):
return jnp.array([])
def new_random_attrs(self, state, randkey):
return jnp.array([])
def mutate(self, state, randkey, attrs):
return jnp.array([])
def distance(self, state, attrs1, attrs2):
return 0
def forward(self, state, attrs, inputs, is_output_node=False):
return Agg.sum(inputs)

View File

@@ -39,9 +39,9 @@ if __name__ == "__main__":
),
output_transform=Act.tanh, # the activation function for output node in NEAT
),
pop_size=10000,
pop_size=1000,
species_size=10,
compatibility_threshold=3.5,
compatibility_threshold=2,
survival_threshold=0.03,
),
),

View File

@@ -0,0 +1,48 @@
from pipeline import Pipeline
from algorithm.neat import *
from algorithm.neat.gene.node.kan_node import KANNode
from algorithm.neat.gene.conn.bspline import BSplineConn
from problem.func_fit import XOR3d
from utils import Act
if __name__ == "__main__":
pipeline = Pipeline(
algorithm=NEAT(
species=DefaultSpecies(
genome=DefaultGenome(
num_inputs=3,
num_outputs=1,
max_nodes=50,
max_conns=100,
node_gene=KANNode(),
conn_gene=BSplineConn(),
output_transform=Act.sigmoid, # the activation function for output node
mutation=DefaultMutation(
node_add=0.1,
conn_add=0.1,
node_delete=0.05,
conn_delete=0.05,
),
),
pop_size=1000,
species_size=20,
compatibility_threshold=1.5,
survival_threshold=0.01, # magic
),
),
# problem=XOR3d(return_data=True),
problem=XOR3d(),
generation_limit=10000,
fitness_target=-1e-8,
# update_batch_size=8,
# pre_update=True,
)
# initialize state
state = pipeline.setup()
# print(state)
# run until terminate
state, best = pipeline.auto_run(state)
# show result
pipeline.show(state, best)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,112 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-06-02T08:29:04.093990Z",
"start_time": "2024-06-02T08:29:04.085992900Z"
}
},
"outputs": [],
"source": [
"import jax\n",
"from jax import vmap, jit, numpy as jnp"
]
},
{
"cell_type": "code",
"execution_count": 26,
"outputs": [],
"source": [
"def func(x, y):\n",
" return x + y\n",
"\n",
"def loop2():\n",
" s = 0\n",
" for i in range(1000):\n",
" x = jnp.full((10000, 1), i)\n",
" y = jnp.full((10000, 1), i + 1)\n",
" s = (vmap(func)(x, y)).sum()\n",
" return s\n",
"\n",
"def loop3():\n",
" s = 0\n",
" vmap_func = vmap(func)\n",
" for i in range(1000):\n",
" x = jnp.full((10000, 1), i)\n",
" y = jnp.full((10000, 1), i + 1)\n",
" s = (vmap_func(x, y)).sum()\n",
" return s"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-02T08:31:13.023886300Z",
"start_time": "2024-06-02T08:31:13.003026800Z"
}
},
"id": "39f803029127aaa8"
},
{
"cell_type": "code",
"execution_count": 27,
"outputs": [
{
"data": {
"text/plain": "Array(19990000, dtype=int32)"
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"compile_loop = jit(loop3).lower().compile()\n",
"compile_loop()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-06-02T08:31:14.526380100Z",
"start_time": "2024-06-02T08:31:13.870916800Z"
}
},
"id": "ab9f83d0a313f51d"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "c1bd963e51aa5fd4"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because one or more lines are too long