From d6e9ff5d9a6ccc5ab59a8ac676ca722d836b09ae Mon Sep 17 00:00:00 2001 From: wls2002 Date: Fri, 31 May 2024 19:43:14 +0800 Subject: [PATCH] fix bug in restore genome. --- tensorneat/algorithm/neat/genome/default.py | 2 + tensorneat/algorithm/neat/genome/recurrent.py | 3 + tensorneat/pipeline.py | 19 ++++-- tensorneat/test/test_flatten.ipynb | 58 ++++++++++++++----- 4 files changed, 62 insertions(+), 20 deletions(-) diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index a9ed7f2..fb2eb9d 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -56,6 +56,8 @@ class DefaultGenome(BaseGenome): def restore(self, state, transformed): seqs, nodes, u_conns = transformed conns = flatten_conns(nodes, u_conns, C=self.max_conns) + # restore enable + conns = jnp.insert(conns, obj=2, values=1, axis=1) return nodes, conns def forward(self, state, inputs, transformed): diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index d3dae82..4670eae 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -57,6 +57,9 @@ class RecurrentGenome(BaseGenome): def restore(self, state, transformed): nodes, u_conns = transformed conns = flatten_conns(nodes, u_conns, C=self.max_conns) + + # restore enable + conns = jnp.insert(conns, obj=2, values=1, axis=1) return nodes, conns def forward(self, state, inputs, transformed): diff --git a/tensorneat/pipeline.py b/tensorneat/pipeline.py index caa5afa..edaead7 100644 --- a/tensorneat/pipeline.py +++ b/tensorneat/pipeline.py @@ -50,6 +50,11 @@ class Pipeline: self.fetch_data = lambda data: data else: raise NotImplementedError + else: + if isinstance(problem, RLEnv): + assert not problem.record_episode, "record_episode must be False" + elif isinstance(problem, FuncFit): + assert not problem.return_data, "return_data must be False" def setup(self, state=State()): print("initializing") @@ -90,6 +95,13 @@ class Pipeline: self.problem.evaluate, in_axes=(None, 0, None, 0) )(state, keys, self.algorithm.forward, pop_transformed) + # update population + pop_nodes, pop_conns = jax.vmap(self.algorithm.restore, in_axes=(None, 0))( + state, pop_transformed + ) + state = state.update(pop_nodes=pop_nodes, pop_conns=pop_conns) + + # update data for next generation data = self.fetch_data(raw_data) assert ( data.ndim == 3 @@ -119,9 +131,10 @@ class Pipeline: # replace nan with -inf fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses) + previous_pop = self.algorithm.ask(state) state = self.algorithm.tell(state, fitnesses) - return state.update(randkey=randkey), fitnesses + return state.update(randkey=randkey), previous_pop, fitnesses def auto_run(self, state): print("start compile") @@ -135,9 +148,7 @@ class Pipeline: self.generation_timestamp = time.time() - previous_pop = self.algorithm.ask(state) - - state, fitnesses = compiled_step(state) + state, previous_pop, fitnesses = compiled_step(state) fitnesses = jax.device_get(fitnesses) diff --git a/tensorneat/test/test_flatten.ipynb b/tensorneat/test/test_flatten.ipynb index de1d77e..6599aff 100644 --- a/tensorneat/test/test_flatten.ipynb +++ b/tensorneat/test/test_flatten.ipynb @@ -7,8 +7,8 @@ "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2024-05-30T11:40:55.584592400Z", - "start_time": "2024-05-30T11:40:53.016051600Z" + "end_time": "2024-05-31T09:01:41.824974900Z", + "start_time": "2024-05-31T09:01:39.138674100Z" } }, "outputs": [], @@ -25,7 +25,7 @@ "outputs": [ { "data": { - "text/plain": "((10, 5), (10, 4))" + "text/plain": "((5, 5), (5, 4))" }, "execution_count": 2, "metadata": {}, @@ -33,7 +33,7 @@ } ], "source": [ - "genome = DefaultGenome(num_inputs=3, num_outputs=2, max_nodes=10, max_conns=10)\n", + "genome = DefaultGenome(num_inputs=3, num_outputs=1, max_nodes=5, max_conns=5)\n", "state = genome.setup()\n", "key = jax.random.PRNGKey(0)\n", "nodes, conns = genome.initialize(state, key)\n", @@ -42,8 +42,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-05-30T11:40:59.021858400Z", - "start_time": "2024-05-30T11:40:55.592593400Z" + "end_time": "2024-05-31T09:01:45.179170400Z", + "start_time": "2024-05-31T09:01:41.832976100Z" } }, "id": "89fb5cd0e77a028d" @@ -54,7 +54,7 @@ "outputs": [ { "data": { - "text/plain": "(2, 10, 10)" + "text/plain": "(Array([0, 1, 2, 4, 3], dtype=int32, weak_type=True),\n Array([[ 0. , -1.013169 , 1. , 0. , 0. ],\n [ 1. , -0.3775248 , 1. , 0. , 0. ],\n [ 2. , 0.7407059 , 1. , 0. , 0. ],\n [ 3. , -0.66817343, 1. , 0. , 0. ],\n [ 4. , 0.5336131 , 1. , 0. , 0. ]], dtype=float32, weak_type=True),\n Array([[[ nan, nan, nan, nan,\n 0.13149254],\n [ nan, nan, nan, nan,\n 0.02001922],\n [ nan, nan, nan, nan,\n -0.79229796],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.57102853,\n nan]]], dtype=float32, weak_type=True))" }, "execution_count": 3, "metadata": {}, @@ -62,14 +62,14 @@ } ], "source": [ - "unflatten = unflatten_conns(nodes, conns)\n", - "unflatten.shape" + "transformed = genome.transform(state, nodes, conns)\n", + "transformed" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-05-30T11:40:59.472701700Z", - "start_time": "2024-05-30T11:40:59.021858400Z" + "end_time": "2024-05-31T09:01:45.729969500Z", + "start_time": "2024-05-31T09:01:45.178173400Z" } }, "id": "aaa88227bbf29936" @@ -80,7 +80,7 @@ "outputs": [ { "data": { - "text/plain": "(Array([[ 0. , 5. , 1. , -0.41923347],\n [ 1. , 5. , 1. , -3.1815007 ],\n [ 2. , 5. , 1. , 0.5184341 ],\n [ 5. , 3. , 1. , -1.9784615 ],\n [ 5. , 4. , 1. , 0.7132204 ],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan]], dtype=float32, weak_type=True),\n Array([[ 0. , 5. , 1. , -0.41923347],\n [ 1. , 5. , 1. , -3.1815007 ],\n [ 2. , 5. , 1. , 0.5184341 ],\n [ 5. , 3. , 1. , -1.9784615 ],\n [ 5. , 4. , 1. , 0.7132204 ],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan]], dtype=float32))" + "text/plain": "(Array([[ 0. , -1.013169 , 1. , 0. , 0. ],\n [ 1. , -0.3775248 , 1. , 0. , 0. ],\n [ 2. , 0.7407059 , 1. , 0. , 0. ],\n [ 3. , -0.66817343, 1. , 0. , 0. ],\n [ 4. , 0.5336131 , 1. , 0. , 0. ]], dtype=float32, weak_type=True),\n Array([[ 1. , 0. , 4. , 0.13149254],\n [ 1. , 1. , 4. , 0.02001922],\n [ 1. , 2. , 4. , -0.79229796],\n [ 1. , 4. , 3. , -0.57102853],\n [ 1. , nan, nan, nan]], dtype=float32))" }, "execution_count": 4, "metadata": {}, @@ -89,18 +89,44 @@ ], "source": [ "# single flatten\n", - "flatten = flatten_conns(nodes, unflatten, C=10)\n", - "conns, flatten" + "nodes, conns = genome.restore(state, transformed)\n", + "nodes, conns" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-05-30T11:41:00.308954100Z", - "start_time": "2024-05-30T11:40:59.469541500Z" + "end_time": "2024-05-31T09:01:46.660023600Z", + "start_time": "2024-05-31T09:01:45.724970700Z" } }, "id": "f2c65de38fdcff8f" }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [ + { + "data": { + "text/plain": "Array([[ 1. , 3. , 0. , 1. , 4. ,\n 0.13149254],\n [ 1. , 3. , 1. , 1. , 4. ,\n 0.02001922],\n [ 1. , 3. , 2. , 1. , 4. ,\n -0.79229796],\n [ 1. , 3. , 4. , 1. , 3. ,\n -0.57102853],\n [ 1. , 3. , nan, 1. , nan,\n nan]], dtype=float32)" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "conns = jnp.insert(conns, obj=3, values=1, axis=1)\n", + "conns" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-31T09:03:35.665080500Z", + "start_time": "2024-05-31T09:03:35.013654700Z" + } + }, + "id": "10bcb665c32fb728" + }, { "cell_type": "code", "execution_count": 8,