Pattern matching in Scala

Pattern matching is a familiar feature in many functional languages, and Scala is no exception.

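A match expression compares a value against a sequence of patterns, each of which is associated with an expression. The patterns are tried in order, and the expression associated with the first matching pattern is evaluated. Its result becomes the value of the whole match. If no pattern matches, a scala.MatchError is thrown at runtime.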
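As a rough illustration of the "first match wins" rule, here is a small sketch. The helper names classify and onlyZero are hypothetical, and the snippet assumes it is run as a Scala script or in the REPL (Scala 2.13 or 3). Note that 0 would also match the wildcard, but the earlier case is chosen.

```scala
// Cases are tried top to bottom; the first pattern that matches wins.
def classify(n: Int): String = n match {
  case 0     => "zero"           // constant pattern
  case 1 | 2 => "small"          // alternative pattern: 1 or 2
  case _     => "something else" // wildcard matches anything, so it goes last
}

// Without a catch-all case, an unmatched value throws scala.MatchError.
def onlyZero(n: Int): String = n match {
  case 0 => "zero"
}

println(classify(0)) // zero
println(classify(2)) // small
println(classify(7)) // something else

// Catch the runtime failure to show what happens when nothing matches.
val failed: String =
  try onlyZero(5)
  catch { case _: MatchError => "MatchError" }
println(failed) // MatchError
```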

The syntax of pattern matching in Scala is defined as follows:

e match { case p1 => e1 ... case pn => en }  
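To make the general form concrete, the sketch below fills in p1 ... pn with a few common kinds of patterns. The function describe is a hypothetical example, and the snippet assumes a Scala script or REPL session. Because the whole match is an expression, its value is simply whatever the selected case produces.

```scala
// `x match { ... }` is an expression: its value is the result of the chosen case.
def describe(x: Any): String = x match {
  case 42        => "the answer"                     // constant pattern
  case s: String => s"a string of length ${s.length}" // typed pattern, binds s
  case (a, b)    => s"a pair of $a and $b"           // tuple pattern, binds a and b
  case other     => s"something else: $other"        // variable pattern, matches anything
}

println(describe(42))       // the answer
println(describe("hi"))     // a string of length 2
println(describe((1, "x"))) // a pair of 1 and x
println(describe(3.5))      // something else: 3.5
```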

A simple example

Let's now create a simple example that builds a string showing the cons cells of a list. To do this, we write a recursive function that uses pattern matching.

def pCons(list: List[Int]): String = list match {  
    case Nil => "nil"                              // empty list: end of the chain
    case x :: xs => "(" + x + "." + pCons(xs) + ")" // head x, tail xs: wrap x and recurse on the tail
}

If we execute this function with a list, pCons(List(1,2,3,4)), we'll get the following result:

(1.(2.(3.(4.nil))))
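The same idea works for lists of any element type. The sketch below is a hypothetical generic variant (pConsAny is not from the post), assuming a Scala script or REPL. It also writes the cons pattern in prefix form: x :: xs is infix notation for the case class ::, so case ::(x, xs) matches exactly the same lists. This previews the case-class decomposition discussed next.

```scala
// Generic version of pCons; `::(x, xs)` is the prefix spelling of `x :: xs`.
def pConsAny[A](list: List[A]): String = list match {
  case Nil       => "nil"
  case ::(x, xs) => "(" + x + "." + pConsAny(xs) + ")"
}

println(pConsAny(List(1, 2, 3, 4))) // (1.(2.(3.(4.nil))))
println(pConsAny(List("a", "b")))   // (a.(b.nil))
```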

Case classes

The real beauty of pattern matching in Scala is how it can decompose object hierarchies.

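Scala has a special kind of class, the case class, that differs a bit from a normal class. Among other things, the compiler generates a companion object with apply and unapply methods (so instances can be created without new), turns the constructor parameters into public fields, and derives equals, hashCode, toString and copy from those parameters. The features that matter most for pattern matching are that we can access the constructor parameters, and that case classes provide a recursive decomposition mechanism through pattern matching.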
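To see what "decomposition" means in practice, the sketch below contrasts a plain class with a hand-written extractor (an unapply method in its companion object) and the equivalent case class, which gets its extractor for free. Point, CasePoint and onAxis are hypothetical names. The hand-written unapply only approximates the compiler-generated one (its exact shape differs between Scala 2 and Scala 3). The snippet assumes a Scala script or REPL.

```scala
// Plain class: there is no generated unapply, so we write an extractor by hand.
class Point(val x: Int, val y: Int)

object Point {
  // Lets `Point(a, b)` be used as a pattern.
  def unapply(p: Point): Option[(Int, Int)] = Some((p.x, p.y))
}

// The case class equivalent: apply, unapply, equals, hashCode, toString and copy are generated.
case class CasePoint(x: Int, y: Int)

def onAxis(p: Point): Boolean = p match {
  case Point(0, _) | Point(_, 0) => true // uses the hand-written Point.unapply
  case _                         => false
}

def onAxisCase(p: CasePoint): Boolean = p match {
  case CasePoint(0, _) | CasePoint(_, 0) => true // uses the generated extractor
  case _                                 => false
}

println(onAxis(new Point(0, 5)))            // true
println(onAxisCase(CasePoint(1, 1)))        // false
println(CasePoint(1, 2) == CasePoint(1, 2)) // true: structural equality
println(new Point(1, 2) == new Point(1, 2)) // false: reference equality
```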

Now let's look at another example that uses case classes. I've created a small language with four kinds of expressions: Fun, Number, Sum and Product.

abstract class Exp

case class Fun(e: Exp) extends Exp  
case class Number(n: Int) extends Exp  
case class Sum(exp1: Exp, exp2: Exp) extends Exp  
case class Product(exp1: Exp, exp2: Exp) extends Exp  
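A common variation, shown here as a sketch rather than as the post's own code, is to declare the base class sealed. All subclasses must then live in the same file, which lets the compiler warn when a match on Exp forgets a case. The helper kind is hypothetical, and the snippet assumes a single Scala script or file.

```scala
// Same hierarchy, but sealed so that matches on Exp can be checked for exhaustiveness.
sealed abstract class Exp
case class Fun(e: Exp) extends Exp
case class Number(n: Int) extends Exp
case class Sum(exp1: Exp, exp2: Exp) extends Exp
case class Product(exp1: Exp, exp2: Exp) extends Exp

// Names the node at the top of a tree. Removing any one case
// should make the compiler warn that the match may not be exhaustive.
def kind(e: Exp): String = e match {
  case Fun(_)        => "function"
  case Number(_)     => "number"
  case Sum(_, _)     => "sum"
  case Product(_, _) => "product"
}

// Case classes are built without `new` and nest naturally into a tree.
val tree: Exp = Fun(Product(Sum(Number(4), Number(2)), Number(3)))
println(kind(tree)) // function
```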

Let's now use pattern matching to work with this class hierarchy.

First, let's create a recursive function that takes an expression and returns it as a String in a Clojure-like prefix notation.

def print(e: Exp): String = e match {  
    case Number(x) => x.toString
    case Sum(e1, e2) => "(+ " + print(e1) + " " + print(e2) + ")"
    case Product(e1, e2) => "(* " + print(e1) + " " + print(e2) + ")"
    case Fun(e) => "(fn [] " + print(e) + ")"
}

OK, let's test the print function with an expression.

print(Fun(Product(Sum(Number(4), Number(2)), Number(3))))  

This returns the following String:

(fn [] (* (+ 4 2) 3))

Next, let's create a function, calculate, that evaluates an expression to an Int.

def calculate(e: Exp): Int = e match {  
    case Number(x) => x
    case Sum(e1, e2) => calculate(e1) + calculate(e2)
    case Product(e1, e2) => calculate(e1) * calculate(e2)
    case Fun(e) => calculate(e) // a Fun simply evaluates its body
}

If we now pass the expression we created earlier to our calculate function, we get 18, as expected: (4 + 2) * 3 = 18.
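For readers who want to try this end to end, here is the hierarchy and calculate combined into one runnable sketch, assuming a Scala script or REPL session. The comments trace how the recursion unfolds for the example expression; the Fun case's binder is renamed to body only for readability.

```scala
abstract class Exp
case class Fun(e: Exp) extends Exp
case class Number(n: Int) extends Exp
case class Sum(exp1: Exp, exp2: Exp) extends Exp
case class Product(exp1: Exp, exp2: Exp) extends Exp

def calculate(e: Exp): Int = e match {
  case Number(x)       => x
  case Sum(e1, e2)     => calculate(e1) + calculate(e2)
  case Product(e1, e2) => calculate(e1) * calculate(e2)
  case Fun(body)       => calculate(body) // a Fun evaluates its body
}

val expr = Fun(Product(Sum(Number(4), Number(2)), Number(3)))
// calculate(expr)
//   = calculate(Product(Sum(4, 2), 3))
//   = calculate(Sum(4, 2)) * 3
//   = (4 + 2) * 3
//   = 18
println(calculate(expr)) // 18
```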

Guards

In this last example, we want to add another case to our print function. We want to catch every Product whose two operands are Numbers with the same value and print it in a squared form instead of the normal * form. To achieve this, we use a guard. A guard adds an extra check to a case: after the pattern we write if followed by a condition, and the case is selected only if the pattern matches and the condition is true.

case p1 if c1 => e1  
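Before applying this to expressions, here is a minimal guard sketch on plain tuples, assuming a Scala script or REPL. The function comparePair is hypothetical. When a guard evaluates to false, matching simply continues with the next case.

```scala
// Guards refine a pattern with a boolean condition.
def comparePair(p: (Int, Int)): String = p match {
  case (a, b) if a == b => s"both are $a"    // pattern matches and the guard holds
  case (a, b) if a > b  => "first is larger" // reached only if the previous guard failed
  case _                => "second is larger"
}

println(comparePair((3, 3))) // both are 3
println(comparePair((5, 2))) // first is larger
println(comparePair((1, 4))) // second is larger
```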

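Here is the print function after adding the squared check. The guarded case has to come before the general Product case: cases are tried in order, so the general case would otherwise always match first.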

def print(e: Exp): String = e match {  
    case Number(x) => x.toString
    case Sum(e1, e2) => "(+ " + print(e1) + " " + print(e2) + ")"
    case Product(Number(x), Number(y)) if x == y => "(squared " + x + ")"
    case Product(e1, e2) => "(* " + print(e1) + " " + print(e2) + ")"
    case Fun(e) => "(fn [] " + print(e) + ")"
}

If we then execute print(Fun(Product(Number(4), Product(Number(3), Number(3))))), we'll get the following String back:

(fn [] (* 4 (squared 3)))
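To see why case order matters with guards, the sketch below puts the guarded print next to a hypothetical printSwapped that lists the general Product case first. In the swapped version the guarded case can never be reached (the compiler may warn that it is unreachable), so the squared form never appears. The snippet assumes a Scala script or REPL, and the Fun case's binder is renamed to body only for readability.

```scala
abstract class Exp
case class Fun(e: Exp) extends Exp
case class Number(n: Int) extends Exp
case class Sum(exp1: Exp, exp2: Exp) extends Exp
case class Product(exp1: Exp, exp2: Exp) extends Exp

// Guarded case before the general Product case.
def print(e: Exp): String = e match {
  case Number(x)                               => x.toString
  case Sum(e1, e2)                             => "(+ " + print(e1) + " " + print(e2) + ")"
  case Product(Number(x), Number(y)) if x == y => "(squared " + x + ")"
  case Product(e1, e2)                         => "(* " + print(e1) + " " + print(e2) + ")"
  case Fun(body)                               => "(fn [] " + print(body) + ")"
}

// Hypothetical variant with the two Product cases swapped:
// the general case matches first, so the guarded case is dead code.
def printSwapped(e: Exp): String = e match {
  case Number(x)                               => x.toString
  case Sum(e1, e2)                             => "(+ " + printSwapped(e1) + " " + printSwapped(e2) + ")"
  case Product(e1, e2)                         => "(* " + printSwapped(e1) + " " + printSwapped(e2) + ")"
  case Product(Number(x), Number(y)) if x == y => "(squared " + x + ")"
  case Fun(body)                               => "(fn [] " + printSwapped(body) + ")"
}

val t = Fun(Product(Number(4), Product(Number(3), Number(3))))
println(print(t))        // (fn [] (* 4 (squared 3)))
println(printSwapped(t)) // (fn [] (* 4 (* 3 3)))
```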

As you can see, pattern matching in Scala offers a concise way to match all sorts of data, letting you work with the structure of values instead of just individual object instances.
