— Deep Learning, Computer Vision, Machine Learning, Neural Network, Transfer Learning, Python — 4 min read
TL;DR Learn how to use Transfer Learning to classify traffic sign images. You’ll build a dataset of images in a format suitable for working with Torchvision, and get predictions on images from the wild (downloaded from the Internet).
In this tutorial, you’ll learn how to fine-tune a pre-trained model to classify traffic signs from raw pixels.
Here’s what we’ll go over:

- Exploring the GTSRB traffic sign dataset
- Building an image dataset in a format suitable for Torchvision
- Fine-tuning a pre-trained ResNet using Transfer Learning
- Getting predictions on images from the wild
- Will this model be ready for the real world?
```python
import torch, torchvision

from pathlib import Path
import numpy as np
import cv2
import pandas as pd
from tqdm import tqdm
import PIL.Image as Image
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from matplotlib.ticker import MaxNLocator
from torch.optim import lr_scheduler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from glob import glob
import shutil
from collections import defaultdict

from torch import nn, optim

import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import models

%matplotlib inline
%config InlineBackend.figure_format='retina'

sns.set(style='whitegrid', palette='muted', font_scale=1.2)

HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]

sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))

rcParams['figure.figsize'] = 12, 8

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
The German Traffic Sign Recognition Benchmark (GTSRB) contains more than 50,000 annotated images covering 43 types of traffic signs. Given an image, your task is to recognize the traffic sign on it.
```python
!wget https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Training_Images.zip
!unzip -qq GTSRB_Final_Training_Images.zip
```
Let’s start by getting a feel for the data. The images for each traffic sign are stored in a separate directory. How many do we have?
```python
train_folders = sorted(glob('GTSRB/Final_Training/Images/*'))
len(train_folders)
```
```text
43
```
We’ll create 3 helper functions that use OpenCV and Torchvision to load and show images:
```python
def load_image(img_path, resize=True):
  img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)

  if resize:
    img = cv2.resize(img, (64, 64), interpolation = cv2.INTER_AREA)

  return img

def show_image(img_path):
  img = load_image(img_path)
  plt.imshow(img)
  plt.axis('off')

def show_sign_grid(image_paths):
  images = [load_image(img) for img in image_paths]
  images = torch.as_tensor(images)
  images = images.permute(0, 3, 1, 2)
  grid_img = torchvision.utils.make_grid(images, nrow=11)
  plt.figure(figsize=(24, 12))
  plt.imshow(grid_img.permute(1, 2, 0))
  plt.axis('off');
```
Let’s have a look at some examples for each traffic sign:
```python
sample_images = [np.random.choice(glob(f'{tf}/*ppm')) for tf in train_folders]
show_sign_grid(sample_images)
```
And here is a single sign:
```python
img_path = glob(f'{train_folders[16]}/*ppm')[1]

show_image(img_path)
```
To keep things simple, we’ll focus on classifying some of the most common traffic signs:
```python
class_names = ['priority_road', 'give_way', 'stop', 'no_entry']

class_indices = [12, 13, 14, 17]
```
We’ll copy the image files to a new directory, so it’s easier to use Torchvision’s dataset helpers. Let’s start with the directories for each class:
```python
!rm -rf data

DATA_DIR = Path('data')

DATASETS = ['train', 'val', 'test']

for ds in DATASETS:
  for cls in class_names:
    (DATA_DIR / ds / cls).mkdir(parents=True, exist_ok=True)
```
We’ll reserve 80% of the images for training, 10% for validation, and 10% for testing, for each class. We’ll copy each image to the correct dataset directory:
```python
for i, cls_index in enumerate(class_indices):
  image_paths = np.array(glob(f'{train_folders[cls_index]}/*.ppm'))
  class_name = class_names[i]
  print(f'{class_name}: {len(image_paths)}')
  np.random.shuffle(image_paths)

  ds_split = np.split(
    image_paths,
    indices_or_sections=[int(.8*len(image_paths)), int(.9*len(image_paths))]
  )

  dataset_data = zip(DATASETS, ds_split)

  for ds, images in dataset_data:
    for img_path in images:
      shutil.copy(img_path, f'{DATA_DIR}/{ds}/{class_name}/')
```
```text
priority_road: 2100
give_way: 2160
stop: 780
no_entry: 1110
```
We have some class imbalance, but it is not that bad. We’ll ignore it.
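If the imbalance were more severe, a common remedy is to weight the loss by inverse class frequency. Here’s a minimal sketch of that idea, using the counts printed above; we won’t actually use it in this tutorial:

```python
# Hypothetical alternative (not used below): weight the loss so that rarer
# classes contribute more. Counts are in alphabetical order, matching the
# class order ImageFolder will produce later.
class_counts = torch.tensor([2160., 1110., 2100., 780.])  # give_way, no_entry, priority_road, stop
class_weights = class_counts.sum() / (len(class_counts) * class_counts)

weighted_loss_fn = nn.CrossEntropyLoss(weight=class_weights.to(device))
```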
We’ll apply some image augmentation techniques to artificially increase the size of our training dataset:
```python
mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]

transforms = {'train': T.Compose([
  T.RandomResizedCrop(size=256),
  T.RandomRotation(degrees=15),
  T.RandomHorizontalFlip(),
  T.ToTensor(),
  T.Normalize(mean_nums, std_nums)
]), 'val': T.Compose([
  T.Resize(size=256),
  T.CenterCrop(size=224),
  T.ToTensor(),
  T.Normalize(mean_nums, std_nums)
]), 'test': T.Compose([
  T.Resize(size=256),
  T.CenterCrop(size=224),
  T.ToTensor(),
  T.Normalize(mean_nums, std_nums)
]),
}
```
We apply some random resizing, rotation, and horizontal flips. Finally, we normalize the tensors using the per-channel mean and standard deviation of the ImageNet dataset. This is a requirement for using the pre-trained models in Torchvision.
We’ll create a PyTorch dataset for each image dataset folder and data loaders for easier training:
```python
image_datasets = {
  d: ImageFolder(f'{DATA_DIR}/{d}', transforms[d]) for d in DATASETS
}

data_loaders = {
  d: DataLoader(image_datasets[d], batch_size=4, shuffle=True, num_workers=4)
  for d in DATASETS
}
```
We’ll also store the number of examples in each dataset and class names for later:
```python
dataset_sizes = {d: len(image_datasets[d]) for d in DATASETS}
class_names = image_datasets['train'].classes

dataset_sizes
```
```text
{'test': 615, 'train': 4920, 'val': 615}
```
Let’s have a look at some example images with the transformations applied. We also need to reverse the normalization and reorder the color channels to get correct image data:
```python
def imshow(inp, title=None):
  inp = inp.numpy().transpose((1, 2, 0))
  mean = np.array([mean_nums])
  std = np.array([std_nums])
  inp = std * inp + mean
  inp = np.clip(inp, 0, 1)
  plt.imshow(inp)
  if title is not None:
    plt.title(title)
  plt.axis('off')

inputs, classes = next(iter(data_loaders['train']))
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])
```
Our model will receive raw image pixels and try to classify them into one of four traffic signs. How hard can it be? Try building a model from scratch; a possible starting point is sketched below.
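If you want to take up that challenge, here’s a minimal from-scratch CNN you might start with. It’s a hypothetical baseline for the exercise, not the model we train below, and it works with the normalized images our transforms produce:

```python
# A small baseline CNN for the from-scratch exercise (hypothetical, not used below).
class BaselineCNN(nn.Module):
  def __init__(self, n_classes):
    super().__init__()
    self.features = nn.Sequential(
      nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
      nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
      nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    )
    # Global average pooling keeps the classifier independent of input resolution
    self.classifier = nn.Sequential(
      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, n_classes)
    )

  def forward(self, x):
    return self.classifier(self.features(x))
```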
Here, we’ll use Transfer Learning to copy the architecture of the very popular ResNet model. On top of that, we’ll use the weights the model learned from training on the ImageNet dataset. All of this is made easy to use by Torchvision:
```python
def create_model(n_classes):
  model = models.resnet34(pretrained=True)

  n_features = model.fc.in_features
  model.fc = nn.Linear(n_features, n_classes)

  return model.to(device)
```
We reuse almost everything, changing only the output layer. This is needed because the number of classes in our dataset is different from the 1,000 classes of ImageNet.
Let’s create an instance of our model:
```python
base_model = create_model(len(class_names))
```
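Note that we’ll fine-tune all layers of the network. If you’d rather use the pre-trained network as a fixed feature extractor, a sketch of that variation (not what we do here) is to freeze everything except the new head:

```python
# Hypothetical variation (not used in this tutorial): freeze the pre-trained
# backbone and train only the replacement classifier head.
for param in base_model.parameters():
  param.requires_grad = False
for param in base_model.fc.parameters():
  param.requires_grad = True

# The optimizer would then receive only the trainable parameters, e.g.:
# optim.SGD(base_model.fc.parameters(), lr=0.001, momentum=0.9)
```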
We’ll write 3 helper functions to encapsulate the training and evaluation logic. Let’s start with `train_epoch`:
```python
def train_epoch(
  model,
  data_loader,
  loss_fn,
  optimizer,
  device,
  scheduler,
  n_examples
):
  model = model.train()

  losses = []
  correct_predictions = 0

  for inputs, labels in data_loader:
    inputs = inputs.to(device)
    labels = labels.to(device)

    outputs = model(inputs)

    _, preds = torch.max(outputs, dim=1)
    loss = loss_fn(outputs, labels)

    correct_predictions += torch.sum(preds == labels)
    losses.append(loss.item())

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

  scheduler.step()

  return correct_predictions.double() / n_examples, np.mean(losses)
```
We start by putting our model in train mode and going over the data. After getting the predictions, we take the class with maximum probability along with the loss, so we can calculate the epoch loss and accuracy.
Note that we’re also using a learning rate scheduler (demonstrated below).
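The schedule we’ll use in `train_model` is `StepLR(step_size=7, gamma=0.1)`, which multiplies the learning rate by 0.1 every 7 epochs. A tiny standalone demo of the behavior, assuming a PyTorch version that has `get_last_lr`:

```python
# Demo only: watch StepLR decay the learning rate (no real training happens here).
demo_optimizer = optim.SGD(base_model.parameters(), lr=0.001, momentum=0.9)
demo_scheduler = lr_scheduler.StepLR(demo_optimizer, step_size=7, gamma=0.1)

for epoch in range(15):
  demo_optimizer.step()  # placeholder for a real training epoch
  demo_scheduler.step()
  print(epoch + 1, demo_scheduler.get_last_lr())  # lr drops 10x after epochs 7 and 14
```

With that out of the way, here’s the evaluation helper: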
```python
def eval_model(model, data_loader, loss_fn, device, n_examples):
  model = model.eval()

  losses = []
  correct_predictions = 0

  with torch.no_grad():
    for inputs, labels in data_loader:
      inputs = inputs.to(device)
      labels = labels.to(device)

      outputs = model(inputs)

      _, preds = torch.max(outputs, dim=1)

      loss = loss_fn(outputs, labels)

      correct_predictions += torch.sum(preds == labels)
      losses.append(loss.item())

  return correct_predictions.double() / n_examples, np.mean(losses)
```
The evaluation of the model is pretty similar, except that we don’t do any gradient calculations.
Let’s put everything together:
```python
def train_model(model, data_loaders, dataset_sizes, device, n_epochs=3):
  optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
  scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
  loss_fn = nn.CrossEntropyLoss().to(device)

  history = defaultdict(list)
  best_accuracy = 0

  for epoch in range(n_epochs):

    print(f'Epoch {epoch + 1}/{n_epochs}')
    print('-' * 10)

    train_acc, train_loss = train_epoch(
      model,
      data_loaders['train'],
      loss_fn,
      optimizer,
      device,
      scheduler,
      dataset_sizes['train']
    )

    print(f'Train loss {train_loss} accuracy {train_acc}')

    val_acc, val_loss = eval_model(
      model,
      data_loaders['val'],
      loss_fn,
      device,
      dataset_sizes['val']
    )

    print(f'Val loss {val_loss} accuracy {val_acc}')
    print()

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_loss'].append(val_loss)

    if val_acc > best_accuracy:
      torch.save(model.state_dict(), 'best_model_state.bin')
      best_accuracy = val_acc

  print(f'Best val accuracy: {best_accuracy}')

  model.load_state_dict(torch.load('best_model_state.bin'))

  return model, history
```
We do a lot of string formatting and record the training history. The hard work gets delegated to the previous helper functions. We also want the best model, so the weights of the most accurate model are saved during training.
Let’s train our first model:
```python
%%time

base_model, history = train_model(base_model, data_loaders, dataset_sizes, device)
```
```text
Epoch 1/3
----------
Train loss 0.31827690804876935 accuracy 0.8859756097560976
Val loss 0.0012465072916699694 accuracy 1.0

Epoch 2/3
----------
Train loss 0.12230596961529275 accuracy 0.9615853658536585
Val loss 0.0007955377752130681 accuracy 1.0

Epoch 3/3
----------
Train loss 0.07771141678094864 accuracy 0.9745934959349594
Val loss 0.0025791768387877366 accuracy 0.9983739837398374

Best val accuracy: 1.0
CPU times: user 2min 24s, sys: 48.2 s, total: 3min 12s
Wall time: 3min 21s
```
Here’s a little helper function that visualizes the training history for us:
```python
def plot_training_history(history):
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))

  ax1.plot(history['train_loss'], label='train loss')
  ax1.plot(history['val_loss'], label='validation loss')

  ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
  ax1.set_ylim([-0.05, 1.05])
  ax1.legend()
  ax1.set_ylabel('Loss')
  ax1.set_xlabel('Epoch')

  ax2.plot(history['train_acc'], label='train accuracy')
  ax2.plot(history['val_acc'], label='validation accuracy')

  ax2.xaxis.set_major_locator(MaxNLocator(integer=True))
  ax2.set_ylim([-0.05, 1.05])
  ax2.legend()

  ax2.set_ylabel('Accuracy')
  ax2.set_xlabel('Epoch')

  fig.suptitle('Training history')

plot_training_history(history)
```
The pre-trained model is so good that we get very high accuracy and low loss after just 3 epochs. Unfortunately, our validation set is too small to draw meaningful conclusions from these metrics.
Let’s see some predictions on traffic signs from the test set:
```python
def show_predictions(model, class_names, n_images=6):
  model = model.eval()
  images_handled = 0
  plt.figure()

  with torch.no_grad():
    for i, (inputs, labels) in enumerate(data_loaders['test']):
      inputs = inputs.to(device)
      labels = labels.to(device)

      outputs = model(inputs)
      _, preds = torch.max(outputs, 1)

      for j in range(inputs.shape[0]):
        images_handled += 1
        ax = plt.subplot(2, n_images//2, images_handled)
        ax.set_title(f'predicted: {class_names[preds[j]]}')
        imshow(inputs.cpu().data[j])
        ax.axis('off')

        if images_handled == n_images:
          return
```
```python
show_predictions(base_model, class_names, n_images=8)
```
Very good! Even the barely visible priority road sign is classified correctly. Let’s dive a bit deeper.
We’ll start by getting the predictions from our model:
```python
def get_predictions(model, data_loader):
  model = model.eval()
  predictions = []
  real_values = []
  with torch.no_grad():
    for inputs, labels in data_loader:
      inputs = inputs.to(device)
      labels = labels.to(device)

      outputs = model(inputs)
      _, preds = torch.max(outputs, 1)
      predictions.extend(preds)
      real_values.extend(labels)
  predictions = torch.as_tensor(predictions).cpu()
  real_values = torch.as_tensor(real_values).cpu()
  return predictions, real_values
```
```python
y_pred, y_test = get_predictions(base_model, data_loaders['test'])
```
```python
print(classification_report(y_test, y_pred, target_names=class_names))
```
```text
               precision    recall  f1-score   support

     give_way       1.00      1.00      1.00       216
     no_entry       1.00      1.00      1.00       111
priority_road       1.00      1.00      1.00       210
         stop       1.00      1.00      1.00        78

     accuracy                           1.00       615
    macro avg       1.00      1.00      1.00       615
 weighted avg       1.00      1.00      1.00       615
```
The classification report shows that our model is perfect, which is not something you see every day! Does this thing make any mistakes?
```python
def show_confusion_matrix(confusion_matrix, class_names):

  cm = confusion_matrix.copy()

  cell_counts = cm.flatten()

  cm_row_norm = cm / cm.sum(axis=1)[:, np.newaxis]

  row_percentages = ["{0:.2f}".format(value) for value in cm_row_norm.flatten()]

  cell_labels = [f"{cnt}\n{per}" for cnt, per in zip(cell_counts, row_percentages)]
  cell_labels = np.asarray(cell_labels).reshape(cm.shape[0], cm.shape[1])

  df_cm = pd.DataFrame(cm_row_norm, index=class_names, columns=class_names)

  hmap = sns.heatmap(df_cm, annot=cell_labels, fmt="", cmap="Blues")
  hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
  hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
  plt.ylabel('True Sign')
  plt.xlabel('Predicted Sign');
```
```python
cm = confusion_matrix(y_test, y_pred)
show_confusion_matrix(cm, class_names)
```
No, no mistakes here!
Ok, but how good will our model be when confronted with a real-world image? Let’s check it out:
```python
!gdown --id 19Qz3a61Ou_QSHsLeTznx8LtDBu4tbqHr
```
```python
show_image('stop-sign.jpg')
```
For this, we’ll have a look at the confidence for each class. Let’s get this from our model:
```python
def predict_proba(model, image_path):
  img = Image.open(image_path)
  img = img.convert('RGB')
  img = transforms['test'](img).unsqueeze(0)

  pred = model(img.to(device))
  pred = F.softmax(pred, dim=1)
  return pred.detach().cpu().numpy().flatten()
```
```python
pred = predict_proba(base_model, 'stop-sign.jpg')
pred
```
```text
array([1.1296713e-03, 1.9811286e-04, 3.4486805e-04, 9.9832731e-01],
      dtype=float32)
```
This is a bit hard to understand. Let’s plot it:
```python
def show_prediction_confidence(prediction, class_names):
  pred_df = pd.DataFrame({
    'class_names': class_names,
    'values': prediction
  })
  sns.barplot(x='values', y='class_names', data=pred_df, orient='h')
  plt.xlim([0, 1]);
```
```python
show_prediction_confidence(pred, class_names)
```
Again, our model is performing very well! It’s really confident in the correct traffic sign!
The last challenge for our model is a traffic sign that it hasn’t seen before:
```python
!gdown --id 1F61-iNhlJk-yKZRGcu6S9P29HxDFxF0u
```
```python
show_image('unknown-sign.jpg')
```
Let’s get the predictions:
```python
pred = predict_proba(base_model, 'unknown-sign.jpg')
pred
```
```text
array([9.9413127e-01, 1.1861280e-06, 3.9936006e-03, 1.8739274e-03],
      dtype=float32)
```
```python
show_prediction_confidence(pred, class_names)
```
Our model is very certain (more than 99% confidence) that this is a give way sign. This is obviously wrong. How can you make your model recognize situations like this?
While there are a variety of ways to handle this situation (one is described in the paper A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks), we’ll do something simpler.
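For reference, the baseline from that paper simply thresholds the maximum softmax probability. A rough sketch using our `predict_proba` helper; the threshold value is an assumption you’d have to tune on held-out data:

```python
# Sketch of the max-softmax-probability baseline; the threshold is illustrative.
def predict_or_reject(model, image_path, threshold=0.9):
  probs = predict_proba(model, image_path)
  if probs.max() < threshold:
    return 'unknown'
  return class_names[probs.argmax()]
```

As the give way example shows, though, a confident wrong prediction slips right past such a threshold. That’s why we’ll take a data-driven route instead.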
We’ll get the indices of all traffic signs that weren’t included in our original dataset:
```python
unknown_indices = [
  i for i, f in enumerate(train_folders)
  if i not in class_indices
]

len(unknown_indices)
```
```text
39
```
We’ll create a new folder for the unknown class and copy some of the images there:
```python
for ds in DATASETS:
  (DATA_DIR / ds / 'unknown').mkdir(parents=True, exist_ok=True)

for ui in unknown_indices:
  image_paths = np.array(glob(f'{train_folders[ui]}/*.ppm'))
  image_paths = np.random.choice(image_paths, 50)

  ds_split = np.split(
    image_paths,
    indices_or_sections=[int(.8*len(image_paths)), int(.9*len(image_paths))]
  )

  dataset_data = zip(DATASETS, ds_split)

  for ds, images in dataset_data:
    for img_path in images:
      shutil.copy(img_path, f'{DATA_DIR}/{ds}/unknown/')
```
The next steps are identical to what we’ve already done:
```python
image_datasets = {
  d: ImageFolder(f'{DATA_DIR}/{d}', transforms[d]) for d in DATASETS
}

data_loaders = {
  d: DataLoader(image_datasets[d], batch_size=4, shuffle=True, num_workers=4)
  for d in DATASETS
}

dataset_sizes = {d: len(image_datasets[d]) for d in DATASETS}
class_names = image_datasets['train'].classes

dataset_sizes
```
```text
{'test': 784, 'train': 5704, 'val': 794}
```
```python
%%time

enhanced_model = create_model(len(class_names))
enhanced_model, history = train_model(enhanced_model, data_loaders, dataset_sizes, device)
```
```text
Epoch 1/3
----------
Train loss 0.39523224640235327 accuracy 0.8650070126227208
Val loss 0.002290595416447625 accuracy 1.0

Epoch 2/3
----------
Train loss 0.173455789528505 accuracy 0.9446002805049089
Val loss 0.030148923471944415 accuracy 0.9886649874055415

Epoch 3/3
----------
Train loss 0.11575758963990512 accuracy 0.9640603085553997
Val loss 0.0014996432778823317 accuracy 1.0

Best val accuracy: 1.0
CPU times: user 2min 47s, sys: 56.2 s, total: 3min 44s
Wall time: 3min 53s
```
```python
plot_training_history(history)
```
Again, our model learns very quickly. Let’s have another look at the sample image:
```python
show_image('unknown-sign.jpg')
```
```python
pred = predict_proba(enhanced_model, 'unknown-sign.jpg')
show_prediction_confidence(pred, class_names)
```
Great, the model doesn’t give much weight to any of the known classes. It doesn’t magically know that this is a two-way traffic sign, but it recognizes it as unknown.
Let’s have a look at some example predictions on our new dataset:
```python
show_predictions(enhanced_model, class_names, n_images=8)
```
Let’s get an overview of the new model’s performance:
```python
y_pred, y_test = get_predictions(enhanced_model, data_loaders['test'])
```
```python
print(classification_report(y_test, y_pred, target_names=class_names))
```
```text
               precision    recall  f1-score   support

     give_way       1.00      1.00      1.00       216
     no_entry       1.00      1.00      1.00       111
priority_road       1.00      1.00      1.00       210
         stop       1.00      1.00      1.00        78
      unknown       1.00      1.00      1.00       169

     accuracy                           1.00       784
    macro avg       1.00      1.00      1.00       784
 weighted avg       1.00      1.00      1.00       784
```
```python
cm = confusion_matrix(y_test, y_pred)
show_confusion_matrix(cm, class_names)
```
Our model is still perfect. Go ahead, try it on more images!
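For example, you can reuse the helpers from above on any image you download. The file name here is just a placeholder:

```python
# Classify your own image (the file name is a placeholder).
pred = predict_proba(enhanced_model, 'my-traffic-sign.jpg')
show_prediction_confidence(pred, class_names)
print(class_names[pred.argmax()])
```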
Good job! You trained two different models for classifying traffic signs from raw pixels. You also built a dataset that is compatible with Torchvision.
Here’s what you’ve learned:

- Building an image dataset in a format suitable for Torchvision
- Using Transfer Learning to fine-tune a pre-trained ResNet
- Evaluating your model with a classification report and confusion matrix
- Handling traffic signs the model hasn’t seen by adding an unknown class
Can you use transfer learning for other tasks? How do you do it? Let me know in the comments below.