
Curiousily

How Neural Networks make predictions

Deep Learning, Object Detection, Neural Network, PyTorch, Python · 2 min read


TL;DR: Learn how models make their predictions and how to evaluate them.

How do models (and Neural Networks in particular) make their predictions? How can they be improved?

  • Run the notebook in your browser (Google Colab)

You’ll learn how to:

  • Build a baseline model with a single weight value
  • Evaluate your model's performance by measuring its error
  • Find better weight values with the guessing method

The almighty function

Most of your work (as an ML person) comes down to how well this one single function performs:

```python
def predict(data):
    ...
```

To the untrained eye, this looks deceptively simple. Isn't this just one function? I guarantee you're about to develop a love-hate relationship with it!

In reality, when dealing with Deep Learning, the predict() function can have a simple implementation:

```python
def predict(data):
    # `weight` is the model's parameter, assumed to be defined elsewhere
    return data * weight
```

But what does weight contain? This is the “brains” of your model. Without good weights, you don't get to drink Corona on your favorite beach with washboard abs and hotties around you.

How can you get some good weights for yourself?

There are several ways to find good weight values, but let's start with something really dumb (a good baseline, if you will) and improve on it (hopefully, in a non-linear fashion).

Let’s say you need to predict the Tesla stock price. The current price is 420 USD. What will be the price in a minute?

```python
def predict(data, weight):
    return data * weight
```

Actually, this is your first model. Congratulations! Let’s use it:

```python
current_stock_price = 420

prediction = predict(current_stock_price, weight=1.01)

f"Predicted Tesla stock price is {prediction} USD"
```

```
'Predicted Tesla stock price is 424.2 USD'
```

Yes, it works! Unfortunately, the real price is 425 USD. How can you use this information to make your model better?

One popular way is to start by measuring the error of the prediction:

```python
import torch
import torch.nn.functional as F

real_next_price = 425

error = F.mse_loss(
    input=torch.tensor(prediction),
    target=torch.tensor(real_next_price)
)

error
```

```
tensor(0.6400)
```
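With its default `reduction="mean"`, `F.mse_loss` averages the squared differences between predictions and targets. For a single prediction that average is one squared difference, so the 0.64 above is just (424.2 - 425)². Here is a minimal plain-Python sketch of the same computation; `mean_squared_error` is a hypothetical helper for illustration, not part of PyTorch.

```python
def mean_squared_error(predictions, targets):
    """Average of the squared differences between predictions and targets.

    Hypothetical helper: mirrors what F.mse_loss computes with reduction="mean".
    """
    if len(predictions) != len(targets) or not predictions:
        raise ValueError("predictions and targets must be non-empty and the same length")
    squared = [(p - t) ** 2 for p, t in zip(predictions, targets)]
    return sum(squared) / len(squared)


# One prediction: the mean of a single squared difference is just (424.2 - 425) ** 2
print(round(mean_squared_error([424.2], [425.0]), 4))  # 0.64

# Two predictions: ((1 - 2) ** 2 + (3 - 5) ** 2) / 2 = (1 + 4) / 2
print(mean_squared_error([1.0, 3.0], [2.0, 5.0]))  # 2.5
```

Squaring keeps every error positive and punishes large misses more than small ones, which is why a miss of 0.8 USD turns into an error of 0.64.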

Great, now you have a concrete goal: reduce the error to 0. We can continue with our dumb strategy of guessing weight values:

```python
weight = 1.01

change = 0.0001  # how far to nudge the weight on each step

target = torch.tensor(real_next_price)

for _ in range(1000):
    # try a slightly larger and a slightly smaller weight
    guess_up = torch.tensor(predict(current_stock_price, weight + change))
    guess_down = torch.tensor(predict(current_stock_price, weight - change))

    error_up = F.mse_loss(guess_up, target)
    error_down = F.mse_loss(guess_down, target)

    # keep whichever direction gives the smaller error
    if error_up < error_down:
        weight += change
    else:
        weight -= change
```

We start with the current weight and nudge it up and down by a small amount (change). We keep whichever direction gives the smaller error and repeat, 1,000 times in total.
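To watch the guessing method step by step, here is a plain-Python sketch of the same loop that records every weight it visits. It drops PyTorch and uses a squared error directly, which equals `F.mse_loss` for a single value; `squared_error`, `guess_weight`, and `history` are hypothetical names introduced only for this illustration.

```python
def predict(data, weight):
    return data * weight


def squared_error(prediction, target):
    # Same value F.mse_loss returns for a single prediction/target pair
    return (prediction - target) ** 2


def guess_weight(data, target, weight, change=0.0001, steps=1000):
    """Nudge `weight` up or down by `change`, keeping the direction with the smaller error."""
    history = [weight]
    for _ in range(steps):
        error_up = squared_error(predict(data, weight + change), target)
        error_down = squared_error(predict(data, weight - change), target)
        if error_up < error_down:
            weight += change
        else:
            weight -= change
        history.append(weight)
    return weight, history


weight, history = guess_weight(420, 425, weight=1.01)

# The last few weights alternate between two neighboring values
print([round(w, 4) for w in history[-4:]])  # [1.0119, 1.012, 1.0119, 1.012]
```

The weight climbs from 1.01 for about 20 steps and then bounces between 1.0119 and 1.0120 for the rest of the loop. Because the step size is fixed, the method can circle the best weight but never settle exactly on it.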

What is the predicted stock price with the new model?

```python
predict(current_stock_price, weight)
```

```
425.0399999999999
```

This one looks much better. What about the error?

```python
F.mse_loss(
    input=torch.tensor(predict(current_stock_price, weight)),
    target=target
)
```

```
tensor(0.0016)
```
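Why doesn't the error reach 0? For one data point, the MSE is just the squared error of the single prediction, written here as E(w) with x = 420 (the current price) and y = 425 (the real next price); the symbols E, w, x, and y are our own notation. The arithmetic below shows that the ideal weight falls between two values the guessing method can actually reach, starting from 1.01 in steps of 0.0001.

```latex
E(w) = (x\,w - y)^2, \qquad x = 420,\; y = 425

E(w) = 0 \iff w^{*} = \frac{y}{x} = \frac{425}{420} \approx 1.011905

w = 1.01 + k \cdot 0.0001 \;\Rightarrow\; \text{closest reachable values: } 1.0119\ (k = 19),\; 1.0120\ (k = 20)

w = 1.0120:\quad 420 \times 1.0120 = 425.04, \qquad E = (425.04 - 425)^2 = 0.04^2 = 0.0016
```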

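A drastic improvement over the first model (0.0016 vs. 0.64). Can you really use this guessing method in practice? No! Every step needs two extra predictions for each weight, and the fixed step size limits how close it can get to the best value. That quickly becomes too slow for real models, which have far more than one weight.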

Summary

Making good models is all about finding good weight values. The guessing method is a simple way to find weights for your model. But it is slow and might not give good results.

You learned how to:

  • Build a baseline model with a single weight value
  • Evaluate your model's performance by measuring its error
  • Find better weight values with the guessing method

Next, we’ll get back to the real world and build an end-to-end project containing our baseline model. From there, we’ll be free to experiment and look for improvements.
