Churn prediction with Spark


Churn analysis is one application of predictive analytics. It is mostly used in businesses with a contractual customer-supplier relationship, where the primary goal is to keep existing customers while attracting new ones. In churn analysis we try to identify which customers are most likely to leave their supplier (to churn) and which factors indicate that decision. The supplier can then act on this information, for example by contacting the customers most likely to churn and offering them special benefits or prices. Knowing which factors drive churn also helps the supplier prevent it in the future.

The suggested measure of how well a targeting model classifies churners is lift. Lift compares the response within the target group (the customers the model identifies as most likely to churn) with the average response for the population as a whole: target response divided by average response. The model performs well when the target response is much higher than the average response.

With more and more data science computing shifting towards the Big Data ecosystem, and Apache Spark being a natural choice of engine for large-scale data processing, we decided to explore building and scoring a churn model with Spark in Java. Such a solution scales, points towards real-time churn classification on CDR data, and serves as a proof of concept for quickly consolidating churn solutions onto the Spark platform.


The model was developed in Java, using the Apache Spark libraries to perform binary classification with logistic regression. We chose lift as the measure of model performance.

Installation and configuration for Windows

1.      Compatible Hadoop and Spark versions need to be installed (for example, hadoop-2.7.1 and spark-2.3.0-bin-hadoop2.7).

2.      Environment variables need to be set: SCALA_HOME and HADOOP_HOME must point to the respective installation directories, and both bin directories must be added to PATH: set PATH=%SCALA_HOME%\bin;%PATH% and set PATH=%HADOOP_HOME%\bin;%PATH%

Java dependency

groupId: org.apache.spark

artifactId: spark-core_2.11

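version: 2.3.0

The code below also uses MLlib classes (LabeledPoint, LogisticRegressionWithLBFGS), which are packaged in the spark-mllib_2.11 artifact of the same group and version.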


The dataset we are using is in CSV format. It holds information about international calls, special number calls, CC calls, customer complaints and so on for a given month. The last column shows whether a customer churned in the following month: 1 if yes, 0 if no.



For this example, we prepared a set of variables that notably includes average prices and durations of international and special number calls over a two-month period, the duration of free number calls and the number of mobile phone calls. Other important variables are the penalties amount, the number of CC calls regarding complaints and the number of CC calls regarding information about services and offers. Some of these values can directly indicate whether a customer is interested in the offers and services or is not satisfied with them at all.

Our goal was to train a model that predicts whether a customer will churn in the following month by assigning a high probability to likely churners.

Java source code

First we load the data from the CSV file into a JavaRDD<String> and remove the first line, which contains the column names.

String trainingDataFilePath = "data.dsv";

SparkConf sparkConfig = new SparkConf();



JavaSparkContext sparkContext = new JavaSparkContext(sparkConfig);


JavaRDD<String> data = sparkContext.textFile(trainingDataFilePath);

final String headerRow = data.first();

data = data.filter(item -> !item.equals(headerRow));

Then we convert it into a JavaRDD<LabeledPoint>, specifying the feature and label columns.

JavaRDD<LabeledPoint> formattedData = data.map(getFunctionToConvertLineToLabelledPoint());

    private static Function<String, LabeledPoint> getFunctionToConvertLineToLabelledPoint() {

        return new Function<String, LabeledPoint>() {

            public LabeledPoint call(String line) throws Exception {

                String[] parts = line.split(";");

                double BLL_CHG_INTERNATIONAL_AVG2 = Double.parseDouble(parts[0]);

                double BLL_CHG_SPECIAL_AVG2 = Double.parseDouble(parts[1]);

                double BLL_DUR_FREE_MIN = Double.parseDouble(parts[2]);

                double BLL_DUR_SPECIAL_MIN_AVG2 = Double.parseDouble(parts[3]);

                double BLL_NUM_FREE_AVG6 = Double.parseDouble(parts[4]);

                double BLL_NUM_MOBILE = Double.parseDouble(parts[5]);

                double BLL_DUR_INT_MIN_AVG2 = Double.parseDouble(parts[6]);

                double BLL_DUR_MOBILE_MIN = Double.parseDouble(parts[7]);

                double CC_COMPLAINT_AVG36 = Double.parseDouble(parts[8]);

                double CC_INFO_BILL_AVG6 = Double.parseDouble(parts[9]);

                double CC_INFO_OTHER_CNT2 = Double.parseDouble(parts[10]);

                double CC_INFO_SERVICE = Double.parseDouble(parts[11]);

                double CC_OUT_CON = Double.parseDouble(parts[12]);

                double CC_WCOMPLAINT_0_CNT2 = Double.parseDouble(parts[13]);

                double CC_WCOMPLAINT_2ST_AVG36 = Double.parseDouble(parts[14]);

                double CRM_DOWNSELL_CNT = Double.parseDouble(parts[15]);

                double CRM_PENALTIES_AMOUNT = Double.parseDouble(parts[16]);

                double CRM_ACCESS_TYPE_FO = Double.parseDouble(parts[17]);

                double CRM_ACCESS_TYPE_CPS_WLR = Double.parseDouble(parts[18]);

                double CRM_ACCESS_TYPE_LL = Double.parseDouble(parts[19]);

                double CRM_ACCESS_TYPE_CPS = Double.parseDouble(parts[20]);

                double CRM_ACCESS_TYPE_NA = Double.parseDouble(parts[21]);

                double CRM_ACCESS_TYPE_BSA = Double.parseDouble(parts[22]);

                double CRM_ACCESS_TYPE_0 = Double.parseDouble(parts[23]);

                double CRM_CHANNEL_O = Double.parseDouble(parts[24]);

                double CRM_CHANNEL_P = Double.parseDouble(parts[25]);

                double CRM_CHANNEL_PO = Double.parseDouble(parts[26]);

                double CRM_CHANNEL_OT = Double.parseDouble(parts[27]);

                double CRM_CHANNEL_0 = Double.parseDouble(parts[28]);

                double CRM_CHANNEL_DP = Double.parseDouble(parts[29]);

                double CRM_CHANNEL_DS = Double.parseDouble(parts[30]);

                double label = Double.parseDouble(parts[31]);

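                // The feature vector holds columns 0-30 of the row, in file order.
                Vector featureVector = Vectors.dense(new double[]{BLL_CHG_INTERNATIONAL_AVG2, BLL_CHG_SPECIAL_AVG2, BLL_DUR_FREE_MIN, BLL_DUR_SPECIAL_MIN_AVG2,

                        BLL_NUM_FREE_AVG6, BLL_NUM_MOBILE, BLL_DUR_INT_MIN_AVG2, BLL_DUR_MOBILE_MIN, CC_COMPLAINT_AVG36, CC_INFO_BILL_AVG6, CC_INFO_OTHER_CNT2,

                        CC_INFO_SERVICE, CC_OUT_CON, CC_WCOMPLAINT_0_CNT2, CC_WCOMPLAINT_2ST_AVG36, CRM_DOWNSELL_CNT, CRM_PENALTIES_AMOUNT, CRM_ACCESS_TYPE_FO,

                        CRM_ACCESS_TYPE_CPS_WLR, CRM_ACCESS_TYPE_LL, CRM_ACCESS_TYPE_CPS, CRM_ACCESS_TYPE_NA, CRM_ACCESS_TYPE_BSA, CRM_ACCESS_TYPE_0,

                        CRM_CHANNEL_O, CRM_CHANNEL_P, CRM_CHANNEL_PO, CRM_CHANNEL_OT, CRM_CHANNEL_0, CRM_CHANNEL_DP, CRM_CHANNEL_DS});

                return new LabeledPoint(label, featureVector);

            }

        };

    }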

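The conversion above names every column explicitly. As a more compact alternative, here is a minimal sketch in plain Java (no Spark types). It assumes, as in the code above, that each line holds 31 numeric features followed by the label, separated by semicolons. The class name ChurnLineParser and its methods are hypothetical and not part of the original project.

```java
/** Hypothetical helper: splits one semicolon-delimited row into features and label. */
final class ChurnLineParser {

    // Columns 0..30 are features in the code above; column 31 is the label.
    static final int FEATURE_COUNT = 31;

    /** Returns the first FEATURE_COUNT columns parsed as doubles. */
    static double[] features(String line) {
        String[] parts = line.split(";");
        if (parts.length != FEATURE_COUNT + 1) {
            throw new IllegalArgumentException(
                    "Expected " + (FEATURE_COUNT + 1) + " columns but found " + parts.length);
        }
        double[] values = new double[FEATURE_COUNT];
        for (int i = 0; i < FEATURE_COUNT; i++) {
            values[i] = Double.parseDouble(parts[i].trim());
        }
        return values;
    }

    /** Returns the last column (1 = churned, 0 = stayed) as a double. */
    static double label(String line) {
        String[] parts = line.split(";");
        return Double.parseDouble(parts[parts.length - 1].trim());
    }
}
```

Inside the Spark function, the two results could be passed to new LabeledPoint(label, Vectors.dense(features)). Checking the column count makes a malformed row fail with a clear message instead of an index error.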

We split the data into training and cross-validation sets and set up the logistic regression classifier.

JavaRDD<LabeledPoint>[] splits = formattedData.randomSplit(new double[]{0.7, 0.3}, SPLIT_SEED);

JavaRDD<LabeledPoint> trainingData = splits[0];


JavaRDD<LabeledPoint> crossValidationData = splits[1];

final LogisticRegressionModel model = new LogisticRegressionWithLBFGS().setNumClasses(2).run(trainingData.rdd());

Then we clear the default threshold so that the model returns raw probabilities instead of class labels.

model.clearThreshold();

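To illustrate why the threshold is cleared: with a threshold set, a binary logistic model returns only the class (0 or 1), while without one it returns the score itself, which is what ranking customers by probability needs. The sketch below mimics that behaviour in plain Java. The class name, weights and intercept are made up for illustration and are not taken from the trained model.

```java
/** Illustrative only: a logistic score with an optional decision threshold. */
final class ThresholdSketch {

    /** Logistic (sigmoid) score: 1 / (1 + exp(-(w . x + b))). */
    static double score(double[] weights, double intercept, double[] features) {
        double margin = intercept;
        for (int i = 0; i < weights.length; i++) {
            margin += weights[i] * features[i];
        }
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    /**
     * With a threshold, returns 1.0 when the score exceeds it and 0.0 otherwise.
     * With a null threshold, returns the raw score.
     */
    static double predict(double[] weights, double intercept, double[] features, Double threshold) {
        double s = score(weights, intercept, features);
        if (threshold == null) {
            return s;
        }
        return s > threshold ? 1.0 : 0.0;
    }
}
```

With only 0/1 predictions, all customers in the same class would tie, so sorting them into percentiles for lift and gain would not be meaningful.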

We calculate probabilities on the cross-validation set. The result is a JavaRDD<Tuple2<Object, Object>>, where each Tuple2<Object, Object> contains the estimated probability and the value of the CHURNUM (label) column.

JavaRDD<Tuple2<Object, Object>> predictionAndLabels = crossValidationData.map(

        new Function<LabeledPoint, Tuple2<Object, Object>>() {

    public Tuple2<Object, Object> call(LabeledPoint p) {

        // With the threshold cleared, predict() returns the churn probability.
        Double prediction = model.predict(p.features());

        return new Tuple2<Object, Object>(prediction, p.label());

    }

});


We sort the result by probability values in descending order.

JavaRDD<Tuple2<Object, Object>> predictionAndLabelsSorted = predictionAndLabels.sortBy(new Function<Tuple2<Object, Object>, Double>() {

            private static final long serialVersionUID = 1L;

            public Double call(Tuple2<Object, Object> value) throws Exception {

                // Sort key: the predicted probability.
                return (Double) value._1();

            }

        }, false, 1);

We split the sorted data into 100 sections (percentiles) and calculate the cumulative lift up to each section.

CalculateCumulativeLift_100Quantiles(predictionAndLabelsSorted, true, "cv");

static private void CalculateCumulativeLift_100Quantiles(JavaRDD<Tuple2<Object, Object>> testPredictionAndLabelsSorted, boolean csvW, String dataName) throws IOException {

        double lift = 0.0, averageChurn, modelAverageChurn;

        int size, selSize, liftNum = 0, tempSize;

        ArrayList<Double> cLift = new ArrayList<Double>();

        size = testPredictionAndLabelsSorted.collect().size();

        averageChurn = (double) (testPredictionAndLabelsSorted.filter(pAndL -> pAndL._2$mcD$sp() == 1.0).count()) / (double) (testPredictionAndLabelsSorted.collect().size());

        selSize = (int) (0.01 * size);

        for (int i = 0; i < 99; i++) {

            // Take the top (i + 1) percentiles and keep only the churners among them.
            ArrayList<Tuple2<Object, Object>> tempTestPredictionAndLabelsSorted = new ArrayList<Tuple2<Object, Object>>(testPredictionAndLabelsSorted.collect().subList(0, (i + 1) * selSize));

            tempTestPredictionAndLabelsSorted.removeIf(pAndL -> !(pAndL._2$mcD$sp() == 1.0));

            modelAverageChurn = (double) tempTestPredictionAndLabelsSorted.size() / (double) ((i + 1) * selSize);

            lift = modelAverageChurn / averageChurn;

            cLift.add(lift);

        }

        // The last point covers the whole data set, so its lift is 1.
        ArrayList<Tuple2<Object, Object>> tempTestPredictionAndLabelsSorted = new ArrayList<Tuple2<Object, Object>>(testPredictionAndLabelsSorted.collect().subList(0, size));

        tempTestPredictionAndLabelsSorted.removeIf(pAndL -> !(pAndL._2$mcD$sp() == 1.0));

        modelAverageChurn = (double) tempTestPredictionAndLabelsSorted.size() / (double) (size);

        lift = modelAverageChurn / averageChurn;

        cLift.add(lift);

        System.out.println("Cumulative lift: " + cLift.toString());

        if (csvW) {

            String fileName = dataName + "_100Quantiles.csv";


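The function above calls collect() and rebuilds a sublist on every iteration. Once the sorted labels are on the driver, the same cumulative lift values can be computed in a single pass with a running count of churners. A minimal sketch in plain Java follows. It assumes the labels are already sorted by descending predicted probability, and the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative helper, not part of the original project. */
final class CumulativeLiftSketch {

    /**
     * @param labelsSortedByScore 1.0 for a churner, 0.0 otherwise, ordered by descending score
     * @param buckets number of cut-offs (100 for percentiles)
     * @return cumulative lift at each cut-off; the last cut-off covers every row
     */
    static List<Double> cumulativeLift(double[] labelsSortedByScore, int buckets) {
        if (buckets <= 0 || labelsSortedByScore.length < buckets) {
            throw new IllegalArgumentException("Need at least one row per bucket");
        }
        int size = labelsSortedByScore.length;
        int bucketSize = size / buckets;

        long totalChurners = 0;
        for (double label : labelsSortedByScore) {
            if (label == 1.0) {
                totalChurners++;
            }
        }
        if (totalChurners == 0) {
            throw new IllegalArgumentException("Lift is undefined without any churners");
        }
        double averageChurn = (double) totalChurners / size;

        List<Double> lift = new ArrayList<>(buckets);
        long churnersSoFar = 0;
        int row = 0;
        for (int b = 1; b <= buckets; b++) {
            // The final cut-off takes all rows, including any remainder.
            int cutoff = (b == buckets) ? size : b * bucketSize;
            while (row < cutoff) {
                if (labelsSortedByScore[row] == 1.0) {
                    churnersSoFar++;
                }
                row++;
            }
            double churnRateInTarget = (double) churnersSoFar / cutoff;
            lift.add(churnRateInTarget / averageChurn);
        }
        return lift;
    }
}
```

The labels could be obtained with a single collect() on the sorted RDD. For data too large for the driver, a distributed approach would be needed instead.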


We split the sorted data into 100 sections (percentiles) and calculate the cumulative gain up to each section.

CalculateGain_100Quantiles(predictionAndLabelsSorted, true, "cv");

static private void CalculateGain_100Quantiles(JavaRDD<Tuple2<Object, Object>> testPredictionAndLabelsSorted, boolean csvW, String dataName) throws IOException {

        double  modelChurn;

        int size, selSize, liftNum = 0, tempSize;

        long churnNumber;

        ArrayList<Double> cGain = new ArrayList<Double>();

        size = testPredictionAndLabelsSorted.collect().size();

        churnNumber = testPredictionAndLabelsSorted.filter(pAndL -> pAndL._2$mcD$sp() == 1.0).count();


        selSize = (int) (0.01 * size);

        for (int i = 0; i < 99; i++) {

            // Take the top (i + 1) percentiles and keep only the churners among them.
            ArrayList<Tuple2<Object, Object>> tempTestPredictionAndLabelsSorted = new ArrayList<Tuple2<Object, Object>>(testPredictionAndLabelsSorted.collect().subList(0, (i + 1) * selSize));

            tempTestPredictionAndLabelsSorted.removeIf(pAndL -> !(pAndL._2$mcD$sp() == 1.0));

            modelChurn = ((double) tempTestPredictionAndLabelsSorted.size() / (double) (churnNumber))*100;

            cGain.add(modelChurn);

        }

        // The last point covers the whole data set, so its gain is 100%.
        ArrayList<Tuple2<Object, Object>> tempTestPredictionAndLabelsSorted = new ArrayList<Tuple2<Object, Object>>(testPredictionAndLabelsSorted.collect().subList(0, size));

        tempTestPredictionAndLabelsSorted.removeIf(pAndL -> !(pAndL._2$mcD$sp() == 1.0));

        modelChurn = ((double) tempTestPredictionAndLabelsSorted.size() / (double) (churnNumber))*100;

        cGain.add(modelChurn);

        System.out.println("Gain: " + cGain.toString());

        if (csvW) {

            String fileName = dataName + "_100Quantiles_gain.csv";





Cross validation set



The results of this model are satisfying. Lift for the first percentile is 6.1, which means that the 1% of customers ranked highest by the model can be expected to contain 6.1 times as many churners as a randomly selected 1% of customers. A gain of 32% for the top 10 percentiles means that 32% of all churners are found within the 10% of customers the model targets first.


A lift of 5.5 for the top 3 percentiles shows that the telecom company can expect 5.5 times as many churners as in a randomly selected 3% of customers, which helps when deciding which customers to contact and offer better prices.



Gain at each percentile is the cumulative number of targets up to that percentile divided by the total number of targets.
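
Gain and cumulative lift are two views of the same counts: dividing the gain at a cut-off by the share of the population that the cut-off covers gives the cumulative lift there. The short sketch below shows both calculations in plain Java. The class and method names are hypothetical.

```java
/** Illustrative helper relating gain to cumulative lift. */
final class GainSketch {

    /** Gain in percent: the share of all churners captured in the target group. */
    static double gainPercent(long churnersInTarget, long totalChurners) {
        if (totalChurners <= 0) {
            throw new IllegalArgumentException("Gain is undefined without any churners");
        }
        return 100.0 * churnersInTarget / totalChurners;
    }

    /** Cumulative lift implied by a gain value at a given share of the population. */
    static double liftFromGain(double gainPercent, double populationPercent) {
        if (populationPercent <= 0) {
            throw new IllegalArgumentException("Population share must be positive");
        }
        return gainPercent / populationPercent;
    }
}
```

For example, a gain of 32% at the top 10 percentiles corresponds to a cumulative lift of 3.2 at that cut-off.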

[Figure: churn ROC curve]

Area under ROC = 0.7356327255193738
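
The article does not show how this figure was computed. In Spark MLlib, BinaryClassificationMetrics.areaUnderROC() is the usual route. To make the number easier to interpret, the sketch below computes the same quantity directly from its definition: the probability that a randomly chosen churner receives a higher score than a randomly chosen non-churner, with ties counted as half. It is plain Java with hypothetical names, and it compares every pair, so it is meant for small samples only.

```java
/** Illustrative pairwise computation of the area under the ROC curve. */
final class AucSketch {

    /**
     * @param scores predicted probabilities
     * @param labels 1.0 for a churner, 0.0 otherwise, in the same order as scores
     */
    static double areaUnderRoc(double[] scores, double[] labels) {
        if (scores.length != labels.length) {
            throw new IllegalArgumentException("scores and labels must have the same length");
        }
        double favourable = 0.0;
        long pairs = 0;
        for (int i = 0; i < scores.length; i++) {
            if (labels[i] != 1.0) {
                continue; // i must be a churner
            }
            for (int j = 0; j < scores.length; j++) {
                if (labels[j] != 0.0) {
                    continue; // j must be a non-churner
                }
                pairs++;
                if (scores[i] > scores[j]) {
                    favourable += 1.0;
                } else if (scores[i] == scores[j]) {
                    favourable += 0.5;
                }
            }
        }
        if (pairs == 0) {
            throw new IllegalArgumentException("Need at least one churner and one non-churner");
        }
        return favourable / pairs;
    }
}
```

A value of 0.5 corresponds to random ranking and 1.0 to a perfect ranking, so a value around 0.73 means the model orders a churner above a non-churner in roughly 73% of such pairs.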

Test set

Results on an independent test set are also satisfying. Lift for the top percentile is 9.5 and lift for the top 3 percentiles is 5. A gain of 35% for the top 10 percentiles means that 35% of all churners are found within the 10% of customers the model targets first.


Area under ROC = 0.7261240728177258


The resulting predictive model is, as expected, on par with similar models built through R, Python or data mining suites such as Oracle Data Miner. Having built such a model, we are free to store it in HDFS together with the build, test and out-of-time test data, schedule our scoring runs through an Oozie workflow and monitor them through the Spark web UI.


Written by:

Jasmina Redžić