This commit is contained in:
root
2024-07-10 16:50:36 +08:00
parent 4cdac932d3
commit 51cb4695af
8 changed files with 71 additions and 72 deletions

View File

@@ -3,7 +3,7 @@ from typing import Callable, Sequence
import numpy as np
import jax
from jax import vmap, numpy as jnp
from ..gene import BaseNodeGene, BaseConnGene
from .gene import BaseNodeGene, BaseConnGene
from .operations import BaseMutation, BaseCrossover, BaseDistance
from tensorneat.common import (
State,

View File

@@ -5,8 +5,8 @@ from jax import vmap, numpy as jnp
import numpy as np
import sympy as sp
from . import BaseGenome
from ..gene import DefaultNodeGene, DefaultConnGene
from .base import BaseGenome
from .gene import DefaultNodeGene, DefaultConnGene
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
@@ -102,7 +102,7 @@ class DefaultGenome(BaseGenome):
state,
nodes_attrs[i],
ins,
is_output_node=jnp.isin(nodes[0], self.output_idx), # nodes[0] -> the key of nodes
is_output_node=jnp.isin(nodes[i, 0], self.output_idx), # nodes[0] -> the key of nodes
)
# set new value

View File

@@ -3,9 +3,9 @@ from jax import vmap, numpy as jnp
from .utils import unflatten_conns
from .base import BaseGenome
from .gene import DefaultNodeGene, DefaultConnGene
from .operations import DefaultMutation, DefaultCrossover, DefaultDistance
from .utils import unflatten_conns, extract_node_attrs, extract_conn_attrs
from ..gene import DefaultNodeGene, DefaultConnGene
from tensorneat.common import attach_with_inf