— Deep Learning, PyTorch, Machine Learning, Computer Vision, Object Detection, Face Detection, Python — 5 min read
TL;DR Learn how to prepare a custom Face Detection dataset for Detectron2 and PyTorch. Fine-tune a pre-trained model to find face boundaries in images.
Face detection is the task of finding (boundaries of) faces in images. This is useful for a wide range of applications, such as photo tagging, security systems, and face recognition pipelines.
Historically, this was a really tough problem to solve. It took tons of manual feature engineering, and novel algorithms and methods were developed to improve the state of the art.
These days, face detection models are included in almost every computer vision package/framework. Some of the best-performing ones use Deep Learning methods. OpenCV, for example, provides a variety of tools like the Cascade Classifier.
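For comparison, here's a minimal sketch of the classical approach with OpenCV's Haar Cascade Classifier (the cascade file ships with the opencv-python package; the image path is just a placeholder):

import cv2

# Pre-trained Haar cascade for frontal faces, bundled with opencv-python
face_cascade = cv2.CascadeClassifier(
    cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
)

img = cv2.imread('some_photo.jpg')  # hypothetical input image
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

# Returns a list of (x, y, w, h) boxes around detected faces
faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)

for (x, y, w, h) in faces:
    cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)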
In this guide, you’ll learn how to:

- prepare a custom face detection dataset (images and bounding box annotations) for Detectron2
- fine-tune a pre-trained model from the Detectron2 Model Zoo
- evaluate the fine-tuned model and use it to find faces in new images
Here’s an example of what you’ll get at the end of this guide:
Detectron2 is a framework for building state-of-the-art object detection and image segmentation models. It is developed by the Facebook Research team. Detectron2 is a complete rewrite of the first version.
Under the hood, Detectron2 uses PyTorch (compatible with the latest version(s)) and allows for blazing-fast training. You can learn more in the introductory blog post by Facebook Research.
The real power of Detectron2 lies in the HUGE amount of pre-trained models available at the Model Zoo. But what good would they be if you couldn’t fine-tune them on your own datasets? Fortunately, that’s super easy! We’ll see how it’s done in this guide.
At the time of this writing, Detectron2 is still in an alpha stage. While there is an official release, we’ll clone and compile from the master branch. This should be equivalent to version 0.1.
Let’s start by installing some requirements:
!pip install -q cython pyyaml==5.1
!pip install -q -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
And download, compile, and install the Detectron2 package:
!git clone https://github.com/facebookresearch/detectron2 detectron2_repo
!pip install -q -e detectron2_repo
At this point, you’ll need to restart the notebook runtime to continue!
!pip install -q -U watermark
%reload_ext watermark
%watermark -v -p numpy,pandas,pycocotools,torch,torchvision,detectron2
CPython 3.6.9
IPython 5.5.0

numpy 1.17.5
pandas 0.25.3
pycocotools 2.0
torch 1.4.0
torchvision 0.5.0
detectron2 0.1
import torch, torchvision
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import glob

import os
import ntpath
import numpy as np
import cv2
import random
import itertools
import pandas as pd
from tqdm import tqdm
import urllib
import json
import PIL.Image as Image

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.structures import BoxMode

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

%matplotlib inline
%config InlineBackend.figure_format='retina'

sns.set(style='whitegrid', palette='muted', font_scale=1.2)

HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]

sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))

rcParams['figure.figsize'] = 12, 8

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
Our dataset is provided by Dataturks, and it is hosted on Kaggle. Here’s an excerpt from the description:
Faces in images marked with bounding boxes. Have around 500 images with around 1100 faces manually tagged via bounding box.
I’ve downloaded the JSON file containing the annotations and uploaded it to Google Drive. Let’s get it:
!gdown --id 1K79wJgmPTWamqb04Op2GxW0SW9oxw8KS
Let’s load the file into a Pandas dataframe:
faces_df = pd.read_json('face_detection.json', lines=True)
Each line contains a single face annotation. Note that multiple lines might point to a single image (e.g. multiple faces per image).
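Here’s a rough (hypothetical, shortened) example of what a single record looks like. The field names match what the parsing code below uses, the point coordinates are relative to the image size, and each entry in annotation corresponds to one face:

{
    "content": "https://.../some_image.jpg",
    "annotation": [
        {
            "imageWidth": 1280,
            "imageHeight": 720,
            "points": [
                {"x": 0.25, "y": 0.33},
                {"x": 0.42, "y": 0.61}
            ]
        }
    ]
}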
The dataset contains only image URLs and annotations. We’ll have to download the images. We’ll also normalize the annotations, so it’s easier to use them with Detectron2 later on:
1os.makedirs("faces", exist_ok=True)23dataset = []45for index, row in tqdm(faces_df.iterrows(), total=faces_df.shape[0]):6 img = urllib.request.urlopen(row["content"])7 img = Image.open(img)8 img = img.convert('RGB')910 image_name = f'face_{index}.jpeg'1112 img.save(f'faces/{image_name}', "JPEG")1314 annotations = row['annotation']15 for an in annotations:1617 data = {}1819 width = an['imageWidth']20 height = an['imageHeight']21 points = an['points']2223 data['file_name'] = image_name24 data['width'] = width25 data['height'] = height2627 data["x_min"] = int(round(points[0]["x"] * width))28 data["y_min"] = int(round(points[0]["y"] * height))29 data["x_max"] = int(round(points[1]["x"] * width))30 data["y_max"] = int(round(points[1]["y"] * height))3132 data['class_name'] = 'face'3334 dataset.append(data)
Let’s put the data into a dataframe so we can have a better look:
df = pd.DataFrame(dataset)
print(df.file_name.unique().shape[0], df.shape[0])
409 1132
We have a total of 409 images (a lot fewer than the promised 500) and 1132 annotations. Let’s save them to disk (so you can reuse them later):
df.to_csv('annotations.csv', header=True, index=None)
Let’s see some sample annotated data. We’ll use OpenCV to load an image, add the bounding boxes, and resize it. We’ll define a helper function to do it all:
def annotate_image(annotations, resize=True):
    file_name = annotations.file_name.to_numpy()[0]
    img = cv2.cvtColor(cv2.imread(f'faces/{file_name}'), cv2.COLOR_BGR2RGB)

    for i, a in annotations.iterrows():
        cv2.rectangle(img, (a.x_min, a.y_min), (a.x_max, a.y_max), (0, 255, 0), 2)

    if not resize:
        return img

    return cv2.resize(img, (384, 384), interpolation=cv2.INTER_AREA)
Let’s start by showing some annotated images:
img_df = df[df.file_name == df.file_name.unique()[0]]
img = annotate_image(img_df, resize=False)

plt.imshow(img)
plt.axis('off');
img_df = df[df.file_name == df.file_name.unique()[1]]
img = annotate_image(img_df, resize=False)

plt.imshow(img)
plt.axis('off');
Those are good ones; the annotations are clearly visible. We can use torchvision to create a grid of images. Note that the images come in various sizes, so we’ll resize them:
sample_images = [annotate_image(df[df.file_name == f]) for f in df.file_name.unique()[:10]]
sample_images = torch.as_tensor(sample_images)
sample_images.shape
torch.Size([10, 384, 384, 3])
sample_images = sample_images.permute(0, 3, 1, 2)
sample_images.shape
torch.Size([10, 3, 384, 384])
plt.figure(figsize=(24, 12))
grid_img = torchvision.utils.make_grid(sample_images, nrow=5)

plt.imshow(grid_img.permute(1, 2, 0))
plt.axis('off');
You can clearly see that some annotations are missing (column 4). That’s real-life data for you; sometimes you have to deal with missing annotations one way or another.
It is time to go through the steps of fine-tuning a model using a custom dataset. But first, let’s save 5% of the data for testing:
df = pd.read_csv('annotations.csv')

IMAGES_PATH = f'faces'

unique_files = df.file_name.unique()

train_files = set(np.random.choice(unique_files, int(len(unique_files) * 0.95), replace=False))
train_df = df[df.file_name.isin(train_files)]
test_df = df[~df.file_name.isin(train_files)]
The classical train_test_split won’t work here, because we want the split to happen across file names - every annotation for a given image must end up in the same subset.
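If you prefer scikit-learn, the same kind of file-level split can be done with GroupShuffleSplit. A minimal sketch (assuming scikit-learn is available; it isn’t used in the rest of the guide):

from sklearn.model_selection import GroupShuffleSplit

# Group by file name, so all annotations of an image land in the same subset
splitter = GroupShuffleSplit(n_splits=1, test_size=0.05, random_state=RANDOM_SEED)
train_idx, test_idx = next(splitter.split(df, groups=df.file_name))

train_df, test_df = df.iloc[train_idx], df.iloc[test_idx]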
The next parts are written in a slightly more generic way. Obviously, we have a single class - face. But adding more should be as simple as adding more annotations to the dataframe:
classes = df.class_name.unique().tolist()
Next, we’ll write a function that converts our dataset into a format that is used by Detectron2:
def create_dataset_dicts(df, classes):
    dataset_dicts = []
    for image_id, img_name in enumerate(df.file_name.unique()):

        record = {}

        image_df = df[df.file_name == img_name]

        file_path = f'{IMAGES_PATH}/{img_name}'
        record["file_name"] = file_path
        record["image_id"] = image_id
        record["height"] = int(image_df.iloc[0].height)
        record["width"] = int(image_df.iloc[0].width)

        objs = []
        for _, row in image_df.iterrows():

            xmin = int(row.x_min)
            ymin = int(row.y_min)
            xmax = int(row.x_max)
            ymax = int(row.y_max)

            poly = [
                (xmin, ymin), (xmax, ymin),
                (xmax, ymax), (xmin, ymax)
            ]
            poly = list(itertools.chain.from_iterable(poly))

            obj = {
                "bbox": [xmin, ymin, xmax, ymax],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": classes.index(row.class_name),
                "iscrowd": 0
            }
            objs.append(obj)

        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts
We convert the annotation rows for each image into a single record with a list of annotations. You might also notice that we’re building a polygon with the exact same shape as the bounding box. This is required for the image segmentation models in Detectron2.
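For example, a box with (x_min, y_min, x_max, y_max) = (10, 20, 50, 80) becomes the flattened polygon [10, 20, 50, 20, 50, 80, 10, 80] - its four corners listed clockwise, starting from the top-left.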
You’ll have to register your dataset in the dataset and metadata catalogues:
1for d in ["train", "val"]:2 DatasetCatalog.register("faces_" + d, lambda d=d: create_dataset_dicts(train_df if d == "train" else test_df, classes))3 MetadataCatalog.get("faces_" + d).set(thing_classes=classes)45statement_metadata = MetadataCatalog.get("faces_train")
Unfortunately, an evaluator for the test set is not included by default. We can easily fix that by writing our own trainer:
class CocoTrainer(DefaultTrainer):

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):

        if output_folder is None:
            os.makedirs("coco_eval", exist_ok=True)
            output_folder = "coco_eval"

        return COCOEvaluator(dataset_name, cfg, False, output_folder)
The evaluation results will be stored in the coco_eval folder if no output folder is provided.
Fine-tuning a Detectron2 model is nothing like writing PyTorch code. We’ll load a configuration file, change a few values, and start the training process. But hey, it really helps if you know what you’re doing 😂
For this tutorial, we’ll use the Mask R-CNN X101-FPN model. It is pre-trained on the COCO dataset and achieves very good performance. The downside is that it is slow to train.
Let’s load the config file and the pre-trained model weights:
cfg = get_cfg()

cfg.merge_from_file(
    model_zoo.get_config_file(
        "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"
    )
)

cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
    "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"
)
Specify the datasets (we registered those) we’ll use for training and evaluation:
cfg.DATASETS.TRAIN = ("faces_train",)
cfg.DATASETS.TEST = ("faces_val",)
cfg.DATALOADER.NUM_WORKERS = 4
And for the optimizer, we’ll do a bit of magic to converge to something nice:
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.BASE_LR = 0.001
cfg.SOLVER.WARMUP_ITERS = 1000
cfg.SOLVER.MAX_ITER = 1500
cfg.SOLVER.STEPS = (1000, 1500)
cfg.SOLVER.GAMMA = 0.05
Apart from the standard stuff (batch size, max number of iterations, and learning rate), we have a couple of interesting params (a short worked example follows the list):

- WARMUP_ITERS - the learning rate starts from 0 and grows to the preset one over this number of iterations
- STEPS - the checkpoints (numbers of iterations) at which the learning rate will be reduced by GAMMA
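To make that concrete: with BASE_LR = 0.001 and GAMMA = 0.05, the learning rate grows from 0 to 0.001 during the first 1000 iterations, is multiplied by 0.05 at iteration 1000 (giving 0.001 × 0.05 = 0.00005), and would be reduced again at iteration 1500 - which coincides with MAX_ITER, so training stops there.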
Finally, we’ll specify the number of classes and the period at which we’ll evaluate on the test set:
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(classes)

cfg.TEST.EVAL_PERIOD = 500
Time to train, using our custom trainer:
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

trainer = CocoTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
Evaluating object detection models is a bit different when compared to evaluating standard classification or regression models.
The main metric you need to know about is IoU (intersection over union). It measures the overlap between two boundaries - the predicted one and the ground truth. It takes values between 0 and 1.
IoU = area of overlap / area of union

Using IoU, one can define a threshold (e.g. >0.5) to classify whether a prediction is a true positive (TP) or a false positive (FP).
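To illustrate the formula, here’s a minimal sketch of how IoU could be computed for two boxes in (x_min, y_min, x_max, y_max) format (not part of the training pipeline above):

def iou(box_a, box_b):
    # Corners of the intersection rectangle
    x_min = max(box_a[0], box_b[0])
    y_min = max(box_a[1], box_b[1])
    x_max = min(box_a[2], box_b[2])
    y_max = min(box_a[3], box_b[3])

    # Area of the overlap (0 if the boxes don't intersect)
    overlap = max(0, x_max - x_min) * max(0, y_max - y_min)

    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])

    return overlap / (area_a + area_b - overlap)

iou((10, 10, 50, 50), (30, 30, 70, 70))  # ~0.14 - a rather poor match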
Now you can calculate average precision (AP) by taking the area under the precision-recall curve.
AP@X (e.g. AP50) is just AP at some IoU threshold. This should give you a working understanding of how object detection models are evaluated.
I suggest you read the mAP (mean Average Precision) for Object Detection tutorial by Jonathan Hui if you want to learn more on the topic.
I’ve prepared a pre-trained model for you, so you don’t have to wait for the training to complete. Let’s download it:
!gdown --id 18Ev2bpdKsBaDufhVKf0cT6RmM3FjW3nL
!mv face_detector.pth output/model_final.pth
We can start making predictions by loading the model and setting a minimum score threshold of 85% - predictions below it will be discarded:
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.85
predictor = DefaultPredictor(cfg)
Let’s run the evaluator with the trained model:
evaluator = COCOEvaluator("faces_val", cfg, False, output_dir="./output/")
val_loader = build_detection_test_loader(cfg, "faces_val")
inference_on_dataset(trainer.model, val_loader, evaluator)
Next, let’s create a folder and save all images with predicted annotations in the test set:
1os.makedirs("annotated_results", exist_ok=True)23test_image_paths = test_df.file_name.unique()
for test_image in test_image_paths:
    file_path = f'{IMAGES_PATH}/{test_image}'
    im = cv2.imread(file_path)
    outputs = predictor(im)
    v = Visualizer(
        im[:, :, ::-1],
        metadata=statement_metadata,
        scale=1.,
        instance_mode=ColorMode.IMAGE
    )
    # We only care about the boxes here, so drop the predicted masks
    instances = outputs["instances"].to("cpu")
    instances.remove('pred_masks')
    v = v.draw_instance_predictions(instances)
    result = v.get_image()[:, :, ::-1]
    file_name = ntpath.basename(test_image)
    write_res = cv2.imwrite(f'annotated_results/{file_name}', result)
Let’s have a look:
annotated_images = [f'annotated_results/{f}' for f in test_df.file_name.unique()]
img = cv2.cvtColor(cv2.imread(annotated_images[0]), cv2.COLOR_BGR2RGB)

plt.imshow(img)
plt.axis('off');
img = cv2.cvtColor(cv2.imread(annotated_images[1]), cv2.COLOR_BGR2RGB)

plt.imshow(img)
plt.axis('off');
img = cv2.cvtColor(cv2.imread(annotated_images[3]), cv2.COLOR_BGR2RGB)

plt.imshow(img)
plt.axis('off');
img = cv2.cvtColor(cv2.imread(annotated_images[4]), cv2.COLOR_BGR2RGB)

plt.imshow(img)
plt.axis('off');
Not bad. Not bad at all. I suggest you explore more images on your own, too!
Note that some faces have multiple bounding boxes (on the second image) with different degrees of certainty. Maybe training the model longer will help? How about adding more or augmenting the existing data?
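If you want to try data augmentation, newer Detectron2 releases let you plug transforms into the training data loader. Here's a rough sketch, assuming a recent Detectron2 version in which DatasetMapper accepts an augmentations argument (the 0.1 version used in this guide exposes a slightly different API):

from detectron2.data import DatasetMapper, build_detection_train_loader
import detectron2.data.transforms as T

class AugmentedTrainer(CocoTrainer):

    @classmethod
    def build_train_loader(cls, cfg):
        # Random horizontal flips and mild brightness jitter as example augmentations
        mapper = DatasetMapper(cfg, is_train=True, augmentations=[
            T.RandomFlip(prob=0.5, horizontal=True),
            T.RandomBrightness(0.8, 1.2)
        ])
        return build_detection_train_loader(cfg, mapper=mapper)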
Congratulations! You now know the basics of Detectron2 for object detection! You might be surprised by the results, given the small dataset we have. That’s the power of large pre-trained models for you 😍
You learned how to:

- prepare a custom face detection dataset and convert it to the format Detectron2 expects
- fine-tune a pre-trained Mask R-CNN model from the Detectron2 Model Zoo
- evaluate the model and visualize its predictions on the test set