fix bug in restore genome.
This commit is contained in:
@@ -56,6 +56,8 @@ class DefaultGenome(BaseGenome):
|
|||||||
def restore(self, state, transformed):
|
def restore(self, state, transformed):
|
||||||
seqs, nodes, u_conns = transformed
|
seqs, nodes, u_conns = transformed
|
||||||
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
||||||
|
# restore enable
|
||||||
|
conns = jnp.insert(conns, obj=2, values=1, axis=1)
|
||||||
return nodes, conns
|
return nodes, conns
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, inputs, transformed):
|
||||||
|
|||||||
@@ -57,6 +57,9 @@ class RecurrentGenome(BaseGenome):
|
|||||||
def restore(self, state, transformed):
|
def restore(self, state, transformed):
|
||||||
nodes, u_conns = transformed
|
nodes, u_conns = transformed
|
||||||
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
conns = flatten_conns(nodes, u_conns, C=self.max_conns)
|
||||||
|
|
||||||
|
# restore enable
|
||||||
|
conns = jnp.insert(conns, obj=2, values=1, axis=1)
|
||||||
return nodes, conns
|
return nodes, conns
|
||||||
|
|
||||||
def forward(self, state, inputs, transformed):
|
def forward(self, state, inputs, transformed):
|
||||||
|
|||||||
@@ -50,6 +50,11 @@ class Pipeline:
|
|||||||
self.fetch_data = lambda data: data
|
self.fetch_data = lambda data: data
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
if isinstance(problem, RLEnv):
|
||||||
|
assert not problem.record_episode, "record_episode must be False"
|
||||||
|
elif isinstance(problem, FuncFit):
|
||||||
|
assert not problem.return_data, "return_data must be False"
|
||||||
|
|
||||||
def setup(self, state=State()):
|
def setup(self, state=State()):
|
||||||
print("initializing")
|
print("initializing")
|
||||||
@@ -90,6 +95,13 @@ class Pipeline:
|
|||||||
self.problem.evaluate, in_axes=(None, 0, None, 0)
|
self.problem.evaluate, in_axes=(None, 0, None, 0)
|
||||||
)(state, keys, self.algorithm.forward, pop_transformed)
|
)(state, keys, self.algorithm.forward, pop_transformed)
|
||||||
|
|
||||||
|
# update population
|
||||||
|
pop_nodes, pop_conns = jax.vmap(self.algorithm.restore, in_axes=(None, 0))(
|
||||||
|
state, pop_transformed
|
||||||
|
)
|
||||||
|
state = state.update(pop_nodes=pop_nodes, pop_conns=pop_conns)
|
||||||
|
|
||||||
|
# update data for next generation
|
||||||
data = self.fetch_data(raw_data)
|
data = self.fetch_data(raw_data)
|
||||||
assert (
|
assert (
|
||||||
data.ndim == 3
|
data.ndim == 3
|
||||||
@@ -119,9 +131,10 @@ class Pipeline:
|
|||||||
# replace nan with -inf
|
# replace nan with -inf
|
||||||
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
|
fitnesses = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
|
||||||
|
|
||||||
|
previous_pop = self.algorithm.ask(state)
|
||||||
state = self.algorithm.tell(state, fitnesses)
|
state = self.algorithm.tell(state, fitnesses)
|
||||||
|
|
||||||
return state.update(randkey=randkey), fitnesses
|
return state.update(randkey=randkey), previous_pop, fitnesses
|
||||||
|
|
||||||
def auto_run(self, state):
|
def auto_run(self, state):
|
||||||
print("start compile")
|
print("start compile")
|
||||||
@@ -135,9 +148,7 @@ class Pipeline:
|
|||||||
|
|
||||||
self.generation_timestamp = time.time()
|
self.generation_timestamp = time.time()
|
||||||
|
|
||||||
previous_pop = self.algorithm.ask(state)
|
state, previous_pop, fitnesses = compiled_step(state)
|
||||||
|
|
||||||
state, fitnesses = compiled_step(state)
|
|
||||||
|
|
||||||
fitnesses = jax.device_get(fitnesses)
|
fitnesses = jax.device_get(fitnesses)
|
||||||
|
|
||||||
|
|||||||
@@ -7,8 +7,8 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": true,
|
"collapsed": true,
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2024-05-30T11:40:55.584592400Z",
|
"end_time": "2024-05-31T09:01:41.824974900Z",
|
||||||
"start_time": "2024-05-30T11:40:53.016051600Z"
|
"start_time": "2024-05-31T09:01:39.138674100Z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -25,7 +25,7 @@
|
|||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": "((10, 5), (10, 4))"
|
"text/plain": "((5, 5), (5, 4))"
|
||||||
},
|
},
|
||||||
"execution_count": 2,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -33,7 +33,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"genome = DefaultGenome(num_inputs=3, num_outputs=2, max_nodes=10, max_conns=10)\n",
|
"genome = DefaultGenome(num_inputs=3, num_outputs=1, max_nodes=5, max_conns=5)\n",
|
||||||
"state = genome.setup()\n",
|
"state = genome.setup()\n",
|
||||||
"key = jax.random.PRNGKey(0)\n",
|
"key = jax.random.PRNGKey(0)\n",
|
||||||
"nodes, conns = genome.initialize(state, key)\n",
|
"nodes, conns = genome.initialize(state, key)\n",
|
||||||
@@ -42,8 +42,8 @@
|
|||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false,
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2024-05-30T11:40:59.021858400Z",
|
"end_time": "2024-05-31T09:01:45.179170400Z",
|
||||||
"start_time": "2024-05-30T11:40:55.592593400Z"
|
"start_time": "2024-05-31T09:01:41.832976100Z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"id": "89fb5cd0e77a028d"
|
"id": "89fb5cd0e77a028d"
|
||||||
@@ -54,7 +54,7 @@
|
|||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": "(2, 10, 10)"
|
"text/plain": "(Array([0, 1, 2, 4, 3], dtype=int32, weak_type=True),\n Array([[ 0. , -1.013169 , 1. , 0. , 0. ],\n [ 1. , -0.3775248 , 1. , 0. , 0. ],\n [ 2. , 0.7407059 , 1. , 0. , 0. ],\n [ 3. , -0.66817343, 1. , 0. , 0. ],\n [ 4. , 0.5336131 , 1. , 0. , 0. ]], dtype=float32, weak_type=True),\n Array([[[ nan, nan, nan, nan,\n 0.13149254],\n [ nan, nan, nan, nan,\n 0.02001922],\n [ nan, nan, nan, nan,\n -0.79229796],\n [ nan, nan, nan, nan,\n nan],\n [ nan, nan, nan, -0.57102853,\n nan]]], dtype=float32, weak_type=True))"
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -62,14 +62,14 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"unflatten = unflatten_conns(nodes, conns)\n",
|
"transformed = genome.transform(state, nodes, conns)\n",
|
||||||
"unflatten.shape"
|
"transformed"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false,
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2024-05-30T11:40:59.472701700Z",
|
"end_time": "2024-05-31T09:01:45.729969500Z",
|
||||||
"start_time": "2024-05-30T11:40:59.021858400Z"
|
"start_time": "2024-05-31T09:01:45.178173400Z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"id": "aaa88227bbf29936"
|
"id": "aaa88227bbf29936"
|
||||||
@@ -80,7 +80,7 @@
|
|||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": "(Array([[ 0. , 5. , 1. , -0.41923347],\n [ 1. , 5. , 1. , -3.1815007 ],\n [ 2. , 5. , 1. , 0.5184341 ],\n [ 5. , 3. , 1. , -1.9784615 ],\n [ 5. , 4. , 1. , 0.7132204 ],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan]], dtype=float32, weak_type=True),\n Array([[ 0. , 5. , 1. , -0.41923347],\n [ 1. , 5. , 1. , -3.1815007 ],\n [ 2. , 5. , 1. , 0.5184341 ],\n [ 5. , 3. , 1. , -1.9784615 ],\n [ 5. , 4. , 1. , 0.7132204 ],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan],\n [ nan, nan, nan, nan]], dtype=float32))"
|
"text/plain": "(Array([[ 0. , -1.013169 , 1. , 0. , 0. ],\n [ 1. , -0.3775248 , 1. , 0. , 0. ],\n [ 2. , 0.7407059 , 1. , 0. , 0. ],\n [ 3. , -0.66817343, 1. , 0. , 0. ],\n [ 4. , 0.5336131 , 1. , 0. , 0. ]], dtype=float32, weak_type=True),\n Array([[ 1. , 0. , 4. , 0.13149254],\n [ 1. , 1. , 4. , 0.02001922],\n [ 1. , 2. , 4. , -0.79229796],\n [ 1. , 4. , 3. , -0.57102853],\n [ 1. , nan, nan, nan]], dtype=float32))"
|
||||||
},
|
},
|
||||||
"execution_count": 4,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -89,18 +89,44 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# single flatten\n",
|
"# single flatten\n",
|
||||||
"flatten = flatten_conns(nodes, unflatten, C=10)\n",
|
"nodes, conns = genome.restore(state, transformed)\n",
|
||||||
"conns, flatten"
|
"nodes, conns"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false,
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2024-05-30T11:41:00.308954100Z",
|
"end_time": "2024-05-31T09:01:46.660023600Z",
|
||||||
"start_time": "2024-05-30T11:40:59.469541500Z"
|
"start_time": "2024-05-31T09:01:45.724970700Z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"id": "f2c65de38fdcff8f"
|
"id": "f2c65de38fdcff8f"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "Array([[ 1. , 3. , 0. , 1. , 4. ,\n 0.13149254],\n [ 1. , 3. , 1. , 1. , 4. ,\n 0.02001922],\n [ 1. , 3. , 2. , 1. , 4. ,\n -0.79229796],\n [ 1. , 3. , 4. , 1. , 3. ,\n -0.57102853],\n [ 1. , 3. , nan, 1. , nan,\n nan]], dtype=float32)"
|
||||||
|
},
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"conns = jnp.insert(conns, obj=3, values=1, axis=1)\n",
|
||||||
|
"conns"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-31T09:03:35.665080500Z",
|
||||||
|
"start_time": "2024-05-31T09:03:35.013654700Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "10bcb665c32fb728"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 8,
|
||||||
|
|||||||
Reference in New Issue
Block a user