Unit testing Spark Scala code

Published May 16, 2019

Unit tests. They’re a good thing. I use them even in single-person projects, because I like being able to double-check my own logic, and because it’s less effort to run a couple tests than to remember the way my code interacts with my other code every time I make a change.

I actually test pipeline code, too. I know this is a little uncommon – some people believe that if there is a mistake in their code, they will know because their pipeline will fail. I don’t follow this logic: one, not all mistakes cause pipeline failures. What if you have a subtle math error in your pipeline that causes you to miscalculate, say, your customers’ clickthrough rates? The pipeline infrastructure does not care that your code does not reflect your intentions correctly; it will do what you told it to do.

Two, pipeline failures can be pretty expensive themselves. An hour of one relatively wimpy instance on Amazon EMR (say c3.2xlarge, the smallest type currently available in Frankfurt) only comes to about $0.63, but most pipelines will need more and bigger instances and run for several hours. And then you’ll have to debug your pipeline, probably under time pressure because it needs to run because something else needs the data. Surely it’s better all around to test before you deploy.

Recently I’ve been working on my Scala skills after years of mostly C++ and Python, and thought I’d write up what I found out about testing Spark pipeline code written in Scala. Where I’ve found good writeups, I’ve linked to them, and tried what they recommended for myself. If you notice that my Scala style is less than idiomatic, this is because I’m still figuring out what works best — if you can spare the time, please drop me a note with your improvements!

Testing styles

Scala has a unit testing framework called ScalaTest. It gives you a choice of testing styles, ranging from FunSuite, which looks reassuringly familiar if you come from xUnit-style testing, to various types of Spec-based tests that (I think) are intended to make it easier to express the behaviour your test should be verifying.
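
To give a flavour of the difference, here is the same trivial check written in FunSuite style and in FlatSpec style (the class names and the check itself are made up for illustration):

import org.scalatest.{FlatSpec, FunSuite}

// xUnit-flavoured: tests are named blocks inside a suite.
class KeyFormatSuite extends FunSuite {
  test("key combines id and year") {
    assert("pub2" + "." + 10 == "pub2.10")
  }
}

// Spec-flavoured: the test name reads like a behaviour description.
class KeyFormatSpec extends FlatSpec {
  "A citation key" should "combine id and year" in {
    assert("pub2" + "." + 10 == "pub2.10")
  }
}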

Background for the example

For the examples below, I have used code from a pipeline evaluating citations of computer science publications based on a data dump obtained from dblp. DBLP make their data available under an Open Data Commons ODC-By 1.0 License.

The data I have is in JSON format. The fields I use are:

  - id: a unique identifier for the publication
  - references: the list of ids of the publications this one cites
  - year: the year of publication

Other fields are available, e.g. venue (name of conference or journal), authors (list of author names) and of course title, abstract etc. The data also contains a field called n_citation that should be the number of times the paper was cited; however, a lot of records have a value of 50 there, so I suspect the value is capped. There are a few interesting “impurities” in the data, for example some publications that seem to have been cited before they were published; this usually appears to be a merge artifact, where a famous paper was re-issued in a collection and DBLP merged two publication records into one (which may be correct) and used the later publication date (which is not correct in all cases).

I wanted to see how the number of citations a paper attracts develops over time, so originally I made a Jupyter notebook. And then I wanted to have a simple-but-not-trivial example for trying some Scala features, so I decided to use the DBLP data.

First, here are the first few pieces of the Scala code:

package com.nephometrics.dblp

import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.SparkSession


case class BasePublication(id: String, references: Array[String], year: Int)

class Citations(@transient val spark: SparkSession) {

  def this() = this(SparkSession.builder.getOrCreate())

  import spark.implicits._

  def idRefYearDs(baseData: DataFrame): Dataset[BasePublication] = {
    baseData.select("id", "references", "year").filter(
      "references is not NULL").as[BasePublication]
  }

  // Take a dataset of BasePublications, invert on the reference ids,
  // group by reference id + yearCited, sum the citation counts.
  def countCitationsByYear(publications: Dataset[BasePublication]): Dataset[(String, Int)] = {
    val citedPublications = publications.flatMap(row =>
      for (b <- row.references) yield (b + "." + row.year, 1))
    citedPublications.groupByKey(_._1).reduceGroups(
      (a, b) => (a._1, a._2 + b._2)).map(_._2)
  }
}

This code uses Datasets wherever possible because it seems worth trying. I might write another blog post about how the code differs when I allow myself to use Dataframes or even RDDs in a few places, and what implications that might have for performance.

One thing to note is that transforming baseData into a Dataset[BasePublication] requires an import of spark.implicits._ or a custom Encoder. I found this a little awkward because I like to write my base logic without depending on a SparkSession (cue the “what is a true Unit Test” discussion), but creating a custom Encoder just to keep my Unit Tests pure did not seem like a good idea either.
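
For completeness, the explicit-Encoder alternative would look roughly like this inside the Citations class (an illustration only; I did not end up using it):

import org.apache.spark.sql.{Encoder, Encoders}

// Replaces the need for spark.implicits._ when calling .as[BasePublication]:
implicit val basePublicationEncoder: Encoder[BasePublication] =
  Encoders.product[BasePublication]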

I wrote the code so that it can either getOrCreate() a SparkSession using the default builder or take one as a parameter to the constructor of the Citations class. This is so I can “inject” a pre-constructed SparkSession, e.g. one that is set up for testing. More on why that matters below. I have marked the SparkSession field @transient to exclude it from serialization.
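
Putting these pieces together, a driver program might use the class roughly like this (a sketch; the object name and the input path are placeholders):

import org.apache.spark.sql.SparkSession

object CitationsDriver {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("citations").getOrCreate()
    val citations = new Citations(spark)

    // Read the dblp JSON dump and count citations per (publication, year) key.
    val baseData = spark.read.json("path/to/dblp.json")
    val counted = citations.countCitationsByYear(citations.idRefYearDs(baseData))
    counted.show()
  }
}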

A basic Unit Test

Here is a very basic test, “FunSuite” style.

package com.nephometrics.dblp

import org.apache.spark.sql.types._
import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite

// This just makes it easier to create test data.
case class Publication(id: String, year: Int, venue: String,
  references: Option[Array[String]])

class DblpCitationsTest extends FunSuite {

  // Do not copy this code! You probably want to put
  // it into a beforeAll() or beforeEach() method, and
  // add some useful options (explained below).
  val spark: SparkSession = SparkSession.builder().appName(
       "citations").master("local").getOrCreate()

  import spark.implicits._

  // pub1 published year 10, no citations
  // pub2 published year 9, cited by pub1 and pub5 in year 10 (at age 1)
  // pub3 published year 8, cited by pub2 in year 9 and pub1 in year 10
  // pub5 published in year 10, no citations
  // pub4 does not have an entry
  val sourceDF = Seq(
    Publication("pub1", 10, "here", Some(Array("pub2", "pub3"))),
    Publication("pub2", 9, "there", Some(Array("pub3", "pub4"))),
    Publication("pub3", 8, "there", None),
    Publication("pub5", 10, "there", Some(Array("pub2", "pub4"))),
  ).toDF()

  val stats = new Citations(spark)

  test("citations are counted by year") {
    val ds = stats.idRefYearDs(sourceDF)
    val counted = stats.countCitationsByYear(ds)
    val results = counted.collect()
    assert(5 == counted.count())

    for (r <- results) {
      if (r._1 == "pub2.10") {
        assert(2 == r._2)
      } else {
        assert(Array("pub3.9", "pub3.10", "pub4.10", "pub4.9").contains(r._1))
        assert(1 == r._2)
      }
    }
  }
}

and, for completeness, here is a basic build.sbt file:

name := "DBLP-Scala"
version := "0.9"
scalaVersion := "2.12.8"
val sparkVersion = "2.4.2"
scalacOptions += "-target:jvm-1.8"
javaOptions ++= Seq("-Xms512M", "-Xmx2048M", "-XX:+CMSClassUnloadingEnabled")
fork in Test := true
parallelExecution in Test := false
testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, "-oD")
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % sparkVersion,
  "org.apache.spark" %% "spark-sql" % sparkVersion,

  "org.scala-lang" % "scala-reflect" % "2.12.8",
  "org.scalatest" %% "scalatest" % "3.2.0-SNAP9" % Test
)

I specify a target JVM because I also have Java 10 installed and some spark features don’t work with it.

The Java memory options are there to avoid running out of memory when I run tests repeatedly from an sbt shell.

The -oD test option enables timing information for tests.

The line fork in Test := true tells sbt to fork the JVM. This avoids an sbt restart if a test causes the JVM to shut down.

The line parallelExecution in Test := false makes sure tests are executed serially. I’ve seen this recommended especially when you create a separate SparkSession for every test, though I have not been able to get my tests to fail when I leave this setting out. I enabled it here because I wanted to measure test running times independently, and it seemed best to run them serially for that purpose.

Note that while this build.sbt file works fine for tests, I have had some grief packaging my code and running spark-submit on the resulting .jar file: the pipeline fails with a NoClassDefFoundError for scala/runtime/LambdaDeserialize. As best I can tell so far, this tends to be caused by version mismatches or incompatible versions, e.g. of Scala and Spark. If I figure out a generic way to diagnose and fix this kind of issue, I’ll write it up as a separate blog post; right now I can basically poke at this kind of issue until it works, but I don’t think I can explain why in a helpful way.

Spark Sessions in Unit Tests

As mentioned above, you will need a SparkSession to run most of the code you want to test. A unit testing purist will tell you that if you need a Spark Context or Session, your test has strayed from Unit Test into Integration Test territory. If you want to follow that line, you could break out “just the logic” of your code into a scala library that you test with “real” unit tests and then write integration tests for the other parts. It’s a clean way to structure your code and probably makes sense especially for large code bases.
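
As a small illustration of that split, the key-building step from countCitationsByYear could live in a plain Scala object and get a Spark-free test (CitationKeys and its test are made-up names for the example):

import org.scalatest.FunSuite

object CitationKeys {
  // Pure logic: build the "referenceId.yearCited" key used for grouping.
  def citationKey(referenceId: String, yearCited: Int): String =
    referenceId + "." + yearCited
}

class CitationKeysTest extends FunSuite {
  test("citation key combines reference id and year cited") {
    assert("pub2.10" == CitationKeys.citationKey("pub2", 10))
  }
}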

For my own purposes, I refer to fast, frequently run tests as Unit Tests even when they use a SparkSession; what matters to me is how I use them.

Creating a Spark Session

So let’s look at how to create a SparkSession for Tests. I’ve shown one option above where I create a custom SparkSession in the Test Suite class and pass it to the constructor of the class I’m testing.

This is ok, but there are cleaner ways to structure this code. The Scalatest library has before and after fixtures that you can use for setup and teardown around your tests. These come in Each and All variants (executed before and after every test vs. before and after the entire test suite).

You can use beforeAll to set up a spark session that can then get used by every test, like so:

import org.scalatest.BeforeAndAfterAll

class DblpCitationsTest extends FunSuite with BeforeAndAfterAll {
  var stats: Citations = _

  override def beforeAll(): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .appName("citations")
      .master("local")
      .getOrCreate()

    stats = new Citations(spark)
    super.beforeAll()
  }
}

You can declare this using beforeEach in pretty much the same way, and it will have a very similar effect because getOrCreate() will return the current Spark Session if one already exists. I wrote “similar” and not “the same” because maybe you run your tests in parallel in separate threads and use thread-local SparkSessions. Or maybe you implement an afterEach() method that calls spark.stop() – that will stop your spark session after every test, so you will get a new one the next time you call getOrCreate().
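
For reference, here is a sketch of what that per-test variant could look like, with an afterEach() that stops the session (an illustration of the option just described, not the setup I settled on):

import org.scalatest.BeforeAndAfterEach

class DblpCitationsTest extends FunSuite with BeforeAndAfterEach {
  var stats: Citations = _

  override def beforeEach(): Unit = {
    // getOrCreate() builds a fresh session here because afterEach() stopped the previous one.
    val spark: SparkSession = SparkSession
      .builder()
      .appName("citations")
      .master("local")
      .getOrCreate()
    stats = new Citations(spark)
    super.beforeEach()
  }

  override def afterEach(): Unit = {
    super.afterEach()
    stats.spark.stop()
  }
}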

The ScalaTest docs explain more about the options here, e.g. how to run tests with fixtures, and how to run the same test with different fixtures.

More examples and utilities:

  1. Holden Karau’s spark-testing-base package will create a spark context for you if you mix a relevant trait into your test class
  2. spark-fast-tests shows how to use a SparkSessionTestWrapper trait that gives you some control over the spark session and only instantiates one per test suite.
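
The wrapper-trait idea boils down to something roughly like this (my paraphrase of the pattern, not code copied from the library; class and app names are made up):

import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite

trait SparkSessionTestWrapper {
  // lazy: the session is only built the first time a test touches it,
  // and getOrCreate() ensures there is at most one per JVM.
  lazy val spark: SparkSession = SparkSession
    .builder()
    .appName("citations-test")
    .master("local")
    .getOrCreate()
}

class DblpCitationsWrapperTest extends FunSuite with SparkSessionTestWrapper {
  import spark.implicits._
  // tests that need the session go here
}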

Just to experiment, I added a second test to my suite, ran test from the sbt shell and looked at how long it took. Of course the raw number of seconds is meaningless by itself because a lot depends on the specs of the machine running the test and what else is running on it. I ran this on a machine with four CPUs (well, two cores with two hardware threads each) and 8 gigs of RAM that wasn’t doing much else at the time besides a couple of vim sessions. I ran the test suite with a beforeAll() setup a few times and results were between 15 and 19 seconds. That seems like a lot for two simple tests.

Does it get worse if I create a new Spark Session for every test? Actually, it gets faster (takes about 14 seconds per run). That seems counterintuitive at first, but more investigation showed that creating and stopping Spark Sessions was actually not the main thing slowing my tests down – my test suite only contained two tests, and moreover, I had not yet used some other good optimization options.

Configuring a Spark Session for testing

I ran my tests using sbt test 2>/tmp/sparktest.log and went over the logs in more detail.

The suite now has two tests. The first one is called BaseDataConversion; it just creates a DataFrame with the base data, converts it into a Dataset, and counts the number of entries. The second one is the “citations are counted by year” test for which I included the code above.

SBT reported:

[info] DblpCitationsTest:
[info] - BaseDataConversion (2 seconds, 79 milliseconds)
[info] - citations are counted by year (4 seconds, 723 milliseconds)
[info] Run completed in 13 seconds, 511 milliseconds.
[info] Total number of tests run: 2
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 2, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
[success] Total time: 16 s, completed May 20, 2019, 11:27:57 AM

SBT reported 16s total time for this test, of which 13s 511 milliseconds were spent on the test run, 2s 79 ms on the first test, and 4s 723 ms on the second. The first 4s or so of the total time are not recorded in the log file so I’ll assume those were spent starting up sbt, loading the project definition and settings.

I don’t know which operations went into the 2s 79 ms that sbt reported for the first test; from the log file, it seems like 6s passed between “start the context” and “start job for action”, the count action took about 620ms, and then teardown took less than 1s.

For the second test, it’s 2s to start up services, 1.6s to run the job, and less than 1s for teardown. It’s consistently faster to start services for the second test, even with a spark.stop() call in afterEach(). I suspect this is because the first session start brings up a couple of servers (like the block manager) that (as far as I can tell) are left running even when you call spark.stop().

Overall this is a pretty dismally long running time for a couple of very simple tests. I went over the steps outlined here to optimize. The most effective step is to run tests from the sbt shell instead of making separate sbt calls from the commandline, since this saves a lot of startup time. For the following tests, I ran from the commandline anyway because I wanted to be able to redirect stderr to different log files for different versions of the code.

Another recommendation is to set the number of shuffle partitions to 1. The log files for the second test show 200 tasks launched for shuffling, so this looked like a promising thing to try.

Add .config("spark.sql.shuffle.partitions", "1") to the configuration of the spark session builder, and we get:

[info] DblpCitationsTest:
[info] - BaseDataConversion (1 second, 803 milliseconds)
[info] - citations are counted by year (1 second, 404 milliseconds)
[info] Run completed in 10 seconds, 630 milliseconds.
[info] Total number of tests run: 2
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 2, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
[success] Total time: 20 s, completed May 20, 2019, 12:11:34 PM

The time spent on running the tests has come down nicely. The total time is up though – from the log file, it looks like it took 11s to start sbt and compile the code this time. This would get better if I used the sbt shell, and it will also amortize for large test suites.

Inside the tests, the count action from the first unit test took 520ms this time (it does not need to shuffle so the number of shuffle partitions should have no effect on it), while the collect call in the second test is down to 469ms now. As mentioned above, the raw numbers are almost meaningless, but the reduction in running time from 1.6s to 469ms is nice.

The Medium post I referenced above also talks about using a single SparkSession and SparkContext instead of creating a new one for every test. I experimented with this but did not see a significant impact on running times; then again, I only have two tests.

Next, I don’t need spark to run a UI server during tests, so remove that as well:

.config("spark.ui.enabled", "false")

This has no noticeable impact on performance but still seems worth doing.
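
Putting the settings together, the builder in my beforeAll() now looks roughly like this:

val spark: SparkSession = SparkSession
  .builder()
  .appName("citations")
  .master("local")
  // one shuffle partition is plenty for the tiny test datasets (no 200 shuffle tasks)
  .config("spark.sql.shuffle.partitions", "1")
  // no need for the Spark UI server during tests
  .config("spark.ui.enabled", "false")
  .getOrCreate()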

Once I felt like I would not want to read the detailed log output any more, I added this line:

spark.sparkContext.setLogLevel(org.apache.log4j.Level.ERROR.toString())

It had little to no impact on running times, but when I redirected stderr to /dev/null, I got the total wall time down to 12s, which is nice. Now I run my tests like this:

sbt test 2> /dev/null

so I still get the test results on stdout and then if I see a failure, I can re-run without the redirect to get the logs.

To be sure, 12 seconds is still longer than I want for unit tests. Based on the timing above, most of the time is spent in starting sbt, compiling code, and bringing up the spark context and servers.

Code structure and testability

Spark Scala code is often written in the ‘fluent’ style, with chains of method invocations. This can make it a bit more difficult to diagnose bugs that are buried in the middle of a chain. Debugging is doable in the scala shell of course, but I like to write my unit tests with a focus on small pieces of logic. Maybe that means I test too much of the implementation as opposed to the end to end behaviour; I just find it works for me.

I usually break the whole pipeline down into methods of 15-30 lines that represent what I think of as “chunks” of logic. I test those methods individually, and then I assemble them into a pipeline and test the pipeline end to end. Performance-wise, this makes no difference so long as the methods do not introduce extra actions.

Fortunately, my approach does not even make a performance difference during testing; at least based on my observations, it looks as though Spark is clever enough to cache and re-use intermediate results. I was pretty happy when I noticed that.

Here is an example:

  test("citations are counted by year") {
    val counted = stats.countCitationsByYear(baseDataSet)
    val results = counted.collect()
    assert(5 == counted.count())

    for (r <- results) {
      if (r._1 == "pub2.10") {
        assert(2 == r._2)
      } else {
        assert(Array("pub3.9", "pub3.10", "pub4.10", "pub4.9").contains(r._1))
        assert(1 == r._2)
      }
    }
  }

  test("citation age is computed from year published and year cited") {
    val counted = stats.countCitationsByYear(baseDataSet)
    val cited = stats.countCitationsByAge(counted, sourceDF)
    val results = cited.collect()
    assert(5 == cited.count())

    // More assertions come here
  }

During testing, I noticed that the “citation age” test is consistently slower when I remove the “citations are counted” test. Comparing test timings and logs shows that Spark does not (or at least does not always) re-compute counted for the second test if it has already been computed for the first. I specify serial execution of my tests to make sure I can take advantage of this.

My pipeline takes a sourceDF and does something like this:

val baseDataSet = idRefYearDs(sourceDF)
val counted = countCitationsByYear(baseDataSet)
val cited = countCitationsByAge(counted, sourceDF)
// ... more transformations after this ...

For the optimized plan that Spark creates when I call an action, it makes no difference whether I split things up into methods and assign intermediate results to vals. Maybe I’ll adopt more of the fluent style over time and change the way I test, but for now, it is reassuring to know I can test in small chunks.

There is a situation where breaking out a separate method does make running in Spark more difficult: trying to pass a locally defined method into something like map(). More background here, but the upshot is that if you ask for code to be executed on worker nodes, Spark needs to send that code to those nodes. It does this by serializing the object that defines the code, and it also needs access to the values of the variables the code uses. You need to make sure the object in question is actually Serializable, and you (probably) want to keep that object small.

One more thing to note in this context is that a method declared with def will contain a reference to this (Background), so if this is not Serializable, or if you just don’t want it to be serialized, use val not def to declare the function that you want the worker nodes to execute.
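
To make that concrete, here is a small sketch (KeyBuilder and its members are made-up names for illustration):

class KeyBuilder {
  // A def is compiled as a method on the enclosing instance; turning it into a
  // function to hand to map() (keyFromDef _) produces a closure that holds a
  // reference to this, so the whole KeyBuilder would have to be serializable.
  def keyFromDef(ref: (String, Int)): String = ref._1 + "." + ref._2

  // A val already is a function object; handing it to map() ships just the
  // function itself (and whatever it closes over), not the enclosing instance.
  val keyFromVal: ((String, Int)) => String = ref => ref._1 + "." + ref._2
}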