{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from IPython.display import Image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# CNTK 105: Basic autoencoder (AE) with MNIST data\n", "\n", "**Prerequisites**: We assume that you have successfully downloaded the MNIST data by completing the tutorial titled CNTK_103A_MNIST_DataLoader.ipynb.\n", "\n", "\n", "## Introduction\n", "\n", "In this tutorial we introduce you to the basics of [Autoencoders](https://en.wikipedia.org/wiki/Autoencoder). An autoencoder is an artificial neural network used for unsupervised learning of efficient encodings. In other words, autoencoders perform lossy, data-specific compression that is learnt automatically instead of relying on human-engineered features. The aim of an autoencoder is to learn a representation (encoding) for a set of data, typically for the purpose of dimensionality reduction. \n", "\n", "Autoencoders are very specific to the dataset at hand and differ from standard codecs such as JPEG or MPEG. Once the information is encoded and then decoded back to the original dimensions, some amount of information is lost in the process. Because the encodings are data-specific, autoencoders are not used as general-purpose compressors. However, there are two areas where autoencoders have been found very effective: denoising and dimensionality reduction.\n", "\n", "Autoencoders have attracted attention since they have long been thought to be a potential approach for unsupervised learning. Truly unsupervised approaches involve learning useful representations without the need for labels. Autoencoders fall under self-supervised learning, a specific instance of supervised learning where the targets are generated from the input data. \n", "\n", "**Goal** \n", "\n", "Our goal is to train an autoencoder that compresses an MNIST digit image to a vector of smaller dimension and then restores the image. The MNIST data comprises hand-written digits with little background noise." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Figure 1\n", "Image(url=\"http://cntk.ai/jup/MNIST-image.jpg\", width=300, height=300)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we will use the [MNIST hand-written digits data](https://en.wikipedia.org/wiki/MNIST_database) to show how images can be encoded and decoded (restored) using feed-forward networks. We will visualize the original and the restored images. We illustrate two feed-forward-network-based autoencoders: a simple and a deep autoencoder. More advanced autoencoders will be covered in future 200 series tutorials."
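, "\n", "\n", "Formally, an autoencoder is the composition of an encoder $z = f(x)$, which maps an input $x$ to a lower-dimensional code $z$, and a decoder $\\hat{x} = g(z)$, which maps the code back to the input space. Encoder and decoder are trained jointly to minimize a reconstruction loss between $x$ and $\\hat{x}$; in this tutorial we use the binary cross-entropy\n", "\n", "$$ L(x, \\hat{x}) = -\\sum_i \\left[ x_i \\log(\\hat{x}_i) + (1 - x_i) \\log(1 - \\hat{x}_i) \\right] $$\n", "\n", "computed on the pixel values after scaling them to the $[0, 1]$ range, as shown in the training code below."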
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Import the relevant modules\n", "from __future__ import print_function # Use a function definition from future version (say 3.x from 2.7 interpreter)\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import os\n", "import sys\n", "\n", "# Import CNTK\n", "import cntk as C\n", "import cntk.tests.test_utils\n", "cntk.tests.test_utils.set_device_from_pytest_env() # (only needed for our build system)\n", "C.cntk_py.set_fixed_random_seed(1) # fix a random seed for CNTK components\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are two run modes:\n", "- *Fast mode*: `isFast` is set to `True`. This is the default mode for the notebooks, which means we train for fewer iterations or train / test on limited data. This ensures functional correctness of the notebook, though the models produced are far from what a completed training would produce.\n", "\n", "- *Slow mode*: We recommend setting this flag to `False` once you have gained familiarity with the notebook content and want to gain insight from running the notebook for a longer period with different parameters for training. " ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "isFast = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data reading\n", "\n", "In this section, we will read the data generated in CNTK 103 Part A.\n", "\n", "The data is in the following format:\n", "\n", "    |labels 0 0 0 0 0 0 0 1 0 0 |features 0 0 0 0 ...\n", "    (784 integers each representing a pixel)\n", "\n", "In this tutorial we are going to use the image pixels corresponding to the integer stream named \"features\". We define a `create_reader` function to read the training and test data using the [CTF deserializer](https://cntk.ai/pythondocs/cntk.io.html?highlight=ctfdeserializer#cntk.io.CTFDeserializer). The labels are [1-hot encoded](https://en.wikipedia.org/wiki/One-hot); we ignore them in this tutorial. \n", "\n", "We also check that the training and test data files have been downloaded and are available for reading by the `create_reader` function. In this tutorial we are using the MNIST data you downloaded with the CNTK_103A_MNIST_DataLoader notebook. The dataset has 60,000 training images and 10,000 test images, with each image being 28 x 28 pixels."
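, "\n", "\n", "For example, in the sample record above the label stream `0 0 0 0 0 0 0 1 0 0` is the one-hot encoding of the digit 7: a 10-dimensional vector with a single 1 at the index corresponding to the digit."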
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Read a CTF formatted text file (as described above) using the CTF deserializer\n", "def create_reader(path, is_training, input_dim, num_label_classes):\n", "    return C.io.MinibatchSource(C.io.CTFDeserializer(path, C.io.StreamDefs(\n", "        labels_viz = C.io.StreamDef(field='labels', shape=num_label_classes, is_sparse=False),\n", "        features   = C.io.StreamDef(field='features', shape=input_dim, is_sparse=False)\n", "    )), randomize = is_training, max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data directory is ..\\Examples\\Image\\DataSets\\MNIST\n" ] } ], "source": [ "# Ensure the training and test data is generated and available for this tutorial.\n", "# We search in two locations in the toolkit for the cached MNIST data set.\n", "data_found = False\n", "for data_dir in [os.path.join(\"..\", \"Examples\", \"Image\", \"DataSets\", \"MNIST\"),\n", "                 os.path.join(\"data\", \"MNIST\")]:\n", "    train_file = os.path.join(data_dir, \"Train-28x28_cntk_text.txt\")\n", "    test_file = os.path.join(data_dir, \"Test-28x28_cntk_text.txt\")\n", "    if os.path.isfile(train_file) and os.path.isfile(test_file):\n", "        data_found = True\n", "        break\n", "\n", "if not data_found:\n", "    raise ValueError(\"Please generate the data by completing CNTK 103 Part A\")\n", "print(\"Data directory is {0}\".format(data_dir))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Creation (Simple AE)\n", "\n", "We start with a simple autoencoder that uses a single fully-connected feed-forward layer as the encoder and another as the decoder (as shown in the figure below):" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Figure 2\n", "Image(url=\"http://cntk.ai/jup/SimpleAEfig.jpg\", width=200, height=200)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The input data is a set of hand-written digit images, each of 28 x 28 pixels. In this tutorial, we will treat each image as a linear array of 784 pixel values, i.e. as an input with 784 dimensions, one per pixel. Since the goal of the autoencoder is to compress the data and reconstruct the original image, the output dimension is the same as the input dimension. We will compress the input to a mere 32 dimensions (referred to as the `encoding_dim`). Additionally, since the maximum input value is 255, we scale the input to lie between 0 and 1. " ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "input_dim = 784\n", "encoding_dim = 32\n", "output_dim = input_dim\n", "\n", "def create_model(features):\n", "    with C.layers.default_options(init = C.glorot_uniform()):\n", "        # We scale the input pixels to the 0-1 range\n", "        encode = C.layers.Dense(encoding_dim, activation = C.relu)(features/255.0)\n", "        decode = C.layers.Dense(input_dim, activation = C.sigmoid)(encode)\n", "\n", "    return decode" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train and test the model\n", "\n", "In previous tutorials, we have defined the training and testing phases separately. 
In this tutorial, we combine the two in one place so that this template can be used as a recipe for your own tasks. \n", "\n", "The `train_and_test` function performs two major tasks:\n", "- Train the model\n", "- Evaluate the accuracy of the model on test data\n", "\n", "For training:\n", "\n", "> The function takes a reader (`reader_train`), a model function (`model_func`) and the target (a.k.a. `label`) as input. In this tutorial, we show how to create and pass your **own** loss function. We normalize the `label` to emit values between 0 and 1 so that we can compute the label error with the `C.classification_error` function.\n", "\n", "> We use a variant of the Adam optimizer (`fsadagrad`) in this tutorial, chosen from the range of [learners](https://www.cntk.ai/pythondocs/cntk.learner.html#module-cntk.learner) (optimizers) available in the toolkit. \n", "\n", "For testing:\n", "\n", "> The function additionally takes a reader (`reader_test`) and evaluates the pixel values predicted by the model against the reference data, in this case the original pixel values of each image.\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def train_and_test(reader_train, reader_test, model_func):\n", "    \n", "    ###############################################\n", "    # Training the model\n", "    ###############################################\n", "    \n", "    # Instantiate the input and the label variables\n", "    input = C.input_variable(input_dim)\n", "    label = C.input_variable(input_dim)\n", "    \n", "    # Create the model function\n", "    model = model_func(input)\n", "    \n", "    # The labels for this network are the same as the input MNIST images.\n", "    # Note: Inside the model we scale the input to the 0-1 range.\n", "    # Hence we rescale the label to the same range.\n", "    # We show how one can use a custom loss function:\n", "    #   loss = -(y * log(p) + (1 - y) * log(1 - p)) where p = model output and y = target\n", "    # We have normalized the input between 0-1. 
Hence we scale the target to the same range.\n", "    \n", "    target = label/255.0 \n", "    loss = -(target * C.log(model) + (1 - target) * C.log(1 - model))\n", "    label_error = C.classification_error(model, target)\n", "    \n", "    # Training config\n", "    epoch_size = 30000        # 30000 samples is half the dataset size \n", "    minibatch_size = 64\n", "    num_sweeps_to_train_with = 5 if isFast else 100\n", "    num_samples_per_sweep = 60000\n", "    num_minibatches_to_train = (num_samples_per_sweep * num_sweeps_to_train_with) // minibatch_size\n", "    \n", "    \n", "    # Instantiate the trainer object to drive the model training\n", "    lr_per_sample = [0.00003]\n", "    lr_schedule = C.learning_parameter_schedule_per_sample(lr_per_sample, epoch_size)\n", "    \n", "    # Momentum which is applied on every minibatch_size = 64 samples\n", "    momentum_schedule = C.momentum_schedule(0.9126265014311797, minibatch_size)\n", "    \n", "    # We use a variant of the Adam optimizer which is known to work well on this dataset\n", "    # Feel free to try other optimizers from \n", "    # https://www.cntk.ai/pythondocs/cntk.learner.html#module-cntk.learner\n", "    learner = C.fsadagrad(model.parameters,\n", "                          lr=lr_schedule, momentum=momentum_schedule) \n", "    \n", "    # Instantiate the trainer\n", "    progress_printer = C.logging.ProgressPrinter(0)\n", "    trainer = C.Trainer(model, (loss, label_error), learner, progress_printer)\n", "    \n", "    # Map the data streams to the input and labels.\n", "    # Note: for autoencoders input == label\n", "    input_map = {\n", "        input : reader_train.streams.features,\n", "        label : reader_train.streams.features\n", "    } \n", "    \n", "    aggregate_metric = 0\n", "    for i in range(num_minibatches_to_train):\n", "        # Read a minibatch from the training data file\n", "        data = reader_train.next_minibatch(minibatch_size, input_map = input_map)\n", "        \n", "        # Run the trainer and perform model training\n", "        trainer.train_minibatch(data)\n", "        samples = trainer.previous_minibatch_sample_count\n", "        aggregate_metric += trainer.previous_minibatch_evaluation_average * samples\n", "    \n", "    train_error = (aggregate_metric * 100.0) / (trainer.total_number_of_samples_seen)\n", "    print(\"Average training error: {0:0.2f}%\".format(train_error))\n", "    \n", "    #############################################################################\n", "    # Testing the model\n", "    # Note: we use a test file reader to read data different from the training data\n", "    #############################################################################\n", "    \n", "    # Test data for trained model\n", "    test_minibatch_size = 32\n", "    num_samples = 10000\n", "    num_minibatches_to_test = num_samples / test_minibatch_size\n", "    test_result = 0.0\n", "    \n", "    # Test error metric calculation\n", "    metric_numer = 0\n", "    metric_denom = 0\n", "\n", "    test_input_map = {\n", "        input : reader_test.streams.features,\n", "        label : reader_test.streams.features\n", "    }\n", "\n", "    for i in range(0, int(num_minibatches_to_test)):\n", "        \n", "        # We are loading test data in batches specified by test_minibatch_size.\n", "        # Each data point in the minibatch is an MNIST digit image of 784 dimensions,\n", "        # one pixel per dimension, that we will encode / decode with the\n", "        # trained model.\n", "        data = reader_test.next_minibatch(test_minibatch_size,\n", "                                          input_map = test_input_map)\n", "\n", "        # Evaluate the model on the test minibatch and accumulate the error\n", "        eval_error = trainer.test_minibatch(data)\n", "        
metric_numer += np.abs(eval_error * test_minibatch_size)\n", "        metric_denom += test_minibatch_size\n", "\n", "    # Average of the evaluation errors of all test minibatches\n", "    test_error = (metric_numer * 100.0) / (metric_denom) \n", "    print(\"Average test error: {0:0.2f}%\".format(test_error))\n", "    \n", "    return model, train_error, test_error" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us train the simple autoencoder. We create a training and a test reader." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "f:\\projects\\cntk\\CNTK\\bindings\\python\\cntk\\learners\\__init__.py:340: RuntimeWarning: When providing the schedule as a number, epoch_size is ignored\n", " warnings.warn('When providing the schedule as a number, epoch_size is ignored', RuntimeWarning)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " average since average since examples\n", " loss last metric last \n", " ------------------------------------------------------\n", "Learning rate per 1 samples: 3e-05\n", " 544 544 0.846 0.846 64\n", " 544 544 0.848 0.85 192\n", " 544 543 0.868 0.883 448\n", " 542 541 0.859 0.852 960\n", " 538 533 0.848 0.837 1984\n", " 496 456 0.754 0.662 4032\n", " 385 275 0.584 0.417 8128\n", " 303 221 0.442 0.301 16320\n", " 250 197 0.339 0.236 32704\n", " 208 167 0.257 0.176 65472\n", " 173 138 0.182 0.108 131008\n", " 142 111 0.116 0.0496 262080\n", "Average training error: 10.57%\n", "Average test error: 2.98%\n" ] } ], "source": [ "num_label_classes = 10\n", "reader_train = create_reader(train_file, True, input_dim, num_label_classes)\n", "reader_test = create_reader(test_file, False, input_dim, num_label_classes)\n", "model, simple_ae_train_error, simple_ae_test_error = train_and_test(reader_train, \n", "                                                                    reader_test, \n", "                                                                    model_func = create_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize simple AE results" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original image statistics:\n", "Max: 255.00, Median: 0.00, Mean: 24.07, Min: 0.00\n", "Decoded image statistics:\n", "Max: 252.06, Median: 0.44, Mean: 26.61, Min: 0.00\n" ] } ], "source": [ "# Read some data to run the eval\n", "num_label_classes = 10\n", "reader_eval = create_reader(test_file, False, input_dim, num_label_classes)\n", "\n", "eval_minibatch_size = 50\n", "eval_input_map = { input : reader_eval.streams.features } \n", "    \n", "eval_data = reader_eval.next_minibatch(eval_minibatch_size,\n", "                                       input_map = eval_input_map)\n", "\n", "img_data = eval_data[input].asarray()\n", "\n", "# Select a random image\n", "np.random.seed(0) \n", "idx = np.random.choice(eval_minibatch_size)\n", "\n", "orig_image = img_data[idx,:,:]\n", "decoded_image = model.eval(orig_image)[0]*255\n", "\n", "# Print image statistics\n", "def print_image_stats(img, text):\n", "    print(text)\n", "    print(\"Max: {0:.2f}, Median: {1:.2f}, Mean: {2:.2f}, Min: {3:.2f}\".format(np.max(img),\n", "                                                                              np.median(img),\n", "                                                                              np.mean(img),\n", "                                                                              np.min(img))) \n", "    \n", "# Print original image\n", "print_image_stats(orig_image, \"Original image statistics:\")\n", "\n", "# Print decoded image\n", "print_image_stats(decoded_image, \"Decoded image statistics:\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us plot the original and the decoded image. They should look visually similar."
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Define a helper function to plot a pair of images\n", "def plot_image_pair(img1, text1, img2, text2):\n", " fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6, 6))\n", "\n", " axes[0].imshow(img1, cmap=\"gray\")\n", " axes[0].set_title(text1)\n", " axes[0].axis(\"off\")\n", "\n", " axes[1].imshow(img2, cmap=\"gray\")\n", " axes[1].set_title(text2)\n", " axes[1].axis(\"off\")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAADHCAYAAAAJSqg8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAEu9JREFUeJzt3XuwVeV5x/HfA4druF9EAeVQiYDx\nEhCLSKvOQIt4GW2n2tZWhUErySSpNjrWRA0qJjGjUadptCVO1WKSKi0z6GgNKGmLRR2jNrZeAAHl\n5jFyv8nNt3+sdeKa8z4L9j6c63u+nxlmNs963rXXPmftZ79nve96t4UQBABo/zq19gEAAJoGBR0A\nEkFBB4BEUNABIBEUdABIBAUdABJBQa+QmX3LzH7S1LkV7CuY2aiSbc+Z2dVN8TxAY5nZeWa2vjna\nmtkuM/udxh9dx1LT2gfQGsxshqRvSjpR0g5JCyXdEkLYVtYmhPDdSvdfTe7RCCFMb4nnQdtkZmsl\nDZF0UNIhSW9LelzSP4YQPmvFQ2syIYRerX0M7UmH66Gb2Tcl3SPpJkl9JZ0laYSkxWbWtaRNh/zg\nQ7twcQiht7Jz+PuSbpb0SOseElpLhyroZtZH0h2Svh5C+PcQwoEQwlpJlyt7Q/xlnjfHzBaY2Xwz\n2yFpRh6bX9jXVWb2gZltNrPbzGytmU0ttJ+fP67NL5tcbWYfmtknZvbtwn5+18yWm9k2M9tkZj8q\n+2BxXs8vzeya/PEMM3vJzO7P97XazM7O4+vM7OPi5Rkzu9DM3jCzHfn2OQ32fbjX18nM/tbM3s+3\nP2lmA6r/jaCphBC2hxAWSfpTSVeb2SmSZGbdzOze/NyrM7OHzaxHfTszu8TM3szPg/fN7Pw8PtTM\nFpnZFjNbZWbXFtr0MLNHzWyrmb0t6cziseRt/9XMfmNma8zsG5W2bah4yTFv9+P8UuOu/Hw/1swe\nyPf3rpmNK7StP0d3mtnbZvZHhW2dzey+/P24xsy+lj9XTb69r5k9kr8nN5jZXDPr3JjfTUvqUAVd\n0tmSukv6t2IwhLBL0nOS/qAQvkTSAkn9JD1RzDezkyX9WNJfSDpOWU9/2BGe+/ckjZY0RdLtZjY2\njx+SdIOkQZIm5du/WuXrqjdR0q8lDZT0U0k/V/aGGaXsw+pHZlb/J+xuSVflr+9CSV8xs0srfH3f\nkHSppHMlDZW0VdLfN/KY0YRCCK9KWi/p9/PQPZJOkvRlZefBMEm3S1lnQtklmpuUnQfnSFqbt/tZ\nvp+hkv5E0nfNbEq+7TvKLleeKGmapGJHoZOkpyX9T/5cUyRdb2bTjtS2QpdLulXZ+2WfpOWSXs//\nv0DSDwu57+c/h77KOnLzzey4fNu1kqbnP5fxys7noseUXcoaJWmcpD+UdE2Vx9ryQggd5p+yovZR\nybbvS1qcP54j6T8bbJ8jaX7++HZJPyts6ylpv6SpTm6tpCBpeCH/VUl/VnIc10taWPh/kDSqJPeX\nkq7JH8+QtLKw7dS87ZBCbLOkL5fs6wFJ91f4+t6RNKWw/ThJByTVtPbvuCP9U1Z8pzrxlyV9W5Ip\n++A+sbBtkqQ1+eN/qP+dN2h/vLKORu9C7HuSHs0fr5Z0fmHbX0lanz+eKOnDBvu7RdI/HaltyWv8\n7fkv6VFJ8wrbvi7pncL/T5W07TD7elPSJfnjFyVdV9g2NX+uGmXjEvsk9Shs/3NJS1v7d36kfx3t\n2vAnkgaZWU0I4WCDbcfl2+utO8x+hha3hxD2mNnmIzz3R4XHeyT1kiQzO0lZr2KCssJZI+lXR9hX\nmbrC4735sTWM1T/vRGUfYqdI6iqpm6Sn8rwjvb4RkhaaWXHg7ZCyN8KGRh47ms4wSVskDVZ2Tv3K\nzOq3maT6SwfHS3rWaT9U0pYQws5C7ANl52j99nUNttUbIWmomRUnGHSW9F8VtK1Ew/PZPb+l7LKh\npL9R1qlSvm1QyXEUH4+Q1EXSpsLPrZMOXxPahI52yWW5sk/ePy4GzewLyv78eqEQPtwylJskDS+0\n76HsMkdjPCTpXUlfDCH0kfQtZW+65vZTSYskHR9C6Cvp4cLzHun1rZM0PYTQr/CvewiBYt7KzOxM\nZQV9mbIOyl5JXyr8nvqGz2eOrFN26aOhjZIGmFnvQuwEff5hvUnZh0FxW711yv4CKJ4bvUMIF1TQ\ntsmY2QhJ8yR9TdLAEEI/Sf+rknO8wTGtU1YnBhVeQ58Qwpea41ibUocq6CGE7cqupf2dmZ1vZl3M\nrFZZz3S9pH+ucFcLJF2cDzp2zffZ2CLcW9nUyV1mNkbSVxq5n8Y875YQwqf5tdQrCtuO9PoelnR3\n/qaRmQ02s0ta6LjhMLM+ZnaRsnGT+SGEt0I2dXGepPvN7Jg8b1jhevYjkmaa2ZR8oHuYmY0JIayT\n9N+Svmdm3c3sNEmz9PlY0pOSbjGz/mY2XNmlj3qvStphZjfnA6CdzeyU/IPmSG2b0heUdcp+k7/u\nmcr+Gq33pKS/zl9zP2WzgyRJIYRNkn4h6b7859rJzE40s3Ob6VibTIcq6JIUQviBsl7wvcoK6SvK\nPpGnhBD2VbiP/1N2Iv5c2Sf9TkkfK/tUr9aNyorpTmVvvn9pxD4a46uS7jSzncqumT9Zv6GC1/eg\nst79L/L2Lyu7doqW93T+O1in7Lr5DyXNLGy/WdIqSS9bNmNribLBeYVsAHWmpPslbZf0H8ouN0jZ\nNeNaZb31hZK+E0JYnG+7Q9mlkjXKCt9vO0IhhEOSLlY22LhG2V8JP1E2MHnYtk0phPC2pPuU/VVe\np+z6+kuFlHn58/9a0hvKLj3Vz+eXsgkDXZXN7d+qrJNznNo4yy/44yjkM0e2Kbtss
qa1j6eppf76\nADObLunhEMKIIya3YR2uh95UzOxiM+uZX3+/V9Jb+nzKV7uX+utDx5ZfDrrAzGrMbJiy6ZQLW/u4\njhYFvfEuUfbn6EZJX1Q2DTGlP3dSf33o2EzZ5Z+tyi65vKN8fn57xiUXAEgEPXQASAQFHQAS0aJ3\nipoZ13fQrEIILXFTVoRzG82tknObHjoAJIKCDgCJoKADQCIo6ACQCAo6ACSCgg4AiaCgA0AiKOgA\nkAgKOgAkgoIOAImgoANAIijoAJAICjoAJIKCDgCJaNHlcwG0DWbxSqxeTJK6dOkSxXr37u3m7t27\nN4p17tw5iu3bt89t/9lnn0WxQ4cOubnet6119G9go4cOAImgoANAIijoAJAICjoAJIKCDgCJYJZL\nlYYMGeLG58yZE8Vmz57t5noj8U888UQUu+2229z2a9euLT9AJK9sNoqnZ8+ebnzAgAFRbODAgW7u\npZdeWlF7SaqtrY1i+/fvj2Lbt2932z/33HNRbOnSpW6ut4+DBw+6uR0FPXQASAQFHQASQUEHgERQ\n0AEgEQyKHoY3ALpkyRI39+STT45i3m3MZa644ooo9thjj7m5DIp2bDU1/tu2a9euFcUkacyYMVFs\n8uTJbu7EiROj2AknnODmeoOw3jFs27bNbV9XVxfFVq5c6ebu3LkzipUNGDfXkgDe87Xm8gP00AEg\nERR0AEgEBR0AEkFBB4BEUNABIBHMcjmMuXPnRrGy0f158+ZFsa1bt7q5N9xwQxTzvkTgpptuctuX\nzbRBeqq5zb9bt25RrH///m5uv379oljZbfOLFi2KYqNGjap4v94yAd5yAJL/ZRZlPwNvts+BAwfc\n3KNVdgzel3eUzW6rZtZbY9FDB4BEUNABIBEUdABIBAUdABLBoOhheOstz5o1y81dsGBBxfsdNmxY\nFLvsssuimDfIJfm3UpcNMqF9q+bWcm9AsWyQcMWKFVHs9ddfr3i/3bt3d3O9QVFvmYGRI0e67dev\nXx/FjjnmGDf3gw8+iGLVDLaW6dQp7ud6MallBjqrQQ8dABJBQQeARFDQASARFHQASAQFHQASYS25\nGLuZtd7K722IN3tl8eLFUazsCwfOPvvsKPbKK68c/YElIIRQ+b3yTaglz+2yGRferfDerelluWUz\nQbwlAcqOoVevXlFs0qRJUWzcuHFu+z179kSxZcuWubmrVq2KYmXLbXizUcpu5/deW1mut99qZtRU\no5Jzmx46ACSCgg4AiaCgA0AiKOgAkAhu/W8F3oBQ2QAo0FDZ7ebe4GXZpIdqlorw9uENqkr+2udD\nhgyJYoMHD3bbf/zxx1GsbJ323bt3u3GPNzhczVrzZbyfTdl+W2ICCj10AEgEBR0AEkFBB4BEUNAB\nIBEUdABIBLNcWkHZgv0NvfXWW27cu+UZ8GZRlM0QqeaLM7zcslv/Tz311Cg2bdq0KDZw4EC3/caN\nG6PYhg0b3FzvFvvmmrlSNrPIy23J5VQaoocOAImgoANAIijoAJAICjoAJIJB0VZw7bXXVpRXV1fn\nxjdv3tyUh4NENNdt6N4+Ro0a5eaed955UWz8+PFRrOwcXr16dRTbu3evm+sNzJYN1lYzeHm0ua2J\nHjoAJIKCDgCJoKADQCIo6ACQCAo6ACSCWS7NqGwmwJlnnllR+759+7rxs846K4qVLRNQzZcAoGMr\nmyHSr1+/KHbllVe6uRdccEEU69OnTxTbsWNHxcd17LHHunFvH2VLHXjLBJQdg3ebf1ubzVKGHjoA\nJIKCDgCJoKADQCIo6ACQCAZFm5E3GCSVf+N5Q2WDpy+99FIUe/zxx93ce+65J4q9++67FT0/2r9q\n1jivqfHLwdixY6PYaaed5uZ27do1inmDrXv27HHbDxkyJIpNmjTJzfWsX7/ejb/55ptRrL3czl8N\neugAkAgKOgAkgoIOAImgoANAIhgUbUZld6ItWbIkinlrRg8YMKDi57rqqqvc+IgRI6LYRRdd5OaW\nDVQhPd7AX9mdosOHD49i+/fvd3NXrFgRxby7NA8cOOC2nzhxYkXPL/mTDp555hk3d+XKlVFs69at\nbm57Rg8dABJBQQeARFDQASARFHQASAQFHQASwSyXZrRq1So3Pm3atCg2ZsyYKDZu3Di3/fXXXx/F\nJkyY4Oaee+65UWzZsmVurreW9UcffeTmIj3ebBRJeuONN6LYvn373NxPP/00innri48cOdJt752D\ntbW1bm737t2jWNl3EHjfLbBp0yY3t+zn0B7QQweARFDQASARFHQASAQFHQASwaBoG+GtUV62bvmz\nzz4bxZYvX+7mjh49Ooqdfvrpbq63ljU6jrLb8T/88MMotnbt2or326VLlyi2a9cuN/ecc86JYmVL\nEnjxsi9W95bR8NaEb+/ooQNAIijoAJAICjoAJIKCDgCJoKADQCKY5dIObd++PYrt3bu3FY4EKSn7\ntnvvdv6j1bNnTzfuzYjxlg4oi3vLAUj+DK7OnTu7udXMfin7mbUWeugAkAgKOgAkgoIOAImgoANA\nIhgUbcOGDh3qxq+77rooNnbs2Ir3u3r1aje+e/fuiveB1tPSt6xX83ze7fiDBg2KYpMnT3bbn3TS\nSVGsR48ebu7Bgwej2HvvvefmbtiwIYq153XPy9BDB4BEUNABIBEUdABIBAUdABJBQQeARDDLpY2Y\nPn16FLvjjjvc3DPOOKPi/XozWrznkqTNmzdXvF+0DG+GSdkXPlTavimOoVu3bm5ubW1tFLvxxhuj\n2LRp09z2vXr1qvi4vC/ZWLp0qZu7cePGKObNkmnv6KEDQCIo6ACQCAo6ACSCgg4AiWBQtBnNnDnT\njd99991RrH///lHMW8O5zFNPPeXGb7311ii2atWqiveLtqdsoNNbS7ws1xvULBsk7NOnTxS78MIL\n3dxZs2ZFsfHjx0exmhq/9Hi342/ZssXNff7556PYiy++6Obu2LGjoueS2t4a59Wghw4AiaCgA0Ai\nKOgAkAgKOgAkgoIOAIlglksTmTFjRhR76KGH3FxvNkI15s6dG8XuuusuNzfF25s7Ou/b7iV/dkbZ\nLfp9+/at+PkmTJgQxWbPnu3mjh49OoqVzWjx7Nq1K4q98MILbu4DDzwQxbZt2+bmejNa2vNsljL0\n0AEgERR0AEgEBR0AEkFBB4BEMCjaRLz1nY928NNbIkCS7rzzziiW4jeYo7qBuwMHDlSc6w0eDho0\nyM0dMGBARe3LjsEb6Pzkk0/c9k8//XQUe/DBB93curq6KJbi7fzVoIcOAImgoANAIijoAJAICjoA\nJIKCDgCJYJZLE1m2bFkUu/zyy93cDRs2RLGpU6dGsbIvoii79RvpKfuCCo93Xuzfv9/N9ZaEKFsm\nYuHChVFs9erVbu7gwYMrOq7XXnvNbb9x48YoVvYamNkVo4cOAImgoANAIijoAJAICjoAJMJa8pZY\nM+sY99+i1YQQKh9FbEIpnNvVDMCW5R5t
Pekot+g3RiXnNj10AEgEBR0AEkFBB4BEUNABIBEUdABI\nBLf+A5BU3QwTZqO0TfTQASARFHQASAQFHQASQUEHgERQ0AEgERR0AEgEBR0AEkFBB4BEUNABIBEU\ndABIBAUdABJBQQeARFDQASARFHQASAQFHQASYaxrDABpoIcOAImgoANAIijoAJAICjoAJIKCDgCJ\noKADQCIo6ACQCAo6ACSCgg4AiaCgA0AiKOgAkAgKOgAkgoIOAImgoANAIijoAJAICjoAJIKCDgCJ\noKADQCIo6ACQCAo6ACSCgg4AiaCgA0AiKOgAkIj/BwoORkOPbfC6AAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the original and the decoded image\n", "img1 = orig_image.reshape(28,28)\n", "text1 = 'Original image'\n", "\n", "img2 = decoded_image.reshape(28,28)\n", "text2 = 'Decoded image'\n", "\n", "plot_image_pair(img1, text1, img2, text2)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## Model Creation (Deep AE)\n", "\n", "We do not have to limit ourselves to a single layer as encoder or decoder, we could instead use a stack of dense layers. Let us create a deep autoencoder." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Figure 3\n", "Image(url=\"http://cntk.ai/jup/DeepAEfig.jpg\", width=500, height=300)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The encoding dimensions are 128, 64 and 32 while the decoding dimensions are symmetrically opposite 64, 128 and 784. This increases the number of parameters used to model the transformation and achieves lower error rates at the cost of longer training duration and memory footprint. If we train this deep encoder for larger number iterations by turning the `isFast` flag to be `False`, we get a lower error and the reconstructed images are also marginally better. 
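\n", "\n", "Putting the pieces together, the deep model maps $784 \\rightarrow 128 \\rightarrow 64 \\rightarrow 32 \\rightarrow 64 \\rightarrow 128 \\rightarrow 784$ dimensions; the 32-dimensional bottleneck holds the compressed representation that we extract later in this notebook.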
" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "input_dim = 784\n", "encoding_dims = [128,64,32]\n", "decoding_dims = [64,128]\n", "\n", "encoded_model = None\n", "\n", "def create_deep_model(features):\n", " with C.layers.default_options(init = C.layers.glorot_uniform()):\n", " encode = C.element_times(C.constant(1.0/255.0), features)\n", "\n", " for encoding_dim in encoding_dims:\n", " encode = C.layers.Dense(encoding_dim, activation = C.relu)(encode)\n", "\n", " global encoded_model\n", " encoded_model= encode\n", " \n", " decode = encode\n", " for decoding_dim in decoding_dims:\n", " decode = C.layers.Dense(decoding_dim, activation = C.relu)(decode)\n", "\n", " decode = C.layers.Dense(input_dim, activation = C.sigmoid)(decode)\n", " return decode " ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "f:\\projects\\cntk\\CNTK\\bindings\\python\\cntk\\learners\\__init__.py:340: RuntimeWarning: When providing the schedule as a number, epoch_size is ignored\n", " warnings.warn('When providing the schedule as a number, epoch_size is ignored', RuntimeWarning)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " average since average since examples\n", " loss last metric last \n", " ------------------------------------------------------\n", "Learning rate per 1 samples: 3e-05\n", " 544 544 0.739 0.739 64\n", " 544 544 0.794 0.822 192\n", " 544 543 0.801 0.805 448\n", " 543 542 0.817 0.831 960\n", " 530 518 0.876 0.931 1984\n", " 415 304 0.743 0.615 4032\n", " 315 216 0.594 0.448 8128\n", " 259 204 0.493 0.392 16320\n", " 215 172 0.366 0.24 32704\n", " 177 138 0.254 0.141 65472\n", " 145 113 0.165 0.0759 131008\n", " 120 95.9 0.104 0.0431 262080\n", "Average training error: 9.52%\n", "Average test error: 2.87%\n" ] } ], "source": [ "num_label_classes = 10\n", "reader_train = create_reader(train_file, True, input_dim, num_label_classes)\n", "reader_test = create_reader(test_file, False, input_dim, num_label_classes)\n", "\n", "model, deep_ae_train_error, deep_ae_test_error = train_and_test(reader_train, \n", " reader_test, \n", " model_func = create_deep_model) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize deep AE results" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original image statistics:\n", "Max: 255.00, Median: 0.00, Mean: 24.07, Min: 0.00\n", "Decoded image statistics:\n", "Max: 248.16, Median: 0.02, Mean: 22.87, Min: 0.00\n" ] } ], "source": [ "# Run the same image as the simple autoencoder through the deep encoder\n", "orig_image = img_data[idx,:,:]\n", "decoded_image = model.eval(orig_image)[0]*255\n", "\n", "# Print image statistics\n", "def print_image_stats(img, text):\n", " print(text)\n", " print(\"Max: {0:.2f}, Median: {1:.2f}, Mean: {2:.2f}, Min: {3:.2f}\".format(np.max(img),\n", " np.median(img),\n", " np.mean(img),\n", " np.min(img))) \n", " \n", "# Print original image\n", "print_image_stats(orig_image, \"Original image statistics:\")\n", "\n", "# Print decoded image\n", "print_image_stats(decoded_image, \"Decoded image statistics:\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us plot the original and the decoded image with the deep autoencoder. They should look visually similar." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAADHCAYAAAAJSqg8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAEghJREFUeJzt3XuQVvV9x/HPl4XVVZc7WHfBRSBB\njdHgBS/YriNUwQimHU1bWxUGqSYmUZo61kQMIiba0RBtGm2JU6SYpIh1JI7pCLXpxaKONCHWaxGI\nYBDlIle5//rHORvP7O/7sA/LXtjf837NMPPwPb/fOefZPc/nOXt+52IhBAEAur5unb0CAIC2QaAD\nQCIIdABIBIEOAIkg0AEgEQQ6ACSCQC+TmX3DzH7Y1m3LmFcws+Elpv3MzK5ri+UArWVmF5nZ2vbo\na2bbzWxo69eusnTv7BXoDGY2SdLXJQ2TtFXSU5JuDyF8VKpPCOHb5c7/UNoejhDC+I5YDo5MZrZa\n0vGS9knaL+l1SfMk/X0I4UAnrlqbCSEc19nr0JVU3B66mX1d0n2SbpXUS9J5khokLTaz6hJ9KvKL\nD13ChBBCrbJt+F5Jt0l6tHNXCZ2logLdzHpKukvSV0MI/xJC2BtCWC3pi8o+EH+Wt5thZgvNbL6Z\nbZU0Ka/NL8zrWjP7tZltNLPpZrbazMYW+s/PXw/JD5tcZ2bvmtkGM/tmYT6jzGypmX1kZuvM7Pul\nvlic9/NzM7s+fz3JzF4ws9n5vFaa2QV5fY2ZfVA8PGNmnzezX5jZ1nz6jGbzPtj762Zmf2Vm7+TT\nF5hZ30P/jaCthBC2hBAWSfojSdeZ2WmSZGZHmdn9+ba33sweMbOapn5mdoWZ/TLfDt4xs3F5vc7M\nFpnZJjNbYWZTC31qzGyumW02s9clnVNcl7zvk2b2oZmtMrOvldu3ueIhx7zfD/JDjdvz7f13zOx7\n+fzeNLORhb5N2+g2M3vdzP6gMK3KzB7IP4+rzOwr+bK659N7mdmj+WfyPTObZWZVrfnddKSKCnRJ\nF0g6WtI/F4shhO2Sfibp9wvlKyQtlNRb0uPF9mZ2qqQfSPpTSSco29Ovb2HZF0oaIWmMpDvN7JS8\nvl/SNEn9JZ2fT//yIb6vJudK+pWkfpJ+JOknyj4ww5V9WX3fzJr+hN0h6dr8/X1e0pfM7Atlvr+v\nSfqCpEZJdZI2S/rbVq4z2lAI4WVJayX9bl66T9KnJX1O2XZQL+lOKduZUHaI5lZl28HvSVqd9/tx\nPp86SVdK+raZjcmnfUvZ4cphki6VVNxR6Cbpp5KW58saI+kWM7u0pb5l+qKkO5R9XnZLWirpf/L/\nL5T03ULbd/KfQy9lO3LzzeyEfNpUSePzn8uZyrbnoseUHcoaLmmkpEskXX+I69rxQggV809ZqL1f\nYtq9khbnr2dI+o9m02dImp+/vlPSjwvTjpG0R9JYp+0QSUHSoEL7lyX9cYn1uEXSU4X/B0nDS7T9\nuaTr89eTJP1fYdpn877HF2obJX2uxLy+J2l2me/vDUljCtNPkLRXUvfO/h1X0j9l4TvWqb8o6ZuS\nTNkX97DCtPMlrcpf/13T77xZ/8HKdjRqC7XvSJqbv14paVxh2p9LWpu/PlfSu83md7ukf2ipb4n3\n+NvtX9JcSXMK074q6Y3C/z8r6aODzOuXkq7IXz8v6YbCtLH5srorG5fYLammMP1PJP1bZ//OW/pX\naceGN0jqb2bdQwj7mk07IZ/eZM1B5lNXnB5C2GlmG1tY9vuF1zslHSdJZvZpZXsVZysLzu6SlrUw\nr1LWF15/nK9b81rTcs9V9iV2mqRqSUdJeiJv19L7a5D0lJkVB972K/sgvNfKdUfbqZe0SdIAZdvU\nMjNrmmaSmg4dDJb0rNO/TtKmEMK2Qu3XyrbRpulrmk1r0iCpzsyKJxhUSfrPMvqWo/n27G7fUnbY\nUNJfKNupUj6tf4n1KL5ukNRD0rrCz62bDp4JR4RKO+SyVNk37x8Wi2Z2rLI/v/61UD7YbSjXSRpU\n6F+j7DBHazws6U1Jnwoh9JT0DWUfuvb2I0mLJA0OIfSS9EhhuS29vzWSxocQehf+HR1CIMw7mZmd\noyzQ/0vZDsrHkj5T+D31Cp+cObJG2aGP5n4jqa+Z1RZqJ+qTL+t1yr4MitOarFH2F0Bx26gNIVxW\nRt82Y2YNkuZI+oqkfiGE3pL+VyW28WbrtEZZTvQvvIeeIYTPtMe6tqWKCvQQwhZlx9L+xszGmVkP\nMxuibM90raR/LHNWCyVNyAcdq/N5tjaEa5WdOrndzE6W9KVWzqc1y90UQtiVH0u9ujCtpff3iKR7\n8g+NzGyAmV3RQesNh5n1NLPLlY2bzA8hvBqyUxfnSJptZgPzdvWF49mPSppsZmPyge56Mzs5hLBG\n0n9L+o6ZHW1mp0uaok/GkhZIut3M+pjZIGWHPpq8LGmrmd2WD4BWmdlp+RdNS33b0rHKdso+zN/3\nZGV/jTZZIOnm/D33VnZ2kCQphLBO0nOSHsh/rt3MbJiZNbbTuraZigp0SQoh/LWyveD7lQXpS8q+\nkceEEHaXOY/XlG2IP1H2Tb9N0gfKvtUP1V8qC9Ntyj58/9SKebTGlyXNNLNtyo6ZL2iaUMb7e1DZ\n3v1zef8XlR07Rcf7af47WKPsuPl3JU0uTL9N0gpJL1p2xtYSZYPzCtkA6mRJsyVtkfTvyg43SNkx\n4yHK9tafkvStEMLifNpdyg6VrFIWfL/dEQoh7Jc0Qdlg4yplfyX8UNnA5EH7tqUQwuuSHlD2V/l6\nZcfXXyg0mZMv/1eSfqHs0FPT+fxSdsJAtbJz+zcr28k5QUc4yw/44zDkZ458pOywyarOXp+2lvr7\nA8xsvKRHQggNLTY+glXcHnpbMbMJZnZMfvz9fkmv6pNTvrq81N8fKlt+OOgyM+tuZvXKTqd8qrPX\n63AR6K13hbI/R38j6VPKTkNM6c+d1N8fKpspO/yzWdkhlzeUn5/flXHIBQASwR46ACSCQAeARHTo\nlaJmxvEdtKsQQkdclBVh20Z7K2fbZg8dABJBoANAIgh0AEgEgQ4AiSDQASARBDoAJIJAB4BEEOgA\nkAgCHQASQaADQCIIdABIBIEOAIkg0AEgEQQ6ACSiQ2+fC+DIZebfnbVbt3i/r6qqym27d+/eqFZT\nUxPV9uzZ4/bft2/fwVYRLWAPHQASQaADQCIIdABIBIEOAIkg0AEgEZzlcoiOP/54tz5jxoyoduON\nN7ptQ4ifJ/z4449HtenTp7v9V69eXXoFgQLvDBVJOuqoo6Janz593LaDBw+OahMnTnTbnnjiiVGt\ntrY2qj3//PNu/wULF
kS19evXu229z1GlYw8dABJBoANAIgh0AEgEgQ4AiWBQ9CC8AdAlS5a4bU89\n9dSoduDAgbKXdfXVV0e1xx57zG3LoCg83qX7pQZF+/btG9VGjRrltp0wYUJUa2xsdNvW19dHNW/w\n8vTTT3f7r1u3LqotWrTIbbt79263XsnYQweARBDoAJAIAh0AEkGgA0AiCHQASARnuRzErFmzopp3\nabMkzZkzJ6pt3rzZbTtt2rSo1qNHj6h26623uv1LnWmD9HhnrpR6uIR3Nskxxxzjth00aFBUGzJk\niNvW2443bNjgtvXOqunXr19U27Ztm9vfexiG99mQOMvFwx46ACSCQAeARBDoAJAIAh0AEsGg6EFs\n2bIlqk2ZMsVtu3DhwrLn610efdVVV0U1757VklRdXR3VSj1FHV1DqUv0vdtHlLqlhDd4WGq72Lhx\nY1R77rnn3LYDBgyIamvXrnXbXnDBBVHt2GOPLbv/ypUro9qh3EKj0rGHDgCJINABIBEEOgAkgkAH\ngEQQ6ACQCOvIJ2ebGY/pln/2yuLFi6Pa6NGj3f7emQQvvfTS4a9YAkII8bXyHaAjt23vdgBtodQl\n9t4ZON6ZK5LUv3//qOY9/OXCCy90+z/55JNRbdmyZW7bSrv0v5xtmz10AEgEgQ4AiSDQASARBDoA\nJIJL/zvByJEjo1qpAVCgufY6keFQbh9RakBy586dUc0bQC1173Wvf6n7vyPGHjoAJIJAB4BEEOgA\nkAgCHQASQaADQCI4y6UTDBw4sKx2r776qltfsWJFW64OcMhK3X6grq4uqjU2Nka1ESNGuP29bX7V\nqlVuW++WBJX+MAz20AEgEQQ6ACSCQAeARBDoAJAIBkU7wdSpU8tqt379erfuPbEd6Eg1NTVufdy4\ncVHt5ptvjmr79u1z+7/11ltR7ZVXXnHbeoOipVTKYCl76ACQCAIdABJBoANAIgh0AEgEgQ4AieAs\nl3Y0fPhwt37OOeeU1b9Xr15u/bzzzotqpW4TsGPHjrKWBZTinU0yceJEt+306dOjWr9+/aLa9u3b\n3f7V1dVRbdiwYW7bd999N6pt2rTJbVsptwlgDx0AEkGgA0AiCHQASASBDgCJYFC0HfXs2dOtDxgw\noKz+pQZPX3jhhag2b948t+19990X1d58882ylg9IUo8ePaLa+eef77b1Bh+7d49jpqqqyu1/ySWX\nRLVSA/vecwVK3Tvdq+/evdtt25Wxhw4AiSDQASARBDoAJIJAB4BEMCjajrZu3erWlyxZEtXOPPPM\nqNa3b9+yl3Xttde69YaGhqh2+eWXu2137txZ9vJQ2ZYtW+bWd+3aFdVGjx4d1YYOHer2954BcMYZ\nZ7htL7744qi2ePFit+3mzZuj2gcffOC2DSG49a6APXQASASBDgCJINABIBEEOgAkgkAHgERYR47o\nmlnXHT5uZyeffHJUGzlypNv2lltuiWpnn3122ctavny5W7/sssui2vvvv1/2fI8EIQTrjOVW2rZ9\n3HHHuXXvzKyzzjorqnmX+EvStm3bolrv3r3dtiNGjIhqpbbXmTNnRrW3337bbbt371633tnK2bbZ\nQweARBDoAJAIAh0AEkGgA0AiGBTtgryHRy9dutRt6w0clXLSSSdFNe9BvEcyBkU7hpn/Y/bqNTU1\nUc3b1iRpyJAhUe3SSy912zY2NkY17x7pkjR79uyoNmfOHLetd5uAI+F2AAyKAkAFIdABIBEEOgAk\ngkAHgEQQ6ACQCB5w0QVt2bIlqn388cedsCaoVKXO+vDq3kMvPvzwQ7e/dwuMiy66yG3rnRGzf/9+\nt+1rr73m1lPDHjoAJIJAB4BEEOgAkAgCHQASwaDoEayurs6t33DDDVHtlFNOKXu+K1eudOs7duwo\nex6Ax7v0v7q6OqrV19e7/b3L+UvdD72qqiqqbd261W1bW1sb1bx7r0tHxmX+rcUeOgAkgkAHgEQQ\n6ACQCAIdABJBoANAIjjL5Qgxfvz4qHbXXXe5bb2nqJfindHiLUuSNm7cWPZ8Udm6d/ejY9CgQVHN\ne0DF1KlT3f6DBw+Oaj179nTb7tu3L6rNmzfPbfvMM89Etb1797ptuzL20AEgEQQ6ACSCQAeARBDo\nAJAI68jLXCvtyeiTJ0926/fcc09U69OnT1TzLpku5YknnnDrd9xxR1RbsWJF2fPtasp5Mnp7SGHb\n9i7bl/xL7L17kUvSlVdeGdWmTZsW1fr37+/279Yt3sfcs2eP2/bZZ5+Natdcc43bdvv27W69Kyln\n22YPHQASQaADQCIIdABIBIEOAIkg0AEgEVz630YmTZoU1R5++GG3bY8ePQ5rWbNmzYpqd999t9vW\nuzwa8M5cKXU5f0NDQ1SbMmWK29Y7y2TgwIFR7cCBA27/TZs2RbW5c+e6be+9996olsLZLIeDPXQA\nSASBDgCJINABIBEEOgAkgkHRNuLd8/lwBz+9WwRI0syZM6Pa/v37D2tZSFOpy/m9QUnvsnvJH9Qc\nNWqU27a2tjaqeZfuL1u2zO3/4IMPRrWnn37abbtr1y63XsnYQweARBDoAJAIAh0AEkGgA0AiCHQA\nSAQPuGgjN910U1R76KGH3LbvvfdeVBs7dmxUK/UgilKXTYMHXByOUmfE1NTURLWhQ4e6bb3bBHhn\nuSxfvtztv2HDhqjG9p7hARcAUEEIdABIBIEOAIkg0AEgEQyKIikMiiJVDIoCQAUh0AEgEQQ6ACSC\nQAeARBDoAJAIAh0AEkGgA0AiCHQASASBDgCJINABIBEEOgAkgkAHgEQQ6ACQCAIdABJBoANAIgh0\nAEgEgQ4AiSDQASARBDoAJIJAB4BEEOgAkAgLgYeVA0AK2EMHgEQQ6ACQCAIdABJBoANAIgh0AEgE\ngQ4AiSDQASARBDoAJIJAB4BEEOgAkAgCHQASQaADQCIIdABIBIEOAIkg0AEgEQQ6ACSCQAeARBDo\nAJAIAh0AEkGgA0AiCHQASASBDgCJINABIBH/Dzd5oszXSBIoAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the original and the decoded image\n", "img1 = orig_image.reshape(28,28)\n", "text1 = 'Original image'\n", "\n", "img2 = decoded_image.reshape(28,28)\n", "text2 = 'Decoded image'\n", "\n", "plot_image_pair(img1, text1, img2, text2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have shown how to encode and decode an input. 
In this section we will explore how to compare one decoded image to another and also how to extract the encoded (compressed) representation of a given input. For visualizing high-dimensional data in 2D, [t-SNE](http://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html) is probably one of the best methods; however, it typically works best on relatively low-dimensional data. A good strategy for visualizing similarity relationships in high-dimensional data is therefore to first encode the data into a low-dimensional space (e.g. 32 dimensions) using an autoencoder, extract the encodings of the input data, and then use t-SNE to map the compressed data onto a 2D plane. \n", "\n", "We will use the deep autoencoder outputs to:\n", "- Compare two images, and\n", "- Show how to retrieve the encoded (compressed) version of an image. \n", "\n", "First we need to read some image data along with their labels. " ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Read some data to get the image data and the corresponding labels\n", "num_label_classes = 10\n", "reader_viz = create_reader(test_file, False, input_dim, num_label_classes)\n", "\n", "image = C.input_variable(input_dim)\n", "image_label = C.input_variable(num_label_classes)\n", "\n", "viz_minibatch_size = 50\n", "\n", "viz_input_map = { \n", "    image       : reader_viz.streams.features, \n", "    image_label : reader_viz.streams.labels_viz \n", "} \n", "    \n", "viz_data = reader_viz.next_minibatch(viz_minibatch_size,\n", "                                     input_map = viz_input_map)\n", "\n", "img_data = viz_data[image].asarray()\n", "imglabel_raw = viz_data[image_label].asarray()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1: [7, 24, 39, 44, 46]\n", "3: [1, 13, 18, 26, 37, 40, 43]\n", "9: [8, 12, 23, 28, 42, 49]\n" ] } ], "source": [ "# Map the image labels into indices in the minibatch array\n", "img_labels = [np.argmax(imglabel_raw[i,:,:]) for i in range(0, imglabel_raw.shape[0])] \n", "    \n", "from collections import defaultdict\n", "label_dict = defaultdict(list)\n", "for img_idx, img_label in enumerate(img_labels):\n", "    label_dict[img_label].append(img_idx) \n", "    \n", "# Print the indices corresponding to 3 digits\n", "randIdx = [1, 3, 9]\n", "for i in randIdx:\n", "    print(\"{0}: {1}\".format(i, label_dict[i]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will [compute the cosine distance](https://en.wikipedia.org/wiki/Cosine_similarity) between two images using `scipy`. 
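\n", "\n", "The cosine similarity of two vectors $A$ and $B$ is defined as\n", "\n", "$$ \\text{similarity}(A, B) = \\frac{A \\cdot B}{\\|A\\| \\, \\|B\\|} $$\n", "\n", "and equals 1 when the vectors point in the same direction and 0 when they are orthogonal. Note that `scipy`'s `spatial.distance.cosine` returns the cosine *distance*, i.e. one minus this similarity, so the helper below flips it back; the values printed further down are therefore similarities, with larger values indicating more similar images.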
" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from scipy import spatial\n", "\n", "def image_pair_cosine_distance(img1, img2):\n", " if img1.size != img2.size:\n", " raise ValueError(\"Two images need to be of same dimension\")\n", " return 1 - spatial.distance.cosine(img1, img2)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Distance between two original image: 0.294\n", "Distance between two decoded image: 0.351\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAADHCAYAAAAJSqg8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAADvNJREFUeJzt3X2MVcUdxvFnKi5bxLK0EEOxvi5i\nVZBU5cUYa1OqiBJAGzS2KVU3KF3BNIWsoG15SSOt1bUWdPuimNTWCg0V321XJIJSSRoo2ETEvqBY\nlAhVdy2C7E7/OJf0lpnD3ve7+9vvJ9kEfnfOOXPYuQ+zO+ec67z3AgD0fJ+odgcAAKVBoAOAEQQ6\nABhBoAOAEQQ6ABhBoAOAEQT6YZxz851zvyx12xz25Z1z9SmvPe2cm16K46D3Ymzb5yxfh+6c+6ak\n70g6VdIHkn4vaZ73/r1q9ivGOeclDfPev17tvhTCOVcj6TeSzpV0oqQvee/XVrVThjG2K8c5N1bS\nYknnSOqQtFbSbO/9rmr2K8bsDN059x1JP5Q0V9IASWOVBM0fM+ET26ZP5Xpo0npJX5f0drU7Yhlj\nu+IGSvq5pJOU/Du3SVpezQ6l8t6b+5L0KUntkqYdVu8vabek6zJ/XyDpd5IeUjLLacjUHsra5huS\ndkjaI+m7kv4paXzW9g9l/nySJC9puqQ3JL0r6das/YyWtEHSe5J2SVoqqSbrdS+pPuV81kpqyPz5\nm5JelNSc2dffJZ2fqb+ZOb/pWdteJmlT5vzelLTgsH0f6fw+IekWSX/LvL5C0qdz+PffKemiao8D\ni1+M7eqO7cy2X5DUVu2xEPuyOkM/X1KtpFXZRe99u6SnJX0lqzxZycCvk/Tr7PbOuTMk3Svpa5KG\nKJkNDe3i2BdIGi7py5K+55z7fKbeIenbkgZJGpd5/Vt5ntchYyRtkfQZJb/m+K2k8yTVK5khL3XO\n9c+0/VDJwK5T8gaY6ZybkuP5zZY0RdIXJX1W0r8lLSuwzygNxnb1x/aFkv5a2OmVl9VAHyTpXe/9\nwchruzKvH7LBe/+o977Te7/vsLZflfS493699/6ApO8pmW0cyULv/T7v/V8k/UXS2ZLkvf+z9/5P\n3vuD3vt/SvqZksFUiH9475d77zskPSLpc5IWee/3e+//IOmAkjeAvPdrvfdbM+e3RdLDWcft6vxu\nUDIT2+m9369k1vZVfnyvKsZ2Fce2c25kZl9zCzy/srL6xnxX0iDnXJ/IwB+Sef2QN4+wn89mv+69\n/49zbk8Xx87+/fF/lPwoLOfcaZLuUrJo2E/Jv/2fu9hXmney/rwv07fDa4eOO0bSEklnSaqR1FfS\nyky7rs7vREm/d851ZtU6JB0n6a0C+47iMLarNLYzV+o8Lelm7/26vM+sAqzO0DdI2i/piuyic+4Y\nSZdKei6rfKRZyS5Jx2dt/0klPwoW4j5JrypZ7f+UpPmSXIH7ysdvJD0m6XPe+wGSWrKO29X5vSnp\nUu99XdZXrfeeMK8exvb/VGxsO+dOlNQqabH3/ldlOJeSMBno3vv3JS2U9FPn3ATn3NHOuZOU/O+9\nU1Ku35DfSZrknDs/c/XAQhU+UI9VsnjT7pw7XdLMAvdTyHH3eu8/cs6NlnRN1mtdnV+LpB9kBrOc\nc4Odc5PTDuSc6+ucq838tcY5V+ucq8Qbu9dgbAfHLfvYds4NlbRG0jLvfUs5TqRUTAa6JHnvf6Rk\npvBjJYPtZSX/K3858zuzXPbxV0mzlCzM7FJyudJuJTOkfM1RMuDaJP1Cye8HK+FbkhY559qU/O5v\nxaEXcji/nyiZAf0hs/2flCxapdmm5EfioZKezfz5xFKeDBjbWSo1thsknSLp+8659kNfZTifopm+\nsajUMqvr7yn50fIf1e5PqVk/P6Sz/r23fn6HmJ2hl4pzbpJzrl/md5Q/lrRVyfWsJlg/P6Sz/r23\nfn4xBHrXJkv6V+ZrmKSrva0fa6yfH9JZ/95bP78Av3IBACOYoQOAEQQ6ABhR0TtFM4/RBMrGe1+V\n694Z2yi3XMY2M3QAMIJABwAjCHQAMIJABwAjCHQAMIJABwAjCHQAMIJABwAjCHQAMIJABwAjCHQA\nMIJABwAjCHQAMIJABwAjKvr4XJTPqaeeGq3PmzcvqF1zzTXRtuPHjw9qL730UnEdQ49RV1cXra9Z\nsyaoHXPMMdG2w4cPL2mfkB9m6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBFe59EDHH398UHvqqaeibevr\n64NaR0dHtO3BgweL6xh6jIEDBwa11tbWaNuzzz47qG3fvr3kfULxmKEDgBEEOgAYQaADgBEEOgAY\nwaJoD3T99dcHtdjiZ5rly5dH6xs3biy4T+ieYoufUnwBdNSoUdG2nZ2dQe3xxx8vrmMoC2boAGAE\ngQ4ARhDoAGAEgQ4ARhDoAGCE895X7mDOVe5gBpx77rnR+gsvvBDU+vbtG20b+4CKiy++ONp23759\nefSue/Leu2oct7uO7SVLlkTrc+fOzXkfLS0tQa2xsbHgPqEwuYxtZugAYASBDgBGEOgAYASBDgBG\ncOt/N3bllVdG67W1tUEt7bb9yZMnBzULi58IDRo0KKhNmDAh5+3ff//9aP2ee+4puE+oLGboAGAE\ngQ4ARhDoAGAEgQ4ARhDoAGAEV7l0Ew0NDUGtqakp2ratrS2oTZs2Ldp27969xXUMPcZzzz0X1M46\n66yct3/44Yej9W3bthXcJ1QWM3QAMIJABwAjCHQAMIJABwAjWBStgtizy2O3+ac9q/6WW24Jam+8\n8UbxHUOPNmLEiKCWNoba29uDWnNzc8n7hMpihg4ARhDoAGAEgQ4ARhDoA
GAEgQ4ARri0VfCyHKyb\nfjJ6pcU+dGLVqlVBrbW1Nbr9JZdcUvI+WZHLJ6OXQ3cY27H3cmdnZ7Rt7JEQgwcPLnmfSuG0006L\n1vv371+xPmzdujVa//jjjyvWh1zGNjN0ADCCQAcAIwh0ADCCQAcAI7j1v4yef/75aH3Dhg1Bbfv2\n7UFt5syZJe8T7KrkBQ7FGj9+fLR+8803B7Vx48ZF2w4cOLCkfTqSNWvWROvr1q0Lag8++GC0bSUe\nz8EMHQCMINABwAgCHQCMINABwAgCHQCM4Nb/Ehk5cmRQe/HFF6Nt+/XrF9SuuOKKoLZ69eriO9bL\n9OZb/2O3+ae9vyt56/+xxx4b1J599tlo2zFjxuS8302bNgW1tra2aNtXXnklqO3ZsyfadtSoUUEt\n7XEbNTU1QW3Hjh3RtrGrXxYtWhRtG8Ot/wDQixDoAGAEgQ4ARhDoAGAEi6Ilsm3btqA2bNiwaNvY\ngtDUqVOD2kcffVR0v04//fSglrZw9NZbbxV9vGpjUfT/pb2/m5ubg9qcOXNK3idJuvvuu4ParFmz\nct5+5cqV0fqMGTOC2gcffJB7x/Jw0003ReuNjY1BLe357TFHHXVUzm1ZFAWAXoRABwAjCHQAMIJA\nBwAjeB56icQWQNMWpO67776gFlsArauri25/2223BbWJEydG2w4dOjSovf3229G2sWdRP/PMM9G2\n6NnSnkdeDieccEJR2y9dujRaL9cCaD596NMnjNA777yz3N1JxQwdAIwg0AHACAIdAIwg0AHACAId\nAIzgKpc8XXDBBTm3PXDgQLSedpXJ4ZqamqL1/v37B7XNmzdH2w4fPjyo1dfXR9vGrr45+eSTj9RF\n9FCxq59QXrHHg5QaM3QAMIJABwAjCHQAMIJABwAjWBTN0+23355z29bW1mh948aNOW0/b968nI8V\nWyiVpHHjxgW1tFuxY/XLL7882vaJJ57IuW/ofmIfbpy2WP7666+XuztHdMMNN0Tr69evr3BPirNw\n4cKyH4MZOgAYQaADgBEEOgAYQaADgBEEOgAYwVUuZfToo49W7Fi1tbXRej4fLvDaa68FNa5m6Tlm\nzpwZ1O69995o29hVUWlXjcyaNSuorVy5Mud+NTY2BrVNmzZF2954441BbcqUKdG2Tz75ZFC74447\nom3Xrl17hB4WbsSIETm3HTJkSFn6kI0ZOgAYQaADgBEEOgAYQaADgBEsipaIcy6oDRs2rAo9+X+x\nfqVZtWpVGXuCcrv//vuDWmyRUZJGjhwZ1AYPHhxte+uttwa1fBZFd+3aFdQWL14cbdve3h7U5syZ\nE207YcKEoHbhhRdG2zY0NAS1Rx55JNo25owzzojWp02blvM+duzYkXPbQjFDBwAjCHQAMIJABwAj\nCHQAMIJABwAjuMqlRLz3QW306NHRtldffXVQW7FiRVDr7OyMbn/00UcHtbFjx+bcr46Ojmjb1atX\nR+voGQ4ePBjUJk6cGG27c+fOnPcbu8Ij7ZECd911V1DL5wMympubg9q6deuibadPnx7UTjnllGjb\nBx54IKhde+210bax98H8+fOjbfv16xfUrrvuumjbSjwKhBk6ABhBoAOAEQQ6ABhBoAOAES62aFa2\ngzlXuYOVyYIFC6L12AJNPs8ij92GnHZr8qRJk4Ja2gJPzLJly6L12bNn57yP7sp7n/uzDkqou47t\ntEc/XHXVVUGtqakp2jb2mIA0H374YVCL3fIee0xBKdTU1ETrsQsUpk6dmvN+t2zZEq1fdtllQS32\nqAMpfoFCPnIZ28zQAcAIAh0AjCDQAcAIAh0AjCDQAcAIrnLJU21tbbR+0UUXBbVFixZF255zzjlF\n9SF25ULa9zF2i3faIwneeeedovrVHXCVS+GmTJkSrY8ZMyaoxa7ukKQzzzyzpH0qlVdffTWoPfbY\nY9G2mzdvDmppt+3v37+/uI7lgatcAKAXIdABwAgCHQCMINABwAgWRcso9txySTrvvPOCWuw50gMG\nDIhuv3v37qC2ZMmSaNuXX345qO3duzfa1gIWRSujT5/4Rykcd9xxQW3GjBnl7k6XWlpaglraLfrd\nFYuiANCLEOgAYASBDgBGEOgAYASBDgBGcJULTOEqF1jFVS4A0IsQ6ABgBIEOAEYQ6ABgBIEOAEYQ\n6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABg\nBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABghPPeV7sPAIASYIYOAEYQ6ABgBIEO\nAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ\n6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEb8F78K5L/Q0oaqAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAADHCAYAAAAJSqg8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAE1hJREFUeJzt3XuM1WV+x/HPl8vIZQYQBrkzhJuI\nAl5Q1GqKiqyUNYtoY8ofuyFN2y223Vo1drNpN002URupq3/sJRZLlI0G1C6srrqtNOuy60ZEbl4q\nt1EBQW4iKAMCPv3jd6Y5y/N94AycufDwfiUkw/d8f+f8zpnnfOd3znOzEIIAAGe/Tu19AgCA6qCg\nA0AmKOgAkAkKOgBkgoIOAJmgoANAJijorcDMpprZttY41sw+N7ORp392wOmjbXdsWRV0M/vAzJrM\n7KCZ7Tez35nZt80sm+cZQqgNIWxp7/M4GTO7xMxeMbM9ZsZEhyqgbXcMZvYtM1tlZgfMbJuZ/auZ\ndWnv82qWTWMoc2sIoU5Sg6QHJd0vaUH7ntI556ikxZL+vL1PJDO07fbXQ9LfS6qXNEXSTZLubdcz\nKpNjQZckhRA+CyEsk3SnpG+Z2SWSZGbnmdnDZvaRmX1iZj8xs+7Nx5nZN8xsTekv8GYzu6UUH2xm\ny8xsn5ltMrO/KDumu5ktNLNPzexdSVeWn0vp2OfMbLeZNZrZ31V67InMLJjZ6NLPC83sR2b2Uunj\n6m/NbKCZ/bB0f/9rZpeVHfuPped00MzeNbPbym7rbGbzS1fVjWb2N6XH6lK6vbeZLTCzHWa23cx+\nYGadE6/9+yGEBZLeOeUvCi1G227Xtv3jEMJvQghfhhC2S/qZpD861e+szYQQsvkn6QNJ05z4R5L+\nuvTzDyUtk9RXUp2kX0h6oHTbVZI+k3Szij92QySNK932a0k/ktRN0qWSdku6qXTbg5J+U7rPYZLe\nlrStdFsnSask/bOkGkkjJW2R9LVTHZt4jkHS6NLPCyXtkXRF6byWS2qU9E1JnSX9QNL/lB37p5IG\nl87pTklfSBpUuu3bkt6VNFTS+ZL+u/RYXUq3/1zSTyX1lHSBpDck/dUpfh+jiybW/m3jbP9H2+5Y\nbbvscX8u6cH2bh//fz7tfQJt1Oh/L+l7kqz0ix5Vdts1khpLP/9U0iPO8cMkHZdUVxZ7QNLC0s9b\nJN1SdttfljX6KZI+OuH+vivpP051bOI5ntjoHy+77W8lvVf2/wmS9p/kvtZI+kbp5+XljVjStOZG\nL2mApCOSupfd/mflb6jE/VPQq/SPtt2x2nYpb66kbZLq27t9NP/rMF/mt7IhkvZJ6q/iO7BVZtZ8\nm6n4iy8VjfuXzvGDJe0LIRwsi30oaXLZ7VtPuK1Zg6TBZra/LNZZxZXLqY6txCdlPzc5/69t/o+Z\nfVPSP0gaUQrVqvgu0DuP8p8bJHWVtKPsdet0Qg7aB21bbd+2zWyWik8g00IIe079VNpG9gXdzK5U\n0ehXqPgI1yTp4lB8/3WirZJGOfGPJfU1s7qyhj9cUvN97FDxhnmn7Lby+2wMIYxJnOLJjq0aM2uQ\n9LiKTpzXQwjHzWyNijd983kMLTtkWNnPW1VcxdSHEI61xvmh5WjbhbZu26W+h8clzQwhrD/T86+m\nbDtFzayXmX1d0jOSFoUQ1ocQvlLxi3jEzC4o5Q0xs6+VDlsgaa6Z3WRmnUq3jQshbJX0O0kPmFk3\nM5uoYgTHz0rHLZb0XTM738yGqvh42OwNSQfM7P5SJ1FnK4b1XVnBsdXUU8XHzN2l5z1X0iVlty+W\n9J3Sc+6jYgSFJCmEsEPSryTNL72uncxslJn9sfdAVuim4ntVlV6z81rlWZ2DaNuRtmzbN6p4bW4P\nIbzROk/n9OVY0H9hZgdV/OX9nqR/U/FdV7P7JW2S9HszO6Cig+RCSSr9guZKekRFB9KvVXwkk4rv\n1UaouKL5T0nfDyH8V+m2f1HxcbJRReN4qvnBQgjHJd2qorOpUcWV1L9L6n2qY6sphPCupPmSXlfx\n0XWCpN+WpTxeevx1klar+Hh+TMX3q1LRGVWjonPpU0nPShqUeLgGFVeLzVdmTZLer9JTOZfRth1t\n3Lb/ScXz+2Vp9M3nZvZSVZ/QGbDSl/vAHzCzGZJ+EkJoOGUycBbJuW3neIWO01D6yPwnZtbFzIZI\n+r6KqzXgrHYutW2u0CFJMrMeKj6Gj1PxFcmLkr4TQjjQricGnKFzqW1T0AEgE3zlAgCZoKADQCba\ndGKRsZQqWlkIwU6dVX20bbS2Sto2V+gAkAkKOgBkgoIOAJnIfnEuAJXp1Mm/vvOGNjPcuWPiCh0A\nMkFBB4BMUNABIBMUdADIBAUdADLBKJdMlO2H+Ae6dIl/xd27d3dzm5qaotjRo0fP7MTQrjp37uzG\nu3XrFsVS7cIb/ZJqF5W2odQoma+++sqNozJcoQNAJijoAJAJCjoAZIKCDgCZoFO0SrzOx1RHpdfx\n4+Wmjq+pqYliEyZMcHNvvfXWKFZfX+/mNjY2RrFFixa5uTt37oxiqQ4tOrqqK9UuvI7OlOuuuy6K\njRkzxs1taIj3Uu7bt6+bu3jx4ii2cePGKLZ9+3b3+CNHjrhxVIYrdADIBAUdADJBQQeATFDQASAT\nFHQAyASjXFqoJSNXvJEvqfvwRq6MHDnSPX7evHlR7IorrnBz+/fvH8XWrVvn5r711ltRLDV13HsO\njGZpG6nRLMeOHYtiqZEr3kinm2++2c31pv577VWSunbtGsVoK22HK3QAyAQFHQAyQUEHgExQ0AEg\nE3SKtlBqHWcv/uWXX57RY/Xu3duNjxs3LoqNHz/ezd29e3cU27Nnj5u7dOnSKLZ37143l3XS24bX\noZhqVz169Kg4d8iQIVEsNZ3fWyf9ww8/dHP79OkTxbzp/KnBBV489Z5DjCt0AMgEBR0AMkFBB4BM\nUNABIBMUdADIBKNcOgivJ3/EiBFurhdPjTo5cOBAFHv22Wfd3H379lV8v6iu1KgPb+mF2tpaN3fo\n0KFR7LbbbnNz586dG8VSo0m8UVGvvPKKm7t8+fKKjveWKUhJvTaeloyISd2vdx+tdQ7VxhU6AGSC\ngg4AmaCgA0AmKOgAkAk6RTsIb9r2lClT3FxvzekvvvjCzfU6r1auXOnmnulSBTh9LelI69mzpxv3\nOkBnz57t5nrT/Ddu3OjmLly4MIq9+OKLbu4nn3wSxVqy9rnX+ei1d6llr5m3pvvx48crvt/UvgDe\nfaTOK/V41cQVOgBkgoIOAJmgoANAJijoAJAJCjoAZIJRLu3A27X9nnvuiWLTp093j/c2rXjiiSfc\n3KeeeiqKecsBoGPyRn1MnDjRzZ01a1YU69Wrl5t76NChKPbyyy+7uUuWLIli3jIRUuUjWlJT6c87\n77wo1qWLX6a8jTdSSwrU1dVFsdRoFO+1aWpqcnNramrO
KLfaI8u4QgeATFDQASATFHQAyAQFHQAy\nQadoK0pNWb766quj2O233x7FNm/e7B7/zDPPRLHnnnvOzfU6aNhF/ezhtaGRI0e6uQMHDoxiqU5R\nr22tX7/ezfU60Vsynd/jTcVPxQcMGODmevHUFP3zzz8/iqVem127dkWxrVu3urlePPW+b4ulNbhC\nB4BMUNABIBMUdADIBAUdADJBQQeATDDKpUq83nlv5Iok3XfffRXd58MPP+zGV6xYEcWOHj1a0X2i\nY0pNhffiw4YNc3O90RXeVPpUfO3atW5uajp9pVqyaYU3GsVb0kDyN4Xp16+fm+ttAJN6Xt7mH0eO\nHHFzvfvYtm2bm8sGFwCAilHQASATFHQAyAQFHQAyQadolXjTsefNm+fmelO0vTWn161b5x5/pp1U\n6HhSyzF4ne3eevqS36GYul9vGnprdax765mPGDHCzZ0xY0YUSw0u8DpbV69e7eZ667dv2rTJzfVe\nx969e7u5DQ0NUeyll15yc7111quNK3QAyAQFHQAyQUEHgExQ0AEgExR0AMgEo1xaKLUw/8yZM6PY\npZde6uZ+9tlnUWzRokVR7ODBg+7xqWnintT5erwREWe6kQHOjDddfOLEiRUf7015l/wRVKkNGLz2\n5o1ckfwRON4IsEcffdQ9fuzYsVEsNfrG2+jlhRdecHP3798fxVKjxaZOnRrFUqNyvJErNTU1bq73\nOlZ7sxmu0AEgExR0AMgEBR0AMkFBB4BM0CnaQrW1tW58zpw5USzVcdTY2BjFdu7cGcVSO5h7nS6p\n9aUHDx5cca7X+bRhw4aKc1F9XqdoU1OTm+t1oqc6RXfs2BHFUm3bm/Y+fPhwN/euu+6KYjfeeGMU\nSy1f4LXN559/3s31BhJs2bLFzfU6QHv16uXmbt26teLcSZMmRbHx48e7uZ9++mkUS/1+ThdX6ACQ\nCQo6AGSCgg4AmaCgA0Am6BRtoXHjxrlxb+PeVMeh1yHlbdrbv39/93iv8+ryyy93c++8884oNnTo\nUDfX65i9++673dz3338/ilV71hv8zrzUOvlXXXVVFEt1oHqbKV900UVu7pQpU6LYDTfc4OZ665nX\n19dHscOHD7vHv/3221Fs4cKFbu7mzZujWGozZ69tejO2Jemdd96JYqlZ3977/vrrr3dzV65cGcVS\nv5/TnaHNFToAZIKCDgCZoKADQCYo6ACQCQo6AGSCUS4n4a0lfsstt7i53bt3j2IHDhxwc73eeW+n\n8dTSAd7omdQa6QMGDIhiqVEuXo/9rFmz3NyHHnrIjeP0pH5/3uiMDz74wM312ltqtITXhvr27evm\nfvzxx1Fs165dbq43+slbwiI1Rf/JJ5+MYps2bXJzvfXbWzI6JJXrxb1RaJK/DEdqWQMvt9r7DXCF\nDgCZoKADQCYo6ACQCQo6AGSCTtGT8DpzUusipzbY9XhrIHvT/Ddu3Ogev3v37oruU5L27t0bxS67\n7DI315uOnZqa7HUYe2t3ozKpZRO813nNmjVurjdlva6uzs312nHPnj3dXG/9fm9tb0l67733oliq\nk9CzfPnyKJbaF6AlS014nc6pDdS9eGoZDu+9uG/fPjfXW5YgdQ5M/QeAcxwFHQAyQUEHgExQ0AEg\nExR0AMgEo1xOwuudv+CCC9zcQ4cORTFvir8krV+/PoqtWrUqiqV6yz179uxx497yAanNBbwRLd5m\nHFL1pyyjct6u9JK0YcOGKJbagd77XU+YMMHN9abpe6NZJH+TFO894y1JIUkjRoyIYqnRU97yBakN\nLrp27RrFvOU2JOnCCy+MYpMnT3ZzvSU/UqPTUss7VBNX6ACQCQo6AGSCgg4AmaCgA0Am6BQ9Ca9T\ndO3atW7uoEGDopg3jVnyO6+8qdSpziCvg+eaa65xc73dylNrO3udba+99pqb25Jp1zh93ut88OBB\nN9frWO/Ro4eb662J73XwSdKMGTOiWGoqfGoJihOlltDwHstbY12Samtro9jq1avdXG8tci8mSf36\n9YtiQ4YMcXO9AQ6p9+3nn3/uxj2n24HKFToAZIKCDgCZoKADQCYo6ACQCQo6AGSCUS4n4W1a4e2s\nLvmbA4waNcrN9XrRvSn23tRmSRo9enQUe+yxx9xcb3pz6n6ffvrpKJba3R1tw1tiwVtmQvKn4197\n7bVurjfKZfDgwW6u145To1S8EVTeqI9JkyZVfF7Tpk1zc72RYVOnTnVzven43uYUkjRw4MAolhp1\n4r0/li1b5ua2ZFOY0x1FxhU6AGSCgg4AmaCgA0AmKOgAkAk6RU/CW1s5NbXYm2I/ceJEN/fee++N\nYt5O7hdffLF7/PTp06NY79693dxjx45FsSVLlri58+fPj2Kse97xpDrSvCUlFi1a5OZ6nZep5SO8\nNpDqFO3bt28U86bopzpgvdxUh2RdXV0US3Xie8tlvPnmm26uNxhi6dKlbq73vvM6diV/zwLvsc4E\nV+gAkAkKOgBkgoIOAJmgoANAJijoAJAJa8uNCszsrNoVwetd93Ywl6TZs2dHsTlz5ri5DQ0NUaxz\n585RrE+fPu7x3sYbqen8r7/+ehSbOXOmm5uaUn42CSG0/tbqjo7atr3RHZLfhlIjT7zlI7zlKyR/\n4wtvtNbYsWPd473c1MYbO3fujGKppTm8jShWrFjh5nqbdKRGFnmjarxlBiT//ZW6X08lbZsrdADI\nBAUdADJBQQeATFDQASATdIq2kLemseRPhZ48ebKbe8cdd0SxMWPGRLHUDuredP4FCxa4ua+++moU\nO3z4sJubAzpFq88bHJB6H3TpEq8m4u0VUFNT4x4/fPjwKDZo0CA3t76+Popt2rTJzd22bVsUS027\n9zoqvceSpO3bt0expqYmN9eLt6T+0ikKAOcQCjoAZIKCDgCZoKADQCYo6ACQCUa5tKLUwvzeSABv\n6n9qWrAXZyOKAqNc8uO9NyR/WYPUUgctec94y2ikRuV49TO1DIf3eIxyAQC4KOgAkAkKOgBkgoIO\nAJmgUxRZoVP03JFafqC9pTpbvUESdIoCAFwUdADIBAUdADJBQQeATFDQASAT8Rx0ADgLnG3LXbTF\niEKu0AEgExR0AMgEBR0AMkFBB4BMUNABIBMUdADIBAUdADJBQQeATFDQASATFHQAyAQFHQAyQUEH\ngExQ0AEgExR0AMgEBR0AMmFtsUYvAKD1cYUOAJmgoANAJijoAJAJCjoAZIKCDgCZoKADQCYo6ACQ\nCQo6AGSCgg4AmaCgA0AmKOgAkAkKOgBkgoIOAJmgoANAJijoAJAJCjoAZIKCDgCZoKADQCYo6ACQ\nCQo6AGSCgg4AmaCgA0AmKOgAkIn/A3kJUavAVR8iAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Let s compute 
the distance between two images of the same number\n", "digit_of_interest = 6\n", "\n", "digit_index_list = label_dict[digit_of_interest]\n", "\n", "if len(digit_index_list) < 2:\n", "    print(\"Need at least two images to compare\")\n", "else:\n", "    imgA = img_data[digit_index_list[0],:,:][0] \n", "    imgB = img_data[digit_index_list[1],:,:][0]\n", "    \n", "    # Print the distance between the two original images\n", "    imgA_B_dist = image_pair_cosine_distance(imgA, imgB)\n", "    print(\"Distance between two original image: {0:.3f}\".format(imgA_B_dist))\n", "    \n", "    # Plot the two original images\n", "    img1 = imgA.reshape(28,28)\n", "    text1 = 'Original image 1'\n", "\n", "    img2 = imgB.reshape(28,28)\n", "    text2 = 'Original image 2'\n", "\n", "    plot_image_pair(img1, text1, img2, text2)\n", "    \n", "    # Decode (reconstruct) the two images with the trained model\n", "    imgA_decoded = model.eval([imgA])[0]\n", "    imgB_decoded = model.eval([imgB])[0] \n", "    imgA_B_decoded_dist = image_pair_cosine_distance(imgA_decoded, imgB_decoded)\n", "\n", "    # Print the distance between the two decoded images\n", "    print(\"Distance between two decoded image: {0:.3f}\".format(imgA_B_decoded_dist))\n", "    \n", "    # Plot the two decoded images\n", "    img1 = imgA_decoded.reshape(28,28)\n", "    text1 = 'Decoded image 1'\n", "\n", "    img2 = imgB_decoded.reshape(28,28)\n", "    text2 = 'Decoded image 2'\n", "\n", "    plot_image_pair(img1, text1, img2, text2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note: The cosine distance between the original images is comparable to the distance between the corresponding decoded images. A value of 1 indicates high similarity between the images and 0 indicates no similarity.\n", "\n", "Let us now see how to get the encoded vector corresponding to an input image. This vector has the dimension of the choke point of the network, shown as the box labeled `E` in the figure above." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Length of the original image is 784 and the encoded image is 32\n", "\n", "The encoded image: \n", "[ 14.24417496 11.13341045 11.24246407 4.64616632 0. 6.89158678\n", " 23.79421425 18.19504166 17.70633888 0. 0. 28.18136215\n", " 13.94447613 17.40437126 16.58884048 7.5404644 14.78264236\n", " 20.94945335 5.16527224 19.49497986 12.03796673 19.87505722\n", " 13.01367664 8.0799036 6.24639368 0. 14.11477566\n", " 20.0975914 4.01841021 10.9685421 16.97727776 13.98702526]\n" ] } ], "source": [ "imgA = img_data[digit_index_list[0],:,:][0] \n", "imgA_encoded = encoded_model.eval([imgA])\n", "\n", "print(\"Length of the original image is {0:3d} and the encoded image is {1:3d}\".format(len(imgA), \n", "                                                                                      len(imgA_encoded[0])))\n", "print(\"\\nThe encoded image: \")\n", "print(imgA_encoded[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us compare the distance between different digits."
] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Distance between two original image: 0.376\n", "Distance between two decoded image: 0.424\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAADHCAYAAAAJSqg8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAADvdJREFUeJzt3X+QlVUdx/HPQQJdfrT+GllWdGVR\nS7P8kUJMoyWKIqKZIE650SiawEgm4WiNlKRONojjCGr+mJxBSQVDdBhxxQaVNBmYZmNcCVIgoC0V\nwVakReD0x3OpG+fc3fv7wve+XzM7A997nuc5z92zn332nnOf67z3AgAc+LpVugMAgOIg0AHACAId\nAIwg0AHACAIdAIwg0AHACAJ9H865nzjnHi122yz25Z1zgzI89qJzblwxjoPqxdi2z1leh+6c+76k\nKZIaJf1L0gJJt3rvt1WyXzHOOS/peO/9Xyvdl3w453pImivpq5KOlfRN7/3SinbKMMZ2+Tjnhkj6\nhaQzJO2WtFTSZO99WyX7FWP2Ct05N0XS3ZKmSvq8pCFKgublVPjEtulevh6atEzSVZL+UemOWMbY\nLrtDJT0sqUHJ89wu6TeV7FBG3ntzX5L6SvpE0hX71HtLel/S1an//1zSfElPKLnKGZ+qPZG2zfck\nbZC0RdJtktZLOi9t+ydS/26Q5CWNk/Q3SR9K+mnafs6S9KakbZLaJM2S1CPtcS9pUIbzWSppfOrf\n35f0B0n3pvb1nqShqfrG1PmNS9t2pKQ/pc5vo6Sf77Pvzs6vm6RbJL2bevwZSYdl8fxvkvSNSo8D\ni1+M7cqO7dS2p0tqr/RYiH1ZvUIfKulgSb9LL3rvP5H0oqTz08qXKhn4tZKeTG/vnDtJ0gOSviup\nTsnVUH0Xx/66pBMlDZM0zTn3xVR9t6QfSTpC0tdSj0/M8bz2Gizpz5IOV/Iyx1OSzpQ0SMkV8izn\nXO9U2+1KBnatkh+ACc65b2V5fpMlfUvSOZL6S9oqaXaefUZxMLYrP7bPlvR2fqdXWlYD/QhJH3rv\nd0Uea0s9vteb3vvnvPd7vPc79mk7WtIL3vtl3vudkqYpudrozO3e+x3e+xZJLZK+Ikne+5Xe+z96\n73d579dL+rWSwZSPdd7733jvd0t6WtIASdO99x3e+2ZJO5X8AMh7v9R7vyp1fn+W9Nu043Z1fj9Q\nciW2yXvfoeSqbTR/vlcUY7uCY9s59+XUvqbmeX4lZfUH80NJRzjnukcGfl3q8b02drKf/umPe+8/\ndc5t6eLY6a8ff6rkT2E5506QNFPJpGGNkud+ZRf7yuSfaf/ekerbvrW9xx0s6ZeSviSph6Sekual\n2nV1fsdKWuCc25NW2y3pKEmb8+w7CsPYrtDYTq3UeVHSD733r+d8ZmVg9Qr9TUkdkr6dXnTO9ZI0\nQtIraeXOrkraJB2dtv0hSv4UzMeDklYrme3vK+knklye+8rFXEnPSxrgvf+8pIfSjtvV+W2UNMJ7\nX5v2dbD3njCvHMb2/5RtbDvnjpW0RNIvvPdzSnAuRWEy0L33H0u6XdL9zrkLnXOfc841KPntvUlS\ntt+Q+ZJGOeeGplYP3K78B2ofJZM3nzjnviBpQp77yee4H3nv/+2cO0vSd9Ie6+r8HpJ0Z2owyzl3\npHPu0kwHcs71dM4dnPpvD+fcwc65cvxgVw3GdnDcko9t51y9pN9Lmu29f6gUJ1IsJgNdkrz3v1Jy\npTBDyWB7S8lv5WGp18yy2cfbkm5QMjHTpmS50vtKrpBy9WMlA65d0iNKXh8sh4mSpjvn2pW89vfM\n3geyOL/7lFwBNae2/6OSSatM/qLkT+J6SS+l/n1sMU8GjO005Rrb4yUNlPQz59wne79KcD4FM/3G\nomJLza5vU/Kn5bpK96fYrJ8fMrP+vbd+fnuZvUIvFufcKOdcTeo1yhmSVilZz2qC9fNDZta/99bP\nL4ZA79qlkv6e+jpe0pXe1p811s8PmVn/3ls/vwAvuQCAEVyhA4ARBDoAGFHWd4qmbqMJlIz3viLr\n3hnbKLVsxjZX6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABg\nBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEOAEYQ6ABgBIEO\nAEYQ6ABgRPdKdwCAbSeffHJQ6949++hpaWkpZndM4wodAIwg0AHACAIdAIwg0AHACAIdAIwwscql\nrq4uqDU2Nma9/QUXXBCtX3LJJXn3qRi6dYv/vl27dm1QmzlzZrTtpk2bgtr69esL6hdwyCGHROvX\nXHNNULvnnnuCWi6rXFatWhWte++z3kfMG2+8Ea3Pnz8/qK1YsSLatr29vaA+FBtX6ABgBIEOAEYQ\n6ABgBIEOAEa4QicWcjqYcyU52PLly4PaGWecUYpDlZVzLlrP5XvW2toa1EaMGBFtG5tAPdB47+NP\nWomVamzvD2IToAsWLIi2HT58eFArNGOK8XNQ6H5Xr14dbXveeecFtba2toL6lUk2Y5srdAAwgkAH\nACMIdAAwgkAHACMIdAAwwsQqlz179gS1cp5XqZRqdn/RokXReqVvdVAMrHLJ3+DBg6P1WbNmBbVc\nVpG99dZbQW3x4sVZb//KK69E68cdd1xQ27FjR7Tttm3bgtrll18ebRu7FUhDQ0O07Zw5c4LauHHj\nom0LxSoXAKgiBDoAGEGgA4ARBDoAGGHifuhNTU1BbezYsWXtw8KFC4Nac3NzQfs8++yzo/XZs2cH\ntT59+hR0LCDTJOHpp58e1DJNzMduw3HxxRcHtS1btuTYu9CyZcsK2n7JkiXR+iOPPBLUrr766mjb\n2HNTSVyhA4ARBDoAGEGgA4ARBDoAGEGgA4ARJla5PPnkk1nVDjTdupXm9+1HH31Ukv0CF110UVDb\nunVrBXrStUwf9HLFFVeUuSfFwxU6ABhBoAOAEQQ6ABhBoAOAESYmRQ80/fr1C2qTJ08Oarfcckt0\n+0Lvh37bbbcVtD1sWrduXcH7GDNmTFB7+OGHC95voQYOHBjUHn300Wjb3r17Z73flStX5t2nUuAK\nHQCMINABwAgCHQCMINABwAgCHQCMcIWumMjpYAfYJ6PX19cHtUwfOhFz7733Rus1NTVBrVevXkHN\nufiHfBf6Pbvxxhuj9Y0
bNwa15557rqBjlVs2n4xeCgfa2M5F7ANVJkyYEG3b1tYW1IYNGxbUVq9e\nXXC/TjjhhKA2ZcqUaNtrr722oGMtWrQoWr/uuuuCWuw5KIZsxjZX6ABgBIEOAEYQ6ABgBIEOAEYw\nKSqpqakpWr/55puD2kknnVTq7vxXqSZFM+no6Ahq06ZNi7Ztbm4Oau+++2607fbt2wvrWA6YFC2+\nww8/PKi98MIL0bZDhgwJai0tLUFt6NCh0e3r6uqCWmzyU5LmzJkT1A477LBo25jYIgBJmjdvXlCb\nPn16tG17e3vWxysUk6IAUEUIdAAwgkAHACMIdAAwgklRZb7v+J133lnmnvy/ck+KFuqBBx6I1m+4\n4Yay9YFJ0fI49NBDo/XXXnstqMUWEqxZsya6fWxStG/fvtG2sZ+DLVu2RNvGxuZ9990Xbbu/fqg1\nk6IAUEUIdAAwgkAHACMIdAAwgkAHACO6V7oD+4NMq0ky1bP12GOPResbNmwIanfccUdBx8rkwgsv\nDGqjRo2Kth05cmRQO+aYY6JtY8/NpEmTom1bW1uD2oMPPhhtiwPDmDFjovVs33qf6e38MZs3b47W\nJ06cGNSWLl0abVvOt+hXElfoAGAEgQ4ARhDoAGAEgQ4ARvDWf0lHHnlktH7UUUcVtN933nknWt+9\ne3dB+y2VU089NagtXLgw2vboo4/Oer8zZ84MalOnTs2+Yzngrf/ZGT58eFAbP358tO3o0aNL3Z1O\n3XrrrdH63XffXeaeVBZv/QeAKkKgA4ARBDoAGEGgA4ARBDoAGMEqF3QqtvJFkl5//fWgVlNTk/V+\nDzrooLz71JlqWOXSv3//aP36668PaplWrvTr1y+o5ZIFS5Ysidabm5uD2sqVK4Pa/fffH90+9mEY\nn332WbTtgAEDgtoHH3wQbWsBq1wAoIoQ6ABgBIEOAEYQ6ABgBPdDz1Hs/uKStHjx4jL3pDxqa2uj\n9VwmNV966aVidafqzJ07N6ide+650baZbmER09HREdTmzZsXbTtjxoygtm7dumjbnTt3BrXYrQMG\nDhzYVRf/q0ePHtF6Y2NjULM8KZoNrtABwAgCHQCMINABwAgCHQCMINABwAhWuXRi2LBhQe3pp5+O\ntn388ceD2pQpU6Jtd+3aVVC/SiW2guepp56Ktu3Zs2fW+33++efz7lO1W7NmTVAbO3Zswftdu3Zt\nUHv55ZejbS+77LKg1tDQEG17yimnBLXTTjstt87tY/PmzdF6pg+QqWZcoQOAEQQ6ABhBoAOAEQQ6\nABjB/dA7MXLkyKCWywRf7N7QUvaTosuWLYvWly9fHtRuuummrPuVyZlnnhnUMr2dfPv27UGtpaUl\n2vaqq64Kahs2bMixd9mphvuh33XXXdH6pEmTglqfPn2ibZ0Ln6ZyZkEmsQnQc845J9r2vffeK3V3\n9ivcDx0AqgiBDgBGEOgAYASBDgBGEOgAYASrXDoxaNCgoJZplcuJJ55Y9OPHViJI+8dqhNgHLzQ1\nNVWgJ/+vGla5ZFJfXx/Urrzyymjburq6oJbpgzNyke3qmWeffTa6/axZs4Laxx9/XHC/LGCVCwBU\nEQIdAIwg0AHACAIdAIxgUjRHNTU10Xrsk83PP//8aNva2tqgFnuL/eDBg6Pb79mzp7MudinT/aVf\nffXVrPcxefLkoLZ169a8+1Qs1TwpCtuYFAWAKkKgA4ARBDoAGEGgA4ARBDoAGMEqlwro1atXUIt9\nEEFjY2N0+9gtCWIrZyRpxYoVQS3TapTW1tZo/UDCKhdYxSoXAKgiBDoAGEGgA4ARBDoAGMGkKExh\nUhRWMSkKAFWEQAcAIwh0ADCCQAcAIwh0ADCCQAcAIwh0ADCCQAcAIwh0ADCCQAcAIwh0ADCCQAcA\nIwh0ADCCQAcAIwh0ADCCQAcAIwh0ADCCQAcAIwh0ADCCQAcAIwh0ADDCec+HlQOABVyhA4ARBDoA\nGEGgA4ARBDoAGEGgA4ARBDoAGEGgA4ARBDoAGEGgA4ARBDoAGEGgA4ARBDoAGEGgA4ARBDoAGEGg\nA4ARBDoAGEGgA4ARBDoAGEGgA4ARBDoAGEGgA4ARBDoAGEGgA4AR/wHNiAlz7Wsi5wAAAABJRU5E\nrkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAADHCAYAAAAJSqg8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAE65JREFUeJzt3X2M1VV+x/HPF1AYYHiSooAwPEZE\nUCtQAyotFVnaum6bxqwPUaNN6za23cY04rqxm+omGoPdTUh23VhTDVUrqenKxjWWal236DLKg8LS\nysMMDqMIMwzPDiB4+sfvTnOX8z3MjMzjmfcrMRm/9/u787t3zv3ym/mec34WQhAAoPfr190nAADo\nGBR0AMgEBR0AMkFBB4BMUNABIBMUdADIBAW9E5jZ75lZfWcca2ZHzWzKVz874KtjbPdsWRV0M9tl\nZs1mdsTMDprZO2b2LTPL5nWGEIaGEGq6+zzOxsxmmdnrZtZoZix06ACM7Z7BzO4ys/VmdtjM6s3s\nCTMb0N3n1SKbwVDm6yGESklVkh6XtEzSM917Sn3OF5JWSfqz7j6RzDC2u99gSX8rabSkqyVdL+nv\nuvWMyuRY0CVJIYRDIYTVkr4p6S4zmyVJZjbQzJabWZ2Z7TWzp8ysouU4M/uGmW0q/Qu808yWluLj\nzGy1mTWZ2Q4z+/OyYyrM7FkzO2BmWyXNKz+X0rEvm1mDmdWa2d+09dgzmVkws2mlr581sx+Z2Wul\nX1fXmtlFZvbD0vP9r5n9dtmxD5Ze0xEz22pmf1L2WH8ze7J0VV1rZn9V+l4DSo8PN7NnzGyPmX1i\nZt83s/6J9/6jEMIzkn7d6g8K7cbY7tax/eMQwi9DCCdDCJ9Iel7SNa39zLpMCCGb/yTtkrTYiddJ\n+svS1z+UtFrSKEmVkn4m6bHSY78j6ZCkG1T8Yzde0ozSY7+Q9CNJgyRdKalB0vWlxx6X9MvSc06Q\ntEVSfemxfpLWS/p7SedLmiKpRtLXWjs28RqDpGmlr5+V1ChpTum83pRUK+lOSf0lfV/Sf5Ude7Ok\ncaVz+qakY5LGlh77lqStki6WNFLSf5a+14DS4z+V9BNJQySNkVQt6d5Wfh7TiiHW/WOjt//H2O5Z\nY7vs+/5U0uPdPT7+/3y6+wS6aND/StJ3JVnpBz217LH5kmpLX/9E0g+c4ydIOi2psiz2mKRnS1/X\nSFpa9thflA36qyXVnfF835H0z60dm3iNZw76p8se+2tJ/1P2/7MlHTzLc22S9I3S12+WD2JJi1sG\nvaQLJZ2QVFH2+K3lH6jE81PQO+g/xnbPGtulvLsl1Usa3d3jo+W/HvPH/E42XlKTpN9S8Tew9WbW\n8pip+BdfKgb3z53jx0lqCiEcKYt9LGlu2eO7z3isRZWkcWZ2sCzWX8WVS2vHtsXesq+bnf8f2vI/\nZnanpPslTSqFhqr4W6B3HuVfV0k6T9Kesvet3xk56B6MbXX92DazP1bxG8jiEEJj6y+la2Rf0M1s\nnopB/98qfoVrlnRZKP7+dabdkqY68U8ljTKzyrKBP1FSy3PsUfGB+XXZY+XPWRtCmJ44xbMd22HM\nrErS0yqaOO+GEE6b2SYVH/qW87i47JAJZV/vVnEVMzqEcKozzg/tx9gudPXYLvUenpb0RyGEzed6\n/h0p26aomQ0zsxsl/aukfwkhbA4hfKniB/EDMxtTyhtvZl8rHfaMpLvN7Hoz61d6bEYIYbekdyQ9\nZmaDzOxyFTM4ni8dt0rSd8xspJldrOLXwxbVkg6b2bJSk6i/FdP65rXh2I40RMWvmQ2l1323pFll\nj6+S9O3Sax6hYgaFJCmEsEfSf0h6svS+9jOzqWb2u943ssIgFX9XVek9G9gpr6oPYmxHunJs/76K\n9+ZPQwjVnfNyvrocC/rPzOyIin95vyvpH1X8ravFMkk7JP3KzA6raJBcIkmlH9Ddkn6gooH0CxW/\nkknF39Umqbii+XdJ3wshrCk99g8qfp2sVTE4VrZ8sxDCaUlfV9FsqlVxJfVPkoa3dmxHCiFslfSk\npHdV/Oo6W9LaspSnS9//Q0kbVfx6fkrF31elohl1vorm0gFJ/yZpbOLbVam4Wmy5MmuW9FEHvZS+\njLHt6OKx/bCK1/fz0uybo2b2Woe+oHNgpT/uA7/BzP5A0lMhhKpWk4FeJOexneMVOr6C0q/Mf2hm\nA8xsvKTvqbhaA3q1vjS2uUKHJMnMBqv4NXyGij+RvCrp2yGEw916YsA56ktjm4IOAJngTy4AkAkK\nOgBkoksXFhlbqaKThRCs9ayOx9hGZ2vL2OYKHQAyQUEHgExQ0AEgExR0AMgEBR0AMkFBB4BMUNAB\nIBMUdADIBAUdADJBQQeATFDQASATFHQAyAQFHQAyQUEHgExQ0AEgExR0AMgEBR0AMkFBB4BMUNAB\nIBMUdADIRJfeJLor9evn/1s1YED8kr2YJA0cOLDNzxtCfI/g48ePu7knTpxo0/EpX375ZZtzgd7E\nLL4Pcurz2b9//yh26tQpN9f7fKU+c73588UVOgBkgoIOAJmgoANAJijoAJCJLJqiXqPSa2hK0tCh\nQ6PYeeed5+aOGjUqii1evNjNrampiWL79+93c48dOxbFBg0aFMUOHjzoHr979+4o1tzc7Ob25gYP\neh+vqVlRURHFJk2a5B4/e/bsKHbJJZe4uRs3boxie/fudXO3bdsWxbxzlfzG6smTJ91cr7Ga+sx5\n8Y7+fHKFDgCZoKADQCYo6ACQCQo6AGSCgg4AmehVs1y8pb6SNGzYsCh21VVXubkjRoyIYldffbWb\nO3bs2Cg2fvx4N9frrqc6495zjB49Ooqltg44cuRIFHvzzTfd3BdffDGK1dfXu7mpZdPAmVIzw7xZ\nYA899FAUS30+vWX+R48edXM3bNgQxVauXOnmVlZWRrGJEye6uXV1dVHM265DkhobG6PYvn373NzP\nP/88iqU+418VV+gAkAkKOgBkgoIOAJmgoANAJnpVU9RbQixJl156aRS7//772/y8gwcPduNTpkyJ\nYrt27XJzvS0FhgwZ4uZ6TdwxY8ZEseHDh7vHe9saLFq0yM1dtmxZFFuxYoWb+8gjj0SxL774ws1F\n3+GNt1tuucXNfeKJJ6KY1/Bvz30FUrnvvPNOFGtoaHBzvYkE3nlJ0oIFC6JYqnlZXV0dxd5++203\n12uKdjSu0AEgExR0AMgEBR0AMkFBB4BMUNABIBO9apZLqtt97bXXtvk5Ro4cGcUOHz7s5npLgNes\nWePmejMBUjeoWLhwYRS78MILo1hqeXVqCwSPN/tmyZIlbu7y5cuj2KFDh9r8vdC7pT5fEyZMiGJ3\n3XWXm+t9vrznTd3YYevWrVEsNaNm586dbfpekjRu3Lgolprlcs0110Sx888/3831PqNvvPGGm5va\nCqQjcYUOAJmgoANAJijoAJAJCjoAZKJXNUVT+yKvXr06
inlNEMlfyu7tkS5JmzZtimLr1q1zc71l\nvd4Sf0navHlzFLvnnnvafPxll10WxQYNGuTmes+Rer1eLk3RPHl3vE9tNXHllVdGsVTD3rsvgDc5\n4NFHH3WPf/nll6PY6dOn3VyP97okvyE5efJkN9erHamJE97n3tu+4Gzn1pG4QgeATFDQASATFHQA\nyAQFHQAyQUEHgEz0qlkuqe6xd5ftp556ys31Ovapmzh89tlnbYpJ/gb6qU3xvfjzzz8fxZqamtzj\np02bFsUeeOABN9dbsvzJJ5+4ud5y7NRS6tTSbXSf9syi8GZFjRo1ys31Zni89dZbbu7HH38cxV55\n5ZUo1tjY6B6f+oy3VWpbjPnz50exOXPmtPk5mpub3dxPP/00inXnZ4YrdADIBAUdADJBQQeATFDQ\nASATWTRFU/uOezZu3BjFUs0kr1maWo4/Y8aMKObtcS5JY8aMiWLeEuKqqir3+Ntuuy2KpbY6GDx4\ncBSbMmWKm+sthT5w4ICb6zV2aZR2jdR4HTAg/jinGnTe/v0VFRVu7v79+6OYt0Rfkurr66OY19w/\n1+an5L8Ps2bNcnMffPDBKJba49ybZPHSSy+5uTt27IhiqckQHfGaW8MVOgBkgoIOAJmgoANAJijo\nAJAJCjoAZKJXzXJJ8WZXeLNGJKmhoSGKXXTRRW6uF7/88svd3Llz50axhQsXurl1dXVRzHsNqQ34\nvdkv3gwHSTp16lQUGzJkiJu7ZMmSKJba2L+2tjaKpbr73jl0Rcc/V+2Z5eLNZpH8pf+p2WLe86aW\n2Hs3o/DON/UavHGRGtvt2QLDm+3jLduXpHfffTeKrV271s398MMPo5i3DUhX4QodADJBQQeATFDQ\nASATFHQAyIR1ZXPKzLrsm6WWPHvNmBEjRri5S5cujWI333yzm3vDDTdEsdSd0b2GrdekSr2G1PN6\nvCaVt5Rb8rdF8Pa3lqTXXnstiq1fv97N9ZpP7bmTe3uEEDr/1uqOnjC2vaX7qeal14RP5U6aNCmK\npba18Mbxtm3bopjXKJf8Rqf32ZL8CQpes17yx/ymTZvc3A8++CCKpe4h4NWT1IQM7z1vz3YZbRnb\nXKEDQCYo6ACQCQo6AGSCgg4Amchipagn1WzwGj8nTpxo83PMnDnTzfUalakmk9eI9po2qcah11BK\nNcoOHTrU5tyJEydGsenTp7u5ixYtimIrVqxwc1etWhXFvBW7kv/esKr0N6Wa4pWVlVEstWrR2ws8\nNV69vfZTkwPmzZsXxbz7CowcOdI93mvspvYt37VrVxRbuXKlm7tmzZoolmpeeud78uRJN9f7LKU+\nX56OvocAV+gAkAkKOgBkgoIOAJmgoANAJijoAJCJbGe5pHgzJlLLkBsbG6PYgQMH3Fxv3+mamho3\n94UXXohi1dXVbf5eXtc/1Vn3uujeUm5Jevjhh6PYFVdc4eZ6syduv/12N9d7bantB1I/C7TOW4bu\nLcVPSc0mGT58eBTzZr5I0oQJE9r0vO3ZvqKpqcmNe1tNvP76626u91lqz/YT7dm/PTUrq7O2uyjH\nFToAZIKCDgCZoKADQCYo6ACQiT7XFPWahKllve+//34Uu/fee91cr5mX2kv82LFjUawrl7ynzmvs\n2LFRbPny5W6ut4d8qtG1YMGCKLZlyxY3l6Zo61INcK/pltrr32vypd77HTt2RLF9+/a5uV6zfNSo\nUVEs1SD04qkl+t4Yas82Hu2Raop6z5v63HbFFhZcoQNAJijoAJAJCjoAZIKCDgCZoKADQCb63CwX\nT6qD3dzcHMXq6+vd3HNdWtyVvA38JWn79u1RLDVrwLtBQWr2xXXXXRfFXn31VTfXu0N8X+aNzdSM\nDW82yZgxY9xc7zm84yV/dsZHH33k5npL/2fPnh3FhgwZ4h7vzZRqz5YEAwb4Jc0bm6nPp/eep2Zw\nec+R+nx1Ba7QASATFHQAyAQFHQAyQUEHgEzQFD0Lr3GTaki1p+nS3VLNy8mTJ0exQ4cOubmp/bA9\nXrMttZwbrRs8eLAbnz9/fhSbMmWKm1tRURHFUuO1rq4uiqUaf972Ed7WGkOHDnWPb8+e7lVVVVEs\n9d4cPXo0irVnO4DUe+M1bFO557r9QFtwhQ4AmaCgA0AmKOgAkAkKOgBkgoIOAJlglouk/v37u3Gv\n4+4tbZb8DnZDQ4Ob683w8GYCpG444MVT2xd4S6FTNz248847o1jq9aZmynhqa2uj2PHjx91c73V0\nxY0Beqr2LEOfPn16FPO2XZCkCy64IIqlZpN4Yyg1y8XbEmLgwIFurscb2wcPHnRzvRvQNDU1ubne\n+abGlTe2U+Pd+9x25+w2rtABIBMUdADIBAUdADJBQQeATNAUPQtvWa+33FiS5s6dG8VuuukmN9dr\nmnh7r9fU1LjHr127Noqlluh7r2Hq1KlubmqZuMdram7YsMHNfe6556JYqtHVlxugHu/9SDXxvcZf\nqgHu7ZOeal56DcFUE97L9c7LG++StGfPnii2cuVKN3fVqlVRLLWlxLkuu/eanz0RV+gAkAkKOgBk\ngoIOAJmgoANAJijoAJAJ68pZBWbWI6cwpGYNDBs2LIotXbrUzX300Uej2MSJE91cbym193NIdea9\nGQIHDhxwc73naGxsdHO99yE182Hz5s1RbMWKFW5udXV1FOusO6OHEPzpF52sK8d26mcyZ86cKHbH\nHXe4uQsXLoxi3nYAkr8lQGp5+969e6PYjh07otiaNWvc4997770otmXLFjf32LFjUawj6llP3X6i\nLWObK3QAyAQFHQAyQUEHgExQ0AEgEzRFlV7G7DWDUsvmb7311ih23333ubmVlZVtOodU49BrBh0+\nfNjNXbduXRRbvXq1m7tz584olmq2evtOp3K9Blpnjbu+0BRNjVdvn/TRo0e7ubNmzYpiM2fOdHO9\n8eptKSH5DUyv0VlfX+8e3559y/samqIA0IdQ0AEgExR0AMgEBR0AMkFBB4BMMMvlLLzZBKm7f3td\n/3Hjxrm5N954YxTztgmora11j9++fXsUS90wwnuOo0ePurntuYN5T5150BdmuZzlHKJYaluL9tzZ\n3pMaF6dOnYpiPXWs9DbMcgGAPoSCDgCZoKADQCYo6ACQCZqivZDX/EotBz/Xu533Nn25KYq80RQF\ngD6Egg4AmaCgA0AmKOgAkAkKOgBkIr79PHo8b2YSy6sBcIUOAJmgoANAJijoAJAJCjoAZIKCDgCZ\noKADQCYo6ACQCQo6AGSCgg4AmaCgA0AmKOgAkAkKOgBkgoIOAJmgoANAJijoAJAJYx9tAMgDV+gA\nkAkKOgBkgoIOAJmgoANAJijoAJAJCjoAZIKCDgCZoKADQCYo6ACQCQo6AGSCgg4AmaCgA0AmKOgA\nkAkKOgBkgoIOAJmgoANAJijoAJAJCjoAZIKCDgCZoKADQCYo6ACQCQo6AGSCgg4Amfg/3qGZnurJ\npRcAAAAASUV
ORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "digitA = 3\n", "digitB = 8\n", "\n", "digitA_index = label_dict[digitA]\n", "digitB_index = label_dict[digitB]\n", "\n", "imgA = img_data[digitA_index[0],:,:][0]\n", "imgB = img_data[digitB_index[0],:,:][0]\n", "\n", "# Print the distance between the two original images\n", "imgA_B_dist = image_pair_cosine_distance(imgA, imgB)\n", "print(\"Distance between two original images: {0:.3f}\".format(imgA_B_dist))\n", "\n", "# Plot the two original images\n", "img1 = imgA.reshape(28,28)\n", "text1 = 'Original image 1'\n", "\n", "img2 = imgB.reshape(28,28)\n", "text2 = 'Original image 2'\n", "\n", "plot_image_pair(img1, text1, img2, text2)\n", "\n", "# Run both images through the autoencoder (encode and decode)\n", "imgA_decoded = model.eval([imgA])[0]\n", "imgB_decoded = model.eval([imgB])[0]\n", "imgA_B_decoded_dist = image_pair_cosine_distance(imgA_decoded, imgB_decoded)\n", "\n", "# Print the distance between the two decoded images\n", "print(\"Distance between two decoded images: {0:.3f}\".format(imgA_B_decoded_dist))\n", "\n", "# Plot the two decoded images\n", "img1 = imgA_decoded.reshape(28,28)\n", "text1 = 'Decoded image 1'\n", "\n", "img2 = imgB_decoded.reshape(28,28)\n", "text2 = 'Decoded image 2'\n", "\n", "plot_image_pair(img1, text1, img2, text2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Print the simple and deep autoencoder test errors for regression testing." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.97620738737\n" ] } ], "source": [ "# Simple autoencoder test error\n", "print(simple_ae_test_error)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.87243351221\n" ] } ], "source": [ "# Deep autoencoder test error\n", "print(deep_ae_test_error)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Suggested tasks**\n", "\n", "- Try different activation functions.\n", "- Find which images are most similar to one another (a) using the original images and (b) using the decoded images.\n", "- Try using mean squared error as the loss function. Does it improve the performance of the autoencoder in terms of reduced test error?\n", "- Can you try a different network structure to reduce the error further? Explain your observations.\n", "- Can you use a different distance metric to compute the similarity between the MNIST images? A minimal sketch of one alternative is given in the cell below.\n", "- Try a deep encoder with layer dimensions [1000, 500, 250, 128, 64, 32]. What is the training error for the same number of iterations?" ] },
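{ "cell_type": "markdown", "metadata": {}, "source": [ "As a starting point for the distance-metric task above, the next cell is a minimal sketch of one alternative similarity measure: a normalized Euclidean (L2) distance computed with NumPy. It reuses `imgA`, `imgB`, `imgA_decoded` and `imgB_decoded` from the earlier cell; the helper `image_pair_euclidean_distance` is introduced here purely for illustration and is not one of the tutorial's helper functions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# A minimal sketch for the \"different distance metric\" suggested task.\n", "# Assumes imgA, imgB, imgA_decoded and imgB_decoded from the earlier cell are in scope.\n", "import numpy as np\n", "\n", "def image_pair_euclidean_distance(a, b):\n", "    # Hypothetical helper (illustration only): normalize each pixel vector so that\n", "    # images with different overall intensities remain comparable, then take the\n", "    # Euclidean (L2) distance between them.\n", "    a = a / (np.linalg.norm(a) + 1e-8)\n", "    b = b / (np.linalg.norm(b) + 1e-8)\n", "    return np.linalg.norm(a - b)\n", "\n", "print(\"Euclidean distance between two original images: {0:.3f}\".format(\n", "    image_pair_euclidean_distance(imgA, imgB)))\n", "print(\"Euclidean distance between two decoded images: {0:.3f}\".format(\n", "    image_pair_euclidean_distance(imgA_decoded, imgB_decoded)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.4" } }, "nbformat": 4, "nbformat_minor": 1 }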