From 3ea9986bd45c37822251156ca7af9bfd7dc72ce8 Mon Sep 17 00:00:00 2001 From: wls2002 Date: Thu, 30 May 2024 23:12:11 +0800 Subject: [PATCH] add "update_by_batch" in genome; add "normalized" gene, which can do normalization before activation func. add related test. --- .../algorithm/neat/gene/node/normalized.py | 12 +- tensorneat/algorithm/neat/genome/base.py | 5 +- tensorneat/algorithm/neat/genome/default.py | 77 ++++- tensorneat/algorithm/neat/genome/recurrent.py | 12 +- tensorneat/test/test_record_episode.py | 0 tensorneat/test/test_update_by_batch.ipynb | 317 ++++++++++++++++++ 6 files changed, 403 insertions(+), 20 deletions(-) create mode 100644 tensorneat/test/test_record_episode.py create mode 100644 tensorneat/test/test_update_by_batch.ipynb diff --git a/tensorneat/algorithm/neat/gene/node/normalized.py b/tensorneat/algorithm/neat/gene/node/normalized.py index 94f2558..62a48df 100644 --- a/tensorneat/algorithm/neat/gene/node/normalized.py +++ b/tensorneat/algorithm/neat/gene/node/normalized.py @@ -29,12 +29,12 @@ class NormalizedNode(BaseNodeGene): aggregation_default: callable = Agg.sum, aggregation_options: Tuple = (Agg.sum,), aggregation_replace_rate: float = 0.1, - alpha_init_mean: float = 0.0, + alpha_init_mean: float = 1.0, alpha_init_std: float = 1.0, alpha_mutate_power: float = 0.5, alpha_mutate_rate: float = 0.7, alpha_replace_rate: float = 0.1, - beta_init_mean: float = 1.0, + beta_init_mean: float = 0.0, beta_init_std: float = 1.0, beta_mutate_power: float = 0.5, beta_mutate_rate: float = 0.7, @@ -92,7 +92,7 @@ class NormalizedNode(BaseNodeGene): alpha = jax.random.normal(k5, ()) * self.alpha_init_std + self.alpha_init_mean beta = jax.random.normal(k6, ()) * self.beta_init_std + self.beta_init_mean - return jnp.array([bias, act, agg, 0, 1, alpha, beta]) + return jnp.array([bias, act, agg, mean, std, alpha, beta]) def mutate(self, state, randkey, node): k1, k2, k3, k4, k5, k6 = jax.random.split(state.randkey, num=6) @@ -178,13 +178,13 @@ class NormalizedNode(BaseNodeGene): batch_z = bias + batch_z # calculate mean - valid_values_count = jnp.sum(~jnp.isnan(batch_inputs)) - valid_values_sum = jnp.sum(jnp.where(jnp.isnan(batch_inputs), 0, batch_inputs)) + valid_values_count = jnp.sum(~jnp.isnan(batch_z)) + valid_values_sum = jnp.sum(jnp.where(jnp.isnan(batch_z), 0, batch_z)) mean = valid_values_sum / valid_values_count # calculate std std = jnp.sqrt( - jnp.sum(jnp.where(jnp.isnan(batch_inputs), 0, (batch_inputs - mean) ** 2)) + jnp.sum(jnp.where(jnp.isnan(batch_z), 0, (batch_z - mean) ** 2)) / valid_values_count ) diff --git a/tensorneat/algorithm/neat/genome/base.py b/tensorneat/algorithm/neat/genome/base.py index c78c7f1..f52a210 100644 --- a/tensorneat/algorithm/neat/genome/base.py +++ b/tensorneat/algorithm/neat/genome/base.py @@ -39,6 +39,9 @@ class BaseGenome: def transform(self, state, nodes, conns): raise NotImplementedError + def restore(self, state, transformed): + raise NotImplementedError + def forward(self, state, inputs, transformed): raise NotImplementedError @@ -121,7 +124,7 @@ class BaseGenome: return nodes, conns - def update_by_batch(self, state, batch_input, nodes, conns): + def update_by_batch(self, state, batch_input, transformed): """ Update the genome by a batch of data. """ diff --git a/tensorneat/algorithm/neat/genome/default.py b/tensorneat/algorithm/neat/genome/default.py index b0ff770..b880fc9 100644 --- a/tensorneat/algorithm/neat/genome/default.py +++ b/tensorneat/algorithm/neat/genome/default.py @@ -1,7 +1,7 @@ from typing import Callable import jax, jax.numpy as jnp -from utils import unflatten_conns, topological_sort, I_INF +from utils import unflatten_conns, flatten_conns, topological_sort, I_INF from . import BaseGenome from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene @@ -53,17 +53,21 @@ class DefaultGenome(BaseGenome): return seqs, nodes, u_conns - def forward(self, state, inputs, transformed): - cal_seqs, nodes, conns = transformed + def restore(self, state, transformed): + seqs, nodes, u_conns = transformed + conns = flatten_conns(nodes, u_conns, C=self.max_conns) + return nodes, conns - N = nodes.shape[0] - ini_vals = jnp.full((N,), jnp.nan) + def forward(self, state, inputs, transformed): + cal_seqs, nodes, u_conns = transformed + + ini_vals = jnp.full((self.max_nodes,), jnp.nan) ini_vals = ini_vals.at[self.input_idx].set(inputs) nodes_attrs = nodes[:, 1:] def cond_fun(carry): values, idx = carry - return (idx < N) & (cal_seqs[idx] != I_INF) + return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF) def body_func(carry): values, idx = carry @@ -71,7 +75,7 @@ class DefaultGenome(BaseGenome): def hit(): ins = jax.vmap(self.conn_gene.forward, in_axes=(None, 1, 0))( - state, conns[:, :, i], values + state, u_conns[:, :, i], values ) z = self.node_gene.forward( state, @@ -80,6 +84,7 @@ class DefaultGenome(BaseGenome): is_output_node=jnp.isin(i, self.output_idx), ) new_values = values.at[i].set(z) + return new_values # the val of input nodes is obtained by the task, not by calculation @@ -94,5 +99,59 @@ class DefaultGenome(BaseGenome): else: return self.output_transform(vals[self.output_idx]) - def update_by_batch(self, state, batch_input, nodes, conns): - pass + def update_by_batch(self, state, batch_input, transformed): + cal_seqs, nodes, u_conns = transformed + + batch_size = batch_input.shape[0] + batch_ini_vals = jnp.full((batch_size, self.max_nodes), jnp.nan) + batch_ini_vals = batch_ini_vals.at[:, self.input_idx].set(batch_input) + nodes_attrs = nodes[:, 1:] + + def cond_fun(carry): + batch_values, nodes_attrs_, u_conns_, idx = carry + return (idx < self.max_nodes) & (cal_seqs[idx] != I_INF) + + def body_func(carry): + batch_values, nodes_attrs_, u_conns_, idx = carry + i = cal_seqs[idx] + + def hit(): + batch_ins, new_conn_attrs = jax.vmap( + self.conn_gene.update_by_batch, in_axes=(None, 1, 1), out_axes=(1, 1) + )(state, u_conns_[:, :, i], batch_values) + batch_z, new_node_attrs = self.node_gene.update_by_batch( + state, + nodes_attrs[i], + batch_ins, + is_output_node=jnp.isin(i, self.output_idx), + ) + new_batch_values = batch_values.at[:, i].set(batch_z) + return ( + new_batch_values, + nodes_attrs_.at[i].set(new_node_attrs), + u_conns_.at[:, :, i].set(new_conn_attrs), + ) + + (batch_values, nodes_attrs_, u_conns_) = jax.lax.cond( + jnp.isin(i, self.input_idx), + lambda: (batch_values, nodes_attrs_, u_conns_), + hit, + ) + # the val of input nodes is obtained by the task, not by calculation + + return batch_values, nodes_attrs_, u_conns_, idx + 1 + + batch_vals, nodes_attrs, u_conns, _ = jax.lax.while_loop( + cond_fun, body_func, (batch_ini_vals, nodes_attrs, u_conns, 0) + ) + + nodes = nodes.at[:, 1:].set(nodes_attrs) + new_transformed = (cal_seqs, nodes, u_conns) + + if self.output_transform is None: + return batch_vals[:, self.output_idx], new_transformed + else: + return ( + jax.vmap(self.output_transform)(batch_vals[:, self.output_idx]), + new_transformed, + ) diff --git a/tensorneat/algorithm/neat/genome/recurrent.py b/tensorneat/algorithm/neat/genome/recurrent.py index 88d88e8..d3dae82 100644 --- a/tensorneat/algorithm/neat/genome/recurrent.py +++ b/tensorneat/algorithm/neat/genome/recurrent.py @@ -1,7 +1,7 @@ from typing import Callable import jax, jax.numpy as jnp -from utils import unflatten_conns +from utils import unflatten_conns, flatten_conns from . import BaseGenome from ..gene import BaseNodeGene, BaseConnGene, DefaultNodeGene, DefaultConnGene @@ -54,11 +54,15 @@ class RecurrentGenome(BaseGenome): return nodes, u_conns + def restore(self, state, transformed): + nodes, u_conns = transformed + conns = flatten_conns(nodes, u_conns, C=self.max_conns) + return nodes, conns + def forward(self, state, inputs, transformed): nodes, conns = transformed - N = nodes.shape[0] - vals = jnp.full((N,), jnp.nan) + vals = jnp.full((self.max_nodes,), jnp.nan) nodes_attrs = nodes[:, 1:] # remove index def body_func(_, values): @@ -73,7 +77,7 @@ class RecurrentGenome(BaseGenome): )(state, conns, values) # calculate nodes - is_output_nodes = jnp.isin(jnp.arange(N), self.output_idx) + is_output_nodes = jnp.isin(jnp.arange(self.max_nodes), self.output_idx) values = jax.vmap(self.node_gene.forward, in_axes=(None, 0, 0, 0))( state, nodes_attrs, node_ins.T, is_output_nodes ) diff --git a/tensorneat/test/test_record_episode.py b/tensorneat/test/test_record_episode.py new file mode 100644 index 0000000..e69de29 diff --git a/tensorneat/test/test_update_by_batch.ipynb b/tensorneat/test/test_update_by_batch.ipynb new file mode 100644 index 0000000..7c427a3 --- /dev/null +++ b/tensorneat/test/test_update_by_batch.ipynb @@ -0,0 +1,317 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-05-30T15:07:59.805322900Z", + "start_time": "2024-05-30T15:07:57.075364700Z" + } + }, + "outputs": [], + "source": [ + "import jax, jax.numpy as jnp\n", + "from algorithm.neat.genome import *\n", + "from algorithm.neat.gene import *\n", + "\n", + "jnp.set_printoptions(precision=2, linewidth=150)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [], + "source": [ + "# genome = DefaultGenome(num_inputs=3, num_outputs=2, max_nodes=10, max_conns=10)\n", + "# state = genome.setup()\n", + "# randkey = jax.random.key(0)\n", + "# genome_key, input_key = jax.random.split(randkey)\n", + "# nodes, conns = genome.initialize(state, genome_key)\n", + "# inputs = jax.random.normal(input_key, (10, 3)) * 2 + 1 # std: 2, mean: 1\n", + "# print(nodes, conns, sep='\\n')\n", + "# print(inputs)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T15:07:59.817325200Z", + "start_time": "2024-05-30T15:07:59.809324300Z" + } + }, + "id": "c81fa2df52f01d93" + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [], + "source": [ + "# transformed = genome.transform(state, nodes, conns)\n", + "# batch_output = jax.vmap(genome.forward, in_axes=(None, 0, None))(state, inputs, transformed)\n", + "# batch_output, transformed" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T15:07:59.817950Z", + "start_time": "2024-05-30T15:07:59.812323Z" + } + }, + "id": "d4b9aa0449c8d706" + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [], + "source": [ + "# batch_output2, new_transformed = genome.update_by_batch(state, inputs, transformed)\n", + "# batch_output2, new_transformed" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T15:07:59.831323800Z", + "start_time": "2024-05-30T15:07:59.821324100Z" + } + }, + "id": "d32986470dad3229" + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [], + "source": [ + "# assert jnp.allclose(new_transformed[0], transformed[0], equal_nan=True)\n", + "# assert jnp.allclose(new_transformed[1], transformed[1], equal_nan=True)\n", + "# assert jnp.allclose(new_transformed[2], transformed[2], equal_nan=True)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T15:07:59.832325200Z", + "start_time": "2024-05-30T15:07:59.826324400Z" + } + }, + "id": "3c4007dfd6770faf" + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 0. 0. 0. 0. 0. 1. 1. 0.]\n", + " [ 1. 0. 0. 0. 0. 1. 1. 0.]\n", + " [ 2. 0. 0. 0. 0. 1. 1. 0.]\n", + " [ 3. 0. 0. 0. 0. 1. 1. 0.]\n", + " [ 4. 0. 0. 0. 0. 1. 1. 0.]\n", + " [ 5. 0. 0. 0. 0. 1. 1. 0.]\n", + " [nan 0. 0. 0. 0. 1. 1. 0.]\n", + " [nan 0. 0. 0. 0. 1. 1. 0.]\n", + " [nan 0. 0. 0. 0. 1. 1. 0.]\n", + " [nan 0. 0. 0. 0. 1. 1. 0.]]\n", + "[[ 0. 5. 1. 1.]\n", + " [ 1. 5. 1. 1.]\n", + " [ 2. 5. 1. 1.]\n", + " [ 5. 3. 1. 1.]\n", + " [ 5. 4. 1. 1.]\n", + " [nan nan nan 1.]\n", + " [nan nan nan 1.]\n", + " [nan nan nan 1.]\n", + " [nan nan nan 1.]\n", + " [nan nan nan 1.]]\n", + "[[-1.9 -3.53 0.94]\n", + " [ 2.92 0.06 3.44]\n", + " [-0.9 -0.06 2.94]\n", + " ...\n", + " [ 2.07 -1.43 1.55]\n", + " [ 1.93 2.85 0.19]\n", + " [ 0.91 -0.65 1.86]]\n" + ] + }, + { + "data": { + "text/plain": "(Array([ 0, 1, 2, 5, 3, 4, 2147483647, 2147483647, 2147483647, 2147483647], dtype=int32, weak_type=True),\n Array([[ 0., 0., 0., 0., 0., 1., 1., 0.],\n [ 1., 0., 0., 0., 0., 1., 1., 0.],\n [ 2., 0., 0., 0., 0., 1., 1., 0.],\n [ 3., 0., 0., 0., 0., 1., 1., 0.],\n [ 4., 0., 0., 0., 0., 1., 1., 0.],\n [ 5., 0., 0., 0., 0., 1., 1., 0.],\n [nan, 0., 0., 0., 0., 1., 1., 0.],\n [nan, 0., 0., 0., 0., 1., 1., 0.],\n [nan, 0., 0., 0., 0., 1., 1., 0.],\n [nan, 0., 0., 0., 0., 1., 1., 0.]], dtype=float32, weak_type=True),\n Array([[[nan, nan, nan, nan, nan, 1., nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, 1., nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, 1., nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, 1., 1., nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]]], dtype=float32, weak_type=True))" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from algorithm.neat.gene.node.normalized import NormalizedNode\n", + "from algorithm.neat.gene.conn import DefaultConnGene\n", + "from utils import Act\n", + "\n", + "genome = DefaultGenome(num_inputs=3, num_outputs=2, max_nodes=10, max_conns=10,\n", + " node_gene=NormalizedNode(activation_default=Act.identity, activation_options=(Act.identity,)),\n", + " conn_gene=DefaultConnGene(weight_init_mean=1))\n", + "state = genome.setup()\n", + "randkey = jax.random.key(0)\n", + "genome_key, input_key = jax.random.split(randkey)\n", + "nodes, conns = genome.initialize(state, genome_key)\n", + "nodes = nodes.at[:, 1:].set(genome.node_gene.new_custom_attrs(state))\n", + "conns = conns.at[:, 3:].set(genome.conn_gene.new_custom_attrs(state))\n", + "\n", + "inputs = jax.random.normal(input_key, (10000, 3)) * 2 + 1 # std: 2, mean: 1\n", + "print(nodes, conns, sep='\\n')\n", + "print(inputs)\n", + "transformed = genome.transform(state, nodes, conns)\n", + "transformed" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T15:08:04.532243100Z", + "start_time": "2024-05-30T15:07:59.832325200Z" + } + }, + "id": "da73909c3414366e" + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [ + { + "data": { + "text/plain": "Array([[-4.49, -4.49],\n [ 6.42, 6.42],\n [ 1.98, 1.98],\n ...,\n [ 2.19, 2.19],\n [ 4.97, 4.97],\n [ 2.12, 2.12]], dtype=float32, weak_type=True)" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_output2 = jax.vmap(genome.forward, in_axes=(None, 0, None))(state, inputs, transformed)\n", + "batch_output2" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T15:08:04.901593900Z", + "start_time": "2024-05-30T15:08:04.527245300Z" + } + }, + "id": "8ef2402bc4c7908d" + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "batch_z: [-4.49 6.42 1.98 ... 2.19 4.97 2.12]\n", + "batch_z_mean: 2.9496588706970215\n", + "batch_z: [-2.15 1. -0.28 ... -0.22 0.58 -0.24]\n", + "batch_z_mean: -2.1362303925798187e-08\n", + "batch_z: [-2.15 1. -0.28 ... -0.22 0.58 -0.24]\n", + "batch_z_mean: -2.1362303925798187e-08\n" + ] + } + ], + "source": [ + "batch_output, new_transformed = genome.update_by_batch(state, inputs, transformed)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T15:08:05.269935400Z", + "start_time": "2024-05-30T15:08:04.899594200Z" + } + }, + "id": "b3c085c7ca28f127" + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [ + { + "data": { + "text/plain": "(Array([[-2.15, -2.15],\n [ 1. , 1. ],\n [-0.28, -0.28],\n ...,\n [-0.22, -0.22],\n [ 0.58, 0.58],\n [-0.24, -0.24]], dtype=float32, weak_type=True),\n (Array([ 0, 1, 2, 5, 3, 4, 2147483647, 2147483647, 2147483647, 2147483647], dtype=int32, weak_type=True),\n Array([[ 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, 1.00e+00, 0.00e+00],\n [ 1.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, 1.00e+00, 0.00e+00],\n [ 2.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, 1.00e+00, 0.00e+00],\n [ 3.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -2.14e-08, 1.00e+00, 1.00e+00, 0.00e+00],\n [ 4.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, -2.14e-08, 1.00e+00, 1.00e+00, 0.00e+00],\n [ 5.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 2.95e+00, 3.46e+00, 1.00e+00, 0.00e+00],\n [ nan, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, 1.00e+00, 0.00e+00],\n [ nan, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, 1.00e+00, 0.00e+00],\n [ nan, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, 1.00e+00, 0.00e+00],\n [ nan, 0.00e+00, 0.00e+00, 0.00e+00, 0.00e+00, 1.00e+00, 1.00e+00, 0.00e+00]], dtype=float32, weak_type=True),\n Array([[[nan, nan, nan, nan, nan, 1., nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, 1., nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, 1., nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, 1., 1., nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],\n [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]]], dtype=float32, weak_type=True)))" + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_output, new_transformed" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T15:08:05.270935800Z", + "start_time": "2024-05-30T15:08:05.261936200Z" + } + }, + "id": "60ce6747ebd95e10" + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [ + { + "data": { + "text/plain": "Array([[-2.15, -2.15],\n [ 1. , 1. ],\n [-0.28, -0.28],\n ...,\n [-0.22, -0.22],\n [ 0.58, 0.58],\n [-0.24, -0.24]], dtype=float32, weak_type=True)" + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_output2 = jax.vmap(genome.forward, in_axes=(None, 0, None))(state, inputs, new_transformed)\n", + "batch_output2" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T15:08:05.415934Z", + "start_time": "2024-05-30T15:08:05.269935400Z" + } + }, + "id": "7b092224d8f33b7" + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-05-30T15:08:05.416935400Z", + "start_time": "2024-05-30T15:08:05.405934300Z" + } + }, + "id": "eec974242eb3867e" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}