fix bugs
This commit is contained in:
@@ -1,4 +1,2 @@
|
||||
from .gene import *
|
||||
from .genome import *
|
||||
from .species import *
|
||||
from .neat import NEAT
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from tensorneat.common import State, StatefulBaseClass
|
||||
from ..genome import BaseGenome
|
||||
from tensorneat.genome import BaseGenome
|
||||
|
||||
|
||||
class BaseSpecies(StatefulBaseClass):
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
import jax, jax.numpy as jnp
|
||||
|
||||
from .base import BaseSpecies
|
||||
from tensorneat.common import (
|
||||
State,
|
||||
rank_elements,
|
||||
argmin_with_mask,
|
||||
fetch_first,
|
||||
)
|
||||
from ..genome.utils import (
|
||||
from tensorneat.genome.utils import (
|
||||
extract_conn_attrs,
|
||||
extract_node_attrs,
|
||||
)
|
||||
from ..genome import BaseGenome
|
||||
from .base import BaseSpecies
|
||||
from tensorneat.genome import BaseGenome
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Callable, Sequence
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax import vmap, numpy as jnp
|
||||
from ..gene import BaseNodeGene, BaseConnGene
|
||||
from .gene import BaseNodeGene, BaseConnGene
|
||||
from .operations import BaseMutation, BaseCrossover, BaseDistance
|
||||
from tensorneat.common import (
|
||||
State,
|
||||
|
||||
@@ -5,8 +5,8 @@ from jax import vmap, numpy as jnp
|
||||
import numpy as np
|
||||
import sympy as sp
|
||||
|
||||
from . import BaseGenome
|
||||
from ..gene import DefaultNodeGene, DefaultConnGene
|
||||
from .base import BaseGenome
|
||||
from .gene import DefaultNodeGene, DefaultConnGene
|
||||
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
|
||||
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
|
||||
|
||||
@@ -102,7 +102,7 @@ class DefaultGenome(BaseGenome):
|
||||
state,
|
||||
nodes_attrs[i],
|
||||
ins,
|
||||
is_output_node=jnp.isin(nodes[0], self.output_idx), # nodes[0] -> the key of nodes
|
||||
is_output_node=jnp.isin(nodes[i, 0], self.output_idx), # nodes[0] -> the key of nodes
|
||||
)
|
||||
|
||||
# set new value
|
||||
|
||||
@@ -3,9 +3,9 @@ from jax import vmap, numpy as jnp
|
||||
from .utils import unflatten_conns
|
||||
|
||||
from .base import BaseGenome
|
||||
from .gene import DefaultNodeGene, DefaultConnGene
|
||||
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
|
||||
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
|
||||
from ..gene import DefaultNodeGene, DefaultConnGene
|
||||
|
||||
from tensorneat.common import attach_with_inf
|
||||
|
||||
|
||||
Reference in New Issue
Block a user