Predicting churn with the SPSS random tree algorithm

This notebook shows you how to create a predictive model of churn rate by using IBM SPSS Algorithm on Apache Spark version 2.0. You'll learn how to create an SPSS random tree model by using the IBM SPSS Machine Learning API, and how to view the model with IBM SPSS Model Viewer.

Random tree algorithms combine multiple classification and regression trees (CART), so you can use them to generate accurate predictive models and solve complex classification and regression problems. Each tree develops from a bootstrap sample that is produced by resampling the original data points with replacement. As each tree grows, the best split variable is selected for each node from a specified smaller number of variables that are drawn randomly from the full set of variables. Each tree grows without pruning. During the scoring phase, the random tree algorithm aggregates tree scores by majority voting (for classification) or averaging (for regression).
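The three ideas in the paragraph above (bootstrap sampling, random selection of candidate split variables, and score aggregation) can be sketched in plain Scala. This is only an illustration under simple assumptions, not the SPSS implementation: the object name RandomTreeSketch and all of its methods are hypothetical, and no trees are actually grown.

```scala
import scala.util.Random

// Hypothetical helper object; it illustrates the ideas only and is not part of the SPSS API.
object RandomTreeSketch {
  // Bootstrap sample: draw data.size items from data, sampling with replacement.
  def bootstrap[A](data: IndexedSeq[A], rng: Random): IndexedSeq[A] =
    IndexedSeq.fill(data.size)(data(rng.nextInt(data.size)))

  // Pick k distinct candidate split variables at random from the full set.
  def candidateVariables(all: Seq[String], k: Int, rng: Random): Seq[String] =
    rng.shuffle(all).take(k)

  // Classification: the label with the most tree votes wins (ties are broken arbitrarily).
  def majorityVote(votes: Seq[Int]): Int =
    votes.groupBy(identity).maxBy { case (_, sameLabel) => sameLabel.size }._1

  // Regression: average the per-tree scores.
  def average(scores: Seq[Double]): Double = scores.sum / scores.size
}
```

For example, RandomTreeSketch.majorityVote(Seq(1, 0, 1, 1, 0)) returns 1, because three of the five hypothetical trees vote for class 1.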

In this notebook, you'll use telecommunications data to build a model that predicts which customers are likely to leave for a competitor, so that you can take action to retain them.

This notebook runs on Scala 2.11 with Spark 2.0. To get the most out of it, you should have some familiarity with the Scala programming language.


This notebook contains the following main sections:

  1. Load the Telco Churn data to the cloud data repository.
  2. Prepare the data.
  3. Configure the RandomTrees model.
  4. View the model.
  5. Summary and next steps.

1. Load the Telco Churn data to the cloud data repository

Telco Churn is a hypothetical data file that concerns a telecommunications company's efforts to reduce turnover in its customer base. Each case corresponds to a separate customer and it records various demographic and service usage information. Before you can work with the data, you must use the URL to get the telco.csv and telco_Feb.csv files from the GitHub repository.

In [1]:
import java.io.File
import java.net.URL
import sys.process._

// Download each CSV file from its URL into the notebook's working directory.
val link_telco = ""
new URL(link_telco) #> new File("telco.csv") !!

val link_telco_Feb = ""
new URL(link_telco_Feb) #> new File("telco_Feb.csv") !!

2. Prepare the data

After uploading the CSV files that contain the data, you must create a SQLContext, put the data from the telco.csv file into a Spark DataFrame, and show the first row in the DataFrame.

In [2]:
val sqlContext = new org.apache.spark.sql.SQLContext(sc)

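// Read telco.csv with a header row and inferred column types, then show the first row.
val dfTelco = sqlContext.read.
    option("header", "true").
    option("inferschema", "true").
    csv("telco.csv")
dfTelco.show(1)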
|region|tenure|age|marital|address|income| ed|employ|retire|gender|reside|tollfree|equip|callcard|wireless|longmon|tollmon|equipmon|cardmon|wiremon|longten|tollten|equipten|cardten|wireten|multline|voice|pager|internet|callid|callwait|forward|confer|ebill|         loglong|logtoll|logequi|         logcard|logwire|           lninc|custcat|churn|
|     2|    13| 44|      1|      9|    64|  4|     5|     0|     0|     2|       0|    0|       1|       0|    3.7|    0.0|     0.0|    7.5|    0.0|  37.45|    0.0|     0.0|  110.0|    0.0|       0|    0|    0|       0|     0|       0|      1|     0|    0|1.30833281965018|       |       |2.01490302054226|       |4.15888308335967|      1|    1|
only showing top 1 row

Review the data. Print the schema of the DataFrame to look at what kind of data you have.

In [3]:
dfTelco.printSchema()
 |-- region: integer (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- age: integer (nullable = true)
 |-- marital: integer (nullable = true)
 |-- address: integer (nullable = true)
 |-- income: integer (nullable = true)
 |-- ed: integer (nullable = true)
 |-- employ: integer (nullable = true)
 |-- retire: integer (nullable = true)
 |-- gender: integer (nullable = true)
 |-- reside: integer (nullable = true)
 |-- tollfree: integer (nullable = true)
 |-- equip: integer (nullable = true)
 |-- callcard: integer (nullable = true)
 |-- wireless: integer (nullable = true)
 |-- longmon: double (nullable = true)
 |-- tollmon: double (nullable = true)
 |-- equipmon: double (nullable = true)
 |-- cardmon: double (nullable = true)
 |-- wiremon: double (nullable = true)
 |-- longten: double (nullable = true)
 |-- tollten: double (nullable = true)
 |-- equipten: double (nullable = true)
 |-- cardten: double (nullable = true)
 |-- wireten: double (nullable = true)
 |-- multline: integer (nullable = true)
 |-- voice: integer (nullable = true)
 |-- pager: integer (nullable = true)
 |-- internet: integer (nullable = true)
 |-- callid: integer (nullable = true)
 |-- callwait: integer (nullable = true)
 |-- forward: integer (nullable = true)
 |-- confer: integer (nullable = true)
 |-- ebill: integer (nullable = true)
 |-- loglong: double (nullable = true)
 |-- logtoll: string (nullable = true)
 |-- logequi: string (nullable = true)
 |-- logcard: string (nullable = true)
 |-- logwire: string (nullable = true)
 |-- lninc: double (nullable = true)
 |-- custcat: integer (nullable = true)
 |-- churn: integer (nullable = true)

Create a DataFrame for the telco_Feb.csv data. You'll use the data in telco.csv to build the model, and the February data to measure its accuracy.

In [4]:
// Read telco_Feb.csv the same way; this DataFrame is used for scoring.
val dfTelcoFeb = sqlContext.read.
    option("header", "true").
    option("inferschema", "true").
    csv("telco_Feb.csv")

3. Configure the RandomTrees model

By running this portion of the code, you create the random trees estimator, import the libraries, and set the ordinal and nominal variables. Because no inputFieldList value is set, all fields except the target, frequency, and analysis weight fields are treated as input fields. To make the random tree model build faster, set the number of trees to 10 instead of the default value, which is 100. Finally, you must specify the churn target field.

In [5]:

val ordinal = Array("ed")
val nominal = Array("region",
val srf = RandomTrees().setTargetField("churn").setNumTrees(10)
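// Fit the estimator on the training data after setting the measurement levels.
// setNominalMeasure is assumed here by analogy with setOrdinalMeasure.
val srfModel = srf.fit(dfTelco.setNominalMeasure(nominal, true).setOrdinalMeasure(ordinal, true))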

Do the prediction and get your results.

In [6]:
val predResult = srfModel.transform(dfTelcoFeb)
val predResultNew = predResult.withColumn("prediction", predResult("prediction").cast("double")).
    withColumn("churn", predResult("churn").cast("double"))

To get the accuracy result, use the Apache Spark MulticlassClassificationEvaluator class. Notice that the accuracy is above 90%.

In [7]:
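import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// Compare the "prediction" column (the evaluator's default) with the "churn" label.
val evaluator = new MulticlassClassificationEvaluator().setLabelCol("churn").setMetricName("accuracy")
val acc_result = evaluator.evaluate(predResultNew)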
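The accuracy metric is the fraction of rows whose predicted label equals the actual label. As a minimal sketch of that arithmetic on plain Scala collections (the object AccuracySketch and the sample values are hypothetical and stand in for the churn and prediction columns):

```scala
// Hypothetical helper; it mirrors the definition of accuracy, not Spark's implementation.
object AccuracySketch {
  // Accuracy = (rows where prediction equals label) / (total rows).
  def accuracy(labels: Seq[Double], predictions: Seq[Double]): Double = {
    require(labels.nonEmpty && labels.size == predictions.size,
      "labels and predictions must be non-empty and the same length")
    val correct = labels.zip(predictions).count { case (label, pred) => label == pred }
    correct.toDouble / labels.size
  }
}
```

With labels Seq(1.0, 0.0, 0.0, 1.0, 0.0) and predictions Seq(1.0, 0.0, 1.0, 1.0, 0.0), four of five rows match, so the accuracy is 0.8. The misclassification rate is the complement, 1 minus the accuracy.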

4. View the model

View the model with the SPSS Model Viewer. The visualization for the random trees model includes a confusion matrix, a table with top decision rules, and a table and chart of predictor importance.
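A confusion matrix counts how often each combination of actual and predicted class occurs. The following sketch shows the idea on plain Scala collections; the object ConfusionMatrixSketch is hypothetical and is not how the Model Viewer computes its table.

```scala
// Hypothetical helper that tallies (actual, predicted) pairs.
object ConfusionMatrixSketch {
  // Returns a map from (actual, predicted) to the number of rows with that pair.
  // Pairs that never occur are absent from the map.
  def confusionMatrix(actual: Seq[Int], predicted: Seq[Int]): Map[(Int, Int), Int] = {
    require(actual.size == predicted.size, "actual and predicted must be the same length")
    actual.zip(predicted)
      .groupBy(identity)
      .map { case (pair, rows) => pair -> rows.size }
  }
}
```

The entries on the diagonal, where actual equals predicted, are the correct predictions; all other entries are misclassifications.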

4.1 Generate a project token

Before you can run the model viewer, you need to generate a project token.

  1. In the My Projects banner, click the More icon and then click Insert project token. The project token is inserted into the first cell of the notebook, before the title.
  2. Copy the text, which appears at the beginning of the notebook, into the following cell and run it.

In [8]:

4.2 Start the model viewer

Run the code in the following cell to start SPSS Model Viewer, where you can see a visualization of the model along with model statistics and other characteristics.

In [9]:
Model Visualization
Random Trees
The Random Trees algorithm is a sophisticated modern approach to supervised learning for categorical or continuous targets. The algorithm uses groups of classification or regression trees and randomness to make predictions that are particularly robust when applied to new observations. The IBM SPSS Spark Machine Learning Library implementation features a table of top decision rules for classification models without imbalance handling and measures of relative predictor importance for all models.

For more information, visit the Random trees page on the Data Science Experience web site.


Target Field                  churn
Model Building Method         Random Trees Classification
Number of Predictors Input    36
Model Accuracy                0.695
Misclassification Rate        0.305

4.3 Export the XML files (PMML, StatXML) for other detailed statistics

By exporting your results to different formats, such as Predictive Model Markup Language (PMML) or StatXML, you can share your statistical analyses outside of IBM Data Science Experience.

In [10]:
import java.io.{File, PrintWriter}

// Write the StatXML text to a file and close the writer.
val statXML = srfModel.statXML()
new PrintWriter("StatXML.xml") {
  write(statXML)
  close()
}
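If you prefer the writer to be closed even when the write fails, a try/finally wrapper is one option. This is a minimal sketch; the object XmlExportSketch and its writeText method are hypothetical names, and the sketch assumes that the XML is available as a String.

```scala
import java.io.PrintWriter

// Hypothetical helper for writing exported XML text to disk.
object XmlExportSketch {
  // Writes content to path as UTF-8 and always closes the writer.
  def writeText(path: String, content: String): Unit = {
    val writer = new PrintWriter(path, "UTF-8")
    try writer.write(content)
    finally writer.close()
  }
}
```

In this notebook, a call might look like XmlExportSketch.writeText("StatXML.xml", statXML), assuming that statXML holds the exported text.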

5. Summary and next steps

You have created a predictive model of churn rate by using IBM SPSS Algorithm on Apache Spark. Now you can create a different model to compare model evaluations, such as the test of model effects, residuals, and so on. For details, see the SPSS documentation.


Wang Zhiyuan and Yu Wenpei are SPSS Algorithm Engineers at IBM.

Copyright © 2017 IBM. This notebook and its source code are released under the terms of the MIT License.