initialize methods
This commit is contained in:
BIN
tensorneat/__pycache__/pipeline.cpython-311.pyc
Normal file
BIN
tensorneat/__pycache__/pipeline.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/algorithm/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/algorithm/__pycache__/base.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
@@ -1,5 +1,5 @@
|
|||||||
|
from .ga import *
|
||||||
from .gene import *
|
from .gene import *
|
||||||
from .genome import *
|
from .genome import *
|
||||||
from .species import *
|
from .species import *
|
||||||
from .neat import NEAT
|
from .neat import NEAT
|
||||||
|
|
||||||
|
|||||||
BIN
tensorneat/algorithm/neat/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/neat/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/algorithm/neat/__pycache__/neat.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/neat/__pycache__/neat.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tensorneat/algorithm/neat/gene/__pycache__/base.cpython-311.pyc
Normal file
BIN
tensorneat/algorithm/neat/gene/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -27,6 +27,17 @@ class DefaultConnGene(BaseConnGene):
|
|||||||
def new_custom_attrs(self):
|
def new_custom_attrs(self):
|
||||||
return jnp.array([self.weight_init_mean])
|
return jnp.array([self.weight_init_mean])
|
||||||
|
|
||||||
|
def new_random_attrs(self, key):
|
||||||
|
return jnp.array([mutate_float(key,
|
||||||
|
self.weight_init_mean,
|
||||||
|
self.weight_init_mean,
|
||||||
|
1.0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
def mutate(self, key, conn):
|
def mutate(self, key, conn):
|
||||||
input_index = conn[0]
|
input_index = conn[0]
|
||||||
output_index = conn[1]
|
output_index = conn[1]
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -61,6 +61,16 @@ class DefaultNodeGene(BaseNodeGene):
|
|||||||
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
[self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def new_random_attrs(self, key):
|
||||||
|
return jnp.array([
|
||||||
|
mutate_float(key, self.bias_init_mean, self.bias_init_mean, self.bias_init_std,
|
||||||
|
self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate),
|
||||||
|
mutate_float(key, self.response_init_mean, self.response_init_mean, self.response_init_std,
|
||||||
|
self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate),
|
||||||
|
self.activation_default,
|
||||||
|
self.aggregation_default,
|
||||||
|
])
|
||||||
|
|
||||||
def mutate(self, key, node):
|
def mutate(self, key, node):
|
||||||
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
k1, k2, k3, k4 = jax.random.split(key, num=4)
|
||||||
index = node[0]
|
index = node[0]
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -19,7 +19,8 @@ class DefaultSpecies(BaseSpecies):
|
|||||||
genome_elitism: int = 2,
|
genome_elitism: int = 2,
|
||||||
survival_threshold: float = 0.2,
|
survival_threshold: float = 0.2,
|
||||||
min_species_size: int = 1,
|
min_species_size: int = 1,
|
||||||
compatibility_threshold: float = 3.
|
compatibility_threshold: float = 3.,
|
||||||
|
initialize_method: str = 'one_hidden_node', # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'}
|
||||||
):
|
):
|
||||||
self.genome = genome
|
self.genome = genome
|
||||||
self.pop_size = pop_size
|
self.pop_size = pop_size
|
||||||
@@ -34,11 +35,13 @@ class DefaultSpecies(BaseSpecies):
|
|||||||
self.survival_threshold = survival_threshold
|
self.survival_threshold = survival_threshold
|
||||||
self.min_species_size = min_species_size
|
self.min_species_size = min_species_size
|
||||||
self.compatibility_threshold = compatibility_threshold
|
self.compatibility_threshold = compatibility_threshold
|
||||||
|
self.initialize_method = initialize_method
|
||||||
|
|
||||||
self.species_arange = jnp.arange(self.species_size)
|
self.species_arange = jnp.arange(self.species_size)
|
||||||
|
|
||||||
def setup(self, randkey):
|
def setup(self, randkey):
|
||||||
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome)
|
k1, k2 = jax.random.split(randkey, 2)
|
||||||
|
pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome,k1, self.initialize_method)
|
||||||
|
|
||||||
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
|
species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species
|
||||||
best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species
|
best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species
|
||||||
@@ -62,7 +65,7 @@ class DefaultSpecies(BaseSpecies):
|
|||||||
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns))
|
||||||
|
|
||||||
return State(
|
return State(
|
||||||
randkey=randkey,
|
randkey=k2,
|
||||||
pop_nodes=pop_nodes,
|
pop_nodes=pop_nodes,
|
||||||
pop_conns=pop_conns,
|
pop_conns=pop_conns,
|
||||||
species_keys=species_keys,
|
species_keys=species_keys,
|
||||||
@@ -489,31 +492,159 @@ class DefaultSpecies(BaseSpecies):
|
|||||||
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
return jnp.where(max_cnt == 0, 0, val / max_cnt)
|
||||||
|
|
||||||
|
|
||||||
def initialize_population(pop_size, genome):
|
|
||||||
o_nodes = np.full((genome.max_nodes, genome.node_gene.length), np.nan) # original nodes
|
|
||||||
o_conns = np.full((genome.max_conns, genome.conn_gene.length), np.nan) # original connections
|
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_population(pop_size, genome, randkey, init_method='default'):
|
||||||
|
rand_keys = jax.random.split(randkey, pop_size)
|
||||||
|
|
||||||
|
if init_method == 'one_hidden_node':
|
||||||
|
init_func = init_one_hidden_node
|
||||||
|
elif init_method == 'dense_hideen_layer':
|
||||||
|
init_func = init_dense_hideen_layer
|
||||||
|
elif init_method == 'no_hidden_random':
|
||||||
|
init_func = init_no_hidden_random
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown initialization method: {}".format(init_method))
|
||||||
|
|
||||||
|
pop_nodes, pop_conns = jax.vmap(init_func, in_axes=(None, 0))(genome, rand_keys)
|
||||||
|
|
||||||
|
return pop_nodes, pop_conns
|
||||||
|
|
||||||
|
# one hidden node
|
||||||
|
def init_one_hidden_node(genome, randkey):
|
||||||
input_idx, output_idx = genome.input_idx, genome.output_idx
|
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||||
new_node_key = max([*input_idx, *output_idx]) + 1
|
new_node_key = max([*input_idx, *output_idx]) + 1
|
||||||
|
|
||||||
o_nodes[input_idx, 0] = genome.input_idx
|
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
|
||||||
o_nodes[output_idx, 0] = genome.output_idx
|
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
|
||||||
o_nodes[new_node_key, 0] = new_node_key # one hidden node
|
|
||||||
o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs()
|
|
||||||
o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node
|
|
||||||
|
|
||||||
input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden
|
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||||
o_conns[input_idx, 0:2] = input_conns # in key, out key
|
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||||
o_conns[input_idx, 2] = True # enabled
|
nodes = nodes.at[new_node_key, 0].set(new_node_key)
|
||||||
o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
|
||||||
|
|
||||||
output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes
|
rand_keys_nodes = jax.random.split(randkey, num=len(input_idx) + len(output_idx) + 1)
|
||||||
o_conns[output_idx, 0:2] = output_conns # in key, out key
|
input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[len(input_idx):len(input_idx) + len(output_idx)], rand_keys_nodes[-1]
|
||||||
o_conns[output_idx, 2] = True # enabled
|
|
||||||
o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs()
|
|
||||||
|
|
||||||
# repeat origin genome for P times to create population
|
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None,0))
|
||||||
pop_nodes = np.tile(o_nodes, (pop_size, 1, 1))
|
input_attrs = node_attr_func(input_keys)
|
||||||
pop_conns = np.tile(o_conns, (pop_size, 1, 1))
|
output_attrs = node_attr_func(output_keys)
|
||||||
|
hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key)
|
||||||
|
|
||||||
return pop_nodes, pop_conns
|
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||||
|
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||||
|
nodes = nodes.at[new_node_key, 1:].set(hidden_attrs)
|
||||||
|
|
||||||
|
input_conns = jnp.c_[input_idx, jnp.full_like(input_idx, new_node_key)]
|
||||||
|
conns = conns.at[input_idx, 0:2].set(input_conns)
|
||||||
|
conns = conns.at[input_idx, 2].set(True)
|
||||||
|
|
||||||
|
output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx]
|
||||||
|
conns = conns.at[output_idx, 0:2].set(output_conns)
|
||||||
|
conns = conns.at[output_idx, 2].set(True)
|
||||||
|
|
||||||
|
rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx))
|
||||||
|
input_conn_keys, output_conn_keys = rand_keys_conns[:len(input_idx)], rand_keys_conns[len(input_idx):]
|
||||||
|
|
||||||
|
conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None,0))
|
||||||
|
input_conn_attrs = conn_attr_func(input_conn_keys)
|
||||||
|
output_conn_attrs = conn_attr_func(output_conn_keys)
|
||||||
|
|
||||||
|
conns = conns.at[input_idx, 3:].set(input_conn_attrs)
|
||||||
|
conns = conns.at[output_idx, 3:].set(output_conn_attrs)
|
||||||
|
|
||||||
|
return nodes, conns
|
||||||
|
|
||||||
|
|
||||||
|
#random dense connections with 1 hidden layer
|
||||||
|
def init_dense_hideen_layer( genome, randkey,hiddens=20):
|
||||||
|
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||||
|
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||||
|
input_size = len(input_idx)
|
||||||
|
output_size = len(output_idx)
|
||||||
|
|
||||||
|
hidden_idx = jnp.arange(input_size + output_size, input_size + output_size + hiddens)
|
||||||
|
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32)
|
||||||
|
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||||
|
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||||
|
nodes = nodes.at[hidden_idx, 0].set(hidden_idx)
|
||||||
|
|
||||||
|
total_idx = input_size + output_size + hiddens
|
||||||
|
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||||
|
input_keys = rand_keys_n[:input_size]
|
||||||
|
output_keys = rand_keys_n[input_size:input_size + output_size]
|
||||||
|
hidden_keys = rand_keys_n[input_size + output_size:]
|
||||||
|
|
||||||
|
node_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||||
|
input_attrs = node_attr_func(input_keys)
|
||||||
|
output_attrs = node_attr_func(output_keys)
|
||||||
|
hidden_attrs = node_attr_func(hidden_keys)
|
||||||
|
|
||||||
|
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||||
|
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||||
|
nodes = nodes.at[hidden_idx, 1:].set(hidden_attrs)
|
||||||
|
|
||||||
|
total_connections = input_size * hiddens + hiddens * output_size
|
||||||
|
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32)
|
||||||
|
|
||||||
|
rand_keys_c = jax.random.split(k2, num=total_connections)
|
||||||
|
conns_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||||
|
conns_attrs = conns_attr_func(rand_keys_c)
|
||||||
|
|
||||||
|
input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing='ij')
|
||||||
|
hidden_to_output_ids, output_ids = jnp.meshgrid(hidden_idx, output_idx, indexing='ij')
|
||||||
|
|
||||||
|
conns = conns.at[:input_size * hiddens, 0].set(input_to_hidden_ids.flatten())
|
||||||
|
conns = conns.at[:input_size * hiddens, 1].set(hidden_ids.flatten())
|
||||||
|
conns = conns.at[input_size * hiddens: total_connections, 0].set(hidden_to_output_ids.flatten())
|
||||||
|
conns = conns.at[input_size * hiddens: total_connections, 1].set(output_ids.flatten())
|
||||||
|
conns = conns.at[:input_size * hiddens + hiddens * output_size, 2].set(True)
|
||||||
|
conns = conns.at[:input_size * hiddens + hiddens * output_size, 3:].set(conns_attrs)
|
||||||
|
|
||||||
|
return nodes, conns
|
||||||
|
|
||||||
|
# random sparse connections with no hidden nodes
|
||||||
|
def init_no_hidden_random(genome, randkey):
|
||||||
|
k1, k2, k3 = jax.random.split(randkey, num=3)
|
||||||
|
input_idx, output_idx = genome.input_idx, genome.output_idx
|
||||||
|
|
||||||
|
nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan)
|
||||||
|
nodes = nodes.at[input_idx, 0].set(input_idx)
|
||||||
|
nodes = nodes.at[output_idx, 0].set(output_idx)
|
||||||
|
|
||||||
|
total_idx = len(input_idx) + len(output_idx)
|
||||||
|
rand_keys_n = jax.random.split(k1, num=total_idx)
|
||||||
|
input_keys = rand_keys_n[:len(input_idx)]
|
||||||
|
output_keys = rand_keys_n[len(input_idx):]
|
||||||
|
|
||||||
|
node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0))
|
||||||
|
input_attrs = node_attr_func(input_keys)
|
||||||
|
output_attrs = node_attr_func(output_keys)
|
||||||
|
nodes = nodes.at[input_idx, 1:].set(input_attrs)
|
||||||
|
nodes = nodes.at[output_idx, 1:].set(output_attrs)
|
||||||
|
|
||||||
|
conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan)
|
||||||
|
|
||||||
|
num_connections_per_output = 4
|
||||||
|
total_connections = len(output_idx) * num_connections_per_output
|
||||||
|
|
||||||
|
def create_connections_for_output(key):
|
||||||
|
permuted_inputs = jax.random.permutation(key, input_idx)
|
||||||
|
selected_inputs = permuted_inputs[:num_connections_per_output]
|
||||||
|
return selected_inputs
|
||||||
|
|
||||||
|
conn_keys = jax.random.split(k2, num=len(output_idx))
|
||||||
|
connections = jax.vmap(create_connections_for_output)(conn_keys)
|
||||||
|
connections = connections.flatten()
|
||||||
|
|
||||||
|
output_repeats = jnp.repeat(output_idx, num_connections_per_output)
|
||||||
|
|
||||||
|
rand_keys_c = jax.random.split(k3, num=total_connections)
|
||||||
|
conns_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0))
|
||||||
|
conns_attrs = conns_attr_func(rand_keys_c)
|
||||||
|
|
||||||
|
conns = conns.at[:total_connections, 0].set(connections)
|
||||||
|
conns = conns.at[:total_connections, 1].set(output_repeats)
|
||||||
|
conns = conns.at[:total_connections, 2].set(True) # enabled
|
||||||
|
conns = conns.at[:total_connections, 3:].set(conns_attrs)
|
||||||
|
|
||||||
|
return nodes, conns
|
||||||
|
|||||||
@@ -82,12 +82,13 @@ class Pipeline:
|
|||||||
state = ini_state
|
state = ini_state
|
||||||
compiled_step = jax.jit(self.step).lower(ini_state).compile()
|
compiled_step = jax.jit(self.step).lower(ini_state).compile()
|
||||||
|
|
||||||
for _ in range(self.generation_limit):
|
for w in range(self.generation_limit):
|
||||||
|
|
||||||
self.generation_timestamp = time.time()
|
self.generation_timestamp = time.time()
|
||||||
|
|
||||||
previous_pop = self.algorithm.ask(state.alg)
|
previous_pop = self.algorithm.ask(state.alg)
|
||||||
|
|
||||||
|
|
||||||
state, fitnesses = compiled_step(state)
|
state, fitnesses = compiled_step(state)
|
||||||
|
|
||||||
fitnesses = jax.device_get(fitnesses)
|
fitnesses = jax.device_get(fitnesses)
|
||||||
@@ -102,7 +103,13 @@ class Pipeline:
|
|||||||
if max(fitnesses) >= self.fitness_target:
|
if max(fitnesses) >= self.fitness_target:
|
||||||
print("Fitness limit reached!")
|
print("Fitness limit reached!")
|
||||||
return state, self.best_genome
|
return state, self.best_genome
|
||||||
|
node= previous_pop[0][0][:,0]
|
||||||
|
node_count = jnp.sum(~jnp.isnan(node))
|
||||||
|
conn= previous_pop[1][0][:,0]
|
||||||
|
conn_count = jnp.sum(~jnp.isnan(conn))
|
||||||
|
if(w%5==0):
|
||||||
|
print("node_count",node_count)
|
||||||
|
print("conn_count",conn_count)
|
||||||
print("Generation limit reached!")
|
print("Generation limit reached!")
|
||||||
return state, self.best_genome
|
return state, self.best_genome
|
||||||
|
|
||||||
|
|||||||
BIN
tensorneat/problem/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/problem/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/__pycache__/base.cpython-311.pyc
Normal file
BIN
tensorneat/problem/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/rl_env/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/problem/rl_env/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/rl_env/__pycache__/brax_env.cpython-311.pyc
Normal file
BIN
tensorneat/problem/rl_env/__pycache__/brax_env.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/rl_env/__pycache__/gymnax_env.cpython-311.pyc
Normal file
BIN
tensorneat/problem/rl_env/__pycache__/gymnax_env.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/problem/rl_env/__pycache__/rl_jit.cpython-311.pyc
Normal file
BIN
tensorneat/problem/rl_env/__pycache__/rl_jit.cpython-311.pyc
Normal file
Binary file not shown.
@@ -9,32 +9,55 @@ class RLEnv(BaseProblem):
|
|||||||
jitable = True
|
jitable = True
|
||||||
|
|
||||||
# TODO: move output transform to algorithm
|
# TODO: move output transform to algorithm
|
||||||
def __init__(self):
|
def __init__(self, max_step=1000):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.max_step = max_step
|
||||||
|
|
||||||
|
# def evaluate(self, randkey, state, act_func, params):
|
||||||
|
# rng_reset, rng_episode = jax.random.split(randkey)
|
||||||
|
# init_obs, init_env_state = self.reset(rng_reset)
|
||||||
|
|
||||||
|
# def cond_func(carry):
|
||||||
|
# _, _, _, done, _ = carry
|
||||||
|
# return ~done
|
||||||
|
|
||||||
|
# def body_func(carry):
|
||||||
|
# obs, env_state, rng, _, tr = carry # total reward
|
||||||
|
# action = act_func(obs, params)
|
||||||
|
# next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||||
|
# next_rng, _ = jax.random.split(rng)
|
||||||
|
# return next_obs, next_env_state, next_rng, done, tr + reward
|
||||||
|
|
||||||
|
# _, _, _, _, total_reward = jax.lax.while_loop(
|
||||||
|
# cond_func,
|
||||||
|
# body_func,
|
||||||
|
# (init_obs, init_env_state, rng_episode, False, 0.0)
|
||||||
|
# )
|
||||||
|
|
||||||
|
# return total_reward
|
||||||
|
|
||||||
def evaluate(self, randkey, state, act_func, params):
|
def evaluate(self, randkey, state, act_func, params):
|
||||||
rng_reset, rng_episode = jax.random.split(randkey)
|
rng_reset, rng_episode = jax.random.split(randkey)
|
||||||
init_obs, init_env_state = self.reset(rng_reset)
|
init_obs, init_env_state = self.reset(rng_reset)
|
||||||
|
|
||||||
def cond_func(carry):
|
def cond_func(carry):
|
||||||
_, _, _, done, _ = carry
|
_, _, _, done, _, count = carry
|
||||||
return ~done
|
return ~done & (count < self.max_step)
|
||||||
|
|
||||||
def body_func(carry):
|
def body_func(carry):
|
||||||
obs, env_state, rng, _, tr = carry # total reward
|
obs, env_state, rng, done, tr, count = carry # tr -> total reward
|
||||||
action = act_func(obs, params)
|
action = act_func(obs, params)
|
||||||
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action)
|
||||||
next_rng, _ = jax.random.split(rng)
|
next_rng, _ = jax.random.split(rng)
|
||||||
return next_obs, next_env_state, next_rng, done, tr + reward
|
return next_obs, next_env_state, next_rng, done, tr + reward, count + 1
|
||||||
|
|
||||||
_, _, _, _, total_reward = jax.lax.while_loop(
|
_, _, _, _, total_reward, _ = jax.lax.while_loop(
|
||||||
cond_func,
|
cond_func,
|
||||||
body_func,
|
body_func,
|
||||||
(init_obs, init_env_state, rng_episode, False, 0.0)
|
(init_obs, init_env_state, rng_episode, False, 0.0, 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
return total_reward
|
return total_reward
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(0,))
|
@partial(jax.jit, static_argnums=(0,))
|
||||||
def step(self, randkey, env_state, action):
|
def step(self, randkey, env_state, action):
|
||||||
return self.env_step(randkey, env_state, action)
|
return self.env_step(randkey, env_state, action)
|
||||||
|
|||||||
BIN
tensorneat/utils/__pycache__/__init__.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/activation.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/activation.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/aggregation.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/aggregation.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/graph.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/graph.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/state.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/state.cpython-311.pyc
Normal file
Binary file not shown.
BIN
tensorneat/utils/__pycache__/tools.cpython-311.pyc
Normal file
BIN
tensorneat/utils/__pycache__/tools.cpython-311.pyc
Normal file
Binary file not shown.
@@ -11,7 +11,7 @@ class Act:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tanh(z):
|
def tanh(z):
|
||||||
return jnp.tanh(z)
|
return jnp.tanh(0.6 * z)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sin(z):
|
def sin(z):
|
||||||
|
|||||||
Reference in New Issue
Block a user