diff --git a/examples/tmp.py b/examples/tmp.py index e900e5a..1d08549 100644 --- a/examples/tmp.py +++ b/examples/tmp.py @@ -1,7 +1,7 @@ import jax, jax.numpy as jnp from tensorneat.algorithm import NEAT -from tensorneat.algorithm.neat import DefaultGenome, RecurrentGenome +from tensorneat.genome import DefaultGenome, RecurrentGenome key = jax.random.key(0) genome = DefaultGenome(num_inputs=5, num_outputs=3, max_nodes=100, max_conns=500, init_hidden_layers=(1, 2 ,3)) diff --git a/network.svg b/network.svg index 2989755..d86c21f 100644 --- a/network.svg +++ b/network.svg @@ -6,7 +6,7 @@ - 2024-07-10T15:27:16.806503 + 2024-07-10T16:50:19.947855 image/svg+xml @@ -32,222 +32,222 @@ z +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: none; stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="stroke: #000000; stroke-linecap: round"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> +" clip-path="url(#p8fe09283f8)" style="fill: #0000ff; stroke: #0000ff"/> - + diff --git a/tensorneat/algorithm/neat/__init__.py b/tensorneat/algorithm/neat/__init__.py index 06f14d2..cab968b 100644 --- a/tensorneat/algorithm/neat/__init__.py +++ b/tensorneat/algorithm/neat/__init__.py @@ -1,4 +1,2 @@ -from .gene import * -from .genome import * from .species import * from .neat import NEAT diff --git a/tensorneat/algorithm/neat/species/base.py b/tensorneat/algorithm/neat/species/base.py index f6175de..cf03d4f 100644 --- a/tensorneat/algorithm/neat/species/base.py +++ b/tensorneat/algorithm/neat/species/base.py @@ -1,5 +1,5 @@ from tensorneat.common import State, StatefulBaseClass -from ..genome import BaseGenome +from tensorneat.genome import BaseGenome class BaseSpecies(StatefulBaseClass): diff --git a/tensorneat/algorithm/neat/species/default.py b/tensorneat/algorithm/neat/species/default.py index 310e070..3c5e82a 100644 --- a/tensorneat/algorithm/neat/species/default.py +++ b/tensorneat/algorithm/neat/species/default.py @@ -1,16 +1,17 @@ import jax, jax.numpy as jnp + +from .base import BaseSpecies from tensorneat.common import ( State, rank_elements, argmin_with_mask, fetch_first, ) -from ..genome.utils import ( +from tensorneat.genome.utils import ( extract_conn_attrs, extract_node_attrs, ) -from ..genome import BaseGenome -from .base import BaseSpecies +from tensorneat.genome import BaseGenome """ diff --git a/tensorneat/genome/base.py b/tensorneat/genome/base.py index 502d3c6..f3091f2 100644 --- a/tensorneat/genome/base.py +++ b/tensorneat/genome/base.py @@ -3,7 +3,7 @@ from typing import Callable, Sequence import numpy as np import jax from jax import vmap, numpy as jnp -from ..gene import BaseNodeGene, BaseConnGene +from .gene import BaseNodeGene, BaseConnGene from .operations import BaseMutation, BaseCrossover, BaseDistance from tensorneat.common import ( State, diff --git a/tensorneat/genome/default.py b/tensorneat/genome/default.py index de29eec..73fd736 100644 --- a/tensorneat/genome/default.py +++ b/tensorneat/genome/default.py @@ -5,8 +5,8 @@ from jax import vmap, numpy as jnp import numpy as np import sympy as sp -from . import BaseGenome -from ..gene import DefaultNodeGene, DefaultConnGene +from .base import BaseGenome +from .gene import DefaultNodeGene, DefaultConnGene from .operations import DefaultMutation, DefaultCrossover, DefaultDistance from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs @@ -102,7 +102,7 @@ class DefaultGenome(BaseGenome): state, nodes_attrs[i], ins, - is_output_node=jnp.isin(nodes[0], self.output_idx), # nodes[0] -> the key of nodes + is_output_node=jnp.isin(nodes[i, 0], self.output_idx), # nodes[0] -> the key of nodes ) # set new value diff --git a/tensorneat/genome/recurrent.py b/tensorneat/genome/recurrent.py index a289e8f..d833265 100644 --- a/tensorneat/genome/recurrent.py +++ b/tensorneat/genome/recurrent.py @@ -3,9 +3,9 @@ from jax import vmap, numpy as jnp from .utils import unflatten_conns from .base import BaseGenome +from .gene import DefaultNodeGene, DefaultConnGene from .operations import DefaultMutation, DefaultCrossover, DefaultDistance from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs -from ..gene import DefaultNodeGene, DefaultConnGene from tensorneat.common import attach_with_inf