Tail Call/Recursion Optimization by Example
Here is an example with tail calls:
```python
def foo(n):
    if n > 0:
        return foo(n-1)
    else:
        return bar(n)

def bar(n):
    return n-1
```
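Here, both the call to foo and the call to bar are tail calls, because each one is the last thing its branch does: its result is returned immediately. The call to foo is also a tail recursion, since foo calls itself. Tail calls matter for two reasons. First, they can be transformed into loops for efficiency. Second, if they are not transformed, they may cause problems in some cases.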
Suppose we have very limited stack space and we call foo(100000000000). Without optimization, every recursive call allocates a new stack frame, so about 100000000000 frames would be needed. The program will very likely crash with a stack overflow error long before the recursion gets that deep. This is especially important in functional languages, because many of them have no loop constructs and use tail calls to express loops instead.
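CPython makes this easy to observe, because it does not perform tail call optimization and it caps the recursion depth (the cap is reported by `sys.getrecursionlimit()`). The sketch below reuses foo and bar from above; the helper `overflows` is a hypothetical name used only for illustration. Instead of letting the process overflow its stack, CPython raises RecursionError once the cap is exceeded, which happens long before n reaches anything like 100000000000.

```python
import sys

def foo(n):
    if n > 0:
        return foo(n - 1)  # a tail call, yet CPython still pushes a new frame
    else:
        return bar(n)

def bar(n):
    return n - 1

def overflows(n):
    """Hypothetical helper: True if foo(n) exceeds CPython's recursion limit."""
    try:
        foo(n)
    except RecursionError:
        return True
    return False

print(sys.getrecursionlimit())  # the current limit (1000 by default in CPython)
print(overflows(10))            # False: only a handful of frames are needed
print(overflows(10**6))         # True: far more frames than the limit allows
```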
Examples and Solutions
A non-optimizing compiler would generate the following code (pseudo-assembly) for foo. Assume n is on the stack, and we use loc() to denote the location of something.
```
    mov reg loc(n)   ; load the value n from the stack into a register
    cmp reg 0        ; compare n and 0
    jg 1             ; jump to label 1 if n > 0
    push reg         ; push n onto the stack to pass it as a parameter
    call bar         ; call bar, which uses the parameter
    pop              ; pop n from the stack to clean it up
    ret              ; return bar's result
1:  dec reg          ; n = n - 1
    push reg         ; push the new n onto the stack to pass it as a parameter
    call foo         ; call foo, which uses the parameter
    pop              ; pop n from the stack to clean it up
    ret              ; return foo's result
```
Tail-call-optimized code would look like this:
```
    mov reg loc(n)   ; load the value n from the stack into a register
    cmp reg 0        ; compare n and 0
    jg 1             ; jump to label 1 if n > 0
    jmp bar          ; jump to bar instead of calling it
1:  dec reg          ; n = n - 1
    mov loc(n) reg   ; store the new n back to its location
    jmp foo          ; jump to the start of foo again
```
Here, we no longer push parameters onto the stack and call functions. Instead, we overwrite the current parameters and jump, which is effectively a loop. In other words, rather than building new stack frames, we reuse the current frame after adjusting the parameters the next piece of code needs.
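To make the jumps concrete, here is a hedged Python sketch that mirrors the optimized pseudo-assembly step by step. Python has no goto, so a `label` variable inside a `while True` loop stands in for the jump targets; the names `foo_jumps` and `label` are illustrative. No new call is ever made: n is overwritten in place and control jumps back, so the depth of the Python call stack stays constant.

```python
def foo_jumps(n):
    """Simulate the tail-call-optimized foo/bar code with explicit jumps."""
    label = "foo"                 # the jump target we are currently at
    while True:
        if label == "foo":
            if n > 0:             # cmp reg 0; jg 1
                n = n - 1         # 1: dec reg; mov loc(n) reg
                label = "foo"     # jmp foo (same frame, new parameter)
            else:
                label = "bar"     # jmp bar (n is left where it is)
        elif label == "bar":
            return n - 1          # body of bar: return n-1

print(foo_jumps(5))      # -1
print(foo_jumps(-2))     # -3
print(foo_jumps(10**6))  # -1, and no RecursionError
```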
The optimized machine code corresponds to the following Python code.
```python
def foo(n):
    while n > 0:
        n = n - 1
    return n - 1
```
You can check that this program is semantically equivalent to the original one, even though it contains no tail calls.
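One way to check the equivalence is to run both versions side by side. In this sketch the recursive version is renamed `foo_recursive` and the loop version `foo_loop` so they can coexist; the range of inputs is an arbitrary choice, kept small enough that the recursive version stays under CPython's default recursion limit. Both versions reduce to the same simple rule: -1 when n > 0, otherwise n - 1.

```python
def bar(n):
    return n - 1

def foo_recursive(n):
    # The original, tail-recursive version.
    if n > 0:
        return foo_recursive(n - 1)
    else:
        return bar(n)

def foo_loop(n):
    # The loop version derived from the optimized code.
    while n > 0:
        n = n - 1
    return n - 1

inputs = range(-20, 500)  # recursion depth stays around 500 frames at most

# The two versions agree on every input we try.
for n in inputs:
    assert foo_recursive(n) == foo_loop(n)

# Both follow the same closed-form rule.
for n in inputs:
    assert foo_loop(n) == (-1 if n > 0 else n - 1)

print("equivalent on", len(inputs), "inputs")  # equivalent on 520 inputs
```

Unlike foo_recursive, foo_loop also handles very large inputs such as 10**6 without hitting the recursion limit.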
Constant Folding
Constant folding is the process of simplifying constant expressions at compile time. It is useful mostly because it reduces the amount of computation needed at runtime. For example, the code

```
int x = 17;
int y = (x + 3) / 2;
return y + 5;
```

will be reduced to

```
int x = 17;
int y = 10;
return 15;
```

by constant folding (together with constant propagation, which substitutes the known values of x and y into later expressions). However, the definitions of x and y are not removed at this point; they will probably be removed later by dead code elimination.
How do we actually do it? Recall that the core idea of constant folding is to reduce a constant expression to an equivalent expression that cannot be reduced any further. That implies two things we need to consider: first, how to identify a constant expression, and second, how to reduce it.
There is more than one way to do this. Here are some tips:
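- Expressions without variables are easy: every operand is a literal, so the whole expression can be evaluated at compile time.
- Expressions with variables need further checking: we have to inspect each variable's definition and any assignments to it within its scope, to make sure its value is a known constant at that point.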
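Here is one possible sketch of both tips, using Python's standard `ast` module (Python 3.9+ for `ast.unparse`) on a Python rendering of the example above. Since C's integer `/` becomes `//` in Python, and a bare `return` is replaced by an assignment to a hypothetical `result` variable. The names `assigned_once`, `ConstantFolder`, and `fold` are illustrative. The sketch folds `+`, `-`, `*` and `//` when both operands are literals, and substitutes a variable only if it is assigned exactly once and that assignment's value has already folded to a constant. It assumes straight-line code in a single scope and ignores control flow, so it is far from a production-grade pass.

```python
import ast
import operator

# Binary operators this sketch knows how to evaluate at compile time.
BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
}


def assigned_once(tree):
    """Return the names that are assigned exactly once anywhere in the tree."""
    counts = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
            counts[node.id] = counts.get(node.id, 0) + 1
    return {name for name, count in counts.items() if count == 1}


class ConstantFolder(ast.NodeTransformer):
    def __init__(self, safe_names):
        self.safe_names = safe_names  # variables that are never reassigned
        self.known = {}               # name -> constant value seen so far

    def visit_Assign(self, node):
        node.value = self.visit(node.value)  # fold the right-hand side first
        target = node.targets[0]
        if (len(node.targets) == 1 and isinstance(target, ast.Name)
                and target.id in self.safe_names
                and isinstance(node.value, ast.Constant)):
            self.known[target.id] = node.value.value
        return node

    def visit_Name(self, node):
        # Replace a read of a known-constant variable with its value.
        if isinstance(node.ctx, ast.Load) and node.id in self.known:
            return ast.copy_location(ast.Constant(value=self.known[node.id]), node)
        return node

    def visit_BinOp(self, node):
        self.generic_visit(node)  # fold the operands first (bottom-up)
        op = BIN_OPS.get(type(node.op))
        if op and isinstance(node.left, ast.Constant) and isinstance(node.right, ast.Constant):
            try:
                value = op(node.left.value, node.right.value)
            except (ArithmeticError, TypeError):
                return node  # e.g. 1 // 0: leave it for runtime to report
            return ast.copy_location(ast.Constant(value=value), node)
        return node


def fold(source):
    tree = ast.parse(source)
    tree = ConstantFolder(assigned_once(tree)).visit(tree)
    return ast.unparse(ast.fix_missing_locations(tree))


print(fold("x = 17\ny = (x + 3) // 2\nresult = y + 5"))
# x = 17
# y = 10
# result = 15
```

As in the text, the definitions of x and y survive folding and would be left to dead code elimination. A variable that is assigned twice fails the check from the second tip, so `fold("x = 17\nx = x + 1\ny = x * 2")` comes back unchanged.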