Add CustomFuncFit into problem; Add related examples

This commit is contained in:
root
2024-07-11 18:32:08 +08:00
parent 3cb5fbf581
commit be6a67d7e2
15 changed files with 241 additions and 437 deletions

View File

@@ -17,6 +17,8 @@ class DefaultConn(BaseConn):
weight_mutate_power: float = 0.15,
weight_mutate_rate: float = 0.2,
weight_replace_rate: float = 0.015,
weight_lower_bound: float = -5.0,
weight_upper_bound: float = 5.0,
):
super().__init__()
self.weight_init_mean = weight_init_mean
@@ -24,6 +26,9 @@ class DefaultConn(BaseConn):
self.weight_mutate_power = weight_mutate_power
self.weight_mutate_rate = weight_mutate_rate
self.weight_replace_rate = weight_replace_rate
self.weight_lower_bound = weight_lower_bound
self.weight_upper_bound = weight_upper_bound
def new_zero_attrs(self, state):
return jnp.array([0.0]) # weight = 0
@@ -36,6 +41,7 @@ class DefaultConn(BaseConn):
jax.random.normal(randkey, ()) * self.weight_init_std
+ self.weight_init_mean
)
weight = jnp.clip(weight, self.weight_lower_bound, self.weight_upper_bound)
return jnp.array([weight])
def mutate(self, state, randkey, attrs):
@@ -49,7 +55,7 @@ class DefaultConn(BaseConn):
self.weight_mutate_rate,
self.weight_replace_rate,
)
weight = jnp.clip(weight, self.weight_lower_bound, self.weight_upper_bound)
return jnp.array([weight])
def distance(self, state, attrs1, attrs2):