Concepts of Programming Languages

Tail Recursion

Instructor:

Fibonacci Numbers

  • Using the new Scala 3 syntax, implement a recursive function that computes the Fibonacci numbers
    def fibonacci(n: Int): Int = n match
      case 0 => 0
      case 1 => 1
      case _ => fibonacci(n - 1) + fibonacci(n - 2)
  • Try fibonacci(30) and fibonacci(45)
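
To compare the two calls, one option is to time them. The following is a minimal sketch: `timed` is a hypothetical helper that is not part of the slides, and the measured times depend on the machine and JVM warm-up, so no numbers are given.

```scala
// Naive recursion, repeated from the slide so the block is self-contained.
def fibonacci(n: Int): Int = n match
  case 0 => 0
  case 1 => 1
  case _ => fibonacci(n - 1) + fibonacci(n - 2)

// Hypothetical helper (an assumption, not from the slides): evaluates `body`
// once and returns its result together with the elapsed time in milliseconds.
def timed[A](body: => A): (A, Double) =
  val start = System.nanoTime()
  val result = body
  (result, (System.nanoTime() - start) / 1e6)

// Usage, e.g. in the REPL:
//   timed(fibonacci(30))
//   timed(fibonacci(45))   // expect a noticeably longer wait
```

The by-name parameter `body: => A` delays evaluation until inside `timed`, so the measured interval covers the call itself.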

Fibonacci Numbers

  • Make it fast
    import scala.collection.mutable

    def fibonacci(n: Int, memo: mutable.Map[Int, Int] = mutable.Map.empty): Int =
      if n <= 0 then 0
      else if n == 1 then 1
      else if memo.contains(n) then memo(n)
      else
        // Compute the value once and store it in the map
        val result = fibonacci(n - 1, memo) + fibonacci(n - 2, memo)
        memo(n) = result
        result
  • Applies dynamic programming
  • Do we need to keep all the numbers ever computed?
  • Does it truly solve the problem?

Learning Objectives

How can we get recursive iteration without the stack penalty and without recomputing intermediate results?

  • Identify and express tail recursion

Recursion and Stack Limitations

  def sum(xs: List[Int]): Int = xs match
    case Nil       => 0
    case x :: rest => x + sum(rest)

  val xs = List(11, 21, 31)
  sum(xs)

    sum(11::21::31::Nil)
    --> 11 + sum(21::31::Nil)
    --> 11 + (21 + sum(31::Nil))
    --> 11 + (21 + (31 + sum(Nil)))
    --> 11 + (21 + (31 + 0))
    --> 11 + (21 + 31)
    --> 11 + 52
    --> 63 = (11 + (21 + (31 + 0)))
  • The additions are deferred until after the last recursive call returns
  • What does the stack look like?

Call Stack

  • Contains activation records (AR) for active calls, also known as stack frames
  • Changes to call stack
    • AR pushed when a function/method call is made
    • AR popped when a function/method returns
  • Runtime environments limit the size of the call stack
  • Can cause problems with deep recursion
    • Java, Scala: StackOverflowError
    • C: stack limits set by operating system
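
The stack limit can be observed with the non-tail-recursive `sum` from the previous slide. The following is a minimal sketch: `sumOrOverflow` is a hypothetical wrapper that is not part of the slides, and the list length at which the overflow occurs depends on the JVM stack size.

```scala
// Non-tail-recursive sum from the earlier slide:
// one activation record stays on the stack per list element.
def sum(xs: List[Int]): Int = xs match
  case Nil       => 0
  case x :: rest => x + sum(rest)

// Hypothetical wrapper (an assumption, not from the slides): returns None
// when the recursion exhausts the call stack instead of letting the
// StackOverflowError propagate.
def sumOrOverflow(xs: List[Int]): Option[Int] =
  try Some(sum(xs))
  catch case _: StackOverflowError => None

// sumOrOverflow(List(11, 21, 31))          // Some(63)
// sumOrOverflow(List.fill(1_000_000)(1))   // typically None with default JVM stack settings
```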

Tail Recursive Calls

Sum of elements in a list computing forward

  def sum(xs: List[Int], z: Int = 0): Int = xs match
    case Nil       => z
    case x :: rest => sum(rest, z + x)

    sum(11::21::31::Nil)
    --> sum(11::21::31::Nil, 0)
    --> sum(21::31::Nil, 11)
    --> sum(31::Nil, 32)
    --> sum(Nil, 63)
    --> 63 = (((0 + 11) + 21) + 31)
  • All recursive calls are in tail position
  • The partial sum is computed before the recursive call is made, so no work is left after it returns
  • How is the stack different now?
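
To see the difference in practice, the accumulator version can be run on a long list. This is a sketch under one assumption: the `@tailrec` annotation is added here (the slide version omits it) so that the compiler rejects the definition if the recursive call is not in tail position.

```scala
import scala.annotation.tailrec

// Accumulator version from the slide; z carries the partial sum,
// so nothing remains to be done after the recursive call returns.
@tailrec
def sum(xs: List[Int], z: Int = 0): Int = xs match
  case Nil       => z
  case x :: rest => sum(rest, z + x)

// sum(List(11, 21, 31))          // 63
// sum(List.fill(1_000_000)(1))   // 1000000, no StackOverflowError
```

Because the compiler turns the tail call into a jump, the stack depth stays constant regardless of the list length.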

Tail Call Optimization

  • Many compilers implement tail-call optimization
    • overwrite existing activation record instead of creating new
  • Recursive calls must be in tail position
  • Can also apply to mutual recursion
    • f calls g, which calls back to f
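
On the JVM, the Scala compiler rewrites only direct self-recursive tail calls, so mutually recursive functions are not optimized automatically. One way to still run them in constant stack space is the trampoline in the standard library, `scala.util.control.TailCalls`. The sketch below uses `isEven` and `isOdd` as illustrative stand-ins for f and g; they are not from the slides.

```scala
import scala.util.control.TailCalls.*

// Each function returns a description of the next step (TailRec)
// instead of calling the other function directly.
def isEven(n: Int): TailRec[Boolean] =
  if n == 0 then done(true) else tailcall(isOdd(n - 1))

def isOdd(n: Int): TailRec[Boolean] =
  if n == 0 then done(false) else tailcall(isEven(n - 1))

// .result runs the steps in a loop, so the stack does not grow:
//   isEven(1_000_000).result   // true
```

The trade-off is that each step allocates a small object on the heap instead of an activation record on the stack.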

Exercise: Recursive vs. Tail-Recursive Fibonacci

  def fib(n: Int): Long =
    if n <= 1 then n
    else fib(n - 1) + fib(n - 2)
  • Time complexity
  • How to improve?
  • Values for small n
    n        0  1  2  3  4  5  6  7   8
    fib(n)   0  1  1  2  3  5  8  13  21
  • Represent sliding window in result
    (not tail-recursive!)
    def fib(n: Int): (Long, Long) =
      if n <= 1 then (0L, n.toLong)
      else
        val (a, b) = fib(n - 1)
        (b, a + b)
  • Represent sliding window in arguments
    (tail-recursive)
    def fib(n: Int, a: Long = 0, b: Long = 1): Long =
      if n == 0 then a
      else if n == 1 then b
      else fib(n - 1, b, a + b)
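
The time-complexity question can be explored by counting calls instead of computing values. The following is a minimal sketch: `naiveCalls` and `tailCalls` are hypothetical instrumentation functions that are not from the slides. They mirror the call structure of the doubly recursive `fib` and of the tail-recursive `fib` with the window in its arguments.

```scala
import scala.annotation.tailrec

// Number of calls made by the doubly recursive fib(n):
// the call itself plus all calls made by its two recursive calls.
def naiveCalls(n: Int): Long =
  if n <= 1 then 1L
  else 1L + naiveCalls(n - 1) + naiveCalls(n - 2)

// Number of calls made by the tail-recursive fib(n, a, b):
// one call for each value from n down to 1 (a single call for n = 0).
@tailrec
def tailCalls(n: Int, count: Long = 1L): Long =
  if n <= 1 then count
  else tailCalls(n - 1, count + 1L)

// naiveCalls(10)   // 177
// tailCalls(10)    // 10
```

The naive call count grows like the Fibonacci numbers themselves, that is, exponentially in n, while the tail-recursive count grows linearly.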

Tail-recursive Fibonacci Numbers

  • In Scala 3, implement a tail-recursive function to compute Fibonacci numbers
    import scala.annotation.tailrec

    def fibonacci(n: Int): Int =
      @tailrec
      def fibHelper(n: Int, a: Int, b: Int): Int = n match
        case 0 => a
        case _ => fibHelper(n - 1, b, a + b)
      end fibHelper
      fibHelper(n, 0, 1)
    end fibonacci
  • @tailrec annotation: compiler error if the function is not tail-recursive
  • Stating the intent explicitly helps the compiler generate good code
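
With the stack no longer the limiting factor, the result type becomes one: the `Int` version overflows from `fibonacci(47)` on, because the 47th Fibonacci number exceeds `Int.MaxValue`. The sketch below swaps in `BigInt` accumulators. `fibonacciBig` is a hypothetical name, and the `require` check for negative input is an addition that the slide version does not have.

```scala
import scala.annotation.tailrec

// Hypothetical variant (not from the slides): same helper as above,
// but with arbitrary-precision accumulators.
def fibonacciBig(n: Int): BigInt =
  require(n >= 0, "n must be non-negative")
  @tailrec
  def fibHelper(n: Int, a: BigInt, b: BigInt): BigInt = n match
    case 0 => a
    case _ => fibHelper(n - 1, b, a + b)
  fibHelper(n, BigInt(0), BigInt(1))

// fibonacciBig(47)    // 2971215073
// fibonacciBig(100)   // 354224848179261915075
```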

Translate Tail-Recursion to Loop

Tail-recursive

  import scala.annotation.tailrec

  def factorial(n: Int): Int =
    @tailrec
    def loop(m: Int, result: Int): Int =
      if m > 1 then loop(m - 1, m * result)
      else result
    loop(n, 1)

Recursive (mutable)

  def factorial(n: Int): Int =
    var result = 1
    def loop(m: Int): Unit =
      if m > 1 then
        result = result * m
        loop(m - 1)
    loop(n)
    result

  def factorial(n: Int): Int =
    var result = 1
    var m = n
    def loop(): Unit =
      if m > 1 then
        result = result * m
        m = m - 1
        loop()
    loop()
    result

Loop (mutable data)

  def factorial(n: Int): Int =
    var result = 1   // must be a var: it is reassigned in the loop
    var m = n
    while m > 1 do
      result = result * m
      m = m - 1
    result
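
The same translation steps can be applied to the tail-recursive Fibonacci helper: each parameter becomes a `var`, the base-case test becomes the loop condition, and the tail call becomes a set of reassignments. This is a minimal sketch; `fibonacciLoop` is a hypothetical name that is not from the slides.

```scala
// Loop version of fibHelper(n, a, b): the arguments of the tail call
// fibHelper(n - 1, b, a + b) are written back into the variables.
def fibonacciLoop(n: Int): Int =
  var m = n
  var a = 0
  var b = 1
  while m > 0 do
    val next = a + b   // computed from the old a and b before either is overwritten
    a = b
    b = next
    m = m - 1
  a

// fibonacciLoop(10)   // 55
```

The temporary `next` is needed because the recursive call evaluates all its arguments from the old values at once, whereas assignments happen one after another.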

Summary

  • Tail-call optimization
    • avoids the performance penalty of creating activation records
    • overwrites an existing activation record
    • all recursive calls must be in tail position (last operation)

  • Recursion without tail calls
    • Time complexity $O(n)$ (additional penalty for activation records)
    • Space complexity $O(n)$
  • Tail recursion with tail-call optimization
    • Time complexity $O(n)$ (no penalty for creating activation records)
    • Space complexity $O(1)$