diff --git a/tensorneat/__pycache__/pipeline.cpython-311.pyc b/tensorneat/__pycache__/pipeline.cpython-311.pyc new file mode 100644 index 0000000..b58308c Binary files /dev/null and b/tensorneat/__pycache__/pipeline.cpython-311.pyc differ diff --git a/tensorneat/algorithm/__pycache__/__init__.cpython-311.pyc b/tensorneat/algorithm/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..f01b656 Binary files /dev/null and b/tensorneat/algorithm/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/algorithm/__pycache__/base.cpython-311.pyc b/tensorneat/algorithm/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..3613f0f Binary files /dev/null and b/tensorneat/algorithm/__pycache__/base.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/__init__.py b/tensorneat/algorithm/neat/__init__.py index 97185ca..1af3e2b 100644 --- a/tensorneat/algorithm/neat/__init__.py +++ b/tensorneat/algorithm/neat/__init__.py @@ -1,5 +1,5 @@ +from .ga import * from .gene import * from .genome import * from .species import * -from .neat import NEAT - +from .neat import NEAT \ No newline at end of file diff --git a/tensorneat/algorithm/neat/__pycache__/__init__.cpython-311.pyc b/tensorneat/algorithm/neat/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..667df74 Binary files /dev/null and b/tensorneat/algorithm/neat/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/__pycache__/neat.cpython-311.pyc b/tensorneat/algorithm/neat/__pycache__/neat.cpython-311.pyc new file mode 100644 index 0000000..44cd2d1 Binary files /dev/null and b/tensorneat/algorithm/neat/__pycache__/neat.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/ga/__pycache__/__init__.cpython-311.pyc b/tensorneat/algorithm/neat/ga/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..7109c78 Binary files /dev/null and b/tensorneat/algorithm/neat/ga/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/ga/crossover/__pycache__/__init__.cpython-311.pyc b/tensorneat/algorithm/neat/ga/crossover/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..23075fb Binary files /dev/null and b/tensorneat/algorithm/neat/ga/crossover/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/ga/crossover/__pycache__/base.cpython-311.pyc b/tensorneat/algorithm/neat/ga/crossover/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..47dc800 Binary files /dev/null and b/tensorneat/algorithm/neat/ga/crossover/__pycache__/base.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/ga/crossover/__pycache__/default.cpython-311.pyc b/tensorneat/algorithm/neat/ga/crossover/__pycache__/default.cpython-311.pyc new file mode 100644 index 0000000..7ac87f3 Binary files /dev/null and b/tensorneat/algorithm/neat/ga/crossover/__pycache__/default.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/ga/mutation/__pycache__/__init__.cpython-311.pyc b/tensorneat/algorithm/neat/ga/mutation/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..e6157f7 Binary files /dev/null and b/tensorneat/algorithm/neat/ga/mutation/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/ga/mutation/__pycache__/base.cpython-311.pyc b/tensorneat/algorithm/neat/ga/mutation/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..cd429aa Binary files /dev/null and b/tensorneat/algorithm/neat/ga/mutation/__pycache__/base.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/ga/mutation/__pycache__/default.cpython-311.pyc b/tensorneat/algorithm/neat/ga/mutation/__pycache__/default.cpython-311.pyc new file mode 100644 index 0000000..a0ce422 Binary files /dev/null and b/tensorneat/algorithm/neat/ga/mutation/__pycache__/default.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/gene/__pycache__/__init__.cpython-311.pyc b/tensorneat/algorithm/neat/gene/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..2b87fbe Binary files /dev/null and b/tensorneat/algorithm/neat/gene/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/gene/__pycache__/base.cpython-311.pyc b/tensorneat/algorithm/neat/gene/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..566f10d Binary files /dev/null and b/tensorneat/algorithm/neat/gene/__pycache__/base.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/gene/conn/__pycache__/__init__.cpython-311.pyc b/tensorneat/algorithm/neat/gene/conn/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..3db7a8c Binary files /dev/null and b/tensorneat/algorithm/neat/gene/conn/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/gene/conn/__pycache__/base.cpython-311.pyc b/tensorneat/algorithm/neat/gene/conn/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..5fddd94 Binary files /dev/null and b/tensorneat/algorithm/neat/gene/conn/__pycache__/base.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/gene/conn/__pycache__/default.cpython-311.pyc b/tensorneat/algorithm/neat/gene/conn/__pycache__/default.cpython-311.pyc new file mode 100644 index 0000000..43b1443 Binary files /dev/null and b/tensorneat/algorithm/neat/gene/conn/__pycache__/default.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/gene/conn/default.py b/tensorneat/algorithm/neat/gene/conn/default.py index 39db6e2..cf1b18e 100644 --- a/tensorneat/algorithm/neat/gene/conn/default.py +++ b/tensorneat/algorithm/neat/gene/conn/default.py @@ -26,6 +26,17 @@ class DefaultConnGene(BaseConnGene): def new_custom_attrs(self): return jnp.array([self.weight_init_mean]) + + def new_random_attrs(self, key): + return jnp.array([mutate_float(key, + self.weight_init_mean, + self.weight_init_mean, + 1.0, + 0, + 0, + 1.0, + ) + ]) def mutate(self, key, conn): input_index = conn[0] diff --git a/tensorneat/algorithm/neat/gene/node/__pycache__/__init__.cpython-311.pyc b/tensorneat/algorithm/neat/gene/node/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..dbbdb01 Binary files /dev/null and b/tensorneat/algorithm/neat/gene/node/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/gene/node/__pycache__/base.cpython-311.pyc b/tensorneat/algorithm/neat/gene/node/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..4bc0237 Binary files /dev/null and b/tensorneat/algorithm/neat/gene/node/__pycache__/base.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/gene/node/__pycache__/default.cpython-311.pyc b/tensorneat/algorithm/neat/gene/node/__pycache__/default.cpython-311.pyc new file mode 100644 index 0000000..444556b Binary files /dev/null and b/tensorneat/algorithm/neat/gene/node/__pycache__/default.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/gene/node/default.py b/tensorneat/algorithm/neat/gene/node/default.py index 3f43d4f..a5127e0 100644 --- a/tensorneat/algorithm/neat/gene/node/default.py +++ b/tensorneat/algorithm/neat/gene/node/default.py @@ -60,6 +60,16 @@ class DefaultNodeGene(BaseNodeGene): return jnp.array( [self.bias_init_mean, self.response_init_mean, self.activation_default, self.aggregation_default] ) + + def new_random_attrs(self, key): + return jnp.array([ + mutate_float(key, self.bias_init_mean, self.bias_init_mean, self.bias_init_std, + self.bias_mutate_power, self.bias_mutate_rate, self.bias_replace_rate), + mutate_float(key, self.response_init_mean, self.response_init_mean, self.response_init_std, + self.response_mutate_power, self.response_mutate_rate, self.response_replace_rate), + self.activation_default, + self.aggregation_default, + ]) def mutate(self, key, node): k1, k2, k3, k4 = jax.random.split(key, num=4) diff --git a/tensorneat/algorithm/neat/genome/__pycache__/__init__.cpython-311.pyc b/tensorneat/algorithm/neat/genome/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..d24b973 Binary files /dev/null and b/tensorneat/algorithm/neat/genome/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/genome/__pycache__/base.cpython-311.pyc b/tensorneat/algorithm/neat/genome/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..79a1704 Binary files /dev/null and b/tensorneat/algorithm/neat/genome/__pycache__/base.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/genome/__pycache__/default.cpython-311.pyc b/tensorneat/algorithm/neat/genome/__pycache__/default.cpython-311.pyc new file mode 100644 index 0000000..0895a5b Binary files /dev/null and b/tensorneat/algorithm/neat/genome/__pycache__/default.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/genome/__pycache__/recurrent.cpython-311.pyc b/tensorneat/algorithm/neat/genome/__pycache__/recurrent.cpython-311.pyc new file mode 100644 index 0000000..025241b Binary files /dev/null and b/tensorneat/algorithm/neat/genome/__pycache__/recurrent.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/species/__pycache__/__init__.cpython-311.pyc b/tensorneat/algorithm/neat/species/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..b694bfe Binary files /dev/null and b/tensorneat/algorithm/neat/species/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/species/__pycache__/base.cpython-311.pyc b/tensorneat/algorithm/neat/species/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..90e99b6 Binary files /dev/null and b/tensorneat/algorithm/neat/species/__pycache__/base.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/species/__pycache__/default.cpython-311.pyc b/tensorneat/algorithm/neat/species/__pycache__/default.cpython-311.pyc new file mode 100644 index 0000000..400cfd5 Binary files /dev/null and b/tensorneat/algorithm/neat/species/__pycache__/default.cpython-311.pyc differ diff --git a/tensorneat/algorithm/neat/species/default.py b/tensorneat/algorithm/neat/species/default.py index 7cf3e93..323841c 100644 --- a/tensorneat/algorithm/neat/species/default.py +++ b/tensorneat/algorithm/neat/species/default.py @@ -19,7 +19,8 @@ class DefaultSpecies(BaseSpecies): genome_elitism: int = 2, survival_threshold: float = 0.2, min_species_size: int = 1, - compatibility_threshold: float = 3. + compatibility_threshold: float = 3., + initialize_method: str = 'one_hidden_node', # {'one_hidden_node', 'dense_hideen_layer', 'no_hidden_random'} ): self.genome = genome self.pop_size = pop_size @@ -34,11 +35,13 @@ class DefaultSpecies(BaseSpecies): self.survival_threshold = survival_threshold self.min_species_size = min_species_size self.compatibility_threshold = compatibility_threshold + self.initialize_method = initialize_method self.species_arange = jnp.arange(self.species_size) def setup(self, randkey): - pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome) + k1, k2 = jax.random.split(randkey, 2) + pop_nodes, pop_conns = initialize_population(self.pop_size, self.genome,k1, self.initialize_method) species_keys = jnp.full((self.species_size,), jnp.nan) # the unique index (primary key) for each species best_fitness = jnp.full((self.species_size,), jnp.nan) # the best fitness of each species @@ -62,7 +65,7 @@ class DefaultSpecies(BaseSpecies): pop_nodes, pop_conns = jax.device_put((pop_nodes, pop_conns)) return State( - randkey=randkey, + randkey=k2, pop_nodes=pop_nodes, pop_conns=pop_conns, species_keys=species_keys, @@ -489,31 +492,159 @@ class DefaultSpecies(BaseSpecies): return jnp.where(max_cnt == 0, 0, val / max_cnt) -def initialize_population(pop_size, genome): - o_nodes = np.full((genome.max_nodes, genome.node_gene.length), np.nan) # original nodes - o_conns = np.full((genome.max_conns, genome.conn_gene.length), np.nan) # original connections + +def initialize_population(pop_size, genome, randkey, init_method='default'): + rand_keys = jax.random.split(randkey, pop_size) + + if init_method == 'one_hidden_node': + init_func = init_one_hidden_node + elif init_method == 'dense_hideen_layer': + init_func = init_dense_hideen_layer + elif init_method == 'no_hidden_random': + init_func = init_no_hidden_random + else: + raise ValueError("Unknown initialization method: {}".format(init_method)) + + pop_nodes, pop_conns = jax.vmap(init_func, in_axes=(None, 0))(genome, rand_keys) + + return pop_nodes, pop_conns + +# one hidden node +def init_one_hidden_node(genome, randkey): input_idx, output_idx = genome.input_idx, genome.output_idx new_node_key = max([*input_idx, *output_idx]) + 1 - o_nodes[input_idx, 0] = genome.input_idx - o_nodes[output_idx, 0] = genome.output_idx - o_nodes[new_node_key, 0] = new_node_key # one hidden node - o_nodes[np.concatenate([input_idx, output_idx]), 1:] = genome.node_gene.new_custom_attrs() - o_nodes[new_node_key, 1:] = genome.node_gene.new_custom_attrs() # one hidden node + nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan) + conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan) - input_conns = np.c_[input_idx, np.full_like(input_idx, new_node_key)] # input nodes to hidden - o_conns[input_idx, 0:2] = input_conns # in key, out key - o_conns[input_idx, 2] = True # enabled - o_conns[input_idx, 3:] = genome.conn_gene.new_custom_attrs() + nodes = nodes.at[input_idx, 0].set(input_idx) + nodes = nodes.at[output_idx, 0].set(output_idx) + nodes = nodes.at[new_node_key, 0].set(new_node_key) - output_conns = np.c_[np.full_like(output_idx, new_node_key), output_idx] # hidden to output nodes - o_conns[output_idx, 0:2] = output_conns # in key, out key - o_conns[output_idx, 2] = True # enabled - o_conns[output_idx, 3:] = genome.conn_gene.new_custom_attrs() + rand_keys_nodes = jax.random.split(randkey, num=len(input_idx) + len(output_idx) + 1) + input_keys, output_keys, hidden_key = rand_keys_nodes[:len(input_idx)], rand_keys_nodes[len(input_idx):len(input_idx) + len(output_idx)], rand_keys_nodes[-1] - # repeat origin genome for P times to create population - pop_nodes = np.tile(o_nodes, (pop_size, 1, 1)) - pop_conns = np.tile(o_conns, (pop_size, 1, 1)) + node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(None,0)) + input_attrs = node_attr_func(input_keys) + output_attrs = node_attr_func(output_keys) + hidden_attrs = genome.node_gene.new_custom_attrs(hidden_key) - return pop_nodes, pop_conns + nodes = nodes.at[input_idx, 1:].set(input_attrs) + nodes = nodes.at[output_idx, 1:].set(output_attrs) + nodes = nodes.at[new_node_key, 1:].set(hidden_attrs) + + input_conns = jnp.c_[input_idx, jnp.full_like(input_idx, new_node_key)] + conns = conns.at[input_idx, 0:2].set(input_conns) + conns = conns.at[input_idx, 2].set(True) + + output_conns = jnp.c_[jnp.full_like(output_idx, new_node_key), output_idx] + conns = conns.at[output_idx, 0:2].set(output_conns) + conns = conns.at[output_idx, 2].set(True) + + rand_keys_conns = jax.random.split(randkey, num=len(input_idx) + len(output_idx)) + input_conn_keys, output_conn_keys = rand_keys_conns[:len(input_idx)], rand_keys_conns[len(input_idx):] + + conn_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(None,0)) + input_conn_attrs = conn_attr_func(input_conn_keys) + output_conn_attrs = conn_attr_func(output_conn_keys) + + conns = conns.at[input_idx, 3:].set(input_conn_attrs) + conns = conns.at[output_idx, 3:].set(output_conn_attrs) + + return nodes, conns + + +#random dense connections with 1 hidden layer +def init_dense_hideen_layer( genome, randkey,hiddens=20): + k1, k2, k3 = jax.random.split(randkey, num=3) + input_idx, output_idx = genome.input_idx, genome.output_idx + input_size = len(input_idx) + output_size = len(output_idx) + + hidden_idx = jnp.arange(input_size + output_size, input_size + output_size + hiddens) + nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan, dtype=jnp.float32) + nodes = nodes.at[input_idx, 0].set(input_idx) + nodes = nodes.at[output_idx, 0].set(output_idx) + nodes = nodes.at[hidden_idx, 0].set(hidden_idx) + + total_idx = input_size + output_size + hiddens + rand_keys_n = jax.random.split(k1, num=total_idx) + input_keys = rand_keys_n[:input_size] + output_keys = rand_keys_n[input_size:input_size + output_size] + hidden_keys = rand_keys_n[input_size + output_size:] + + node_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0)) + input_attrs = node_attr_func(input_keys) + output_attrs = node_attr_func(output_keys) + hidden_attrs = node_attr_func(hidden_keys) + + nodes = nodes.at[input_idx, 1:].set(input_attrs) + nodes = nodes.at[output_idx, 1:].set(output_attrs) + nodes = nodes.at[hidden_idx, 1:].set(hidden_attrs) + + total_connections = input_size * hiddens + hiddens * output_size + conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan, dtype=jnp.float32) + + rand_keys_c = jax.random.split(k2, num=total_connections) + conns_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0)) + conns_attrs = conns_attr_func(rand_keys_c) + + input_to_hidden_ids, hidden_ids = jnp.meshgrid(input_idx, hidden_idx, indexing='ij') + hidden_to_output_ids, output_ids = jnp.meshgrid(hidden_idx, output_idx, indexing='ij') + + conns = conns.at[:input_size * hiddens, 0].set(input_to_hidden_ids.flatten()) + conns = conns.at[:input_size * hiddens, 1].set(hidden_ids.flatten()) + conns = conns.at[input_size * hiddens: total_connections, 0].set(hidden_to_output_ids.flatten()) + conns = conns.at[input_size * hiddens: total_connections, 1].set(output_ids.flatten()) + conns = conns.at[:input_size * hiddens + hiddens * output_size, 2].set(True) + conns = conns.at[:input_size * hiddens + hiddens * output_size, 3:].set(conns_attrs) + + return nodes, conns + +# random sparse connections with no hidden nodes +def init_no_hidden_random(genome, randkey): + k1, k2, k3 = jax.random.split(randkey, num=3) + input_idx, output_idx = genome.input_idx, genome.output_idx + + nodes = jnp.full((genome.max_nodes, genome.node_gene.length), jnp.nan) + nodes = nodes.at[input_idx, 0].set(input_idx) + nodes = nodes.at[output_idx, 0].set(output_idx) + + total_idx = len(input_idx) + len(output_idx) + rand_keys_n = jax.random.split(k1, num=total_idx) + input_keys = rand_keys_n[:len(input_idx)] + output_keys = rand_keys_n[len(input_idx):] + + node_attr_func = jax.vmap(genome.node_gene.new_random_attrs, in_axes=(0)) + input_attrs = node_attr_func(input_keys) + output_attrs = node_attr_func(output_keys) + nodes = nodes.at[input_idx, 1:].set(input_attrs) + nodes = nodes.at[output_idx, 1:].set(output_attrs) + + conns = jnp.full((genome.max_conns, genome.conn_gene.length), jnp.nan) + + num_connections_per_output = 4 + total_connections = len(output_idx) * num_connections_per_output + + def create_connections_for_output(key): + permuted_inputs = jax.random.permutation(key, input_idx) + selected_inputs = permuted_inputs[:num_connections_per_output] + return selected_inputs + + conn_keys = jax.random.split(k2, num=len(output_idx)) + connections = jax.vmap(create_connections_for_output)(conn_keys) + connections = connections.flatten() + + output_repeats = jnp.repeat(output_idx, num_connections_per_output) + + rand_keys_c = jax.random.split(k3, num=total_connections) + conns_attr_func = jax.vmap(genome.conn_gene.new_random_attrs, in_axes=(0)) + conns_attrs = conns_attr_func(rand_keys_c) + + conns = conns.at[:total_connections, 0].set(connections) + conns = conns.at[:total_connections, 1].set(output_repeats) + conns = conns.at[:total_connections, 2].set(True) # enabled + conns = conns.at[:total_connections, 3:].set(conns_attrs) + + return nodes, conns diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index ad33945..e3c73d6 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -82,12 +82,13 @@ class Pipeline: state = ini_state compiled_step = jax.jit(self.step).lower(ini_state).compile() - for _ in range(self.generation_limit): + for w in range(self.generation_limit): self.generation_timestamp = time.time() previous_pop = self.algorithm.ask(state.alg) + state, fitnesses = compiled_step(state) fitnesses = jax.device_get(fitnesses) @@ -102,7 +103,13 @@ class Pipeline: if max(fitnesses) >= self.fitness_target: print("Fitness limit reached!") return state, self.best_genome - + node= previous_pop[0][0][:,0] + node_count = jnp.sum(~jnp.isnan(node)) + conn= previous_pop[1][0][:,0] + conn_count = jnp.sum(~jnp.isnan(conn)) + if(w%5==0): + print("node_count",node_count) + print("conn_count",conn_count) print("Generation limit reached!") return state, self.best_genome diff --git a/tensorneat/problem/__pycache__/__init__.cpython-311.pyc b/tensorneat/problem/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..0c65415 Binary files /dev/null and b/tensorneat/problem/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/problem/__pycache__/base.cpython-311.pyc b/tensorneat/problem/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..89ca915 Binary files /dev/null and b/tensorneat/problem/__pycache__/base.cpython-311.pyc differ diff --git a/tensorneat/problem/rl_env/__pycache__/__init__.cpython-311.pyc b/tensorneat/problem/rl_env/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..c6c77f3 Binary files /dev/null and b/tensorneat/problem/rl_env/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/problem/rl_env/__pycache__/brax_env.cpython-311.pyc b/tensorneat/problem/rl_env/__pycache__/brax_env.cpython-311.pyc new file mode 100644 index 0000000..1e18481 Binary files /dev/null and b/tensorneat/problem/rl_env/__pycache__/brax_env.cpython-311.pyc differ diff --git a/tensorneat/problem/rl_env/__pycache__/gymnax_env.cpython-311.pyc b/tensorneat/problem/rl_env/__pycache__/gymnax_env.cpython-311.pyc new file mode 100644 index 0000000..a1b67a4 Binary files /dev/null and b/tensorneat/problem/rl_env/__pycache__/gymnax_env.cpython-311.pyc differ diff --git a/tensorneat/problem/rl_env/__pycache__/rl_jit.cpython-311.pyc b/tensorneat/problem/rl_env/__pycache__/rl_jit.cpython-311.pyc new file mode 100644 index 0000000..e55bf7b Binary files /dev/null and b/tensorneat/problem/rl_env/__pycache__/rl_jit.cpython-311.pyc differ diff --git a/tensorneat/problem/rl_env/rl_jit.py b/tensorneat/problem/rl_env/rl_jit.py index 128ebfb..89e1b7c 100644 --- a/tensorneat/problem/rl_env/rl_jit.py +++ b/tensorneat/problem/rl_env/rl_jit.py @@ -9,32 +9,55 @@ class RLEnv(BaseProblem): jitable = True # TODO: move output transform to algorithm - def __init__(self): + def __init__(self, max_step=1000): super().__init__() + self.max_step = max_step + + # def evaluate(self, randkey, state, act_func, params): + # rng_reset, rng_episode = jax.random.split(randkey) + # init_obs, init_env_state = self.reset(rng_reset) + + # def cond_func(carry): + # _, _, _, done, _ = carry + # return ~done + + # def body_func(carry): + # obs, env_state, rng, _, tr = carry # total reward + # action = act_func(obs, params) + # next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) + # next_rng, _ = jax.random.split(rng) + # return next_obs, next_env_state, next_rng, done, tr + reward + + # _, _, _, _, total_reward = jax.lax.while_loop( + # cond_func, + # body_func, + # (init_obs, init_env_state, rng_episode, False, 0.0) + # ) + + # return total_reward def evaluate(self, randkey, state, act_func, params): rng_reset, rng_episode = jax.random.split(randkey) init_obs, init_env_state = self.reset(rng_reset) def cond_func(carry): - _, _, _, done, _ = carry - return ~done + _, _, _, done, _, count = carry + return ~done & (count < self.max_step) def body_func(carry): - obs, env_state, rng, _, tr = carry # total reward - action = act_func(obs, params) - next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) - next_rng, _ = jax.random.split(rng) - return next_obs, next_env_state, next_rng, done, tr + reward + obs, env_state, rng, done, tr, count = carry # tr -> total reward + action = act_func(obs, params) + next_obs, next_env_state, reward, done, _ = self.step(rng, env_state, action) + next_rng, _ = jax.random.split(rng) + return next_obs, next_env_state, next_rng, done, tr + reward, count + 1 - _, _, _, _, total_reward = jax.lax.while_loop( + _, _, _, _, total_reward, _ = jax.lax.while_loop( cond_func, body_func, - (init_obs, init_env_state, rng_episode, False, 0.0) + (init_obs, init_env_state, rng_episode, False, 0.0, 0) ) - return total_reward - + return total_reward @partial(jax.jit, static_argnums=(0,)) def step(self, randkey, env_state, action): return self.env_step(randkey, env_state, action) diff --git a/tensorneat/utils/__pycache__/__init__.cpython-311.pyc b/tensorneat/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..4077070 Binary files /dev/null and b/tensorneat/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/tensorneat/utils/__pycache__/activation.cpython-311.pyc b/tensorneat/utils/__pycache__/activation.cpython-311.pyc new file mode 100644 index 0000000..a4daa9d Binary files /dev/null and b/tensorneat/utils/__pycache__/activation.cpython-311.pyc differ diff --git a/tensorneat/utils/__pycache__/aggregation.cpython-311.pyc b/tensorneat/utils/__pycache__/aggregation.cpython-311.pyc new file mode 100644 index 0000000..af36a46 Binary files /dev/null and b/tensorneat/utils/__pycache__/aggregation.cpython-311.pyc differ diff --git a/tensorneat/utils/__pycache__/graph.cpython-311.pyc b/tensorneat/utils/__pycache__/graph.cpython-311.pyc new file mode 100644 index 0000000..720a2fc Binary files /dev/null and b/tensorneat/utils/__pycache__/graph.cpython-311.pyc differ diff --git a/tensorneat/utils/__pycache__/state.cpython-311.pyc b/tensorneat/utils/__pycache__/state.cpython-311.pyc new file mode 100644 index 0000000..0975cdc Binary files /dev/null and b/tensorneat/utils/__pycache__/state.cpython-311.pyc differ diff --git a/tensorneat/utils/__pycache__/tools.cpython-311.pyc b/tensorneat/utils/__pycache__/tools.cpython-311.pyc new file mode 100644 index 0000000..6d48676 Binary files /dev/null and b/tensorneat/utils/__pycache__/tools.cpython-311.pyc differ diff --git a/tensorneat/utils/activation.py b/tensorneat/utils/activation.py index 03a2b4d..11e34f3 100644 --- a/tensorneat/utils/activation.py +++ b/tensorneat/utils/activation.py @@ -11,7 +11,7 @@ class Act: @staticmethod def tanh(z): - return jnp.tanh(z) + return jnp.tanh(0.6 * z) @staticmethod def sin(z):