
In this blog post we discuss the problem of list concatenation and give a rough introduction to trampolines.[1]

Before we go further, we need to answer two questions:

  • What is Continuation Passing Style?
  • What is StackOverflow?

Let’s try to understand them in the following two sections.

What is Continuation Passing Style?

In functional programming, continuation-passing style (CPS) is a style of programming in which control is passed explicitly in the form of a continuation.[2]

Here is a traditional function that applies the Pythagorean theorem to compute the length of the hypotenuse.[3]

def pythagorean(x:Double,y:Double):Double = {
    Math.sqrt(x*x+y*y)
}
$ pythagorean(3,4)
res1: Double = 5.0

Now let’s rewrite this function in CPS:

def multiply(x:Double,y:Double,continuation:Double=>Double):Double = {
    continuation(x*y)
}

def sqrt(x:Double,continuation:Double=>Double):Double = {
    continuation(Math.sqrt(x))
}

def add(x:Double,y:Double,continuation:Double=>Double):Double = {
    continuation(x+y)
}

def pythagorean(x:Double,y:Double,continuation:Double=>Double):Double = {
    multiply(x,x,vx=>{
        multiply(y,y,vy=>{
            add(vx,vy,vxy=>{
                sqrt(vxy,continuation)
            })
        })
    })
}
$ pythagorean(3,4,x=>x)
res6: Double = 5.0
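In the code above every continuation has type Double=>Double. As a small variation of our own (the object CpsGeneric and the type parameter R are illustrative, not from the original), the final result type can be made a parameter, which makes it clear that the continuation alone decides what happens to the result:

```scala
object CpsGeneric {
  // Each helper is polymorphic in the final answer type R:
  // it hands its result to the continuation k instead of returning it
  def multiply[R](x: Double, y: Double, k: Double => R): R = k(x * y)
  def add[R](x: Double, y: Double, k: Double => R): R = k(x + y)
  def sqrt[R](x: Double, k: Double => R): R = k(math.sqrt(x))

  // Same structure as the CPS pythagorean: each intermediate value
  // (vx, vy, vxy) is named by the continuation that receives it
  def pythagorean[R](x: Double, y: Double, k: Double => R): R =
    multiply[R](x, x, vx =>
      multiply[R](y, y, vy =>
        add[R](vx, vy, vxy =>
          sqrt[R](vxy, k))))
}
```

For example, CpsGeneric.pythagorean[String](3, 4, r => s"hypotenuse = $r") returns "hypotenuse = 5.0": the last continuation turns the number into a string, so the whole call returns a String rather than a Double.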

To write a function in CPS, we add one extra argument, the continuation, and make sure the function passes its result to the continuation instead of returning it. This makes several things explicit:

  • Procedure returns

    A function returns only by explicitly calling its continuation.

  • Intermediate values

    Every intermediate result gets a name: it becomes the parameter of a continuation (vx, vy and vxy above).

  • Order of argument evaluation

    The nesting of the continuations fixes the order in which arguments are evaluated.

  • Tail calls[4]

    This is the main benefit.

    In a non-recursive function, the call to the continuation is always a tail call, because passing the result to the continuation is the last thing the function does.

    def add(x:Double,y:Double,continuation:Double=>Double):Double = {
        continuation(x+y)
    }
    

    A recursive function can easily be made tail-recursive:

    def sum(n:BigDecimal):BigDecimal = {
        if(n == 0) 0
        else n+sum(n-1)
    }
    
    def sum(n:BigDecimal, continuation:BigDecimal=>BigDecimal):BigDecimal = {
        if(n==0) continuation(0)
        else sum(n-1, x=> continuation(x+n))
    }
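
One way to check that the CPS sum really makes its recursive call in tail position is scala.annotation.tailrec: the compiler rejects an annotated method whose self-recursive call is not a tail call. A minimal sketch (the wrapper object SumCps is only for illustration):

```scala
import scala.annotation.tailrec

object SumCps {
  // Compiles only because sum(n - 1, ...) is the last thing the method does;
  // the pending "+ n" has moved into the continuation
  @tailrec
  def sum(n: BigDecimal, continuation: BigDecimal => BigDecimal): BigDecimal =
    if (n == 0) continuation(0)
    else sum(n - 1, x => continuation(x + n))
}
```

SumCps.sum(100, x => x) returns 5050. The annotation only guarantees the recursive call is a tail call; each step still wraps the continuation in one more closure.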
    

What is StackOverflow?

We usually run into a stack overflow in one of two scenarios:

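  • Recursion that the compiler does not turn into a loop

    isOdd and isEven call each other in tail position, but scalac only eliminates direct self-recursive tail calls and the JVM does not eliminate tail calls in general, so every call still pushes a new stack frame.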

    object Test {
        def isOdd(v:Int):Boolean = {
            if( v==0 ) 
                false
            else 
                isEven(v-1)
        }
        def isEven(v:Int):Boolean = {
            if(v ==0) 
                true
            else 
                isOdd(v-1)
        }
    }
      
    
    $ Test.isOdd(10000000)
    java.lang.StackOverflowError
      ammonite.$sess.cmd0$Test$.isEven(cmd0.sc:6)
      ammonite.$sess.cmd0$Test$.isOdd(cmd0.sc:3)
      ammonite.$sess.cmd0$Test$.isEven(cmd0.sc:6)
      ammonite.$sess.cmd0$Test$.isOdd(cmd0.sc:3)
    
  • Deeply Nested Function

    object Test {
        def nestedFunctions(n:Int):Int => Int ={
            Range(0,n).foldLeft[Int => Int](identity)((acc,e)=>{
                x:Int => e+acc(x)
            })
        }
    }
    
    $ val f = Test.nestedFunctions(10000)
    $ f(1)
    java.lang.StackOverflowError
      ammonite.$sess.cmd3$Test$.$anonfun$nestedFunctions$3(cmd3.sc:4)
      ammonite.$sess.cmd3$Test$.$anonfun$nestedFunctions$3(cmd3.sc:4)
      ammonite.$sess.cmd3$Test$.$anonfun$nestedFunctions$3(cmd3.sc:4)
      ammonite.$sess.cmd3$Test$.$anonfun$nestedFunctions$3(cmd3.sc:4)
      ammonite.$sess.cmd3$Test$.$anonfun$nestedFunctions$3(cmd3.sc:4)
    

If the compiler does nothing special for a recursive function, a deep enough recursion will always end in a StackOverflowError: every call needs stack space for its parameters and local variables, and the stack has a fixed size.

This diagram illustrates it clearly.[5]

The Scala compiler already optimizes self-recursive tail calls by compiling them into loops. So to avoid a StackOverflowError, we should make our recursive functions tail-recursive and avoid deeply nested functions.
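
Tail recursion does not require CPS; a more common approach is an accumulator that carries the partial result. A minimal sketch, with the hypothetical object SumAcc and parameter acc, annotated with @tailrec so the compiler confirms it can turn the recursion into a loop:

```scala
import scala.annotation.tailrec

object SumAcc {
  // acc holds the running total, so nothing is left to do
  // after the recursive call returns
  @tailrec
  def sum(n: BigDecimal, acc: BigDecimal = 0): BigDecimal =
    if (n == 0) acc
    else sum(n - 1, acc + n)
}
```

SumAcc.sum(100000) returns 5000050000 without growing the stack, and unlike the CPS version it builds no chain of closures.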

The Problem of List Concatenation

List is a recursive data type in Scala; roughly, it can be defined like this:

sealed trait List[+A]
case object Nil extends List[Nothing]
case class ::[A](head:A, tail:List[A]) extends List[A]

To concatenate two lists, we can write:

def concat[A](left:List[A],right:List[A]):List[A] = {
    left match {
        case Nil => right
        case head::tail => head::concat(tail,right)
    }
}

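But this implementation has a problem: head::concat(tail,right) is not a tail call, because the cons happens after the recursive call returns, so it throws a StackOverflowError when the left list is too long: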

$ concat[Int](List.fill(100000)(0),Nil)
java.lang.StackOverflowError
  scala.collection.immutable.Nil$.equals(List.scala:433)
  ammonite.$sess.cmd31$.concat(cmd31.sc:3)
  ammonite.$sess.cmd31$.concat(cmd31.sc:4)
  ammonite.$sess.cmd31$.concat(cmd31.sc:4)

The call stack looks like this

Following the previous section, we can use CPS to make this function tail-recursive:

def concat[A](left:List[A],right:List[A],continuation:List[A]=>List[A]):List[A] = {
    left match {
        case Nil => continuation(right)
        case head::tail => concat(tail,right, x=> continuation(head::x))
    }
}

But this tail-recursive version still throws a StackOverflowError:

$ concat[Int](List.fill(100000)(0),Nil,identity)
java.lang.StackOverflowError
  ammonite.$sess.cmd34$.$anonfun$concat$1(cmd34.sc:4)
  ammonite.$sess.cmd34$.$anonfun$concat$1(cmd34.sc:4)
  ammonite.$sess.cmd34$.$anonfun$concat$1(cmd34.sc:4)
  ammonite.$sess.cmd34$.$anonfun$concat$1(cmd34.sc:4)

Why does this happen? Let’s look at the call stack of this function.

On the left side, the tail recursion is optimized and no StackOverflowError is thrown. However, every call wraps the previous continuation in a new one, so by the final call we have built a deeply nested function.

When we finally evaluate this continuation, we are invoking a deeply nested function, and as we saw in the previous section, that throws a StackOverflowError.
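
To make the nesting concrete, here is a hand-expanded sketch of the continuations the CPS concat builds for left = List(1, 2, 3), starting from identity; the names k0 to k3 are ours:

```scala
object NestedContinuation {
  val k0: List[Int] => List[Int] = x => x          // identity, passed by the caller
  val k1: List[Int] => List[Int] = x => k0(1 :: x) // built when head = 1
  val k2: List[Int] => List[Int] = x => k1(2 :: x) // built when head = 2
  val k3: List[Int] => List[Int] = x => k2(3 :: x) // built when head = 3
  // concat(List(1, 2, 3), right, k0) ends by calling k3(right),
  // which calls k2, which calls k1, which calls k0
}
```

NestedContinuation.k3(List(9)) returns List(1, 2, 3, 9). Each continuation calls the previous one, so the final call needs roughly one nested call per element of left; with 100000 elements that is far deeper than the stack allows.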

How can we avoid the deeply nested function? In other words, how can we turn nested function calls into tail recursion?

Nested function to Tail-recursion

Let’s say we have three functions

def f1(x:Int):Int = x+1
def f2(x:Int):Int = f1(x)+2
def f3(x:Int):Int = f2(x)+3

How do we turn f3 into a tail-recursive function? We can use a data type to store each pending step of the computation, and then walk that data type with a tail-recursive loop:

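import scala.annotation.tailrec

// Action describes a computation as data instead of running it
sealed trait Action[A]
// A finished computation holding its result
case class Done[A](v:A) extends Action[A]
// "Run x first, then pass its result to f"
case class Doing[A](x:Action[A], f:A=>Action[A]) extends Action[A]

// The outer loop: every recursive call is a tail call
@tailrec
def evaluate[A](x:Action[A]):A = {
    x match {
        case Done(v) => v
        case Doing(inner,f) => inner match {
            // The first step is finished: feed its result to f
            case Done(v) => evaluate(f(v))
            // The first step is itself a Doing: re-associate
            // Doing(Doing(a,g),f) into Doing(a, y=>Doing(g(y),f))
            case Doing(a,g) => evaluate(Doing[A](a,y=>Doing(g(y),f)))
        }
    }
}
def f1M(x:Int):Action[Int] = Doing[Int](Done(x),v=>Done(v+1))
def f2M(x:Int):Action[Int] = Doing[Int](f1M(x),v=>Done(v+2))
def f3M(x:Int):Action[Int] = Doing[Int](f2M(x),v=>Done(v+3))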
$ evaluate(f3M(100))
res46: Int = 106

The basic idea is that Action stores the pending calls as data instead of on the call stack. This breaks the chain of nested calls, and the outer loop in evaluate then runs them one step at a time.
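
As a sketch of how this solves the original problem (our own combination of the two ideas, with the hypothetical names TrampolinedConcat and concatT), the CPS concat can be rewritten so that its continuation returns an Action: instead of calling the previous continuation directly, each continuation returns a Doing step, and evaluate runs those steps in its loop. The example uses Scala's standard List.

```scala
import scala.annotation.tailrec

object TrampolinedConcat {
  // Same data type as above: a computation described as data
  sealed trait Action[A]
  case class Done[A](v: A) extends Action[A]
  case class Doing[A](x: Action[A], f: A => Action[A]) extends Action[A]

  // Outer loop: runs one step at a time in constant stack space
  @tailrec
  def evaluate[A](x: Action[A]): A = x match {
    case Done(v) => v
    case Doing(inner, f) => inner match {
      case Done(v)     => evaluate(f(v))
      case Doing(a, g) => evaluate(Doing[A](a, y => Doing(g(y), f)))
    }
  }

  // CPS concat whose continuation returns an Action instead of a List.
  // Each new continuation does not call k; it returns a Doing step
  // describing "prepend head, then continue with k", which evaluate runs later
  @tailrec
  private def concatT[A](left: List[A], right: List[A],
                         k: List[A] => Action[List[A]]): Action[List[A]] =
    left match {
      case Nil          => k(right)
      case head :: tail =>
        concatT(tail, right, (x: List[A]) => Doing[List[A]](Done(head :: x), k))
    }

  def concat[A](left: List[A], right: List[A]): List[A] =
    evaluate(concatT(left, right, (x: List[A]) => Done(x)))
}
```

concatT is still tail-recursive, and every continuation returns a Doing step immediately instead of calling the next one, so neither phase grows the stack and TrampolinedConcat.concat(List.fill(100000)(0), Nil) completes without a StackOverflowError.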

There is more to Action than the code above shows: it uses a technique called trampolining,[6][1] which we will discuss in more detail in another blog post.
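
For reference, Scala’s standard library already ships a trampoline in scala.util.control.TailCalls: done wraps a final result, tailcall suspends the next call, and result runs the loop. A sketch of the isOdd/isEven example from earlier rewritten with it (the object name EvenOdd is ours):

```scala
import scala.util.control.TailCalls._

object EvenOdd {
  // tailcall(...) returns a suspended step instead of calling isEven directly,
  // so the mutual recursion no longer consumes stack frames
  def isOdd(v: Int): TailRec[Boolean] =
    if (v == 0) done(false) else tailcall(isEven(v - 1))

  def isEven(v: Int): TailRec[Boolean] =
    if (v == 0) done(true) else tailcall(isOdd(v - 1))
}
```

EvenOdd.isOdd(10000000).result returns false instead of throwing a StackOverflowError.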

Reference

  1. Steven E. Ganz, Daniel P. Friedman, and Mitchell Wand. Trampolined Style. In International Conference on Functional Programming, ACM Press, 1999, pp. 18–27.

  2. Continuation passing style 

  3. Pythagorean theorem 

  4. Tail Calls 

  5. What is a StackOverflow 

  6. Stackless Scala with Free Monads 
