Skip to content

Commit d57e5e0

Browse files
committed
chore: TCO for tailstrict and Materializer.
1 parent 7d5f5b3 commit d57e5e0

7 files changed

Lines changed: 908 additions & 94 deletions

File tree

sjsonnet/src/sjsonnet/Error.scala

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,20 @@ class Error(msg: String, stack: List[Error.Frame] = Nil, underlying: Option[Thro
1717
def addFrame(pos: Position, expr: Expr = null)(implicit ev: EvalErrorScope): Error = {
1818
if (stack.isEmpty || alwaysAddPos(expr)) {
1919
val exprErrorString = if (expr == null) null else expr.exprErrorString
20-
val newFrame = new Error.Frame(pos, exprErrorString)
21-
stack match {
22-
case s :: ss if s.pos == pos =>
23-
if (s.exprErrorString == null && exprErrorString != null) copy(stack = newFrame :: ss)
24-
else this
25-
case _ => copy(stack = newFrame :: stack)
26-
}
20+
addFrameString(pos, exprErrorString)
2721
} else this
2822
}
2923

24+
def addFrameString(pos: Position, exprErrorString: String)(implicit ev: EvalErrorScope): Error = {
25+
val newFrame = new Error.Frame(pos, exprErrorString)
26+
stack match {
27+
case s :: ss if s.pos == pos =>
28+
if (s.exprErrorString == null && exprErrorString != null) copy(stack = newFrame :: ss)
29+
else this
30+
case _ => copy(stack = newFrame :: stack)
31+
}
32+
}
33+
3034
def asSeenFrom(ev: EvalErrorScope): Error =
3135
copy(stack = stack.map(_.asSeenFrom(ev)))
3236

sjsonnet/src/sjsonnet/Evaluator.scala

Lines changed: 137 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,22 @@ class Evaluator(
202202
}
203203
}
204204

205+
/**
206+
* Function application entry points (visitApply/visitApply0-3 for user functions,
207+
* visitApplyBuiltin/visitApplyBuiltin0-4 for built-in functions).
208+
*
209+
* When `e.tailstrict` is true, the result is wrapped in `TailCall.resolve()` which iteratively
210+
* resolves any [[TailCall]] chain. When false, arguments are wrapped as lazy thunks to preserve
211+
* Jsonnet's default lazy evaluation semantics, and `Val.Func.apply` resolves any TailCall
212+
* internally via `TailCall.resolve` before returning.
213+
*/
205214
protected def visitApply(e: Apply)(implicit scope: ValScope): Val = {
206215
val lhs = visitExpr(e.value)
207216
implicit val tailstrictMode: TailstrictMode =
208217
if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled
209218

210219
if (e.tailstrict) {
211-
lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos)
220+
TailCall.resolve(lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos))
212221
} else {
213222
lhs.cast[Val.Func].apply(e.args.map(visitAsLazy(_)), e.namedNames, e.pos)
214223
}
@@ -218,15 +227,19 @@ class Evaluator(
218227
val lhs = visitExpr(e.value)
219228
implicit val tailstrictMode: TailstrictMode =
220229
if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled
221-
lhs.cast[Val.Func].apply0(e.pos)
230+
if (e.tailstrict) {
231+
TailCall.resolve(lhs.cast[Val.Func].apply0(e.pos))
232+
} else {
233+
lhs.cast[Val.Func].apply0(e.pos)
234+
}
222235
}
223236

224237
protected def visitApply1(e: Apply1)(implicit scope: ValScope): Val = {
225238
val lhs = visitExpr(e.value)
226239
implicit val tailstrictMode: TailstrictMode =
227240
if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled
228241
if (e.tailstrict) {
229-
lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos)
242+
TailCall.resolve(lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos))
230243
} else {
231244
val l1 = visitAsLazy(e.a1)
232245
lhs.cast[Val.Func].apply1(l1, e.pos)
@@ -239,7 +252,7 @@ class Evaluator(
239252
if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled
240253

241254
if (e.tailstrict) {
242-
lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos)
255+
TailCall.resolve(lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos))
243256
} else {
244257
val l1 = visitAsLazy(e.a1)
245258
val l2 = visitAsLazy(e.a2)
@@ -253,7 +266,9 @@ class Evaluator(
253266
if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled
254267

255268
if (e.tailstrict) {
256-
lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos)
269+
TailCall.resolve(
270+
lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos)
271+
)
257272
} else {
258273
val l1 = visitAsLazy(e.a1)
259274
val l2 = visitAsLazy(e.a2)
@@ -262,41 +277,48 @@ class Evaluator(
262277
}
263278
}
264279

265-
protected def visitApplyBuiltin0(e: ApplyBuiltin0): Val = e.func.evalRhs(this, e.pos)
280+
protected def visitApplyBuiltin0(e: ApplyBuiltin0): Val = {
281+
val result = e.func.evalRhs(this, e.pos)
282+
if (e.tailstrict) TailCall.resolve(result) else result
283+
}
266284

267285
protected def visitApplyBuiltin1(e: ApplyBuiltin1)(implicit scope: ValScope): Val = {
268286
if (e.tailstrict) {
269-
e.func.evalRhs(visitExpr(e.a1), this, e.pos)
287+
TailCall.resolve(e.func.evalRhs(visitExpr(e.a1), this, e.pos))
270288
} else {
271289
e.func.evalRhs(visitAsLazy(e.a1), this, e.pos)
272290
}
273291
}
274292

275293
protected def visitApplyBuiltin2(e: ApplyBuiltin2)(implicit scope: ValScope): Val = {
276294
if (e.tailstrict) {
277-
e.func.evalRhs(visitExpr(e.a1), visitExpr(e.a2), this, e.pos)
295+
TailCall.resolve(e.func.evalRhs(visitExpr(e.a1), visitExpr(e.a2), this, e.pos))
278296
} else {
279297
e.func.evalRhs(visitAsLazy(e.a1), visitAsLazy(e.a2), this, e.pos)
280298
}
281299
}
282300

283301
protected def visitApplyBuiltin3(e: ApplyBuiltin3)(implicit scope: ValScope): Val = {
284302
if (e.tailstrict) {
285-
e.func.evalRhs(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), this, e.pos)
303+
TailCall.resolve(
304+
e.func.evalRhs(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), this, e.pos)
305+
)
286306
} else {
287307
e.func.evalRhs(visitAsLazy(e.a1), visitAsLazy(e.a2), visitAsLazy(e.a3), this, e.pos)
288308
}
289309
}
290310

291311
protected def visitApplyBuiltin4(e: ApplyBuiltin4)(implicit scope: ValScope): Val = {
292312
if (e.tailstrict) {
293-
e.func.evalRhs(
294-
visitExpr(e.a1),
295-
visitExpr(e.a2),
296-
visitExpr(e.a3),
297-
visitExpr(e.a4),
298-
this,
299-
e.pos
313+
TailCall.resolve(
314+
e.func.evalRhs(
315+
visitExpr(e.a1),
316+
visitExpr(e.a2),
317+
visitExpr(e.a3),
318+
visitExpr(e.a4),
319+
this,
320+
e.pos
321+
)
300322
)
301323
} else {
302324
e.func.evalRhs(
@@ -319,7 +341,7 @@ class Evaluator(
319341
arr(idx) = visitExpr(e.argExprs(idx))
320342
idx += 1
321343
}
322-
e.func.evalRhs(arr, this, e.pos)
344+
TailCall.resolve(e.func.evalRhs(arr, this, e.pos))
323345
} else {
324346
while (idx < e.argExprs.length) {
325347
val boundIdx = idx
@@ -638,10 +660,107 @@ class Evaluator(
638660
scope: ValScope): Val.Func =
639661
new Val.Func(outerPos, scope, params) {
640662
def evalRhs(vs: ValScope, es: EvalScope, fs: FileScope, pos: Position): Val =
641-
visitExpr(rhs)(vs)
663+
visitExprWithTailCallSupport(rhs)(vs)
642664
override def evalDefault(expr: Expr, vs: ValScope, es: EvalScope): Val = visitExpr(expr)(vs)
643665
}
644666

667+
/**
668+
* Evaluate an expression with tail-call support. When a `tailstrict` call is encountered at a
669+
* potential tail position, returns a [[TailCall]] sentinel instead of recursing, enabling
670+
* `TailCall.resolve` in `visitApply*` to iterate rather than grow the JVM stack.
671+
*
672+
* Potential tail positions are propagated through: IfElse (both branches), LocalExpr (returned),
673+
* and AssertExpr (returned). All other expression types delegate to normal `visitExpr`.
674+
*/
675+
@tailrec
676+
private def visitExprWithTailCallSupport(e: Expr)(implicit scope: ValScope): Val = e match {
677+
case e: IfElse =>
678+
visitExpr(e.cond) match {
679+
case Val.True(_) => visitExprWithTailCallSupport(e.`then`)
680+
case Val.False(_) =>
681+
e.`else` match {
682+
case null => Val.Null(e.pos)
683+
case v => visitExprWithTailCallSupport(v)
684+
}
685+
case v => Error.fail("Need boolean, found " + v.prettyName, e.pos)
686+
}
687+
case e: LocalExpr =>
688+
val bindings = e.bindings
689+
val s =
690+
if (bindings == null) scope
691+
else {
692+
val base = scope.length
693+
val newScope = scope.extendBy(bindings.length)
694+
var i = 0
695+
while (i < bindings.length) {
696+
val b = bindings(i)
697+
newScope.bindings(base + i) = b.args match {
698+
case null => visitAsLazy(b.rhs)(newScope)
699+
case argSpec =>
700+
new Lazy(() => visitMethod(b.rhs, argSpec, b.pos)(newScope))
701+
}
702+
i += 1
703+
}
704+
newScope
705+
}
706+
visitExprWithTailCallSupport(e.returned)(s)
707+
case e: AssertExpr =>
708+
if (!visitExpr(e.asserted.value).isInstanceOf[Val.True]) {
709+
e.asserted.msg match {
710+
case null => Error.fail("Assertion failed", e)
711+
case msg =>
712+
Error.fail("Assertion failed: " + materializeError(visitExpr(msg)), e)
713+
}
714+
}
715+
visitExprWithTailCallSupport(e.returned)
716+
// Tail-position tailstrict calls: match TailstrictableExpr to unify the tailstrict guard,
717+
// then dispatch by concrete type.
718+
//
719+
// - Apply* (user function calls): construct a TailCall sentinel that the caller's
720+
// TailCall.resolve loop will resolve iteratively, avoiding JVM stack growth for
721+
// tail-recursive calls.
722+
// - ApplyBuiltin* (built-in function calls): fall through to visitExpr, which dispatches to
723+
// visitApplyBuiltin*. Those methods already wrap their result in TailCall.resolve() when
724+
// tailstrict=true, resolving any TailCall that a user-defined callback (e.g. the function
725+
// argument to std.makeArray or std.sort) may have returned.
726+
case e: TailstrictableExpr if e.tailstrict =>
727+
e match {
728+
case e: Apply =>
729+
try {
730+
val func = visitExpr(e.value).cast[Val.Func]
731+
new TailCall(func, e.args.map(visitExpr(_)).asInstanceOf[Array[Eval]], e.namedNames, e)
732+
} catch Error.withStackFrame(e)
733+
case e: Apply0 =>
734+
try {
735+
val func = visitExpr(e.value).cast[Val.Func]
736+
new TailCall(func, Evaluator.emptyLazyArray, null, e)
737+
} catch Error.withStackFrame(e)
738+
case e: Apply1 =>
739+
try {
740+
val func = visitExpr(e.value).cast[Val.Func]
741+
new TailCall(func, Array[Eval](visitExpr(e.a1)), null, e)
742+
} catch Error.withStackFrame(e)
743+
case e: Apply2 =>
744+
try {
745+
val func = visitExpr(e.value).cast[Val.Func]
746+
new TailCall(func, Array[Eval](visitExpr(e.a1), visitExpr(e.a2)), null, e)
747+
} catch Error.withStackFrame(e)
748+
case e: Apply3 =>
749+
try {
750+
val func = visitExpr(e.value).cast[Val.Func]
751+
new TailCall(
752+
func,
753+
Array[Eval](visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3)),
754+
null,
755+
e
756+
)
757+
} catch Error.withStackFrame(e)
758+
case _ => visitExpr(e)
759+
}
760+
case _ =>
761+
visitExpr(e)
762+
}
763+
645764
def visitBindings(bindings: Array[Bind], scope: => ValScope): Array[Eval] = {
646765
val arrF = new Array[Eval](bindings.length)
647766
var i = 0

sjsonnet/src/sjsonnet/Expr.scala

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,27 @@ trait Expr {
2323

2424
override def toString: String = s"$exprErrorString@$pos"
2525
}
26+
27+
/**
28+
* Marker trait for [[Expr]] nodes that represent function calls eligible for tail-call
29+
* optimization. All Apply* (user function calls) and ApplyBuiltin* (built-in function calls) mix in
30+
* this trait, providing a uniform `tailstrict` flag. The evaluator handles the two families
31+
* differently when `tailstrict` is true:
32+
*
33+
* - '''User function calls''' (Apply*) in tail position: the evaluator constructs a [[TailCall]]
34+
* sentinel and returns it to the caller's [[TailCall.resolve]] trampoline loop, avoiding JVM
35+
* stack growth for tail-recursive calls.
36+
* - '''Built-in function calls''' (ApplyBuiltin*): the evaluator wraps the result in
37+
* [[TailCall.resolve]] at the call site, resolving any [[TailCall]] that a user-defined
38+
* callback (e.g. the function argument to `std.makeArray` or `std.sort`) may have returned.
39+
*
40+
* @see
41+
* [[TailCall]] for the sentinel value used in the TCO protocol
42+
*/
43+
trait TailstrictableExpr extends Expr {
44+
def tailstrict: Boolean
45+
}
46+
2647
object Expr {
2748
private final def arrStr(a: Array[?]): String = {
2849
if (a == null) "null" else a.mkString("[", ", ", "]")
@@ -189,17 +210,19 @@ object Expr {
189210
args: Array[Expr],
190211
namedNames: Array[String],
191212
tailstrict: Boolean)
192-
extends Expr {
213+
extends TailstrictableExpr {
193214
final override private[sjsonnet] def tag = ExprTags.Apply
194215
}
195-
final case class Apply0(pos: Position, value: Expr, tailstrict: Boolean) extends Expr {
216+
final case class Apply0(pos: Position, value: Expr, tailstrict: Boolean)
217+
extends TailstrictableExpr {
196218
final override private[sjsonnet] def tag = ExprTags.Apply0
197219
}
198-
final case class Apply1(pos: Position, value: Expr, a1: Expr, tailstrict: Boolean) extends Expr {
220+
final case class Apply1(pos: Position, value: Expr, a1: Expr, tailstrict: Boolean)
221+
extends TailstrictableExpr {
199222
final override private[sjsonnet] def tag = ExprTags.Apply1
200223
}
201224
final case class Apply2(pos: Position, value: Expr, a1: Expr, a2: Expr, tailstrict: Boolean)
202-
extends Expr {
225+
extends TailstrictableExpr {
203226
final override private[sjsonnet] def tag = ExprTags.Apply2
204227
}
205228
final case class Apply3(
@@ -209,25 +232,25 @@ object Expr {
209232
a2: Expr,
210233
a3: Expr,
211234
tailstrict: Boolean)
212-
extends Expr {
235+
extends TailstrictableExpr {
213236
final override private[sjsonnet] def tag = ExprTags.Apply3
214237
}
215238
final case class ApplyBuiltin(
216239
pos: Position,
217240
func: Val.Builtin,
218241
argExprs: Array[Expr],
219242
tailstrict: Boolean)
220-
extends Expr {
243+
extends TailstrictableExpr {
221244
final override private[sjsonnet] def tag = ExprTags.ApplyBuiltin
222245
override def exprErrorString: String = s"std.${func.functionName}"
223246
}
224247
final case class ApplyBuiltin0(pos: Position, func: Val.Builtin0, tailstrict: Boolean)
225-
extends Expr {
248+
extends TailstrictableExpr {
226249
final override private[sjsonnet] def tag = ExprTags.ApplyBuiltin0
227250
override def exprErrorString: String = s"std.${func.functionName}"
228251
}
229252
final case class ApplyBuiltin1(pos: Position, func: Val.Builtin1, a1: Expr, tailstrict: Boolean)
230-
extends Expr {
253+
extends TailstrictableExpr {
231254
final override private[sjsonnet] def tag = ExprTags.ApplyBuiltin1
232255
override def exprErrorString: String = s"std.${func.functionName}"
233256
}
@@ -237,7 +260,7 @@ object Expr {
237260
a1: Expr,
238261
a2: Expr,
239262
tailstrict: Boolean)
240-
extends Expr {
263+
extends TailstrictableExpr {
241264
final override private[sjsonnet] def tag = ExprTags.ApplyBuiltin2
242265
override def exprErrorString: String = s"std.${func.functionName}"
243266
}
@@ -248,7 +271,7 @@ object Expr {
248271
a2: Expr,
249272
a3: Expr,
250273
tailstrict: Boolean)
251-
extends Expr {
274+
extends TailstrictableExpr {
252275
final override private[sjsonnet] def tag = ExprTags.ApplyBuiltin3
253276
override def exprErrorString: String = s"std.${func.functionName}"
254277
}
@@ -260,7 +283,7 @@ object Expr {
260283
a3: Expr,
261284
a4: Expr,
262285
tailstrict: Boolean)
263-
extends Expr {
286+
extends TailstrictableExpr {
264287
override private[sjsonnet] def tag = ExprTags.ApplyBuiltin4
265288
override def exprErrorString: String = s"std.${func.functionName}"
266289
}

0 commit comments

Comments
 (0)