TL;DR Learn how to create a REST API for Sentiment Analysis using a pre-trained BERT model
*Getting Things Done with PyTorch* on GitHub

In this tutorial, you'll learn how to deploy a pre-trained BERT model as a REST API using FastAPI. Here's the plan: set up the project with Pipenv, stub out the API, wrap the pre-trained model in a small classifier module, and wire it into the request handler.
We’ll manage our dependencies using Pipenv. Here’s the complete Pipfile:
```toml
[[source]]
name = "pypi"
url = "https://pypi.org/simple"
verify_ssl = true

[dev-packages]
black = "==19.10b0"
isort = "*"
flake8 = "*"
gdown = "*"

[packages]
fastapi = "*"
uvicorn = "*"
pydantic = "*"
torch = "*"
transformers = "*"

[requires]
python_version = "3.8"

[pipenv]
allow_prereleases = true
```
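To create the environment and install everything (a quick sketch, assuming you have Pipenv installed):

```bash
# install all packages, including the dev tools
pipenv install --dev

# spawn a shell inside the virtual environment
pipenv shell
```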
The backbone of our REST API will be FastAPI (the web framework), Uvicorn (the ASGI server that runs it), and Pydantic (for request and response validation), with PyTorch and Hugging Face Transformers doing the heavy lifting.
The dev tools black, isort, and flake8 will help us write better code (thanks to Momchil Hardalov for the configs), and gdown will download the pre-trained model.
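If you want to run the tools, something like this should work from the project root (a sketch, assuming the Pipenv shell is active):

```bash
black sentiment_analyzer    # format the code
isort sentiment_analyzer    # sort the imports
flake8 sentiment_analyzer   # lint
```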
Let’s start by creating a skeleton structure for our project. Your directory should look like this:
```
.
├── Pipfile
├── Pipfile.lock
└── sentiment_analyzer
    └── api.py
```
We'll start by creating a dummy/stubbed response to test that everything is working end-to-end. Here are the contents of api.py:

```python
from typing import Dict

# Depends will be used later to inject the model into the handler
from fastapi import Depends, FastAPI
from pydantic import BaseModel

app = FastAPI()


class SentimentRequest(BaseModel):
    text: str


class SentimentResponse(BaseModel):
    probabilities: Dict[str, float]
    sentiment: str
    confidence: float


@app.post("/predict", response_model=SentimentResponse)
def predict(request: SentimentRequest):
    return SentimentResponse(
        sentiment="positive",
        confidence=0.98,
        probabilities=dict(negative=0.005, neutral=0.015, positive=0.98),
    )
```
Our API expects a single field, text: the review to analyze. The response contains the predicted sentiment, the confidence (the softmax output for that sentiment), and the probabilities for all sentiment classes.
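To check that the stub works end-to-end, you can already start the server and poke it (a quick sketch, using the same HTTPie client we'll use later):

```bash
uvicorn sentiment_analyzer.api:app --reload

# in another terminal:
http POST http://localhost:8000/predict text="test"
```

You should get back the hardcoded "positive" response, which confirms the routing and the request/response models are wired up correctly.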
Here’s the file structure of the complete project:
```
.
├── assets
│   └── model_state_dict.bin
├── bin
│   └── download_model
├── config.json
├── Pipfile
├── Pipfile.lock
└── sentiment_analyzer
    ├── api.py
    └── classifier
        ├── model.py
        └── sentiment_classifier.py
```
We'll need the pre-trained model. We'll write the download_model script for that:

```python
#!/usr/bin/env python
import gdown

gdown.download(
    "https://drive.google.com/uc?id=1V8itWtowCYnb2Bc9KlK9SxGff9WwmogA",
    "assets/model_state_dict.bin",
)
```
The model can be downloaded from my Google Drive. Let's get it:

```bash
python bin/download_model
```
Our pre-trained model is stored as a PyTorch state dict. We need to load it and use it to predict the text sentiment.
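If you're curious, a state dict is just a mapping from parameter names to weight tensors. Here's a quick way to peek inside (a sketch, assuming the download above succeeded):

```python
import torch

# load the weights onto the CPU and inspect a few parameter names
state_dict = torch.load("assets/model_state_dict.bin", map_location="cpu")
print(len(state_dict))              # number of parameter tensors
print(list(state_dict.keys())[:5])  # e.g. BERT embedding and encoder weights
```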
Let's start with the config file config.json:

```json
{
  "BERT_MODEL": "bert-base-cased",
  "PRE_TRAINED_MODEL": "assets/model_state_dict.bin",
  "CLASS_NAMES": [
    "negative",
    "neutral",
    "positive"
  ],
  "MAX_SEQUENCE_LEN": 160
}
```
Next, we'll define sentiment_classifier.py:

```python
import json

from torch import nn
from transformers import BertModel

with open("config.json") as json_file:
    config = json.load(json_file)


class SentimentClassifier(nn.Module):
    def __init__(self, n_classes):
        super(SentimentClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(config["BERT_MODEL"])
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        # return_dict=False keeps the tuple output on transformers v4+
        _, pooled_output = self.bert(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=False
        )
        output = self.drop(pooled_output)
        return self.out(output)
```
This is the same model we’ve used for training. It just uses the config file.
Recall that BERT requires some special text preprocessing, so we need a place to apply the tokenizer from Hugging Face. We also need to massage the model outputs into our API response format.
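Here's roughly what that preprocessing produces (a quick sketch; the values mirror our config):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

encoding = tokenizer.encode_plus(
    "This app is great!",
    max_length=160,
    add_special_tokens=True,   # adds [CLS] and [SEP]
    padding="max_length",
    truncation=True,
    return_attention_mask=True,
    return_tensors="pt",
)

print(encoding["input_ids"].shape)       # torch.Size([1, 160])
print(encoding["attention_mask"].shape)  # torch.Size([1, 160])
```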
The Model class provides a nice abstraction (a Facade) over our classifier. It exposes a single predict() method and should be pretty generalizable if you want to use this project structure as a template for your next deployment. Here's the model.py file:

```python
import json

import torch
import torch.nn.functional as F
from transformers import BertTokenizer

from .sentiment_classifier import SentimentClassifier

with open("config.json") as json_file:
    config = json.load(json_file)


class Model:
    def __init__(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.tokenizer = BertTokenizer.from_pretrained(config["BERT_MODEL"])

        classifier = SentimentClassifier(len(config["CLASS_NAMES"]))
        classifier.load_state_dict(
            torch.load(config["PRE_TRAINED_MODEL"], map_location=self.device)
        )
        classifier = classifier.eval()
        self.classifier = classifier.to(self.device)

    def predict(self, text):
        encoded_text = self.tokenizer.encode_plus(
            text,
            max_length=config["MAX_SEQUENCE_LEN"],
            add_special_tokens=True,
            return_token_type_ids=False,
            padding="max_length",  # replaces the deprecated pad_to_max_length=True
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        input_ids = encoded_text["input_ids"].to(self.device)
        attention_mask = encoded_text["attention_mask"].to(self.device)

        with torch.no_grad():
            probabilities = F.softmax(self.classifier(input_ids, attention_mask), dim=1)
        confidence, predicted_class = torch.max(probabilities, dim=1)
        predicted_class = predicted_class.cpu().item()
        probabilities = probabilities.flatten().cpu().numpy().tolist()
        return (
            config["CLASS_NAMES"][predicted_class],
            confidence.cpu().item(),  # convert the tensor to a plain float
            dict(zip(config["CLASS_NAMES"], probabilities)),
        )


model = Model()


def get_model():
    return model
```
We'll run inference on the GPU if one is available. The predict() method returns the name of the predicted sentiment, the confidence, and the probabilities for each sentiment.
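You can smoke-test the classifier on its own before wiring it into the API (a quick sketch, run from the project root after downloading the model):

```python
from sentiment_analyzer.classifier.model import get_model

model = get_model()
sentiment, confidence, probabilities = model.predict("I love this app!")
print(sentiment, confidence)  # e.g. positive 0.99...
print(probabilities)
```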
But why not define all that logic in the request handler function? For a tutorial this small, the separation is arguably overengineering. But in the real world, when you start testing your implementation, it will be such a nice bonus. You see, mixing everything into the request handler will cost you countless sleepless nights. When shit hits the fan (and it will), you'll wonder whether it's your REST code or your model code that's wrong. Keeping them apart lets you test each one separately.
The get_model() function ensures that we have a single instance of our Model (a Singleton). We'll use it in our API handler.
Our request handler needs access to the model to return a prediction. We'll use the Dependency Injection framework provided by FastAPI to inject our model. Here's the new predict function:
```python
# at the top of api.py, import the model wrapper:
from .classifier.model import Model, get_model


@app.post("/predict", response_model=SentimentResponse)
def predict(request: SentimentRequest, model: Model = Depends(get_model)):
    sentiment, confidence, probabilities = model.predict(request.text)
    return SentimentResponse(
        sentiment=sentiment, confidence=confidence, probabilities=probabilities
    )
```
The model gets injected by Depends and our Singleton function get_model. You can really appreciate the power of the abstraction by looking at this!
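This is also what makes the handler testable. Here's a minimal test sketch that swaps in a fake model via FastAPI's dependency_overrides (FakeModel is a hypothetical stand-in, and you'd need the TestClient extras, e.g. requests or httpx, installed; note that model.py still builds the real Model at import time, so a fully isolated test would make that lazy):

```python
# test_api.py: a sketch of testing the handler without touching BERT
from fastapi.testclient import TestClient

from sentiment_analyzer.api import app, get_model


class FakeModel:
    def predict(self, text):
        return "positive", 0.99, {"negative": 0.0, "neutral": 0.01, "positive": 0.99}


# replace the real dependency with the fake one
app.dependency_overrides[get_model] = lambda: FakeModel()

client = TestClient(app)
response = client.post("/predict", json={"text": "anything"})
assert response.status_code == 200
assert response.json()["sentiment"] == "positive"
```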
But does it work?
Let’s fire up the server:
```bash
uvicorn sentiment_analyzer.api:app
```
This should take a couple of seconds to load everything and start the HTTP server. Let's send a negative review (using the HTTPie client):
```bash
http POST http://localhost:8000/predict text="This app is a total waste of time!"
```
Here’s the response:
```json
{
  "confidence": 0.999885082244873,
  "probabilities": {
    "negative": 0.999885082244873,
    "neutral": 8.876612992025912e-05,
    "positive": 2.614063305372838e-05
  },
  "sentiment": "negative"
}
```
Let’s try with a positive one:
```bash
http POST http://localhost:8000/predict text="OMG. I love how easy it is to stick to my schedule. Would recommend to everyone!"
```

```json
{
  "confidence": 0.999932050704956,
  "probabilities": {
    "negative": 1.834999602579046e-05,
    "neutral": 4.956663542543538e-05,
    "positive": 0.999932050704956
  },
  "sentiment": "positive"
}
```
Both results are on point. Feel free to try it out with some real reviews from the Play Store.
You should now be the proud owner of a (kind of) ready-to-deploy Sentiment Analysis REST API built on BERT. Of course, you're still missing lots of things to be production-ready: logging, monitoring, alerting, containerization, and much more. But hey, you did good!
You learned how to:

- manage project dependencies with Pipenv
- stub out a FastAPI endpoint with Pydantic request and response models
- load a pre-trained BERT classifier from a PyTorch state dict
- hide the tokenizer and model behind a single predict() method
- inject the model into the request handler with FastAPI's Depends
- test the API from the command line with HTTPie
Go on then, deploy and make your users happy!