Add CustomFuncFit into problem; Add related examples
This commit is contained in:
@@ -17,6 +17,8 @@ class DefaultConn(BaseConn):
|
||||
weight_mutate_power: float = 0.15,
|
||||
weight_mutate_rate: float = 0.2,
|
||||
weight_replace_rate: float = 0.015,
|
||||
weight_lower_bound: float = -5.0,
|
||||
weight_upper_bound: float = 5.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.weight_init_mean = weight_init_mean
|
||||
@@ -24,6 +26,9 @@ class DefaultConn(BaseConn):
|
||||
self.weight_mutate_power = weight_mutate_power
|
||||
self.weight_mutate_rate = weight_mutate_rate
|
||||
self.weight_replace_rate = weight_replace_rate
|
||||
self.weight_lower_bound = weight_lower_bound
|
||||
self.weight_upper_bound = weight_upper_bound
|
||||
|
||||
|
||||
def new_zero_attrs(self, state):
|
||||
return jnp.array([0.0]) # weight = 0
|
||||
@@ -36,6 +41,7 @@ class DefaultConn(BaseConn):
|
||||
jax.random.normal(randkey, ()) * self.weight_init_std
|
||||
+ self.weight_init_mean
|
||||
)
|
||||
weight = jnp.clip(weight, self.weight_lower_bound, self.weight_upper_bound)
|
||||
return jnp.array([weight])
|
||||
|
||||
def mutate(self, state, randkey, attrs):
|
||||
@@ -49,7 +55,7 @@ class DefaultConn(BaseConn):
|
||||
self.weight_mutate_rate,
|
||||
self.weight_replace_rate,
|
||||
)
|
||||
|
||||
weight = jnp.clip(weight, self.weight_lower_bound, self.weight_upper_bound)
|
||||
return jnp.array([weight])
|
||||
|
||||
def distance(self, state, attrs1, attrs2):
|
||||
|
||||
@@ -47,9 +47,9 @@ class BiasNode(BaseNode):
|
||||
if isinstance(activation_options, Callable):
|
||||
activation_options = [activation_options]
|
||||
|
||||
if len(aggregation_options) == 1 and aggregation_default is None:
|
||||
if aggregation_default is None:
|
||||
aggregation_default = aggregation_options[0]
|
||||
if len(activation_options) == 1 and activation_default is None:
|
||||
if activation_default is None:
|
||||
activation_default = activation_options[0]
|
||||
|
||||
self.bias_init_mean = bias_init_mean
|
||||
|
||||
@@ -52,9 +52,9 @@ class DefaultNode(BaseNode):
|
||||
if isinstance(activation_options, Callable):
|
||||
activation_options = [activation_options]
|
||||
|
||||
if len(aggregation_options) == 1 and aggregation_default is None:
|
||||
if aggregation_default is None:
|
||||
aggregation_default = aggregation_options[0]
|
||||
if len(activation_options) == 1 and activation_default is None:
|
||||
if activation_default is None:
|
||||
activation_default = activation_options[0]
|
||||
|
||||
self.bias_init_mean = bias_init_mean
|
||||
|
||||
Reference in New Issue
Block a user