Scala map function: Examples using Lists and DataFrames

Louise Forbes
Mar 23, 2023


Photo by Adolfo Félix on Unsplash

The map function in Scala is really useful when you need to do the same thing to each element of a list, sequence or array. As data scientists, we often don’t have just a simple list but rather DataFrames, and we are often dealing with the same kind of data split across different attributes (e.g. monthly rainfall values, monthly account statements). The question then is: “Can I use the map function on a sequence of DataFrames?” Well, yes. Let’s see how.

Recap: Map function with list

First, let’s recap how the map function works with a list.

The map function takes each element of a list, applies some function/transformation to it, and returns a new list containing the results.

val my_int_list = List(1, 2, 3, 4, 5)
val doubled_list = my_int_list.map(elem => elem * 2)

doubled_list: List[Int] = List(2, 4, 6, 8, 10)


val my_string_list = List("tree", "Cat", "DOG", "aPPle")
val all_lower_list = my_string_list.map(st => st.toLowerCase)

all_lower_list: List[String] = List(tree, cat, dog, apple)
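The same method works on any Scala collection (Seq, Vector, Array), accepts a named function as well as a lambda, and can change the element type. It always returns a new collection and leaves the original untouched. A small sketch; the names `square`, `nums` and `labels` are illustrative only.

```scala
// A named function can be passed to map just like a lambda
def square(n: Int): Int = n * n

val nums = Vector(1, 2, 3, 4, 5)
val squares = nums.map(square)                    // Vector(1, 4, 9, 16, 25)
val lengths = Array("tree", "Cat").map(_.length)  // Array(4, 3)

// map can also change the element type, here Int => String
val labels = nums.map(n => s"day $n")

// The original collection is unchanged
println(nums)
```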

Using DataFrames

The dataset

The dataset we are going to use contains the daily rainfall over a number of years (provided by an avid geographer who collects rainfall data at his own house). An example of the DataFrame (randomly ordered to show a variety of rows) is below:

import org.apache.spark.sql.functions.rand
allRainfall.orderBy(rand()).show(10)

+----------+--------+
|      Date|Amt (mm)|
+----------+--------+
|2010-08-07|     0.0|
|2015-08-31|     0.0|
|2012-12-31|    12.0|
|2015-06-25|     0.0|
|1997-04-26|     0.0|
|2006-12-20|     0.0|
|1998-09-27|     2.0|
|2008-10-10|     0.0|
|2000-03-24|     0.0|
|1999-03-14|     6.5|
+----------+--------+

Let’s consider a situation where someone has already extracted the summer rainfall, by month, for part of the 2014/2015 rainy season (I’m in the southern hemisphere and it rains in summer #Jozigirl). The raw data has the date and the amount of rainfall on that date.

nov2014.show(5)
+-------------------+--------+
|               Date|Amt (mm)|
+-------------------+--------+
|2014-11-01 00:00:00|       0|
|2014-11-02 00:00:00|       0|
|2014-11-03 00:00:00|     3.5|
|2014-11-04 00:00:00|     4.5|
|2014-11-05 00:00:00|       4|
+-------------------+--------+

dec2014.show(5)
+-------------------+--------+
|               Date|Amt (mm)|
+-------------------+--------+
|2014-12-01 00:00:00|       5|
|2014-12-02 00:00:00|       0|
|2014-12-03 00:00:00|       0|
|2014-12-04 00:00:00|       0|
|2014-12-05 00:00:00|       0|
+-------------------+--------+

jan2015.show(5)
+-------------------+--------+
|               Date|Amt (mm)|
+-------------------+--------+
|2015-01-01 00:00:00|       0|
|2015-01-02 00:00:00|     1.5|
|2015-01-03 00:00:00|       0|
|2015-01-04 00:00:00|      34|
|2015-01-05 00:00:00|       2|
+-------------------+--------+

feb2015.show(5)
+-------------------+--------+
|               Date|Amt (mm)|
+-------------------+--------+
|2015-02-01 00:00:00|      16|
|2015-02-02 00:00:00|       0|
|2015-02-03 00:00:00|     1.7|
|2015-02-04 00:00:00|     1.7|
|2015-02-05 00:00:00|     1.7|
+-------------------+--------+

Now suppose we want to create a summary table with only a few selected statistics for each of these DataFrames. It is possible to write the summary code for each DataFrame individually, but what happens when there are 12 DataFrames for the whole year and you then want to add a statistic? The code gets long and difficult to maintain. Instead, create a sequence of DataFrames and use the map function to produce the summaries.

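val allMonths = Seq(nov2014, dec2014, jan2015, feb2015)

// map builds a summary DataFrame for each month (show() returns Unit, so it is called afterwards)
val allSumDfs = allMonths
  .map(df => df.select("Amt (mm)").summary("count", "min", "mean", "max"))
allSumDfs.foreach(_.show())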

+-------+-----------------+
|summary|         Amt (mm)|
+-------+-----------------+
|  count|               30|
|    min|                0|
|   mean|5.466666666666667|
|    max|                6|
+-------+-----------------+

+-------+-----------------+
|summary|         Amt (mm)|
+-------+-----------------+
|  count|               31|
|    min|                0|
|   mean|4.403225806451613|
|    max|                6|
+-------+-----------------+

+-------+-----------------+
|summary|         Amt (mm)|
+-------+-----------------+
|  count|               31|
|    min|                0|
|   mean|6.064516129032258|
|    max|              9.5|
+-------+-----------------+

+-------+-----------------+
|summary|         Amt (mm)|
+-------+-----------------+
|  count|               28|
|    min|                0|
|   mean|3.267857142857143|
|    max|              8.5|
+-------+-----------------+
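One thing to watch for in these outputs: the Amt (mm) values are displayed as typed (0, 3.5, 34), and the schema printed later in this post shows `Amt (mm): string`, so the column appears to hold text. For a string column, `summary` compares `min` and `max` as text rather than as numbers. That would explain why January’s max is reported as 9.5 even though 4 January alone had 34 mm: as text, "9.5" sorts after "34" because '9' comes after '3'. The plain-Scala sketch below reproduces the effect using the January values from the table above plus the 9.5 the summary reported.

```scala
// January 2015 values from the table above, plus 9.5 (the max the summary reported),
// kept as strings the way the Amt (mm) column appears to store them
val janText = List("0", "1.5", "0", "34", "2", "9.5")

// Strings compare character by character, so "9.5" > "34"
val textMax = janText.max

// Converting to numbers first gives the expected maximum
val numericMax = janText.map(_.toDouble).max
```

In Spark, casting the column before summarising should give numeric results, for example `df.select($"Amt (mm)".cast("double").as("Amt (mm)")).summary("count", "min", "mean", "max")`.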

More complicated map function

What about a more complicated function? Let’s say we want a function that does the following:

  • Extract the rainfall data for a given month
  • Show a summary of that month
  • Save that month’s rainfall data
  • Return the rainfall DataFrame

Note: I know the Single Responsibility Principle says that a function should do only one thing, so this is not a good function, but bear with me for this example.

We can see that the table doesn’t have the year and month as separate columns. We could derive them inside the function, but that would repeat the same calculation on every call, so let’s add them once beforehand.

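import org.apache.spark.sql.functions.{date_format, to_date, year}
import spark.implicits._ // enables the $"colName" syntax; assumes a SparkSession named spark

val allRainfallYearMonth = allRainfall
  .withColumn("Year", year(to_date($"Date")))
  .withColumn("Month", date_format(to_date($"Date"), "MMM"))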

allRainfallYearMonth.orderBy(rand).show(10)
+----------+--------+----+-----+
|      Date|Amt (mm)|Year|Month|
+----------+--------+----+-----+
|2008-02-05|    16.0|2008|  Feb|
|2002-04-30|     0.0|2002|  Apr|
|2009-02-12|     4.0|2009|  Feb|
|2005-10-25|     3.0|2005|  Oct|
|2000-03-30|     7.5|2000|  Mar|
|2001-09-28|     0.0|2001|  Sep|
|2013-10-31|     1.0|2013|  Oct|
|2003-06-06|     6.0|2003|  Jun|
|2004-08-23|     0.0|2004|  Aug|
|2014-06-13|     0.0|2014|  Jun|
+----------+--------+----+-----+
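To make the two derived columns concrete, the sketch below does the same derivation for a single date string in plain Scala with `java.time`: the year, plus the month as a three-letter English abbreviation, which matches what the `"MMM"` pattern produces in the table above. The helper name `yearAndMonth` is hypothetical, and this is an illustration of the idea rather than how Spark computes the columns internally.

```scala
import java.time.LocalDate
import java.time.format.TextStyle
import java.util.Locale

// Hypothetical helper mirroring the Year and Month columns for one "yyyy-MM-dd" date
def yearAndMonth(date: String): (Int, String) = {
  val d = LocalDate.parse(date) // ISO yyyy-MM-dd, like the Date column
  (d.getYear, d.getMonth.getDisplayName(TextStyle.SHORT, Locale.ENGLISH))
}

// Derived once per row up front, so later filters only compare stored values
val rows = Seq("2008-02-05", "2002-04-30", "2005-10-25")
val withYearMonth = rows.map(d => (d, yearAndMonth(d)))
```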

The function we want to use is:

import org.apache.spark.sql.DataFrame
import spark.implicits._ // enables the $"colName" syntax; assumes a SparkSession named spark

def extractRainfallMonth(yearMonthList: List[Any], rainfall: DataFrame, basePath: String): DataFrame = {

  val year = yearMonthList(0)
  val month = yearMonthList(1)

  // Extract the rainfall for the specific year and month
  val rainfallMonth = rainfall
    .filter($"Year" === year and $"Month" === month)

  // Print a summary of the month
  rainfallMonth.select("Amt (mm)")
    .summary("count", "min", "mean", "max").show()

  // Save the month as an Excel file (needs the crealytics spark-excel package on the classpath)
  rainfallMonth.write.format("com.crealytics.spark.excel")
    .option("header", "true").mode("overwrite")
    .save(basePath + month + year + ".xlsx")

  // The last expression is the return value, so no explicit return is needed
  rainfallMonth
}

First, we need a sequence/list/array to hold our inputs. It’s important to note that map passes the function one element at a time, so the year and month for each call need to be combined into a single element. Here they go into a List, which is why our function above takes a List[Any] as its first input. We’re going to use arbitrary months and years to show how this works.

val yearMonthSeq = Seq(List(2002, "Apr"),
List(1999, "Nov"),
List(2004, "Jan"),
List(2011, "Sep"))
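The inputs that stay the same for every call (here the DataFrame and the folder path) don’t need to go into the sequence at all; the lambda passed to map simply captures them. For the values that do change, a tuple such as `(Int, String)` is a more type-safe bundle than `List[Any]`: each position keeps its own type and can be destructured directly. The plain-Scala sketch below shows the pattern with a hypothetical `Reading` case class and `extractMonth` function standing in for the rainfall DataFrame and `extractRainfallMonth`; the rows are taken from the sample table above.

```scala
// Hypothetical stand-in for a rainfall row: year, month, and amount in mm
final case class Reading(year: Int, month: String, amtMm: Double)

// Same idea as extractRainfallMonth, but the year and month arrive as a typed tuple
def extractMonth(yearMonth: (Int, String), readings: Seq[Reading]): Seq[Reading] = {
  val (year, month) = yearMonth
  readings.filter(r => r.year == year && r.month == month)
}

// A few rows from the sample table above
val readings = Seq(
  Reading(2008, "Feb", 16.0),
  Reading(2002, "Apr", 0.0),
  Reading(2009, "Feb", 4.0),
  Reading(2005, "Oct", 3.0),
  Reading(2000, "Mar", 7.5)
)

val yearMonths = Seq((2008, "Feb"), (2005, "Oct"), (1999, "Nov"))

// readings is captured by the lambda; map still passes only one tuple per call
val perMonth = yearMonths.map(ym => extractMonth(ym, readings))
```

In the Spark version, the same change would mean declaring `yearMonth: (Int, String)` in the signature, unpacking it with `val (year, month) = yearMonth`, and building the inputs as `Seq((2002, "Apr"), (1999, "Nov"))`.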

Then we use the map function to loop through these, calling our function on each one:

val folderPath = "D:/Documents/Rainfall/" // replace with the path to your output folder
val allRainDfs = yearMonthSeq
.map(ym => extractRainfallMonth(ym, allRainfallYearMonth, folderPath))

+-------+------------------+
|summary|          Amt (mm)|
+-------+------------------+
|  count|                30|
|    min|               0.0|
|   mean|0.9833333333333333|
|    max|               9.0|
+-------+------------------+

+-------+------------------+
|summary|          Amt (mm)|
+-------+------------------+
|  count|                30|
|    min|               0.0|
|   mean|2.6666666666666665|
|    max|               6.0|
+-------+------------------+

+-------+------------------+
|summary|          Amt (mm)|
+-------+------------------+
|  count|                31|
|    min|               0.0|
|   mean|3.4516129032258065|
|    max|              30.0|
+-------+------------------+

+-------+-------------------+
|summary|           Amt (mm)|
+-------+-------------------+
|  count|                 30|
|    min|                0.0|
|   mean|0.06666666666666667|
|    max|                2.0|
+-------+-------------------+

Summaries are printed, results saved, and our output is a sequence of DataFrames.

allRainDfs: Seq[org.apache.spark.sql.DataFrame] = 
List([Date: string, Amt (mm): string ... 2 more fields],
[Date: string, Amt (mm): string ... 2 more fields],
[Date: string, Amt (mm): string ... 2 more fields],
[Date: string, Amt (mm): string ... 2 more fields])
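The side effects (printing and saving) happen while map runs, and map additionally collects each call’s return value, which is why `allRainDfs` ends up as a sequence of DataFrames. When only the side effects are wanted, `foreach` is the more conventional choice: it runs the function for each element and returns `Unit`, whereas mapping a function that returns `Unit` yields a sequence of `()` values. A small plain-Scala sketch; the `log` buffer is only there to record the side effects.

```scala
import scala.collection.mutable.ListBuffer

val months = Seq("Apr", "Nov", "Jan", "Sep")
val log = ListBuffer.empty[String] // records side effects so they can be inspected

// map: a side effect plus a collected result for every element
val lengths = months.map { m => log += s"map $m"; m.length }

// foreach: side effect only, nothing is collected
months.foreach(m => log += s"foreach $m")

// map with a Unit-returning function produces a Seq of () values
val units: Seq[Unit] = months.map(m => println(m))
```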

Now we could take this sequence and extract the individual DataFrames or, for the sake of this example, combine them into one DataFrame.

val Apr2002 = allRainDfs(0)
val allCombined = allRainDfs.reduce(_ union _)

Apr2002: org.apache.spark.sql.DataFrame = [Date: string, Amt (mm): string ... 2 more fields]
allCombined: org.apache.spark.sql.DataFrame = [Date: string, Amt (mm): string ... 2 more fields]
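`reduce(_ union _)` folds the sequence pairwise: the first two DataFrames are unioned, then that result with the third, and so on. Two things worth knowing: Spark’s `union` matches columns by position (`unionByName` matches them by name), which is fine here because every DataFrame came from the same source; and `reduce` throws on an empty sequence, so `reduceOption` is safer if the list of months could be empty. The same pattern with plain Scala sequences, using illustrative values:

```scala
// Illustrative daily values (mm) standing in for the per-month DataFrames
val monthly = Seq(Seq(0.0, 9.0), Seq(6.0), Seq(30.0, 0.0))

// Pairwise, left to right: (Seq(0.0, 9.0) ++ Seq(6.0)) ++ Seq(30.0, 0.0)
val combined = monthly.reduce(_ ++ _)

// reduce on an empty Seq throws; reduceOption returns None instead
val none = Seq.empty[Seq[Double]].reduceOption(_ ++ _)
```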

I hope this has helped you use the map function for both simple and more complex tasks.
