This commit is contained in:
root
2024-07-10 16:50:36 +08:00
parent 4cdac932d3
commit 51cb4695af
8 changed files with 71 additions and 72 deletions

View File

@@ -1,4 +1,2 @@
from .gene import *
from .genome import *
from .species import *
from .neat import NEAT

View File

@@ -1,5 +1,5 @@
from tensorneat.common import State, StatefulBaseClass
from ..genome import BaseGenome
from tensorneat.genome import BaseGenome
class BaseSpecies(StatefulBaseClass):

View File

@@ -1,16 +1,17 @@
import jax, jax.numpy as jnp
from .base import BaseSpecies
from tensorneat.common import (
State,
rank_elements,
argmin_with_mask,
fetch_first,
)
from ..genome.utils import (
from tensorneat.genome.utils import (
extract_conn_attrs,
extract_node_attrs,
)
from ..genome import BaseGenome
from .base import BaseSpecies
from tensorneat.genome import BaseGenome
"""

View File

@@ -3,7 +3,7 @@ from typing import Callable, Sequence
import numpy as np
import jax
from jax import vmap, numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene
from .gene import BaseNodeGene, BaseConnGene
from .operations import BaseMutation, BaseCrossover, BaseDistance
from tensorneat.common import (
State,

View File

@@ -5,8 +5,8 @@ from jax import vmap, numpy as jnp
import numpy as np
import sympy as sp
from . import BaseGenome
from ..gene import DefaultNodeGene, DefaultConnGene
from .base import BaseGenome
from .gene import DefaultNodeGene, DefaultConnGene
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
@@ -102,7 +102,7 @@ class DefaultGenome(BaseGenome):
state,
nodes_attrs[i],
ins,
is_output_node=jnp.isin(nodes[0], self.output_idx), # nodes[0] -> the key of nodes
is_output_node=jnp.isin(nodes[i, 0], self.output_idx), # nodes[0] -> the key of nodes
)
# set new value

View File

@@ -3,9 +3,9 @@ from jax import vmap, numpy as jnp
from .utils import unflatten_conns
from .base import BaseGenome
from .gene import DefaultNodeGene, DefaultConnGene
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
from ..gene import DefaultNodeGene, DefaultConnGene
from tensorneat.common import attach_with_inf