

Customer churn prediction using Neural Networks with TensorFlow.js | Deep Learning for JavaScript Hackers (Part IV)

Neural Networks, Deep Learning, TensorFlow, Machine Learning, JavaScript | 6 min read


TL;DR Learn about Deep Learning and create a Deep Neural Network model to predict customer churn using TensorFlow.js. Learn how to preprocess string categorical data.

First day! You’ve landed this Data Scientist intern job at a large telecom company. You can’t stop dreaming about the Lambos and designer clothes you’re going to get once you’re a Senior Data Scientist.

Even your mom is calling to remind you to put your Ph.D. in Statistics diploma on the wall. This is the life. Who cares that you’re in your mid-30s and this is your first job ever?

Your team lead comes around, asking how you’re enjoying the job and saying that he might have a task for you! You start imagining implementing complex statistical models from scratch, doing research, and adding cutting-edge methods but… Well, the reality is slightly different. He sends you a link to a CSV file and asks you to predict customer churn. He suggests that you might try applying Deep Learning to the problem.

Your dream is starting now. Time to do some work!

Run the complete source code for this tutorial right in your browser:

Source code on GitHub

Customer churn data

Our dataset Telco Customer Churn comes from Kaggle.

“Predict behavior to retain customers. You can analyze all relevant customer data and develop focused customer retention programs.” [IBM Sample Data Sets]

The data set includes information about:

  • Customers who left within the last month – the column is called Churn
  • Services that each customer has signed up for – phone, multiple lines, internet, online security, online backup, device protection, tech support, and streaming TV and movies
  • Customer account information – how long they’ve been a customer, contract, payment method, paperless billing, monthly charges, and total charges
  • Demographic info about customers – gender, age range, and if they have partners and dependents

It has 7,043 examples and 21 variables:

  • customerID: Customer ID
  • gender: Whether the customer is a male or a female
  • SeniorCitizen: Whether the customer is a senior citizen or not (1, 0)
  • Partner: Whether the customer has a partner or not (Yes, No)
  • Dependents: Whether the customer has dependents or not (Yes, No)
  • tenure: Number of months the customer has stayed with the company
  • PhoneService: Whether the customer has a phone service or not (Yes, No)
  • MultipleLines: Whether the customer has multiple lines or not (Yes, No, No phone service)
  • InternetService: Customer’s internet service provider (DSL, Fiber optic, No)
  • OnlineSecurity: Whether the customer has online security or not (Yes, No, No internet service)
  • OnlineBackup: Whether the customer has online backup or not (Yes, No, No internet service)
  • DeviceProtection: Whether the customer has device protection or not (Yes, No, No internet service)
  • TechSupport: Whether the customer has tech support or not (Yes, No, No internet service)
  • StreamingTV: Whether the customer has streaming TV or not (Yes, No, No internet service)
  • StreamingMovies: Whether the customer has streaming movies or not (Yes, No, No internet service)
  • Contract: The contract term of the customer (Month-to-month, One year, Two year)
  • PaperlessBilling: Whether the customer has paperless billing or not (Yes, No)
  • PaymentMethod: The customer’s payment method (Electronic check, Mailed check, Bank transfer (automatic), Credit card (automatic))
  • MonthlyCharges: The amount charged to the customer monthly
  • TotalCharges: The total amount charged to the customer
  • Churn: Whether the customer churned or not (Yes or No)

We’ll use Papa Parse to load the data:

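const prepareData = async () => {
  const csv = await Papa.parsePromise(
    ""
  )
  // The parsed rows live in the `data` field of the Papa Parse result
  const data = csv.data
  // Drop the trailing empty row
  return data.slice(0, data.length - 1)
}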

Note that we ignore the last row since it is empty.


Let’s get a feel for our dataset. How many of the customers churned?


About 74% of the customers are still using the company services. We have a very unbalanced dataset.
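As a rough sketch of how such a split can be computed, the snippet below counts the values of the Churn column. The `rows` array is made-up toy data standing in for the parsed CSV rows, and `valueShares` is a hypothetical helper name, not something taken from the tutorial's source.

```javascript
// Hypothetical helper: share of each distinct value in a column.
const valueShares = (rows, column) => {
  const counts = {}
  rows.forEach(r => {
    counts[r[column]] = (counts[r[column]] || 0) + 1
  })
  const shares = {}
  Object.keys(counts).forEach(k => {
    shares[k] = counts[k] / rows.length
  })
  return shares
}

// Toy rows (assumed shape: one object per customer, keyed by column name).
const rows = [
  { Churn: "No" },
  { Churn: "No" },
  { Churn: "Yes" },
  { Churn: "No" },
]

const shares = valueShares(rows, "Churn")
console.log(shares)
```

With a split this lopsided, accuracy alone can look good even when the model rarely predicts the minority class, which is worth keeping in mind when reading the evaluation later on.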

Does gender play a role in losing customers?

[Figure: churn by gender]

Seems like it doesn’t. We have about the same number of female and male customers. How about seniority?

[Figure: churn by senior citizen status]

About 20% of the customers are senior, and they are much more likely to churn compared to the non-seniors.
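A comparison like this boils down to a churn rate per group. Here is a minimal sketch of that calculation in plain JavaScript; `churnRateBy` is a hypothetical helper and the `sample` rows are invented for illustration, not drawn from the real dataset.

```javascript
// Hypothetical helper: fraction of churned customers within each group.
const churnRateBy = (rows, column) => {
  const stats = {}
  rows.forEach(r => {
    const key = String(r[column])
    if (!stats[key]) {
      stats[key] = { total: 0, churned: 0 }
    }
    stats[key].total += 1
    if (r.Churn === "Yes") {
      stats[key].churned += 1
    }
  })
  const rates = {}
  Object.keys(stats).forEach(k => {
    rates[k] = stats[k].churned / stats[k].total
  })
  return rates
}

// Toy data, not taken from the real dataset.
const sample = [
  { SeniorCitizen: 1, Churn: "Yes" },
  { SeniorCitizen: 1, Churn: "No" },
  { SeniorCitizen: 0, Churn: "No" },
  { SeniorCitizen: 0, Churn: "No" },
  { SeniorCitizen: 0, Churn: "Yes" },
  { SeniorCitizen: 0, Churn: "No" },
]

const rates = churnRateBy(sample, "SeniorCitizen")
console.log(rates)
```

The same helper would work for any categorical column, for example gender or Contract.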

How long do customers stay with the company?


Seems like the longer you’ve been a customer, the more likely you are to stay with the company.

How do monthly charges affect churn?

[Figure: churn by monthly charges]

Customers with low monthly charges (< $30) are much more likely to be retained.

How about the total amount charged per customer?

[Figure: churn by total charges]

The higher the total amount charged by the company, the more likely it is for this customer to be retained.

Our dataset has a total of 21 features, and we didn’t look through all of those. However, we found some interesting stuff.

We’ve learned that SeniorCitizen, tenure, MonthlyCharges, and TotalCharges are somewhat correlated with the churn status. We’ll use them for our model!

Deep Learning

Deep Learning is a subset of Machine Learning, using Deep Artificial Neural Networks as a primary model to solve a variety of tasks.

To obtain a Deep Neural Network, take a Neural Network with one hidden layer (shallow Neural Network) and add more layers. That’s the definition of a Deep Neural Network: a Neural Network with more than one hidden layer!

In Deep Neural Networks, each layer of neurons is trained on the features/outputs of the previous layer. Thus, you can create a feature hierarchy of increasing abstraction and learn complex concepts.
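To make the layer-on-layer idea concrete, here is a tiny forward pass written in plain JavaScript. It is only an illustration: the weights are made up, nothing is trained, and `dense` is a hypothetical helper rather than the TensorFlow.js layer of the same name.

```javascript
const relu = x => Math.max(0, x)

// One dense layer: output[j] = activation(sum_i(input[i] * weights[i][j]) + bias[j])
const dense = (input, weights, bias, activation) =>
  bias.map((b, j) =>
    activation(input.reduce((sum, x, i) => sum + x * weights[i][j], b))
  )

// Made-up weights: 2 inputs -> 2 hidden units -> 1 hidden unit.
const layer1 = { weights: [[1, -1], [0.5, 2]], bias: [0, 0] }
const layer2 = { weights: [[1], [1]], bias: [-1] }

const input = [2, 1]
// The second layer never sees the raw input, only the first layer's output
const hidden1 = dense(input, layer1.weights, layer1.bias, relu)
const hidden2 = dense(hidden1, layer2.weights, layer2.bias, relu)
console.log(hidden1, hidden2)
```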

These networks are very good at discovering patterns within raw data (images, texts, video, and audio recordings), which makes up most of the data we have. For example, Deep Learning can take millions of images and categorize them into photos of your grandma, funny cats, and delicious cakes.

Deep Neural Nets are holding state-of-the-art scores on a variety of important problems. Examples are image recognition, image segmentation, sound recognition, recommender systems, natural language processing, etc.

So basically, Deep Learning is Large Neural Networks. Why now? Why wasn’t Deep Learning practical before?

  • Most real-world applications of Deep Learning require large amounts of labeled data: developing a driverless car might require thousands of hours of video.

  • Training models with large numbers of parameters (weights) requires substantial computing power: special purpose hardware in the form of GPUs and TPUs offers massively parallel computations, suitable for Deep Learning.

  • Big companies have been storing your data for a while now: they want to monetize it.

  • We learned (kinda) how to initialize the weights of the neurons in the Neural Network models: mostly using small random values

  • We have better regularization techniques (e.g. Dropout)

Last but not least, we have software that is performant and (sometimes) easy to use. Libraries like TensorFlow, PyTorch, MXNet and Chainer allow practitioners to develop, analyze, test and deploy models of varying complexity and reuse work done by other practitioners and researchers.

Predicting customer churn

Let’s use the “all-powerful” Deep Learning machinery to predict which customers are going to churn. First, we need to do some data preprocessing since a lot of the features are categorical.

Data preprocessing

We’ll use all the numerical features (customerID is excluded) and the following categorical features:

const categoricalFeatures = new Set([
  "TechSupport",
  "Contract",
  "PaymentMethod",
  "gender",
  "Partner",
  "InternetService",
  "Dependents",
  "PhoneService",
  "StreamingTV",
  "PaperlessBilling",
])

Let’s create training and testing datasets from our data:

const [xTrain, xTest, yTrain, yTest] = toTensors(data, categoricalFeatures, 0.1)

Here’s how we create our Tensors:

const toTensors = (data, categoricalFeatures, testSize) => {
  // One-hot encode every categorical column up front
  const categoricalData = {}
  categoricalFeatures.forEach(f => {
    categoricalData[f] = toCategorical(data, f)
  })

  const features = [
    "SeniorCitizen",
    "tenure",
    "MonthlyCharges",
    "TotalCharges",
  ].concat(Array.from(categoricalFeatures))

  // Build one flat feature row per customer
  const X = data.map((r, i) =>
    features.flatMap(f => {
      if (categoricalFeatures.has(f)) {
        return categoricalData[f][i]
      }
      return r[f]
    })
  )

  const X_t = normalize(tf.tensor2d(X))
  const y = tf.tensor(toCategorical(data, "Churn"))

  // The first (1 - testSize) share of the rows is used for training
  const splitIdx = parseInt((1 - testSize) * data.length, 10)

  const [xTrain, xTest] = tf.split(X_t, [splitIdx, data.length - splitIdx])
  const [yTrain, yTest] = tf.split(y, [splitIdx, data.length - splitIdx])

  return [xTrain, xTest, yTrain, yTest]
}

First, we use the function toCategorical() to convert categorical features into one-hot encoded vectors. We do that by converting the string values into numbers and using tf.oneHot() to create the vectors.

We create a 2-dimensional Tensor from our features (categorical and numerical) and normalize it. Another, one-hot encoded, Tensor is made from the Churn column.

Finally, we split the data into training and testing datasets and return the results. How do we encode categorical variables?

const toCategorical = (data, column) => {
  const values = data.map(r => r[column])
  const uniqueValues = new Set(values)

  // Map each distinct string value to an integer index
  const mapping = {}
  Array.from(uniqueValues).forEach((value, index) => {
    mapping[value] = index
  })

  const encoded = values
    .map(v => {
      // Missing values are encoded as index 0
      if (!v) {
        return 0
      }
      return mapping[v]
    })
    .map(v => oneHot(v, uniqueValues.size))

  return encoded
}

First, we extract a vector of all values for the feature. Next, we obtain the unique values and create a string to int mapping from it.

Note that we check for missing values and encode those as 0. Finally, we one-hot encode each value.
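To see what the encoding produces, here is a small plain-JavaScript sketch of the same idea that runs without TensorFlow.js. `oneHotPlain` and `encodeColumn` are hypothetical stand-ins for the tutorial's helpers, and the `contracts` rows are toy data.

```javascript
// Plain-JavaScript stand-in for a one-hot vector of length categoryCount.
const oneHotPlain = (index, categoryCount) =>
  Array.from({ length: categoryCount }, (_, i) => (i === index ? 1 : 0))

const encodeColumn = (rows, column) => {
  const values = rows.map(r => r[column])
  // A Set keeps insertion order, so indices follow first appearance
  const unique = Array.from(new Set(values))
  const mapping = {}
  unique.forEach((value, index) => {
    mapping[value] = index
  })
  // Missing values fall back to index 0, mirroring the check described above
  return values.map(v => oneHotPlain(v ? mapping[v] : 0, unique.length))
}

const contracts = [
  { Contract: "Month-to-month" },
  { Contract: "Two year" },
  { Contract: "Month-to-month" },
  { Contract: "One year" },
]

const encodedContracts = encodeColumn(contracts, "Contract")
console.log(encodedContracts)
// [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]
```

One consequence of the fallback is that a missing value becomes indistinguishable from whichever category happened to get index 0.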

Here are the remaining utility functions:

// normalized = (value − min_value) / (max_value − min_value)
const normalize = tensor =>
  tf.div(tf.sub(tensor, tf.min(tensor)), tf.sub(tf.max(tensor), tf.min(tensor)))

const oneHot = (val, categoryCount) =>
  Array.from(tf.oneHot(val, categoryCount).dataSync())
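Note that tf.min() and tf.max() called without an axis reduce over the whole tensor, so normalize() scales every column by a single global minimum and maximum. When features live on very different scales (tenure in months versus TotalCharges in dollars), scaling each column separately is a common alternative. The sketch below shows that variant in plain JavaScript; `normalizeColumns` is a hypothetical name, and this is not the tutorial's method.

```javascript
// Min-max scaling applied to each column independently.
const normalizeColumns = rows => {
  const columnCount = rows[0].length
  const mins = Array(columnCount).fill(Infinity)
  const maxs = Array(columnCount).fill(-Infinity)
  rows.forEach(row => {
    row.forEach((value, j) => {
      mins[j] = Math.min(mins[j], value)
      maxs[j] = Math.max(maxs[j], value)
    })
  })
  return rows.map(row =>
    row.map((value, j) => {
      const range = maxs[j] - mins[j]
      // A constant column has zero range; map it to 0 to avoid dividing by zero
      return range === 0 ? 0 : (value - mins[j]) / range
    })
  )
}

// Toy rows: [tenure, MonthlyCharges]
const scaled = normalizeColumns([
  [1, 20],
  [13, 60],
  [25, 100],
])
console.log(scaled)
```

Whether per-column scaling actually improves this model is something you would have to measure.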

Building a Deep Neural Network

We’ll wrap the building and training of our model into a function called trainModel():

const trainModel = async (xTrain, yTrain) => {
  ...
}

Let’s create a Deep Neural Network using the sequential model API in TensorFlow.js:

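const model = tf.sequential()
// First hidden layer; the input width matches the number of feature columns
model.add(
  tf.layers.dense({
    units: 32,
    activation: "relu",
    inputShape: [xTrain.shape[1]],
  })
)
// Second hidden layer
model.add(
  tf.layers.dense({
    units: 64,
    activation: "relu",
  })
)
// Output layer: one probability per class
model.add(tf.layers.dense({ units: 2, activation: "softmax" }))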

Our Deep Neural Network has two hidden layers with 32 and 64 neurons, respectively, each with a ReLU activation function. The output layer has 2 units with a softmax activation, one per class.
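For a sense of how small this network is, here is a back-of-the-envelope parameter count. The input width of 30 is an assumption (4 numerical features plus the one-hot columns of the categorical ones); the real value is xTrain.shape[1] and depends on the data. `denseParams` is a hypothetical helper.

```javascript
// Trainable parameters of a dense layer:
// one weight per (input, unit) pair plus one bias per unit.
const denseParams = (inputDim, units) => inputDim * units + units

// Assumption: 30 input columns. The real width is xTrain.shape[1].
const inputDim = 30
const layerUnits = [32, 64, 2]

let previous = inputDim
const perLayer = layerUnits.map(units => {
  const count = denseParams(previous, units)
  previous = units
  return count
})
const total = perLayer.reduce((a, b) => a + b, 0)
console.log(perLayer, total)
```

Under that assumption the three layers hold 992, 2,112 and 130 parameters, a little over three thousand in total.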

Time to compile our model:

model.compile({
  optimizer: tf.train.adam(0.001),
  loss: "binaryCrossentropy",
  metrics: ["accuracy"],
})

We’ll train our model using the Adam optimizer and measure our error using Binary Crossentropy.
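If the loss feels abstract, the following plain-JavaScript sketch computes binary cross-entropy by hand for a single example. It is an illustration of the formula, not the library's implementation; the `EPSILON` clipping value is an assumption made to keep the logarithm finite.

```javascript
// Binary cross-entropy averaged over the elements of one prediction.
// EPSILON is an assumed small constant that keeps Math.log away from 0.
const EPSILON = 1e-7

const binaryCrossentropy = (yTrue, yPred) => {
  const losses = yTrue.map((y, i) => {
    const p = Math.min(Math.max(yPred[i], EPSILON), 1 - EPSILON)
    return -(y * Math.log(p) + (1 - y) * Math.log(1 - p))
  })
  return losses.reduce((a, b) => a + b, 0) / losses.length
}

// One-hot label [retained, churned] against two possible softmax outputs
const confident = binaryCrossentropy([1, 0], [0.9, 0.1])
const unsure = binaryCrossentropy([1, 0], [0.6, 0.4])
console.log(confident, unsure)
```

The confident, correct prediction gets the lower loss; the closer the output drifts toward the wrong class, the larger the penalty.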


Finally, we’ll pass the training data to the fit method of our model and train for 100 epochs, shuffle the data, and use 10% of it for validation. We’ll visualize the training progress using tfjs-vis:

const lossContainer = document.getElementById("loss-cont")

await model.fit(xTrain, yTrain, {
  batchSize: 32,
  epochs: 100,
  shuffle: true,
  validationSplit: 0.1,
  // Plot loss and accuracy at the end of every epoch
  callbacks: tfvis.show.fitCallbacks(
    lossContainer,
    ["loss", "val_loss", "acc", "val_acc"],
    {
      callbacks: ["onEpochEnd"],
    }
  ),
})

Let’s train our model:

const model = await trainModel(xTrain, yTrain)



It seems like our model is learning during the first ten epochs and plateaus after that.


Let’s evaluate our model on the test data:

const result = model.evaluate(xTest, yTest, {
  batchSize: 32,
})
// loss
result[0].print()
// accuracy
result[1].print()

Tensor
    0.44808024168014526
Tensor
    0.7929078340530396

The model has an accuracy of 79.3% on the test data. Let’s have a look at what kind of mistakes it makes using the confusion matrix:

// Turn class probabilities and one-hot labels into class indices
const preds = model.predict(xTest).argMax(-1)
const labels = yTest.argMax(-1)
const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds)
const container = document.getElementById("confusion-matrix")
tfvis.render.confusionMatrix(container, {
  values: confusionMatrix,
  tickLabels: ["Retained", "Churned"],
})

[Figure: confusion matrix]

It seems like our model is overconfident in predicting retained customers. Depending on your needs, you might try to tune the model, and predict retained customers better.
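To put a number on that kind of imbalance, you can read per-class metrics off the confusion matrix. The sketch below builds a matrix from toy class indices in plain JavaScript and computes accuracy and the recall of the churned class. The `confusion` helper and the toy arrays are hypothetical, and in this sketch rows are true classes while columns are predictions.

```javascript
// Rows are true classes, columns are predicted classes.
const confusion = (labels, preds, classCount) => {
  const matrix = Array.from({ length: classCount }, () =>
    Array(classCount).fill(0)
  )
  labels.forEach((label, i) => {
    matrix[label][preds[i]] += 1
  })
  return matrix
}

// Toy class indices: 0 = retained, 1 = churned
const toyLabels = [0, 0, 0, 0, 1, 1, 1, 0]
const toyPreds = [0, 0, 0, 1, 1, 0, 0, 0]

const matrix = confusion(toyLabels, toyPreds, 2)
const accuracy = (matrix[0][0] + matrix[1][1]) / toyLabels.length
// Recall for the churned class: share of actual churners the model caught
const churnRecall = matrix[1][1] / (matrix[1][0] + matrix[1][1])
console.log(matrix, accuracy, churnRecall)
```

In the toy data the accuracy is 62.5% while only one in three churners is caught, which is the sort of gap a single accuracy figure hides.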


Great job! You just built a Deep Neural Network that predicts customer churn with ~80% accuracy. Here’s what you’ve learned:

  • What is Deep Learning
  • What is the difference between shallow Neural Networks and Deep Neural Networks
  • Preprocess string categorical data
  • Build and evaluate a Deep Neural Network in TensorFlow.js

But can it be that Deep Learning is even more powerful? So powerful that it can understand images?
