Are you familiar with this error? Do you remember the nightmare it caused when you wrote recursive methods? In this post, we will talk about a technique to eliminate StackOverflowError on the HotSpot JVM.
What is Java Virtual Machine Stack?
To understand this error, we need to know the Java Virtual Machine Stack first.
Java Virtual Machine Stack is also called Java Stack. It is used to store stack frames, which record the current state of the current method in the current thread. When we create a thread, a Java Virtual Machine Stack will be created at the same time and only the corresponding thread can operate it.
A stack frame consists of three parts: the local variables, the operand stack, and the dynamic link. When the current method A invokes another method B, a new stack frame for B is pushed onto the Java Virtual Machine Stack. After B completes, normally or abruptly, B’s stack frame is popped and A’s stack frame is on top again; it is then used to restore the state of A so that A can continue.
Say we have a function factorial
public Long factorial(Long n) {
if (n == 1) {
return 1l;
}
return n * factorial(n - 1);
}
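While factorial(4) runs, the Java Virtual Machine Stack grows by one stack frame per call, from factorial(4) down to factorial(1), and then shrinks again as each call returns.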
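One rough way to watch the frames pile up is to count them at the deepest call. The sketch below is illustrative only: the class and method names are hypothetical, and it uses Throwable.getStackTrace() as a diagnostic trick to count frames.

```java
class FrameDepthDemo {
    // Recurses the way factorial does and reports how many frames are on
    // the current thread's stack when the base case is reached.
    static int framesAtBaseCase(int n) {
        if (n == 1) {
            return new Throwable().getStackTrace().length;
        }
        return framesAtBaseCase(n - 1);
    }

    public static void main(String[] args) {
        int shallow = framesAtBaseCase(1);
        int deep = framesAtBaseCase(4);
        // Each extra level of recursion adds exactly one frame.
        System.out.println(deep - shallow); // prints 3
    }
}
```

The absolute numbers depend on how the program was launched, so only the difference between the two measurements is meaningful.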
The Java Virtual Machine Stack size of each thread is limited. We can check the default size with
java -XX:+PrintFlagsFinal -version|grep ThreadStackSize
We can also specify the maximum stack size
java -Xss2M # set the maximum thread stack size to 2 MB
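The stack size can also be requested for a single thread from code, through the Thread constructor that takes a stackSize argument. The following is a minimal sketch with hypothetical class and method names; note that the Thread documentation describes stackSize as a suggestion that some platforms may ignore.

```java
import java.util.concurrent.atomic.AtomicLong;

class BigStackDemo {
    // Deliberately not tail recursive: every call keeps its frame alive.
    static long sumTo(long n) {
        return n == 0 ? 0 : n + sumTo(n - 1);
    }

    // Runs sumTo(n) on a worker thread whose requested stack size is stackBytes.
    static long sumOnThread(long stackBytes, long n) {
        AtomicLong result = new AtomicLong(-1);
        Thread worker = new Thread(null, () -> result.set(sumTo(n)), "big-stack", stackBytes);
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        // Still -1 if the worker died, for example with a StackOverflowError.
        return result.get();
    }
}
```

With a large enough stackBytes, a recursion depth that overflows the default stack may succeed, but this only postpones the problem; it does not remove the limit.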
The minimum stack size that HotSpot allows differs by platform:
Platform | Minimum Stack Size |
---|---|
Windows | 40KB |
Linux AArch64 | 72KB |
Linux RISC-V | 72KB |
Linux s390 | 32KB |
Linux ARM | 32KB |
Linux x86 | 40KB |
What is StackOverflowError?
According to the Java Virtual Machine Specification, if the computation in a thread requires a larger Java Virtual Machine Stack than is permitted, the Java Virtual Machine throws a StackOverflowError.
For example, if we call factorial(10000l), there will be a StackOverflowError, because the chain of calls exhausts the Java Virtual Machine Stack before it reaches factorial(1l), the call that starts popping stack frames.
factorial(10000l);
java.lang.StackOverflowError
at io.github.sjmyuan.trampoline.StackOverflowTest.factorial(StackOverflowTest.java:12)
at io.github.sjmyuan.trampoline.StackOverflowTest.factorial(StackOverflowTest.java:12)
at io.github.sjmyuan.trampoline.StackOverflowTest.factorial(StackOverflowTest.java:12)
at io.github.sjmyuan.trampoline.StackOverflowTest.factorial(StackOverflowTest.java:12)
at io.github.sjmyuan.trampoline.StackOverflowTest.factorial(StackOverflowTest.java:12)
at io.github.sjmyuan.trampoline.StackOverflowTest.factorial(StackOverflowTest.java:12)
.....
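To get a feeling for how deep a thread can actually go, we can recurse until the error is thrown and count the calls. This is only a rough sketch with hypothetical names; the measured depth varies with the JVM, the stack size, the frame size, and JIT compilation, so no particular number should be expected.

```java
class OverflowDepthDemo {
    private static int depth;

    // Recurses forever; each call adds one frame until the stack is exhausted.
    private static void recurse() {
        depth++;
        recurse();
    }

    // Returns the number of nested calls that fit before StackOverflowError.
    static int measure() {
        depth = 0;
        try {
            recurse();
        } catch (StackOverflowError e) {
            // The stack has been unwound by the time we get here.
        }
        return depth;
    }

    public static void main(String[] args) {
        System.out.println("Overflowed after " + measure() + " nested calls");
    }
}
```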
Tail call and Tail recursion
A tail call is a method invocation which is the final action of the current method. For example
public Integer add(Integer x, Integer y) {
return x + y;
}
public Integer subtract(Integer x, Integer y) {
return add(x, -1*y); // tail call
}
In the above code, the invocation of add(x, <result of -1*y>)
is a tail call.
But in the factorial example, factorial(n - 1) is not a tail call, because the final action is the multiplication (*)
public Long factorial(Long n) {
if (n == 1) {
return 1l;
}
Long nextFactorial = factorial(n - 1); // not a tail call
return n * nextFactorial; // final action of method
}
If the tail call invokes the current method itself, we say the current method is tail recursive. For example
public Long fibonacci(Long n, Long a, Long b) { // tail recursion
if (n == 0) {
return a;
}
if (n == 1) {
return b;
}
return fibonacci(n - 1, b, a + b); // tail call
}
How to eliminate tail recursion?
Tail recursion can be eliminated by a while loop. For example
public Long fibonacci(Long n, Long a, Long b) {
Long nParam = n;
Long aParam = a;
Long bParam = b;
while (true) {
if (nParam == 0) {
return aParam;
}
if (nParam == 1) {
return bParam;
}
nParam = nParam - 1;
Long aCurrent = aParam;
aParam = bParam;
bParam = aCurrent + bParam;
}
}
We can use the following steps to eliminate any tail recursion
- Create a local variable for each parameter, like nParam, aParam, and bParam.
- Wrap the body of the method in a while(true) loop, and replace each reference to a parameter with a reference to the corresponding local variable, like lines 5 - 15 of the listing above.
- Replace the tail call of the current method with assignments that store each argument of the tail call in the corresponding local variable, like lines 12 - 15 of the listing above.
The compiler could eliminate tail recursion automatically. Scala already does this, but Java does not support it yet. One of the reasons is polymorphism: the compiler can’t know whether the method is overridden in a subclass, so it can’t safely use the current method body to eliminate the tail recursion. Even Scala requires a tail-recursive method to be private or final, so that it can’t be overridden.
How to eliminate StackOverflowError?
A method without method invocation won’t throw StackOverflowError.
Let’s focus on the method with method invocation. We know
- StackOverflowError is caused by the limited size of the Java Virtual Machine Stack
- We can eliminate the tail recursion by a while loop that requires fewer stack frames
- The heap size is larger than the Java Virtual Machine Stack size
Could we utilize tail recursion and the heap to eliminate StackOverflowError? The answer is trampoline.
What is trampoline?
Trampoline is a technique that trades stack for heap. The root cause of StackOverflowError is that the current method has to wait for the invoked method to complete before it can release its stack frame. A trampoline lets the current method release its stack frame immediately, because the state of the pending invocation is stored in the heap, not in the stack.
The trade-off is that something has to pick up the invocation state from the heap and execute it, so a trampoline has a scheduler to do this work, and we can implement it with tail recursion.
Unlike with normal method invocation, the stack under a trampoline keeps growing by a frame and shrinking back, bouncing like a trampoline.
CPS
To apply trampoline easily, we first need to convert the method from direct style to CPS.
CPS means continuation-passing style, a style of programming in which control is passed explicitly in the form of a continuation.
To write a CPS method, we pass an extra argument of function type, which we call the continuation. When the current method completes its computation, it passes the result to the continuation instead of returning it.
For example, we can rewrite factorial in CPS
public void factorial(Long n, Consumer<Long> continuation) {
if (n == 1) {
continuation.accept(1l);
return;
}
factorial(n - 1, (Long result) -> continuation.accept(n * result));
}
It is tail recursive now, so we can convert it to a while loop
public void factorial(Long n, Consumer<Long> continuation) {
Long nParam = n;
Consumer<Long> continuationParam = continuation;
while (true) {
if (nParam == 1) {
continuationParam.accept(1l);
return;
}
Long nCurrent = nParam;
nParam = nParam - 1;
final Consumer<Long> currentContinuation = continuationParam;
continuationParam = (Long result) -> currentContinuation.accept(nCurrent * result);
}
}
But it still throws StackOverflowError
factorial(10000l, (x) -> {});
java.lang.StackOverflowError
at java.base/java.lang.Long.longValue(Long.java:1353)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
at io.github.sjmyuan.trampoline.CPSTest.lambda$1(CPSTest.java:28)
....
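continuationParam.accept(1l) (line 6 of the listing) throws the error: every loop iteration wraps the previous continuation in a new lambda that calls currentContinuation, so invoking the final continuation triggers a chain of about 10000 nested calls.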
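The nesting can be made visible with a small experiment. The sketch below (hypothetical names, and again using Throwable.getStackTrace() purely as a measuring trick) wraps a continuation a given number of times, the way the loop does, and reports the stack depth seen by the innermost continuation.

```java
import java.util.function.Consumer;

class ContinuationChainDemo {
    // Wraps a base continuation `wrappers` times, then invokes the outermost
    // one and returns the stack depth observed inside the base continuation.
    static int depthInside(int wrappers) {
        int[] depth = new int[1];
        Consumer<Long> continuation = x -> depth[0] = new Throwable().getStackTrace().length;
        for (int i = 0; i < wrappers; i++) {
            final Consumer<Long> inner = continuation;
            // Same shape as the loop body: a new lambda that calls the old one.
            continuation = x -> inner.accept(x);
        }
        continuation.accept(1L);
        return depth[0];
    }

    public static void main(String[] args) {
        // Every wrapper costs at least one extra frame when the chain is invoked.
        System.out.println(depthInside(100) - depthInside(0) >= 100); // prints true
    }
}
```

The loop itself runs in constant stack space; it is the call into the finished chain that consumes one or more frames per wrapper.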
Not every CPS method is tail recursive. For example, consider these mutually recursive methods
public boolean isEven(Long n) {
if (n == 0)
return true;
return isOdd(n - 1);
}
public boolean isOdd(Long n) {
if (n == 0)
return false;
return isEven(n - 1);
}
We can rewrite these two methods to CPS
public void isEven(Long n, Consumer<Boolean> continuation) {
if (n == 0) {
continuation.accept(true);
return;
}
isOdd(n - 1, (result) -> continuation.accept(result));
}
public void isOdd(Long n, Consumer<Boolean> continuation) {
if (n == 0) {
continuation.accept(false);
return;
}
isEven(n - 1, (result) -> continuation.accept(result));
}
The method invocations above are all tail calls, but isEven and isOdd are not tail recursive, because each one calls the other rather than itself.
CPS alone can’t eliminate StackOverflowError, and it’s easy to make mistakes when using CPS. But in CPS all the method invocations are tail calls and the execution order is explicit, which makes it easier to apply trampoline.
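Since a CPS method returns void, the caller has to supply a continuation that captures the result. Here is a minimal sketch of how that could look for isEven and isOdd; the class name, the direct-style wrapper, and the use of AtomicReference as a result holder are assumptions made for illustration.

```java
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

class CpsEvenOdd {
    static void isEven(long n, Consumer<Boolean> continuation) {
        if (n == 0) {
            continuation.accept(true);
            return;
        }
        isOdd(n - 1, continuation); // tail call
    }

    static void isOdd(long n, Consumer<Boolean> continuation) {
        if (n == 0) {
            continuation.accept(false);
            return;
        }
        isEven(n - 1, continuation); // tail call
    }

    // Direct-style wrapper: the continuation stores the result in a holder.
    static boolean isEven(long n) {
        AtomicReference<Boolean> result = new AtomicReference<>();
        isEven(n, result::set);
        return result.get();
    }

    public static void main(String[] args) {
        System.out.println(isEven(10)); // prints true
        System.out.println(isEven(7));  // prints false
    }
}
```

This version passes the continuation through unchanged instead of wrapping it in a new lambda, which behaves the same here. It still uses one stack frame per step, so a large n overflows the stack just like the direct-style version.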
How to implement trampoline?
A trampoline is a loop that iteratively invokes thunk-returning functions. A thunk is a function that takes no arguments, like Supplier in Java or Function0 in Scala
Supplier<Long> thunk = () -> 1l;
To eliminate the StackOverflowError, we need to take control of method invocation away from the JVM. We can wrap the method invocation in a thunk; then the method won’t be invoked until we run the thunk
Supplier<Long> thunk = () -> factorial(4l); // create a Supplier instance, won't run factorial(4l)
thunk.get(); // run factorial(4)
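The deferred evaluation is easy to verify with a counter. The following sketch uses hypothetical names and a stand-in computation instead of factorial.

```java
import java.util.function.Supplier;

class ThunkDemo {
    static int evaluations = 0;

    // Stand-in for any method whose invocation we want to postpone.
    static long expensive() {
        evaluations++;
        return 42L;
    }

    // Wrapping the call in a Supplier creates the thunk without running it.
    static Supplier<Long> defer() {
        return () -> expensive();
    }

    public static void main(String[] args) {
        Supplier<Long> thunk = defer();
        System.out.println(evaluations); // prints 0
        System.out.println(thunk.get()); // prints 42
        System.out.println(evaluations); // prints 1
    }
}
```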
If we use direct style, we have to run the thunk in the same method, so the thunk gains us nothing
private Long factorial(Long n) {
if (n == 1) {
return 1l;
}
Supplier<Long> thunk = () -> factorial(n-1);
return n * thunk.get();
}
If we use CPS, we also need to run thunk in the same method
public void factorial(Long n, Consumer<Long> continuation) {
if (n == 1) {
Supplier<Void> thunk = () -> {
continuation.accept(1l);
return null;
};
thunk.get();
return;
}
Supplier<Void> thunk = () -> {
factorial(n - 1, (Long result) -> continuation.accept(n * result));
return null;
};
thunk.get();
}
We don’t want the current method to run the thunk. Since the last action of the CPS method is thunk.get(), we can simply return the thunk and decide when to run it ourselves
public Supplier<Void> factorial(Long n, Consumer<Long> continuation) {
if (n == 1) {
Supplier<Void> thunk = () -> {
continuation.accept(1l);
return null;
};
return thunk;
}
Supplier<Void> thunk = () -> {
Supplier<Void> thunkContinuation =
factorial(n - 1, (Long result) -> continuation.accept(n * result));
thunkContinuation.get();
return null;
};
return thunk;
}
We can’t do this in direct style, because other actions depend on the thunk result.
But this method can still throw StackOverflowError, because thunk.get() will call thunkContinuation.get(). We return the tail call of factorial as a thunk, but miss the tail call inside the thunk.
To return the tail call inside the thunk as well, we need to change the thunk signature from Supplier<Void> to Supplier<Supplier<Supplier<....>>>. We can’t define this recursive type with Supplier alone, so we introduce another type
public class More {
private Supplier<More> thunk;
public More(Supplier<More> thunk) {
this.thunk = thunk;
}
// getter used by the run method
public Supplier<More> getThunk() {
return thunk;
}
}
Then we can return More
as thunk.
public More factorial(Long n, Function<Long, More> continuation) {
if (n == 1) {
return new More(() -> continuation.apply(1l));
}
return new More(() -> factorial(n - 1, (Long result) -> new More(() -> continuation.apply(n * result))));
}
We also changed the signature of continuation, for two reasons
- continuation may also throw StackOverflowError, so it needs trampolining too
- We can’t construct More from a Supplier<Void> when n == 1
Then we can iterate thunk to calculate the result
public static void run(More trampoline) {
run(trampoline.getThunk().get());
}
But you may find that we can’t actually call factorial, because we can’t construct the More that the continuation has to return: constructing a More requires another More instance, which leads to an endless regress.
More trampoline = factorial(4l, x -> new More(() -> new More(() -> ....)));
So we need a new class Done to signal the end of the computation. It should have the same parent type as More
public interface Trampoline {
}
public class Done implements Trampoline {
public Done() {
}
}
public class More implements Trampoline {
private Supplier<Trampoline> thunk;
public More(Supplier<Trampoline> thunk) {
this.thunk = thunk;
}
// getter used by the run method
public Supplier<Trampoline> getThunk() {
return thunk;
}
}
The run method also needs to detect the Done signal to end the computation
public static void run(Trampoline trampoline) {
if (trampoline instanceof Done) {
return;
}
run(((More) trampoline).getThunk().get());
}
Now factorial becomes
public Trampoline factorial(Long n, Function<Long, Trampoline> continuation) {
if (n == 1) {
return new More(() -> continuation.apply(1l));
}
return new More(() -> factorial(n - 1, (Long result) -> new More(() -> continuation.apply(n * result))));
}
And we can call it
Trampoline trampoline = factorial(4l, x -> new Done());
run(trampoline);
The run method is tail recursive, and remember that tail recursion can be eliminated by a while loop. With the following version there is no StackOverflowError even if we call factorial(10000l, x -> new Done())
public static void run(Trampoline trampoline) {
Trampoline trampolineParam = trampoline;
while (true) {
if (trampolineParam instanceof Done) {
return;
}
trampolineParam = ((More) trampolineParam).getThunk().get();
}
}
Let’s summarize how to apply trampoline to a CPS method
- Change the return type from void to Trampoline
- Change the continuation from Consumer<Long> to Function<Long, Trampoline>
- Replace all the tail calls with More, including the tail call in the continuation.
- Iterate the Trampoline with the run method
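The steps above can be sketched on the mutually recursive isEven and isOdd methods, which a plain while loop could not handle. This is an illustrative sketch under a few assumptions: the types are nested in a hypothetical EvenOddTrampoline class, the fields are accessed directly instead of through getters, the continuation is passed through unchanged, and the result is captured in a one-element array.

```java
import java.util.function.Function;
import java.util.function.Supplier;

class EvenOddTrampoline {
    interface Trampoline {}

    static final class Done implements Trampoline {}

    static final class More implements Trampoline {
        final Supplier<Trampoline> thunk;
        More(Supplier<Trampoline> thunk) { this.thunk = thunk; }
    }

    // The scheduler: a loop that keeps running thunks until Done shows up.
    static void run(Trampoline trampoline) {
        Trampoline current = trampoline;
        while (!(current instanceof Done)) {
            current = ((More) current).thunk.get();
        }
    }

    static Trampoline isEven(long n, Function<Boolean, Trampoline> continuation) {
        if (n == 0) {
            return new More(() -> continuation.apply(true));
        }
        return new More(() -> isOdd(n - 1, continuation)); // tail call wrapped in More
    }

    static Trampoline isOdd(long n, Function<Boolean, Trampoline> continuation) {
        if (n == 0) {
            return new More(() -> continuation.apply(false));
        }
        return new More(() -> isEven(n - 1, continuation)); // tail call wrapped in More
    }

    // Direct-style wrapper: the final continuation records the answer and stops.
    static boolean isEven(long n) {
        boolean[] result = new boolean[1];
        run(isEven(n, r -> {
            result[0] = r;
            return new Done();
        }));
        return result[0];
    }

    public static void main(String[] args) {
        System.out.println(isEven(1_000_000)); // prints true
        System.out.println(isEven(7));         // prints false
    }
}
```

Each thunk returns to run before the next one is invoked, so the stack depth stays constant even for a million steps.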
In the stack state of run(factorial(4)), we can see that each method invocation returns immediately and no stack frames accumulate.
But if we don’t replace the tail call in the continuation with More, the implementation will be
public Trampoline factorial(Long n, Function<Long, Trampoline> continuation) {
if (n == 1) {
return new More(() -> continuation.apply(1l));
}
return new More(() -> factorial(n - 1,
result -> continuation.apply(n * result))); // call continuation.apply directly
}
It will throw StackOverflowError. In the stack state of run(factorial(4)), the stack frames accumulate when calling continuation3(1): because the continuation doesn’t return a thunk, the JVM invokes the nested continuations directly and stores their method state in the stack, which the trampoline can’t control.
How to make it easy to use?
There are some pain points in using trampoline
- CPS is not customary in our daily work
- It is easy to make a mistake when replacing a tail call with More
CPS passes the continuation as a method argument, which keeps the continuation function on the stack. For the methods that we want to trampoline, we can instead use the heap to store the relationship between a result and its continuation.
Let’s add one more class FlatMap
like this
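public class FlatMap<A, B> implements Trampoline<B> {
private Trampoline<A> lastResult;
private Function<A, Trampoline<B>> continuation;
public FlatMap(Trampoline<A> lastResult, Function<A, Trampoline<B>> continuation) {
this.lastResult = lastResult;
this.continuation = continuation;
}
// getters used by the run method
public Trampoline<A> getLastResult() {
return lastResult;
}
public Function<A, Trampoline<B>> getContinuation() {
return continuation;
}
}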
We store both the result of the last expression and the continuation that follows it in FlatMap, and we allow the continuation to convert the expression result to any type.
Because the trampoline now carries the expression result, Done and More need to be refactored
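public interface Trampoline<T> {
}
public class Done<T> implements Trampoline<T> {
private T result;
public Done(T result) {
this.result = result;
}
// getter used by the run method
public T getResult() {
return result;
}
}
public class More<T> implements Trampoline<T> {
private Supplier<Trampoline<T>> thunk;
public More(Supplier<Trampoline<T>> thunk) {
this.thunk = thunk;
}
// getter used by the run method
public Supplier<Trampoline<T>> getThunk() {
return thunk;
}
}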
The run method also needs to handle FlatMap
- If the last result is Done, just apply the continuation to the result
- If the last result is More, invoke the thunk and create a new FlatMap to apply the continuation to the thunk result
- If the last result is FlatMap, change FlatMap(FlatMap(trampoline, g), f) to FlatMap(trampoline, x -> FlatMap(g(x), f))
public static <S> S run(Trampoline<S> trampoline) {
if (trampoline instanceof Done) {
return ((Done<S>) trampoline).getResult();
} else if (trampoline instanceof More) {
return run(((More<S>) trampoline).getThunk().get());
} else {
FlatMap<Object, S> continuation = (FlatMap<Object, S>) trampoline;
Trampoline<Object> lastResult = continuation.getLastResult();
Function<Object, Trampoline<S>> continuationFunc = continuation.getContinuation();
if (lastResult instanceof FlatMap) {
FlatMap<Object, Object> lastResultContinuation =
(FlatMap<Object, Object>) lastResult;
return run(new FlatMap<Object, S>(lastResultContinuation.getLastResult(),
x -> new FlatMap<Object, S>(
lastResultContinuation.getContinuation().apply(x),
continuationFunc)));
} else if (lastResult instanceof More) {
return run(new FlatMap<Object, S>(((More<Object>) lastResult).getThunk().get(),
continuationFunc));
} else {
return run(continuationFunc.apply(((Done<Object>) lastResult).getResult()));
}
}
}
Now we don’t need to convert the method to CPS, we can apply trampoline directly like this
private Trampoline<Long> factorial(Long n) {
if (n == 1) {
return new Done<Long>(1l);
}
return new FlatMap<Long, Long>(factorial(n - 1), x -> new Done<Long>(n * x));
}
Oops! It still throws StackOverflowError. Why?
Trampoline<Long> trampoline = factorial(10000l);
Trampoline.run(trampoline);
java.lang.StackOverflowError
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
at io.github.sjmyuan.trampoline.v3.TrampolineTest.factorialTrampoline(TrampolineTest.java:11)
....
The reason is that we call factorial(n - 1) directly
new FlatMap<Long, Long>(factorial(n - 1), x -> new Done<Long>(n * x))
Before we can create the FlatMap instance, we have to wait for factorial(n - 1) to return, which in turn creates another FlatMap and waits for factorial(n - 2) to return. The stack frames accumulate.
We can use More to avoid this scenario.
new FlatMap<Long, Long>(new More<Long>(() -> factorial(n - 1)),
x -> new Done<Long>(n * x));
The most important thing here is to wrap every method invocation in More.
But it’s still not convenient to construct the trampoline classes, so we can add some utility methods
public static <A> Trampoline<A> of(A v) {
return new Done<A>(v);
}
public static <A> Trampoline<A> suspend(Supplier<Trampoline<A>> thunk) {
return new More<A>(thunk);
}
<B> Trampoline<B> flatMap(Function<T, Trampoline<B>> continuation);
flatMap
is a method implemented by each class
// Done and More
public <B> Trampoline<B> flatMap(Function<T, Trampoline<B>> continuation) {
return new FlatMap<T, B>(this, continuation);
}
//FlatMap
public <C> Trampoline<C> flatMap(Function<B, Trampoline<C>> nextContinuation) {
return new FlatMap<A, C>(lastResult,
x -> Trampoline.suspend(() -> continuation.apply(x)).flatMap(nextContinuation));
}
We will make the constructor of FlatMap protected, so that a FlatMap can only be created through flatMap, which rules out a direct call when constructing it.
With these methods, factorial becomes cleaner
public Trampoline<Long> factorial(Long n) {
if (n == 1) {
return Trampoline.of(1l);
}
return Trampoline.suspend(() -> factorial(n - 1))
.flatMap(x -> Trampoline.of(n * x));
}
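Putting the pieces together, a compact and self-contained version could look like the sketch below. It is not the code from the repository: the types are nested in a hypothetical TrampolineSketch class, fields are accessed directly instead of through getters, run is written as a while loop, and a single default flatMap is used for every class, which leaves all re-association of nested FlatMaps to run.

```java
import java.util.function.Function;
import java.util.function.Supplier;

class TrampolineSketch {
    interface Trampoline<T> {
        static <A> Trampoline<A> of(A value) { return new Done<A>(value); }

        static <A> Trampoline<A> suspend(Supplier<Trampoline<A>> thunk) { return new More<A>(thunk); }

        default <B> Trampoline<B> flatMap(Function<T, Trampoline<B>> continuation) {
            return new FlatMap<T, B>(this, continuation);
        }
    }

    static final class Done<T> implements Trampoline<T> {
        final T result;
        Done(T result) { this.result = result; }
    }

    static final class More<T> implements Trampoline<T> {
        final Supplier<Trampoline<T>> thunk;
        More(Supplier<Trampoline<T>> thunk) { this.thunk = thunk; }
    }

    static final class FlatMap<A, B> implements Trampoline<B> {
        final Trampoline<A> lastResult;
        final Function<A, Trampoline<B>> continuation;
        FlatMap(Trampoline<A> lastResult, Function<A, Trampoline<B>> continuation) {
            this.lastResult = lastResult;
            this.continuation = continuation;
        }
    }

    // Loop form of run; every branch does a constant amount of stack work.
    @SuppressWarnings("unchecked")
    static <S> S run(Trampoline<S> start) {
        Trampoline<Object> current = (Trampoline<Object>) (Trampoline<?>) start;
        while (true) {
            if (current instanceof Done) {
                return (S) ((Done<Object>) current).result;
            }
            if (current instanceof More) {
                current = ((More<Object>) current).thunk.get();
                continue;
            }
            FlatMap<Object, Object> flatMap = (FlatMap<Object, Object>) current;
            Trampoline<Object> last = flatMap.lastResult;
            Function<Object, Trampoline<Object>> f = flatMap.continuation;
            if (last instanceof Done) {
                current = f.apply(((Done<Object>) last).result);
            } else if (last instanceof More) {
                current = new FlatMap<Object, Object>(((More<Object>) last).thunk.get(), f);
            } else {
                // FlatMap(FlatMap(t, g), f) becomes FlatMap(t, x -> FlatMap(g(x), f))
                FlatMap<Object, Object> inner = (FlatMap<Object, Object>) last;
                Function<Object, Trampoline<Object>> g = inner.continuation;
                current = new FlatMap<Object, Object>(inner.lastResult,
                        x -> new FlatMap<Object, Object>(g.apply(x), f));
            }
        }
    }

    static Trampoline<Long> factorial(long n) {
        if (n <= 1) {
            return Trampoline.of(1L);
        }
        return Trampoline.suspend(() -> factorial(n - 1))
                .flatMap(x -> Trampoline.of(n * x));
    }

    public static void main(String[] args) {
        System.out.println(run(factorial(5)));  // prints 120
        run(factorial(100_000));                // completes without StackOverflowError
    }
}
```

The price is heap allocation: every suspended call and every pending continuation becomes an object. As with the Long version above, the result silently wraps around for n above 20; a type like BigInteger would be needed for exact large factorials.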
Summary
Trampoline is an important technique to eliminate StackOverflowError in functional programming. It’s also essential knowledge for understanding the implementation details of functional programming libraries.
Judging by the implementation, trampoline is a Free Monad; the missing piece is map
public <B> Trampoline<B> map(Function<T, B> continuation) {
return new FlatMap<T, B>(this, x -> Trampoline.suspend(() -> Trampoline.of(continuation.apply(x))));
}
Most of the above content explains Stackless Scala With Free Monads in terms of Java, but there are some points I don’t understand
- The StackOverflowError caused by the left-leaning tower of FlatMaps in section 4.3 An easy thing to get wrong
I tried to reproduce this scenario by
@Test
public void tooManyLeftAssociateContinuationWillNotThrowError() {
    Trampoline<Long> trampoline = new Done<Long>(1l);
    for (int i = 1; i < 50000; i++) {
        trampoline = new FlatMap<Long, Long>(trampoline, x -> new Done<Long>(x));
    }
    assertThat(Trampoline.run(trampoline)).isEqualTo(1);
}
But it didn’t throw StackOverflowError. I also drew the stack state: it won’t blow up the stack unless some continuation blows it up. It’s not a problem of left-associated FlatMaps; it’s a problem of the continuation, which is out of our control.
- The flatMap implementation in 4.3 An easy thing to get wrong
The flatMap implementation in the article is
def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = this match {
    case FlatMap(a, g) => FlatMap(a, (x: Any) => g(x) flatMap f)
    case x => FlatMap(x, f)
}
It throws StackOverflowError in the following scenario
@Test
public void tooManyFlatMapWillThrowError() {
    Trampoline<Long> trampoline = Trampoline.of(1l);
    for (int i = 1; i < 50000; i++) {
        trampoline = trampoline.flatMap(x -> Trampoline.of(x));
    }
    assertThat(Trampoline.run(trampoline)).isEqualTo(1);
}
We can fix it by suspending continuation.apply(x) (g(x) in the article implementation)
//FlatMap
public <C> Trampoline<C> flatMap(Function<B, Trampoline<C>> nextContinuation) {
    return new FlatMap<A, C>(lastResult,
        x -> Trampoline.suspend(() -> continuation.apply(x)).flatMap(nextContinuation));
}
I have put all the code in trampoline-example. You are welcome to review it and let me know if there is any mistake. I hope this post helps you understand Trampoline.