diff --git a/tensorneat/algorithm/neat/gene/conn/base.py b/tensorneat/algorithm/neat/gene/conn/base.py
index da4d54a..4564dc2 100644
--- a/tensorneat/algorithm/neat/gene/conn/base.py
+++ b/tensorneat/algorithm/neat/gene/conn/base.py
@@ -16,13 +16,6 @@ class BaseConnGene(BaseGene):
     def forward(self, state, attrs, inputs):
         raise NotImplementedError
 
-    def update_by_batch(self, state, attrs, batch_inputs):
-        # default: do not update attrs, but to calculate batch_res
-        return (
-            jax.vmap(self.forward, in_axes=(None, None, 0))(state, attrs, batch_inputs),
-            attrs,
-        )
-
     def repr(self, state, conn, precision=2, idx_width=3, func_width=8):
         in_idx, out_idx = conn[:2]
         in_idx = int(in_idx)
diff --git a/tensorneat/algorithm/neat/gene/conn/default.py b/tensorneat/algorithm/neat/gene/conn/default.py
index 21dbb55..0a95cf5 100644
--- a/tensorneat/algorithm/neat/gene/conn/default.py
+++ b/tensorneat/algorithm/neat/gene/conn/default.py
@@ -1,9 +1,8 @@
 import jax.numpy as jnp
 import jax.random
-import numpy as np
 import sympy as sp
 from tensorneat.common import mutate_float
-from . import BaseConnGene
+from .base import BaseConnGene
 
 
 class DefaultConnGene(BaseConnGene):
diff --git a/tensorneat/algorithm/neat/gene/node/__init__.py b/tensorneat/algorithm/neat/gene/node/__init__.py
index 3c10f2c..752cfb1 100644
--- a/tensorneat/algorithm/neat/gene/node/__init__.py
+++ b/tensorneat/algorithm/neat/gene/node/__init__.py
@@ -1,3 +1,3 @@
 from .base import BaseNodeGene
 from .default import DefaultNodeGene
-from .default_without_response import NodeGeneWithoutResponse
+from .bias import BiasNode
diff --git a/tensorneat/algorithm/neat/gene/node/base.py b/tensorneat/algorithm/neat/gene/node/base.py
index 30e324d..f3275d9 100644
--- a/tensorneat/algorithm/neat/gene/node/base.py
+++ b/tensorneat/algorithm/neat/gene/node/base.py
@@ -12,34 +12,6 @@ class BaseNodeGene(BaseGene):
     def forward(self, state, attrs, inputs, is_output_node=False):
         raise NotImplementedError
 
-    def input_transform(self, state, attrs, inputs):
-        """
-        make transformation in the input node.
-        default: do nothing
-        """
-        return inputs
-
-    def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
-        # default: do not update attrs, but to calculate batch_res
-        return (
-            jax.vmap(self.forward, in_axes=(None, None, 0, None))(
-                state, attrs, batch_inputs, is_output_node
-            ),
-            attrs,
-        )
-
-    def update_input_transform(self, state, attrs, batch_inputs):
-        """
-        update the attrs for transformation in the input node.
-        default: do nothing
-        """
-        return (
-            jax.vmap(self.input_transform, in_axes=(None, None, 0))(
-                state, attrs, batch_inputs
-            ),
-            attrs,
-        )
-
     def repr(self, state, node, precision=2, idx_width=3, func_width=8):
         idx = node[0]
diff --git a/tensorneat/algorithm/neat/gene/node/default_without_response.py b/tensorneat/algorithm/neat/gene/node/bias.py
similarity index 98%
rename from tensorneat/algorithm/neat/gene/node/default_without_response.py
rename to tensorneat/algorithm/neat/gene/node/bias.py
index b7f9a0b..69e88f4 100644
--- a/tensorneat/algorithm/neat/gene/node/default_without_response.py
+++ b/tensorneat/algorithm/neat/gene/node/bias.py
@@ -1,7 +1,6 @@
 from typing import Tuple
 
 import jax, jax.numpy as jnp
-import numpy as np
 import sympy as sp
 from tensorneat.common import (
     Act,
@@ -16,7 +15,7 @@ from tensorneat.common import (
 from . import BaseNodeGene
 
 
-class NodeGeneWithoutResponse(BaseNodeGene):
+class BiasNode(BaseNodeGene):
     """
     Default node gene, with the same behavior as in NEAT-python.
     The attribute response is removed.
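Note for downstream users: the `update_by_batch` defaults deleted above did nothing more than `jax.vmap` the gene's `forward` over the batch axis while leaving the attrs untouched. A minimal sketch of reproducing that behavior outside the gene class, assuming `gene` is any conn gene instance (the function name `batch_forward` is illustrative, not part of the library):

```python
import jax


def batch_forward(gene, state, attrs, batch_inputs):
    # Mirrors the removed BaseConnGene.update_by_batch default:
    # vmap forward over the batch dimension and return attrs unchanged.
    batch_res = jax.vmap(gene.forward, in_axes=(None, None, 0))(
        state, attrs, batch_inputs
    )
    return batch_res, attrs
```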
diff --git a/tensorneat/algorithm/neat/gene/node/kan_node.py b/tensorneat/algorithm/neat/gene/node/kan_node.py
deleted file mode 100644
index 298f888..0000000
--- a/tensorneat/algorithm/neat/gene/node/kan_node.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import jax.numpy as jnp
-from . import BaseNodeGene
-from tensorneat.common import Agg
-
-
-class KANNode(BaseNodeGene):
-    "Node gene for KAN, with only a sum aggregation."
-
-    custom_attrs = []
-
-    def __init__(self):
-        super().__init__()
-
-    def new_identity_attrs(self, state):
-        return jnp.array([])
-
-    def new_random_attrs(self, state, randkey):
-        return jnp.array([])
-
-    def mutate(self, state, randkey, attrs):
-        return jnp.array([])
-
-    def distance(self, state, attrs1, attrs2):
-        return 0
-
-    def forward(self, state, attrs, inputs, is_output_node=False):
-        return Agg.sum(inputs)
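Anyone still depending on the deleted `KANNode` can vendor it locally. The sketch below reproduces the removed class body; the `tensorneat.algorithm.neat.gene.node` import path is an assumption based on the `__init__.py` shown above:

```python
import jax.numpy as jnp

from tensorneat.algorithm.neat.gene.node import BaseNodeGene  # assumed public path
from tensorneat.common import Agg


class KANNode(BaseNodeGene):
    "Node gene for KAN, with only a sum aggregation."

    # No evolvable attributes: init, mutation, and distance are all no-ops.
    custom_attrs = []

    def new_identity_attrs(self, state):
        return jnp.array([])

    def new_random_attrs(self, state, randkey):
        return jnp.array([])

    def mutate(self, state, randkey, attrs):
        return jnp.array([])

    def distance(self, state, attrs1, attrs2):
        return 0

    def forward(self, state, attrs, inputs, is_output_node=False):
        return Agg.sum(inputs)
```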
diff --git a/tensorneat/algorithm/neat/gene/node/min_max_node.py b/tensorneat/algorithm/neat/gene/node/min_max_node.py
deleted file mode 100644
index 9560f6c..0000000
--- a/tensorneat/algorithm/neat/gene/node/min_max_node.py
+++ /dev/null
@@ -1,193 +0,0 @@
-from typing import Tuple
-
-import jax, jax.numpy as jnp
-
-from tensorneat.common import Act, Agg, act_func, agg_func, mutate_int, mutate_float
-from . import BaseNodeGene
-
-
-class MinMaxNode(BaseNodeGene):
-    """
-    Node with normalization before activation.
-    """
-
-    # alpha and beta is used for normalization, just like BatchNorm
-    # norm: z = act(agg(inputs) + bias)
-    #       z = (z - min) / (max - min) * (max_out - min_out) + min_out
-    custom_attrs = ["bias", "aggregation", "activation", "min", "max"]
-    eps = 1e-6
-
-    def __init__(
-        self,
-        bias_init_mean: float = 0.0,
-        bias_init_std: float = 1.0,
-        bias_mutate_power: float = 0.5,
-        bias_mutate_rate: float = 0.7,
-        bias_replace_rate: float = 0.1,
-        aggregation_default: callable = Agg.sum,
-        aggregation_options: Tuple = (Agg.sum,),
-        aggregation_replace_rate: float = 0.1,
-        activation_default: callable = Act.sigmoid,
-        activation_options: Tuple = (Act.sigmoid,),
-        activation_replace_rate: float = 0.1,
-        output_range: Tuple[float, float] = (-1, 1),
-        update_hidden_node: bool = False,
-    ):
-        super().__init__()
-        self.bias_init_mean = bias_init_mean
-        self.bias_init_std = bias_init_std
-        self.bias_mutate_power = bias_mutate_power
-        self.bias_mutate_rate = bias_mutate_rate
-        self.bias_replace_rate = bias_replace_rate
-
-        self.aggregation_default = aggregation_options.index(aggregation_default)
-        self.aggregation_options = aggregation_options
-        self.aggregation_indices = jnp.arange(len(aggregation_options))
-        self.aggregation_replace_rate = aggregation_replace_rate
-
-        self.activation_default = activation_options.index(activation_default)
-        self.activation_options = activation_options
-        self.activation_indices = jnp.arange(len(activation_options))
-        self.activation_replace_rate = activation_replace_rate
-
-        self.output_range = output_range
-        assert (
-            len(self.output_range) == 2 and self.output_range[0] < self.output_range[1]
-        )
-        self.update_hidden_node = update_hidden_node
-
-    def new_identity_attrs(self, state):
-        return jnp.array(
-            [0, self.aggregation_default, -1, 0, 1]
-        )  # activation=-1 means Act.identity; min=0, max=1 will do not influence
-
-    def new_random_attrs(self, state, randkey):
-        k1, k2, k3, k4, k5 = jax.random.split(randkey, num=5)
-        bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
-        agg = jax.random.randint(k2, (), 0, len(self.aggregation_options))
-        act = jax.random.randint(k3, (), 0, len(self.activation_options))
-        return jnp.array([bias, agg, act, 0, 1])
-
-    def mutate(self, state, randkey, attrs):
-        k1, k2, k3, k4, k5 = jax.random.split(randkey, num=5)
-        bias, act, agg, min_, max_ = attrs
-
-        bias = mutate_float(
-            k1,
-            bias,
-            self.bias_init_mean,
-            self.bias_init_std,
-            self.bias_mutate_power,
-            self.bias_mutate_rate,
-            self.bias_replace_rate,
-        )
-
-        agg = mutate_int(
-            k2, agg, self.aggregation_indices, self.aggregation_replace_rate
-        )
-
-        act = mutate_int(k3, act, self.activation_indices, self.activation_replace_rate)
-
-        return jnp.array([bias, agg, act, min_, max_])
-
-    def distance(self, state, attrs1, attrs2):
-        bias1, agg1, act1, min1, max1 = attrs1
-        bias2, agg2, act2, min1, max1 = attrs2
-        return (
-            jnp.abs(bias1 - bias2)  # bias
-            + (agg1 != agg2)  # aggregation
-            + (act1 != act2)  # activation
-        )
-
-    def forward(self, state, attrs, inputs, is_output_node=False):
-        """
-        post_act = (agg(inputs) + bias - mean) / std * alpha + beta
-        """
-        bias, agg, act, min_, max_ = attrs
-
-        z = agg_func(agg, inputs, self.aggregation_options)
-        z = bias + z
-
-        # the last output node should not be activated
-        z = jax.lax.cond(
-            is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
-        )
-
-        if self.update_hidden_node:
-            z = (z - min_) / (max_ - min_)  # transform to 01
-            z = (
-                z * (self.output_range[1] - self.output_range[0]) + self.output_range[0]
-            )  # transform to output_range
-
-        return z
-
-    def input_transform(self, state, attrs, inputs):
-        """
-        make transform in the input node.
-        the normalization also need be done in the first node.
-        """
-        bias, agg, act, min_, max_ = attrs
-        inputs = (inputs - min_) / (max_ - min_)  # transform to 01
-        inputs = (
-            inputs * (self.output_range[1] - self.output_range[0])
-            + self.output_range[0]
-        )
-        return inputs
-
-    def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
-
-        bias, agg, act, min_, max_ = attrs
-
-        batch_z = jax.vmap(agg_func, in_axes=(None, 0, None))(
-            agg, batch_inputs, self.aggregation_options
-        )
-
-        batch_z = bias + batch_z
-
-        batch_z = jax.lax.cond(
-            is_output_node,
-            lambda: batch_z,
-            lambda: jax.vmap(act_func, in_axes=(None, 0, None))(
-                act, batch_z, self.activation_options
-            ),
-        )
-
-        if self.update_hidden_node:
-            # calculate min, max
-            min_ = jnp.min(jnp.where(jnp.isnan(batch_z), jnp.inf, batch_z))
-            max_ = jnp.max(jnp.where(jnp.isnan(batch_z), -jnp.inf, batch_z))
-
-            batch_z = (batch_z - min_) / (max_ - min_)  # transform to 01
-            batch_z = (
-                batch_z * (self.output_range[1] - self.output_range[0])
-                + self.output_range[0]
-            )
-
-            # update mean and std to the attrs
-            attrs = attrs.at[3].set(min_)
-            attrs = attrs.at[4].set(max_)
-
-        return batch_z, attrs
-
-    def update_input_transform(self, state, attrs, batch_inputs):
-        """
-        update the attrs for transformation in the input node.
-        default: do nothing
-        """
-        bias, agg, act, min_, max_ = attrs
-
-        # calculate min, max
-        min_ = jnp.min(jnp.where(jnp.isnan(batch_inputs), jnp.inf, batch_inputs))
-        max_ = jnp.max(jnp.where(jnp.isnan(batch_inputs), -jnp.inf, batch_inputs))
-
-        batch_inputs = (batch_inputs - min_) / (max_ - min_)  # transform to 01
-        batch_inputs = (
-            batch_inputs * (self.output_range[1] - self.output_range[0])
-            + self.output_range[0]
-        )
-
-        # update mean and std to the attrs
-        attrs = attrs.at[3].set(min_)
-        attrs = attrs.at[4].set(max_)
-
-        return batch_inputs, attrs
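For reference, the core computation the deleted `MinMaxNode` performed is NaN-aware min-max rescaling into a target range. A standalone sketch of that transform (the free-standing function name and signature are illustrative; the arithmetic follows the deleted `update_by_batch`):

```python
import jax.numpy as jnp


def min_max_rescale(batch_z, output_range=(-1.0, 1.0)):
    # Batch min/max, ignoring NaN entries as the deleted code did.
    min_ = jnp.min(jnp.where(jnp.isnan(batch_z), jnp.inf, batch_z))
    max_ = jnp.max(jnp.where(jnp.isnan(batch_z), -jnp.inf, batch_z))
    z01 = (batch_z - min_) / (max_ - min_)  # map to [0, 1]
    lo, hi = output_range
    return z01 * (hi - lo) + lo  # map to output_range
```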
diff --git a/tensorneat/algorithm/neat/gene/node/normalized.py b/tensorneat/algorithm/neat/gene/node/normalized.py
deleted file mode 100644
index 342a14c..0000000
--- a/tensorneat/algorithm/neat/gene/node/normalized.py
+++ /dev/null
@@ -1,231 +0,0 @@
-from typing import Tuple
-
-import jax, jax.numpy as jnp
-
-from tensorneat.common import Act, Agg, act_func, agg_func, mutate_int, mutate_float
-from . import BaseNodeGene
-
-
-class NormalizedNode(BaseNodeGene):
-    """
-    Node with normalization before activation.
-    """
-
-    # alpha and beta is used for normalization, just like BatchNorm
-    # norm: (data - mean) / (std + eps) * alpha + beta
-    custom_attrs = ["bias", "aggregation", "activation", "mean", "std", "alpha", "beta"]
-    eps = 1e-6
-
-    def __init__(
-        self,
-        bias_init_mean: float = 0.0,
-        bias_init_std: float = 1.0,
-        bias_mutate_power: float = 0.5,
-        bias_mutate_rate: float = 0.7,
-        bias_replace_rate: float = 0.1,
-        aggregation_default: callable = Agg.sum,
-        aggregation_options: Tuple = (Agg.sum,),
-        aggregation_replace_rate: float = 0.1,
-        activation_default: callable = Act.sigmoid,
-        activation_options: Tuple = (Act.sigmoid,),
-        activation_replace_rate: float = 0.1,
-        alpha_init_mean: float = 1.0,
-        alpha_init_std: float = 1.0,
-        alpha_mutate_power: float = 0.5,
-        alpha_mutate_rate: float = 0.7,
-        alpha_replace_rate: float = 0.1,
-        beta_init_mean: float = 0.0,
-        beta_init_std: float = 1.0,
-        beta_mutate_power: float = 0.5,
-        beta_mutate_rate: float = 0.7,
-        beta_replace_rate: float = 0.1,
-    ):
-        super().__init__()
-        self.bias_init_mean = bias_init_mean
-        self.bias_init_std = bias_init_std
-        self.bias_mutate_power = bias_mutate_power
-        self.bias_mutate_rate = bias_mutate_rate
-        self.bias_replace_rate = bias_replace_rate
-
-        self.aggregation_default = aggregation_options.index(aggregation_default)
-        self.aggregation_options = aggregation_options
-        self.aggregation_indices = jnp.arange(len(aggregation_options))
-        self.aggregation_replace_rate = aggregation_replace_rate
-
-        self.activation_default = activation_options.index(activation_default)
-        self.activation_options = activation_options
-        self.activation_indices = jnp.arange(len(activation_options))
-        self.activation_replace_rate = activation_replace_rate
-
-        self.alpha_init_mean = alpha_init_mean
-        self.alpha_init_std = alpha_init_std
-        self.alpha_mutate_power = alpha_mutate_power
-        self.alpha_mutate_rate = alpha_mutate_rate
-        self.alpha_replace_rate = alpha_replace_rate
-
-        self.beta_init_mean = beta_init_mean
-        self.beta_init_std = beta_init_std
-        self.beta_mutate_power = beta_mutate_power
-        self.beta_mutate_rate = beta_mutate_rate
-        self.beta_replace_rate = beta_replace_rate
-
-    def new_identity_attrs(self, state):
-        return jnp.array(
-            [0, self.aggregation_default, -1, 0, 1, 1, 0]
-        )  # activation=-1 means Act.identity
-
-    def new_random_attrs(self, state, randkey):
-        k1, k2, k3, k4, k5 = jax.random.split(randkey, num=5)
-        bias = jax.random.normal(k1, ()) * self.bias_init_std + self.bias_init_mean
-        agg = jax.random.randint(k2, (), 0, len(self.aggregation_options))
-        act = jax.random.randint(k3, (), 0, len(self.activation_options))
-
-        mean = 0
-        std = 1
-        alpha = jax.random.normal(k4, ()) * self.alpha_init_std + self.alpha_init_mean
-        beta = jax.random.normal(k5, ()) * self.beta_init_std + self.beta_init_mean
-
-        return jnp.array([bias, agg, act, mean, std, alpha, beta])
-
-    def mutate(self, state, randkey, attrs):
-        k1, k2, k3, k4, k5 = jax.random.split(randkey, num=5)
-        bias, act, agg, mean, std, alpha, beta = attrs
-
-        bias = mutate_float(
-            k1,
-            bias,
-            self.bias_init_mean,
-            self.bias_init_std,
-            self.bias_mutate_power,
-            self.bias_mutate_rate,
-            self.bias_replace_rate,
-        )
-
-        agg = mutate_int(
-            k2, agg, self.aggregation_indices, self.aggregation_replace_rate
-        )
-
-        act = mutate_int(k3, act, self.activation_indices, self.activation_replace_rate)
-
-        alpha = mutate_float(
-            k4,
-            alpha,
-            self.alpha_init_mean,
-            self.alpha_init_std,
-            self.alpha_mutate_power,
-            self.alpha_mutate_rate,
-            self.alpha_replace_rate,
-        )
-
-        beta = mutate_float(
-            k5,
-            beta,
-            self.beta_init_mean,
-            self.beta_init_std,
-            self.beta_mutate_power,
-            self.beta_mutate_rate,
-            self.beta_replace_rate,
-        )
-
-        return jnp.array([bias, agg, act, mean, std, alpha, beta])
-
-    def distance(self, state, attrs1, attrs2):
-        bias1, agg1, act1, mean1, std1, alpha1, beta1 = attrs1
-        bias2, agg2, act2, mean2, std2, alpha2, beta2 = attrs2
-        return (
-            jnp.abs(bias1 - bias2)  # bias
-            + (agg1 != agg2)  # aggregation
-            + (act1 != act2)  # activation
-            + jnp.abs(alpha1 - alpha2)  # alpha
-            + jnp.abs(beta1 - beta2)  # beta
-        )
-
-    def forward(self, state, attrs, inputs, is_output_node=False):
-        """
-        post_act = (agg(inputs) + bias - mean) / std * alpha + beta
-        """
-        bias, agg, act, mean, std, alpha, beta = attrs
-
-        z = agg_func(agg, inputs, self.aggregation_options)
-        z = bias + z
-        z = (z - mean) / (std + self.eps) * alpha + beta  # normalization
-
-        # the last output node should not be activated
-        z = jax.lax.cond(
-            is_output_node, lambda: z, lambda: act_func(act, z, self.activation_options)
-        )
-
-        return z
-
-    def input_transform(self, state, attrs, inputs):
-        """
-        make transform in the input node.
-        the normalization also need be done in the first node.
-        """
-        bias, agg, act, mean, std, alpha, beta = attrs
-        inputs = (inputs - mean) / (std + self.eps) * alpha + beta  # normalization
-        return inputs
-
-    def update_by_batch(self, state, attrs, batch_inputs, is_output_node=False):
-
-        bias, agg, act, mean, std, alpha, beta = attrs
-
-        batch_z = jax.vmap(agg_func, in_axes=(None, 0, None))(
-            agg, batch_inputs, self.aggregation_options
-        )
-
-        batch_z = bias + batch_z
-
-        # calculate mean
-        valid_values_count = jnp.sum(~jnp.isnan(batch_z))
-        valid_values_sum = jnp.sum(jnp.where(jnp.isnan(batch_z), 0, batch_z))
-        mean = valid_values_sum / valid_values_count
-
-        # calculate std
-        std = jnp.sqrt(
-            jnp.sum(jnp.where(jnp.isnan(batch_z), 0, (batch_z - mean) ** 2))
-            / valid_values_count
-        )
-
-        batch_z = (batch_z - mean) / (std + self.eps) * alpha + beta  # normalization
-        batch_z = jax.lax.cond(
-            is_output_node,
-            lambda: batch_z,
-            lambda: jax.vmap(act_func, in_axes=(None, 0, None))(
-                act, batch_z, self.activation_options
-            ),
-        )
-
-        # update mean and std to the attrs
-        attrs = attrs.at[3].set(mean)
-        attrs = attrs.at[4].set(std)
-
-        return batch_z, attrs
-
-    def update_input_transform(self, state, attrs, batch_inputs):
-        """
-        update the attrs for transformation in the input node.
-        default: do nothing
-        """
-        bias, agg, act, mean, std, alpha, beta = attrs
-
-        # calculate mean
-        valid_values_count = jnp.sum(~jnp.isnan(batch_inputs))
-        valid_values_sum = jnp.sum(jnp.where(jnp.isnan(batch_inputs), 0, batch_inputs))
-        mean = valid_values_sum / valid_values_count
-
-        # calculate std
-        std = jnp.sqrt(
-            jnp.sum(jnp.where(jnp.isnan(batch_inputs), 0, (batch_inputs - mean) ** 2))
-            / valid_values_count
-        )
-
-        batch_inputs = (batch_inputs - mean) / (
-            std + self.eps
-        ) * alpha + beta  # normalization
-
-        # update mean and std to the attrs
-        attrs = attrs.at[3].set(mean)
-        attrs = attrs.at[4].set(std)
-
-        return batch_inputs, attrs
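Likewise, the deleted `NormalizedNode` computed NaN-aware batch statistics before applying its affine normalization `(z - mean) / (std + eps) * alpha + beta`. A standalone sketch of that statistics step (function name illustrative; the arithmetic follows the deleted `update_by_batch`):

```python
import jax.numpy as jnp


def nan_mean_std(batch_z):
    # Mean and std over non-NaN entries only, as in the deleted code.
    valid = ~jnp.isnan(batch_z)
    count = jnp.sum(valid)
    mean = jnp.sum(jnp.where(valid, batch_z, 0.0)) / count
    std = jnp.sqrt(jnp.sum(jnp.where(valid, (batch_z - mean) ** 2, 0.0)) / count)
    return mean, std
```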