Skip to content

Commit 20827c0

Browse files
Re-enabling CurriedFunctionBenchmarks (#14241)
Fixes #8321 by adding `@Tail_Call` annotation and avoiding `StackOverflowError`.
1 parent 437be0d commit 20827c0

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package org.enso.interpreter.bench.benchmarks.semantic;
2+
3+
import java.nio.file.Paths;
4+
import java.util.concurrent.TimeUnit;
5+
import java.util.function.Function;
6+
import java.util.logging.Level;
7+
import org.enso.common.RuntimeOptions;
8+
import org.graalvm.polyglot.Context;
9+
import org.graalvm.polyglot.Value;
10+
import org.graalvm.polyglot.io.IOAccess;
11+
import org.openjdk.jmh.annotations.Benchmark;
12+
import org.openjdk.jmh.annotations.BenchmarkMode;
13+
import org.openjdk.jmh.annotations.Fork;
14+
import org.openjdk.jmh.annotations.Measurement;
15+
import org.openjdk.jmh.annotations.Mode;
16+
import org.openjdk.jmh.annotations.OutputTimeUnit;
17+
import org.openjdk.jmh.annotations.Scope;
18+
import org.openjdk.jmh.annotations.Setup;
19+
import org.openjdk.jmh.annotations.State;
20+
import org.openjdk.jmh.annotations.Warmup;
21+
import org.openjdk.jmh.infra.BenchmarkParams;
22+
import org.openjdk.jmh.infra.Blackhole;
23+
24+
@BenchmarkMode(Mode.AverageTime)
25+
@Fork(1)
26+
@Warmup(iterations = 3)
27+
@Measurement(iterations = 3)
28+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
29+
@State(Scope.Benchmark)
30+
public class CurriedFunctionBenchmarks {
31+
private Value fn;
32+
private Value avg;
33+
34+
@Setup
35+
public void initializeBenchmark(BenchmarkParams params) throws Exception {
36+
var ctx =
37+
Context.newBuilder()
38+
.allowExperimentalOptions(true)
39+
.option(RuntimeOptions.LOG_LEVEL, Level.WARNING.getName())
40+
.logHandler(System.err)
41+
.allowIO(IOAccess.ALL)
42+
.allowAllAccess(true)
43+
.option(
44+
"enso.languageHomeOverride",
45+
Paths.get("../../distribution/component").toFile().getAbsolutePath())
46+
.build();
47+
48+
var benchmarkName = SrcUtil.findName(params);
49+
var code =
50+
"""
51+
from Standard.Base import all
52+
53+
avg fn len =
54+
sum acc i = if i == len then acc else
55+
@Tail_Call sum (acc + fn i) i+1
56+
(sum 0 0) / len
57+
58+
type Holder
59+
three_times x = 3 * x
60+
61+
callback_curried =
62+
h = Holder
63+
h.three_times
64+
65+
callback_lambda =
66+
h = Holder
67+
(x -> h.three_times x)
68+
""";
69+
70+
var module = ctx.eval(SrcUtil.source(benchmarkName, code));
71+
72+
Function<String, Value> getMethod = (name) -> module.invokeMember("eval_expression", name);
73+
switch (benchmarkName) {
74+
case "averageLambda":
75+
{
76+
this.fn = getMethod.apply("callback_lambda");
77+
break;
78+
}
79+
case "averageCurried":
80+
{
81+
this.fn = getMethod.apply("callback_curried");
82+
break;
83+
}
84+
default:
85+
throw new IllegalStateException("Unexpected benchmark: " + params.getBenchmark());
86+
}
87+
this.avg = getMethod.apply("avg");
88+
}
89+
90+
@Benchmark
91+
public void averageLambda(Blackhole matter) {
92+
performBenchmark(matter);
93+
}
94+
95+
@Benchmark
96+
public void averageCurried(Blackhole matter) {
97+
performBenchmark(matter);
98+
}
99+
100+
private void performBenchmark(Blackhole hole) throws AssertionError {
101+
var average = avg.execute(fn, 10000);
102+
if (!average.fitsInDouble()) {
103+
throw new AssertionError("Shall be a double: " + average);
104+
}
105+
var result = (long) average.asDouble();
106+
boolean isResultCorrect = result == 14998;
107+
if (!isResultCorrect) {
108+
throw new AssertionError("Expecting reasonable average but was " + result + "\n" + fn);
109+
}
110+
hole.consume(result);
111+
}
112+
}

0 commit comments

Comments
 (0)