[Scala] Tail recursion

在學習Functional programming過程中,學到遞迴可以分為兩類:

  • Tail recursion
1
If a function calls itself as its last action, the function’s stack frame can be reused. This is called tail recursion.
  • Tail call
1
If the last action of a function consists of calling a function (which may be the same), one stack frame would be sufficient for both functions. Such calls are called tail-calls.

分別以計算最大公因數(gcd)和階層(factorial)為例:

  • gcd

    1
    2
    3
    4
    5
    6
    def gcd(a: Int, b: Int): Int = {
    if (b == 0)
    a
    else
    gcd(b, a % b)
    }
  • factorial

    1
    2
    3
    4
    5
    6
    def factorial(n: Int): Int = {
    if (n == 0)
    1
    else
    n * factorial(n - 1)
    }

例如:

1
gcd(21, 14) -> gcd(14, 7) -> gcd(7, 0) -> 7
1
2
3
4
5
6
factorial(5) -> 
5 * factorial(4) -> 5 * 4 * factorial(3) ->
5 * 4 * 3 * factorial(2) ->
5 * 4 * 3 * 2 * factorial(1) ->
5 * 4 * 3 * 2 * 1 * factorial(0) ->
5 * 4 * 3 * 2 * 1 * 1

可以發現:

  1. 在gcd範例中,每一步不會依賴上一步的結果,上一步的結果是以參數方式傳入到函式裡面。
  2. 在factorial範例中,每一步會依賴上一步的結果,所以需要 stack 來記錄每一步的狀態。
    等走到盡頭後,再取出 stack 內的元素並且計算之,逐一合併結果。

假如執行 factorial(100000) 可以預期會發生 stack overflow,我們可以將原本的程式改成Tail recursion版本,就不會有 stack overflow 發生。
在Scala中會對Tail recursion做最佳化,或是可以透過 @tailrec 標註此函數是Tail recursion。

  • 測試範例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import java.util.concurrent.TimeUnit

import com.google.common.base.Stopwatch

def factorial(n: Int): Int = {
if (n == 0)
1
else
n * factorial(n - 1)
}

def factorialTailrec(n: Int, result: Int): Int = {
if (n == 0)
result
else
factorialTailrec(n - 1, n * result)
}

val sw = Stopwatch.createUnstarted()
sw.elapsed(TimeUnit.MILLISECONDS)
sw.start()
factorial(15)
sw.stop()
println(sw.toString)
sw.reset()

sw.start()
factorialTailrec(15, 1)
sw.stop()
println(sw.toString)
  • 測試結果:
    1
    2
    1.611 ms
    677.4 μs

可以看出來改成Tail recursion版本效率會大幅改善。

參考:

  1. What is tail recursion?
  2. Tail recursion