hyper neat
This commit is contained in:
55
algorithm/neat/species/species_info.py
Normal file
55
algorithm/neat/species/species_info.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from jax.tree_util import register_pytree_node_class
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
|
||||
@register_pytree_node_class
|
||||
class SpeciesInfo:
|
||||
|
||||
def __init__(self, species_keys, best_fitness, last_improved, member_count):
|
||||
self.species_keys = species_keys
|
||||
self.best_fitness = best_fitness
|
||||
self.last_improved = last_improved
|
||||
self.member_count = member_count
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, state):
|
||||
species_keys = np.full((state.S,), np.nan, dtype=np.float32)
|
||||
best_fitness = np.full((state.S,), np.nan, dtype=np.float32)
|
||||
last_improved = np.full((state.S,), np.nan, dtype=np.float32)
|
||||
member_count = np.full((state.S,), np.nan, dtype=np.float32)
|
||||
|
||||
species_keys[0] = 0
|
||||
best_fitness[0] = -np.inf
|
||||
last_improved[0] = 0
|
||||
member_count[0] = state.P
|
||||
|
||||
return cls(species_keys, best_fitness, last_improved, member_count)
|
||||
|
||||
def __getitem__(self, i):
|
||||
return SpeciesInfo(self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i])
|
||||
|
||||
def get(self, i):
|
||||
return self.species_keys[i], self.best_fitness[i], self.last_improved[i], self.member_count[i]
|
||||
|
||||
def set(self, idx, value):
|
||||
species_keys = self.species_keys.at[idx].set(value[0])
|
||||
best_fitness = self.best_fitness.at[idx].set(value[1])
|
||||
last_improved = self.last_improved.at[idx].set(value[2])
|
||||
member_count = self.member_count.at[idx].set(value[3])
|
||||
return SpeciesInfo(species_keys, best_fitness, last_improved, member_count)
|
||||
|
||||
def remove(self, idx):
|
||||
return self.set(idx, jnp.array([jnp.nan] * 4))
|
||||
|
||||
def size(self):
|
||||
return self.species_keys.shape[0]
|
||||
|
||||
|
||||
def tree_flatten(self):
|
||||
children = self.species_keys, self.best_fitness, self.last_improved, self.member_count
|
||||
aux_data = None
|
||||
return children, aux_data
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux_data, children):
|
||||
return cls(*children)
|
||||
Reference in New Issue
Block a user