{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "01.making-predictions.ipynb", "provenance": [], "collapsed_sections": [], "authorship_tag": "ABX9TyOt128pPbuoPXd4T+vPAjzZ" }, "kernelspec": { "name": "python3", "display_name": "Python 3" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "gguoUjRZK1_u" }, "source": [ "# How Neural Networks Make Predictions\n", "\n", "> TL;DR Learn how models make their predictions and learn to evaluate them\n", "\n", "How do models (and Neural Networks in particular) make their predictions? How can they be improved?\n", "\n", "* [Run the notebook in your browser (Google Colab)](https://colab.research.google.com/drive/1LH-V1kYOl6icJ3O2-lWA4IK4i_EEReOP?usp=sharing)\n", "\n", "You’ll learn how to:\n", "\n", "- Build a baseline model with a single weight value\n", "- Evaluate your model performance by measuring its error\n", "- Find better weight values with the guessing method\n", "\n", "## The almighty function\n", "\n", "Most of your work (as an ML person) is related to how well this one single function performs:\n", "\n", "```python\n", "def predict(data):\n", "    ...\n", "```\n", "\n", "To the untrained eye, this looks deceptively simple. Isn't this just one function? I can guarantee that you're about to develop a love-hate relationship with it!\n", "\n", "In reality, when dealing with Deep Learning, the `predict()` function can have a simple implementation:\n", "\n", "```python\n", "def predict(data):\n", "    return data * weight\n", "```\n", "\n", "But what does `weight` contain? This is the \"brains\" of your model. 
Without good weights, you don't get the Corona on your favorite beach, the washboard abs, or the hotties around you.\n", "\n", "## How can you get some good weights for yourself?\n", "\n", "There are several ways to find good weight values, but let's start with something really dumb (a good baseline, if you will) and improve on that (hopefully, in a non-linear fashion).\n", "\n", "Let's say you need to predict the Tesla stock price. The current price is 420 USD. What will the price be in a minute?" ] }, { "cell_type": "code", "metadata": { "id": "2VF8HOqSOGME" }, "source": [ "def predict(data, weight):\n", "    return data * weight" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "inabMRKbU7l6" }, "source": [ "Actually, this is your first model. Congratulations! Let's use it:" ] }, { "cell_type": "code", "metadata": { "id": "8z8_bBHaUmmJ", "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "outputId": "32a68887-a94a-4cb2-f11f-a0ebe532de65" }, "source": [ "current_stock_price = 420\n", "\n", "prediction = predict(current_stock_price, weight=1.01)\n", "\n", "f\"Predicted Tesla stock price is {prediction} USD\"" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'Predicted Tesla stock price is 424.2 USD'" ] }, "metadata": { "tags": [] }, "execution_count": 2 } ] }, { "cell_type": "markdown", "metadata": { "id": "0-SZZ61CWUOp" }, "source": [ "Yes, it works! Unfortunately, the real price is 425 USD. 
How can you use this information to make your model better?\n", "\n", "One popular way is to start by measuring the error of the prediction:" ] }, { "cell_type": "code", "metadata": { "id": "l6U1xrjwVHUW", "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "outputId": "cc0d8ff8-0c8b-4168-d60a-c9aaf0c38bec" }, "source": [ "import torch\n", "import torch.nn.functional as F\n", "\n", "real_next_price = 425\n", "\n", "error = F.mse_loss(\n", "    input=torch.tensor(prediction),\n", "    target=torch.tensor(real_next_price)\n", ")\n", "\n", "error" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor(0.6400)" ] }, "metadata": { "tags": [] }, "execution_count": 3 } ] }, { "cell_type": "markdown", "metadata": { "id": "OQg4ZXy8anws" }, "source": [ "Great, now you have a concrete goal: reduce the error to 0. We can continue with our dumb strategy of guessing weight values:" ] }, { "cell_type": "code", "metadata": { "id": "k69QHbNFXE1B" }, "source": [ "weight = 1.01\n", "\n", "change = 0.0001\n", "\n", "target = torch.tensor(real_next_price)\n", "\n", "for _ in range(1000):\n", "    guess_up = torch.tensor(predict(current_stock_price, weight + change))\n", "    guess_down = torch.tensor(predict(current_stock_price, weight - change))\n", "\n", "    error_up = F.mse_loss(guess_up, target)\n", "    error_down = F.mse_loss(guess_down, target)\n", "\n", "    if error_up < error_down:\n", "        weight += change\n", "    else:\n", "        weight -= change" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "tpDG-bj79DcG" }, "source": [ "We start with the current weight and try to reduce or increase it by a bit. We take whichever change gives the smaller error and repeat.\n", "\n", "What is the predicted stock price with the new model?" 
] }, { "cell_type": "code", "metadata": { "id": "lWLg7cg97Oby", "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "outputId": "ba254686-fbfd-4ee7-910a-cb32d0b07e11" }, "source": [ "predict(current_stock_price, weight)" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "425.0399999999999" ] }, "metadata": { "tags": [] }, "execution_count": 5 } ] }, { "cell_type": "markdown", "metadata": { "id": "tmWfYXYE9_zV" }, "source": [ "This one looks much better. What about the error?" ] }, { "cell_type": "code", "metadata": { "id": "sYoyF0G67RK4", "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "outputId": "aa864dba-5d42-4de5-cf74-e0ee4fc38684" }, "source": [ "F.mse_loss(\n", "    input=torch.tensor(predict(current_stock_price, weight)),\n", "    target=target\n", ")" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor(0.0016)" ] }, "metadata": { "tags": [] }, "execution_count": 6 } ] }, { "cell_type": "markdown", "metadata": { "id": "S2QCedBuAq1l" }, "source": [ "A drastic improvement compared to the first model. Can you really use this guessing method in practice? No!\n", "\n", "## Summary\n", "\n", "Making good models is all about finding good weight values. The guessing method is a simple way to find weights for your model, but it is slow and might not give good results.\n", "\n", "* [Run the notebook in your browser (Google Colab)](https://colab.research.google.com/drive/1LH-V1kYOl6icJ3O2-lWA4IK4i_EEReOP?usp=sharing)\n", "\n", "You learned how to:\n", "\n", "- Build a baseline model with a single weight value\n", "- Evaluate your model performance by measuring its error\n", "- Find better weight values with the guessing method\n", "\n", "Next, we'll get back to the real world and build an end-to-end project containing our baseline model. From there, we'll be free to experiment and look for improvements." ] } ] }