Apache Spark Scala: round() doesn't always round()
BLOT:
Apache Spark Scala’s round() doesn’t always round() the way you’d expect, so check your data types.
Long story:
I recently had to move a piece of code from SAS to Apache Spark. The SAS code was pretty simple:
1. Create a column whose value depends on the value of another column, e.g. if col_1 < 5 then col_1_val = 0.123
2. Do this for multiple columns
3. Sum all the “_val” columns to get the final score
Note: SAS was taken as the standard, so any differences/changes needed to be made in the Spark code and not in the SAS code.
Thankfully, Apache Spark has many ways to replicate steps 1 and 2 with constructs like withColumn and when-otherwise logic, roughly as sketched below.
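A minimal sketch of steps 1 and 2, assuming an input DataFrame raw_df (a made-up name) and hypothetical cut-offs; only the 0.123 for col_1 comes from the example above, the rest are invented for illustration:

import org.apache.spark.sql.functions.{col, when}

// Hypothetical scoring rules; the real ones lived in the SAS code
val df = raw_df
  .withColumn("col_1_val", when(col("col_1") < 5, 0.123).otherwise(0.0))
  .withColumn("col_2_val", when(col("col_2") < 10, 0.456).otherwise(0.0))
  .withColumn("col_3_val", when(col("col_3") < 15, 0.789).otherwise(0.0))
  .withColumn("col_4_val", when(col("col_4") < 20, 0.321).otherwise(0.0))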
Part 3 should have been the easiest, as I could just say:

import org.apache.spark.sql.functions.round // the $-syntax also needs spark.implicits._ in scope

val final_df = df.withColumn("score", round($"col_1_val" + $"col_2_val" + $"col_3_val" + $"col_4_val"))
However, when I compared my results I noticed that a few values did not match between SAS and Spark, e.g.:
+-----------+---------------+-------------+
| Total Sum | SAS round | Spark round |
+-----------+---------------+-------------+
| 165.5 | 166 (correct) | 165 (wrong) |
+-----------+---------------+-------------+
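(As an aside, surfacing rows like this is just a join and a filter; a quick sketch, assuming the SAS scores were exported into a DataFrame sas_df sharing an id key and carrying a sas_score column, all of which are made-up names:)

// Hypothetical audit query: sas_df, id and sas_score are assumptions
val mismatches = final_df
  .join(sas_df, Seq("id"))
  .filter($"score" =!= $"sas_score")

mismatches.show()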
Understanding the problem
After some research (here and here) it seems that the difference comes down to how a decimal number is actually represented by the machine and the data type I was using (DoubleType). Fancy words like “mantissa” and the “IEEE 754 floating-point specification” come up, but the bottom line is that the computer may actually see the data like this:
+-----------+---------------------+-------------+
| Total Sum | Spark actual value | Spark round |
+-----------+---------------------+-------------+
| 165.5 | 165.498986895866996 | 165 |
+-----------+---------------------+-------------+
Given the instruction to round to the nearest integer, the Spark round function has done its job correctly, but the result is not what we would get had we rounded 165.5 by hand.
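You don’t need Spark to see this effect; plain Scala on the JVM shows the same drift, because IEEE 754 doubles cannot represent most decimal fractions exactly:

// Ten additions of 0.1 should give exactly 1.0, but the
// representation error accumulates with every addition.
val sum = List.fill(10)(0.1).sum
println(sum)              // 0.9999999999999999
println(math.round(sum))  // 1 here, but a sum sitting just below x.5 rounds down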
Correcting the problem
To rectify this, the advice I found was to convert the DoubleType to a FloatType before using round():
import org.apache.spark.sql.types.FloatType

val final_df = df.withColumn("score", round(($"col_1_val" + $"col_2_val" + $"col_3_val" + $"col_4_val").cast(FloatType)))
Output is now as expected for all values where differences occurred.
+-----------+-----------+-------------+
| Total Sum | SAS round | Spark round |
+-----------+-----------+-------------+
| 165.5 | 166 | 166 |
+-----------+-----------+-------------+
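(One more option, for what it’s worth: this is not what the fix above used, but if casting down to FloatType feels like throwing away precision, the sum can instead be cast to a DecimalType with an explicit scale before rounding, since decimals do exact base-10 arithmetic at that scale. A sketch, with the precision and scale picked arbitrarily:)

import org.apache.spark.sql.types.DecimalType

// Assumption: 2 decimal places is enough to capture the intended sum;
// the cast half-up rounds 165.498986... to 165.50, so round() gives 166.
val final_df_dec = df.withColumn(
  "score",
  round(($"col_1_val" + $"col_2_val" + $"col_3_val" + $"col_4_val").cast(DecimalType(20, 2)))
)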
Phew! :)