From 8e3b761d9ece3336f70265c9b2f4465336b4ef4c Mon Sep 17 00:00:00 2001 From: Islon Scherer Date: Tue, 1 Jul 2025 16:32:31 +0200 Subject: [PATCH 1/4] Add `iterate` and `iterateIndexed` methods for `Collections` in the stdlib. --- .../org/pkl/core/stdlib/base/ListNodes.java | 76 ++++++++++++++++++ .../org/pkl/core/stdlib/base/SetNodes.java | 79 +++++++++++++++++++ .../LanguageSnippetTests/input/api/list.pkl | 14 ++++ .../LanguageSnippetTests/input/api/set.pkl | 12 +++ .../LanguageSnippetTests/output/api/list.pcf | 12 +++ .../LanguageSnippetTests/output/api/set.pcf | 10 +++ stdlib/base.pkl | 18 +++++ 7 files changed, 221 insertions(+) diff --git a/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java b/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java index 42352823f..4555881ad 100644 --- a/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java +++ b/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java @@ -25,7 +25,10 @@ import org.pkl.core.ast.internal.IsInstanceOfNodeGen; import org.pkl.core.ast.lambda.*; import org.pkl.core.ast.type.TypeNode; +import org.pkl.core.ast.type.TypeNode.BooleanTypeNode; +import org.pkl.core.ast.type.TypeNode.PairTypeNode; import org.pkl.core.ast.type.TypeNode.UInt8TypeAliasTypeNode; +import org.pkl.core.ast.type.TypeNode.UnknownTypeNode; import org.pkl.core.ast.type.VmTypeMismatchException; import org.pkl.core.runtime.*; import org.pkl.core.stdlib.*; @@ -785,6 +788,79 @@ protected Object eval(VmList self, Object initial, VmFunction function) { } } + public abstract static class iterate extends ExternalMethod2Node { + @Child private ApplyVmFunction2Node applyLambdaNode = ApplyVmFunction2NodeGen.create(); + @Child @LateInit private TypeNode typeNode; + + private TypeNode getTypeNode() { + if (typeNode == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + var section = VmUtils.unavailableSourceSection(); + var booleanTypeNode = new BooleanTypeNode(section); + var unknownTypeNode = new UnknownTypeNode(section); + typeNode = new PairTypeNode(section, booleanTypeNode, unknownTypeNode); + } + return typeNode; + } + + @Specialization + protected Object eval(VirtualFrame frame, VmList self, Object initial, VmFunction function) { + var iter = self.iterator(); + var typeNode = getTypeNode(); + var result = initial; + var loop = 0; + + while (iter.hasNext()) { + var elem = iter.next(); + var pairUnchecked = applyLambdaNode.execute(function, result, elem); + var pair = (VmPair) typeNode.executeEagerly(frame, pairUnchecked); + result = pair.getSecond(); + loop++; + if (!(Boolean) pair.getFirst()) break; + } + + LoopNode.reportLoopCount(this, loop); + return result; + } + } + + public abstract static class iterateIndexed extends ExternalMethod2Node { + @Child private ApplyVmFunction3Node applyLambdaNode = ApplyVmFunction3NodeGen.create(); + @Child @LateInit private TypeNode typeNode; + + private TypeNode getTypeNode() { + if (typeNode == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + var section = VmUtils.unavailableSourceSection(); + var booleanTypeNode = new BooleanTypeNode(section); + var unknownTypeNode = new UnknownTypeNode(section); + typeNode = new PairTypeNode(section, booleanTypeNode, unknownTypeNode); + } + return typeNode; + } + + @Specialization + protected Object eval(VirtualFrame frame, VmList self, Object initial, VmFunction function) { + var iter = self.iterator(); + var typeNode = getTypeNode(); + var result = initial; + long index = 0; + var loop = 0; + + while (iter.hasNext()) { + var elem = iter.next(); + var pairUnchecked = applyLambdaNode.execute(function, index++, result, elem); + var pair = (VmPair) typeNode.executeEagerly(frame, pairUnchecked); + result = pair.getSecond(); + loop++; + if (!(Boolean) pair.getFirst()) break; + } + + LoopNode.reportLoopCount(this, loop); + return result; + } + } + public abstract static class reduce extends ExternalMethod1Node { @Child private ApplyVmFunction2Node applyLambdaNode = ApplyVmFunction2NodeGen.create(); diff --git a/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java b/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java index 285e45599..fe1bd9881 100644 --- a/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java +++ b/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java @@ -18,6 +18,7 @@ import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.nodes.LoopNode; import org.pkl.core.ast.expression.binary.GreaterThanNode; import org.pkl.core.ast.expression.binary.GreaterThanNodeGen; @@ -26,11 +27,16 @@ import org.pkl.core.ast.internal.IsInstanceOfNode; import org.pkl.core.ast.internal.IsInstanceOfNodeGen; import org.pkl.core.ast.lambda.*; +import org.pkl.core.ast.type.TypeNode; +import org.pkl.core.ast.type.TypeNode.BooleanTypeNode; +import org.pkl.core.ast.type.TypeNode.PairTypeNode; +import org.pkl.core.ast.type.TypeNode.UnknownTypeNode; import org.pkl.core.runtime.*; import org.pkl.core.stdlib.*; import org.pkl.core.stdlib.base.CollectionNodes.CompareByNode; import org.pkl.core.stdlib.base.CollectionNodes.CompareNode; import org.pkl.core.stdlib.base.CollectionNodes.CompareWithNode; +import org.pkl.core.util.LateInit; // duplication between ListNodes and SetNodes is "intentional" // (sharing nodes between VmCollection subtypes results in @@ -561,6 +567,79 @@ protected Object eval(VmSet self, Object initial, VmFunction function) { } } + public abstract static class iterate extends ExternalMethod2Node { + @Child private ApplyVmFunction2Node applyLambdaNode = ApplyVmFunction2NodeGen.create(); + @Child @LateInit private TypeNode typeNode; + + private TypeNode getTypeNode() { + if (typeNode == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + var section = VmUtils.unavailableSourceSection(); + var booleanTypeNode = new BooleanTypeNode(section); + var unknownTypeNode = new UnknownTypeNode(section); + typeNode = new PairTypeNode(section, booleanTypeNode, unknownTypeNode); + } + return typeNode; + } + + @Specialization + protected Object eval(VirtualFrame frame, VmSet self, Object initial, VmFunction function) { + var iter = self.iterator(); + var typeNode = getTypeNode(); + var result = initial; + var loop = 0; + + while (iter.hasNext()) { + var elem = iter.next(); + var pairUnchecked = applyLambdaNode.execute(function, result, elem); + var pair = (VmPair) typeNode.executeEagerly(frame, pairUnchecked); + result = pair.getSecond(); + loop++; + if (!(Boolean) pair.getFirst()) break; + } + + LoopNode.reportLoopCount(this, loop); + return result; + } + } + + public abstract static class iterateIndexed extends ExternalMethod2Node { + @Child private ApplyVmFunction3Node applyLambdaNode = ApplyVmFunction3NodeGen.create(); + @Child @LateInit private TypeNode typeNode; + + private TypeNode getTypeNode() { + if (typeNode == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + var section = VmUtils.unavailableSourceSection(); + var booleanTypeNode = new BooleanTypeNode(section); + var unknownTypeNode = new UnknownTypeNode(section); + typeNode = new PairTypeNode(section, booleanTypeNode, unknownTypeNode); + } + return typeNode; + } + + @Specialization + protected Object eval(VirtualFrame frame, VmSet self, Object initial, VmFunction function) { + var iter = self.iterator(); + var typeNode = getTypeNode(); + var result = initial; + long index = 0; + var loop = 0; + + while (iter.hasNext()) { + var elem = iter.next(); + var pairUnchecked = applyLambdaNode.execute(function, index++, result, elem); + var pair = (VmPair) typeNode.executeEagerly(frame, pairUnchecked); + result = pair.getSecond(); + loop++; + if (!(Boolean) pair.getFirst()) break; + } + + LoopNode.reportLoopCount(this, loop); + return result; + } + } + public abstract static class reduce extends ExternalMethod1Node { @Child private ApplyVmFunction2Node applyLambdaNode = ApplyVmFunction2NodeGen.create(); diff --git a/pkl-core/src/test/files/LanguageSnippetTests/input/api/list.pkl b/pkl-core/src/test/files/LanguageSnippetTests/input/api/list.pkl index 47c9c2db7..99e5d9374 100644 --- a/pkl-core/src/test/files/LanguageSnippetTests/input/api/list.pkl +++ b/pkl-core/src/test/files/LanguageSnippetTests/input/api/list.pkl @@ -216,6 +216,20 @@ examples { List().foldBack(0, (x, acc) -> x + acc) } + ["iterate()"] { + list1.iterate(0, (acc, x) -> Pair(x < 2, x + acc)) + list1.iterate(0, (acc, x) -> Pair(true, x + acc)) + module.catch(() -> list1.iterate(0, (acc, x) -> x + acc)) + module.catch(() -> list1.iterate(0, (acc, x) -> Pair(x, x + acc))) + } + + ["iterateIndexed()"] { + list1.iterateIndexed(0, (i, acc, x) -> Pair(i < 1, x + acc)) + list1.iterateIndexed(0, (i, acc, x) -> Pair(true, x + acc)) + module.catch(() -> list1.iterateIndexed(0, (i, acc, x) -> x + acc)) + module.catch(() -> list1.iterateIndexed(0, (i, acc, x) -> Pair(i, x + acc))) + } + ["reduce()"] { list1.reduce((x, y) -> x + y) List(1).reduce((x, y) -> x + y) diff --git a/pkl-core/src/test/files/LanguageSnippetTests/input/api/set.pkl b/pkl-core/src/test/files/LanguageSnippetTests/input/api/set.pkl index 0e73308bb..43fe336ac 100644 --- a/pkl-core/src/test/files/LanguageSnippetTests/input/api/set.pkl +++ b/pkl-core/src/test/files/LanguageSnippetTests/input/api/set.pkl @@ -178,6 +178,18 @@ examples { Set().foldBack(0, (x, acc) -> x + acc) } + ["iterate()"] { + set1.iterate(0, (acc, x) -> Pair(x < 2, x + acc)) + set1.iterate(0, (acc, x) -> Pair(true, x + acc)) + module.catch(() -> set1.iterate(0, (acc, x) -> Pair(1, x + acc))) + } + + ["iterateIndexed()"] { + set1.iterateIndexed(0, (i, acc, x) -> Pair(i < 1, x + acc)) + set1.iterateIndexed(0, (i, acc, x) -> Pair(true, x + acc)) + module.catch(() -> set1.iterateIndexed(0, (i, acc, x) -> Pair(i, x + acc))) + } + ["reduce()"] { set1.reduce((x, y) -> x + y) Set(1).reduce((x, y) -> x + y) diff --git a/pkl-core/src/test/files/LanguageSnippetTests/output/api/list.pcf b/pkl-core/src/test/files/LanguageSnippetTests/output/api/list.pcf index abd53a9f9..060eae599 100644 --- a/pkl-core/src/test/files/LanguageSnippetTests/output/api/list.pcf +++ b/pkl-core/src/test/files/LanguageSnippetTests/output/api/list.pcf @@ -172,6 +172,18 @@ examples { 1 0 } + ["iterate()"] { + 3 + 6 + "Expected value of type `Pair`, but got type `Int`. Value: 1" + "Expected value of type `Boolean`, but got type `Int`. Value: 1" + } + ["iterateIndexed()"] { + 3 + 6 + "Expected value of type `Pair`, but got type `Int`. Value: 1" + "Expected value of type `Boolean`, but got type `Int`. Value: 0" + } ["reduce()"] { 6 1 diff --git a/pkl-core/src/test/files/LanguageSnippetTests/output/api/set.pcf b/pkl-core/src/test/files/LanguageSnippetTests/output/api/set.pcf index 833577cc3..e9fe5ae71 100644 --- a/pkl-core/src/test/files/LanguageSnippetTests/output/api/set.pcf +++ b/pkl-core/src/test/files/LanguageSnippetTests/output/api/set.pcf @@ -140,6 +140,16 @@ examples { 1 0 } + ["iterate()"] { + 3 + 6 + "Expected value of type `Boolean`, but got type `Int`. Value: 1" + } + ["iterateIndexed()"] { + 3 + 6 + "Expected value of type `Boolean`, but got type `Int`. Value: 0" + } ["reduce()"] { 6 1 diff --git a/stdlib/base.pkl b/stdlib/base.pkl index f5d00101f..0bd75c418 100644 --- a/stdlib/base.pkl +++ b/stdlib/base.pkl @@ -2571,6 +2571,18 @@ abstract external class Collection extends Any { /// The first parameter of [operator] is the zero-based index of the current element. abstract function foldIndexed(initial: Result, operator: (Int, Result, Element) -> Result): Result + /// Iterate over elements of this collection in order, accumulating a result. + /// The looping function returns a [Pair] where the first value is a [Boolean] indicating if the iteration + /// should keep going ([true]) or stop, and the second value is the accumulated result. + /// This function is similar to [fold], but it can short-circuit. + abstract function iterate(initial: Result, loopFunction: (Result, Element) -> Pair): Result + + /// Iterate over the indexes and elements of this collection in order, accumulating a result. + /// The looping function returns a [Pair] where the first value is a [Boolean] indicating if the iteration + /// should keep going ([true]) or stop, and the second value is the accumulated result. + /// This function is similar to [foldIndexed], but it can short-circuit. + abstract function iterateIndexed(initial: Result, loopFunction: (Int, Result, Element) -> Pair): Result + /// Folds this collection in iteration order using [operator], starting with the first element. /// /// Throws if this collection is empty. @@ -3050,6 +3062,9 @@ external class List extends Collection { external function foldBack(initial: Result, operator: (Element, Result) -> Result): Result external function foldIndexed(initial: Result, operator: (Int, Result, Element) -> Result): Result + external function iterate(initial: Result, loopFunction: (Element, Result) -> Pair): Result + external function iterateIndexed(initial: Result, loopFunction: (Int, Result, Element) -> Pair): Result + external function reduce(operator: (Element|Result, Element) -> Result): Result external function reduceOrNull(operator: (Element|Result, Element) -> Result): Result? @@ -3191,6 +3206,9 @@ external class Set extends Collection { external function foldBack(initial: Result, operator: (Element, Result) -> Result): Result external function foldIndexed(initial: Result, operator: (Int, Result, Element) -> Result): Result + external function iterate(initial: Result, loopFunction: (Element, Result) -> Pair): Result + external function iterateIndexed(initial: Result, loopFunction: (Int, Result, Element) -> Pair): Result + external function reduce(operator: (Element|Result, Element) -> Result): Result external function reduceOrNull(operator: (Element|Result, Element) -> Result): Result? From e3d5187918f129c4e4af3150b92530a1bd7f0ad6 Mon Sep 17 00:00:00 2001 From: Islon Scherer Date: Tue, 1 Jul 2025 18:34:33 +0200 Subject: [PATCH 2/4] Address review comments. --- pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java | 2 +- pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java b/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java index 4555881ad..18f6ca11f 100644 --- a/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java +++ b/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java @@ -844,7 +844,7 @@ protected Object eval(VirtualFrame frame, VmList self, Object initial, VmFunctio var iter = self.iterator(); var typeNode = getTypeNode(); var result = initial; - long index = 0; + var index = 0L; var loop = 0; while (iter.hasNext()) { diff --git a/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java b/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java index fe1bd9881..d6226ab4a 100644 --- a/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java +++ b/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java @@ -623,7 +623,7 @@ protected Object eval(VirtualFrame frame, VmSet self, Object initial, VmFunction var iter = self.iterator(); var typeNode = getTypeNode(); var result = initial; - long index = 0; + var index = 0L; var loop = 0; while (iter.hasNext()) { From d353b855082a9c3e0c09ae779f6b3b7cfb05903c Mon Sep 17 00:00:00 2001 From: Islon Scherer Date: Mon, 14 Jul 2025 16:35:30 +0200 Subject: [PATCH 3/4] Remove iterate functions and add `folded` function to short-circuit the other folding functions --- .../org/pkl/core/stdlib/base/BaseNodes.java | 9 + .../pkl/core/stdlib/base/FoldedException.java | 24 +++ .../org/pkl/core/stdlib/base/ListNodes.java | 192 +++++++---------- .../org/pkl/core/stdlib/base/SetNodes.java | 197 +++++++----------- .../org/pkl/core/errorMessages.properties | 3 + .../LanguageSnippetTests/input/api/list.pkl | 21 +- .../LanguageSnippetTests/input/api/set.pkl | 16 +- .../LanguageSnippetTests/output/api/list.pcf | 17 +- .../LanguageSnippetTests/output/api/set.pcf | 14 +- stdlib/base.pkl | 33 ++- 10 files changed, 227 insertions(+), 299 deletions(-) create mode 100644 pkl-core/src/main/java/org/pkl/core/stdlib/base/FoldedException.java diff --git a/pkl-core/src/main/java/org/pkl/core/stdlib/base/BaseNodes.java b/pkl-core/src/main/java/org/pkl/core/stdlib/base/BaseNodes.java index 5fef5731c..a084e515a 100644 --- a/pkl-core/src/main/java/org/pkl/core/stdlib/base/BaseNodes.java +++ b/pkl-core/src/main/java/org/pkl/core/stdlib/base/BaseNodes.java @@ -151,4 +151,13 @@ protected VmList eval(VirtualFrame frame, VmTyped self, Object args) { throw exceptionBuilder().bug("Node `BaseNodes.Bytes` should never be executed.").build(); } } + + public abstract static class folded extends ExternalMethod1Node { + @Specialization + protected void eval(VirtualFrame frame, VmTyped self, Object value) { + CompilerDirectives.transferToInterpreter(); + var foldedException = new FoldedException(value); + throw exceptionBuilder().evalError("invalidFoldedCall").withCause(foldedException).build(); + } + } } diff --git a/pkl-core/src/main/java/org/pkl/core/stdlib/base/FoldedException.java b/pkl-core/src/main/java/org/pkl/core/stdlib/base/FoldedException.java new file mode 100644 index 000000000..9ab48ae59 --- /dev/null +++ b/pkl-core/src/main/java/org/pkl/core/stdlib/base/FoldedException.java @@ -0,0 +1,24 @@ +/* + * Copyright © 2025 Apple Inc. and the Pkl project authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.pkl.core.stdlib.base; + +public class FoldedException extends Exception { + public final Object value; + + public FoldedException(Object value) { + this.value = value; + } +} diff --git a/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java b/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java index 18f6ca11f..1524bf23d 100644 --- a/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java +++ b/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java @@ -25,10 +25,7 @@ import org.pkl.core.ast.internal.IsInstanceOfNodeGen; import org.pkl.core.ast.lambda.*; import org.pkl.core.ast.type.TypeNode; -import org.pkl.core.ast.type.TypeNode.BooleanTypeNode; -import org.pkl.core.ast.type.TypeNode.PairTypeNode; import org.pkl.core.ast.type.TypeNode.UInt8TypeAliasTypeNode; -import org.pkl.core.ast.type.TypeNode.UnknownTypeNode; import org.pkl.core.ast.type.VmTypeMismatchException; import org.pkl.core.runtime.*; import org.pkl.core.stdlib.*; @@ -742,14 +739,22 @@ public abstract static class fold extends ExternalMethod2Node { @Specialization protected Object eval(VmList self, Object initial, VmFunction function) { - var iter = self.iterator(); - var result = initial; - while (iter.hasNext()) { - var elem = iter.next(); - result = applyLambdaNode.execute(function, result, elem); + try { + var iter = self.iterator(); + var result = initial; + while (iter.hasNext()) { + var elem = iter.next(); + result = applyLambdaNode.execute(function, result, elem); + } + LoopNode.reportLoopCount(this, self.getLength()); + return result; + } catch (VmException vme) { + CompilerDirectives.transferToInterpreter(); + if (vme.getCause() instanceof FoldedException fe) { + return fe.value; + } + throw vme; } - LoopNode.reportLoopCount(this, self.getLength()); - return result; } } @@ -758,14 +763,22 @@ public abstract static class foldBack extends ExternalMethod2Node { @Specialization protected Object eval(VmList self, Object initial, VmFunction function) { - var iter = self.reverseIterator(); - var result = initial; - while (iter.hasNext()) { - var elem = iter.next(); - result = applyLambdaNode.execute(function, elem, result); + try { + var iter = self.reverseIterator(); + var result = initial; + while (iter.hasNext()) { + var elem = iter.next(); + result = applyLambdaNode.execute(function, elem, result); + } + LoopNode.reportLoopCount(this, self.getLength()); + return result; + } catch (VmException vme) { + CompilerDirectives.transferToInterpreter(); + if (vme.getCause() instanceof FoldedException fe) { + return fe.value; + } + throw vme; } - LoopNode.reportLoopCount(this, self.getLength()); - return result; } } @@ -774,90 +787,25 @@ public abstract static class foldIndexed extends ExternalMethod2Node { @Specialization protected Object eval(VmList self, Object initial, VmFunction function) { - var iter = self.iterator(); - var result = initial; - long index = 0; - - while (iter.hasNext()) { - var elem = iter.next(); - result = applyLambdaNode.execute(function, index++, result, elem); - } - - LoopNode.reportLoopCount(this, self.getLength()); - return result; - } - } - - public abstract static class iterate extends ExternalMethod2Node { - @Child private ApplyVmFunction2Node applyLambdaNode = ApplyVmFunction2NodeGen.create(); - @Child @LateInit private TypeNode typeNode; - - private TypeNode getTypeNode() { - if (typeNode == null) { - CompilerDirectives.transferToInterpreterAndInvalidate(); - var section = VmUtils.unavailableSourceSection(); - var booleanTypeNode = new BooleanTypeNode(section); - var unknownTypeNode = new UnknownTypeNode(section); - typeNode = new PairTypeNode(section, booleanTypeNode, unknownTypeNode); - } - return typeNode; - } - - @Specialization - protected Object eval(VirtualFrame frame, VmList self, Object initial, VmFunction function) { - var iter = self.iterator(); - var typeNode = getTypeNode(); - var result = initial; - var loop = 0; - - while (iter.hasNext()) { - var elem = iter.next(); - var pairUnchecked = applyLambdaNode.execute(function, result, elem); - var pair = (VmPair) typeNode.executeEagerly(frame, pairUnchecked); - result = pair.getSecond(); - loop++; - if (!(Boolean) pair.getFirst()) break; - } - - LoopNode.reportLoopCount(this, loop); - return result; - } - } - - public abstract static class iterateIndexed extends ExternalMethod2Node { - @Child private ApplyVmFunction3Node applyLambdaNode = ApplyVmFunction3NodeGen.create(); - @Child @LateInit private TypeNode typeNode; - - private TypeNode getTypeNode() { - if (typeNode == null) { - CompilerDirectives.transferToInterpreterAndInvalidate(); - var section = VmUtils.unavailableSourceSection(); - var booleanTypeNode = new BooleanTypeNode(section); - var unknownTypeNode = new UnknownTypeNode(section); - typeNode = new PairTypeNode(section, booleanTypeNode, unknownTypeNode); - } - return typeNode; - } + try { + var iter = self.iterator(); + var result = initial; + long index = 0; - @Specialization - protected Object eval(VirtualFrame frame, VmList self, Object initial, VmFunction function) { - var iter = self.iterator(); - var typeNode = getTypeNode(); - var result = initial; - var index = 0L; - var loop = 0; + while (iter.hasNext()) { + var elem = iter.next(); + result = applyLambdaNode.execute(function, index++, result, elem); + } - while (iter.hasNext()) { - var elem = iter.next(); - var pairUnchecked = applyLambdaNode.execute(function, index++, result, elem); - var pair = (VmPair) typeNode.executeEagerly(frame, pairUnchecked); - result = pair.getSecond(); - loop++; - if (!(Boolean) pair.getFirst()) break; + LoopNode.reportLoopCount(this, self.getLength()); + return result; + } catch (VmException vme) { + CompilerDirectives.transferToInterpreter(); + if (vme.getCause() instanceof FoldedException fe) { + return fe.value; + } + throw vme; } - - LoopNode.reportLoopCount(this, loop); - return result; } } @@ -868,15 +816,23 @@ public abstract static class reduce extends ExternalMethod1Node { protected Object eval(VmList self, VmFunction function) { self.checkNonEmpty(); - var iterator = self.iterator(); - var result = iterator.next(); - while (iterator.hasNext()) { - var elem = iterator.next(); - result = applyLambdaNode.execute(function, result, elem); - } + try { + var iterator = self.iterator(); + var result = iterator.next(); + while (iterator.hasNext()) { + var elem = iterator.next(); + result = applyLambdaNode.execute(function, result, elem); + } - LoopNode.reportLoopCount(this, self.getLength()); - return result; + LoopNode.reportLoopCount(this, self.getLength()); + return result; + } catch (VmException vme) { + CompilerDirectives.transferToInterpreter(); + if (vme.getCause() instanceof FoldedException fe) { + return fe.value; + } + throw vme; + } } } @@ -887,15 +843,23 @@ public abstract static class reduceOrNull extends ExternalMethod1Node { protected Object eval(VmList self, VmFunction function) { if (self.isEmpty()) return VmNull.withoutDefault(); - var iterator = self.iterator(); - var result = iterator.next(); - while (iterator.hasNext()) { - var elem = iterator.next(); - result = applyLambdaNode.execute(function, result, elem); - } + try { + var iterator = self.iterator(); + var result = iterator.next(); + while (iterator.hasNext()) { + var elem = iterator.next(); + result = applyLambdaNode.execute(function, result, elem); + } - LoopNode.reportLoopCount(this, self.getLength()); - return result; + LoopNode.reportLoopCount(this, self.getLength()); + return result; + } catch (VmException vme) { + CompilerDirectives.transferToInterpreter(); + if (vme.getCause() instanceof FoldedException fe) { + return fe.value; + } + throw vme; + } } } diff --git a/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java b/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java index d6226ab4a..53e8f6bbe 100644 --- a/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java +++ b/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java @@ -18,7 +18,6 @@ import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.nodes.LoopNode; import org.pkl.core.ast.expression.binary.GreaterThanNode; import org.pkl.core.ast.expression.binary.GreaterThanNodeGen; @@ -27,16 +26,11 @@ import org.pkl.core.ast.internal.IsInstanceOfNode; import org.pkl.core.ast.internal.IsInstanceOfNodeGen; import org.pkl.core.ast.lambda.*; -import org.pkl.core.ast.type.TypeNode; -import org.pkl.core.ast.type.TypeNode.BooleanTypeNode; -import org.pkl.core.ast.type.TypeNode.PairTypeNode; -import org.pkl.core.ast.type.TypeNode.UnknownTypeNode; import org.pkl.core.runtime.*; import org.pkl.core.stdlib.*; import org.pkl.core.stdlib.base.CollectionNodes.CompareByNode; import org.pkl.core.stdlib.base.CollectionNodes.CompareNode; import org.pkl.core.stdlib.base.CollectionNodes.CompareWithNode; -import org.pkl.core.util.LateInit; // duplication between ListNodes and SetNodes is "intentional" // (sharing nodes between VmCollection subtypes results in @@ -521,14 +515,22 @@ public abstract static class fold extends ExternalMethod2Node { @Specialization protected Object eval(VmSet self, Object initial, VmFunction function) { - var iter = self.iterator(); - var result = initial; - while (iter.hasNext()) { - var elem = iter.next(); - result = applyLambdaNode.execute(function, result, elem); + try { + var iter = self.iterator(); + var result = initial; + while (iter.hasNext()) { + var elem = iter.next(); + result = applyLambdaNode.execute(function, result, elem); + } + LoopNode.reportLoopCount(this, self.getLength()); + return result; + } catch (VmException vme) { + CompilerDirectives.transferToInterpreter(); + if (vme.getCause() instanceof FoldedException fe) { + return fe.value; + } + throw vme; } - LoopNode.reportLoopCount(this, self.getLength()); - return result; } } @@ -537,14 +539,22 @@ public abstract static class foldBack extends ExternalMethod2Node { @Specialization protected Object eval(VmSet self, Object initial, VmFunction function) { - var iter = self.reverseIterator(); - var result = initial; - while (iter.hasNext()) { - var elem = iter.next(); - result = applyLambdaNode.execute(function, elem, result); + try { + var iter = self.reverseIterator(); + var result = initial; + while (iter.hasNext()) { + var elem = iter.next(); + result = applyLambdaNode.execute(function, elem, result); + } + LoopNode.reportLoopCount(this, self.getLength()); + return result; + } catch (VmException vme) { + CompilerDirectives.transferToInterpreter(); + if (vme.getCause() instanceof FoldedException fe) { + return fe.value; + } + throw vme; } - LoopNode.reportLoopCount(this, self.getLength()); - return result; } } @@ -553,90 +563,25 @@ public abstract static class foldIndexed extends ExternalMethod2Node { @Specialization protected Object eval(VmSet self, Object initial, VmFunction function) { - var iter = self.iterator(); - var result = initial; - long index = 0; - - while (iter.hasNext()) { - var elem = iter.next(); - result = applyLambdaNode.execute(function, index++, result, elem); - } - - LoopNode.reportLoopCount(this, self.getLength()); - return result; - } - } - - public abstract static class iterate extends ExternalMethod2Node { - @Child private ApplyVmFunction2Node applyLambdaNode = ApplyVmFunction2NodeGen.create(); - @Child @LateInit private TypeNode typeNode; - - private TypeNode getTypeNode() { - if (typeNode == null) { - CompilerDirectives.transferToInterpreterAndInvalidate(); - var section = VmUtils.unavailableSourceSection(); - var booleanTypeNode = new BooleanTypeNode(section); - var unknownTypeNode = new UnknownTypeNode(section); - typeNode = new PairTypeNode(section, booleanTypeNode, unknownTypeNode); - } - return typeNode; - } - - @Specialization - protected Object eval(VirtualFrame frame, VmSet self, Object initial, VmFunction function) { - var iter = self.iterator(); - var typeNode = getTypeNode(); - var result = initial; - var loop = 0; - - while (iter.hasNext()) { - var elem = iter.next(); - var pairUnchecked = applyLambdaNode.execute(function, result, elem); - var pair = (VmPair) typeNode.executeEagerly(frame, pairUnchecked); - result = pair.getSecond(); - loop++; - if (!(Boolean) pair.getFirst()) break; - } - - LoopNode.reportLoopCount(this, loop); - return result; - } - } - - public abstract static class iterateIndexed extends ExternalMethod2Node { - @Child private ApplyVmFunction3Node applyLambdaNode = ApplyVmFunction3NodeGen.create(); - @Child @LateInit private TypeNode typeNode; - - private TypeNode getTypeNode() { - if (typeNode == null) { - CompilerDirectives.transferToInterpreterAndInvalidate(); - var section = VmUtils.unavailableSourceSection(); - var booleanTypeNode = new BooleanTypeNode(section); - var unknownTypeNode = new UnknownTypeNode(section); - typeNode = new PairTypeNode(section, booleanTypeNode, unknownTypeNode); - } - return typeNode; - } - - @Specialization - protected Object eval(VirtualFrame frame, VmSet self, Object initial, VmFunction function) { - var iter = self.iterator(); - var typeNode = getTypeNode(); - var result = initial; - var index = 0L; - var loop = 0; + try { + var iter = self.iterator(); + var result = initial; + long index = 0; + + while (iter.hasNext()) { + var elem = iter.next(); + result = applyLambdaNode.execute(function, index++, result, elem); + } - while (iter.hasNext()) { - var elem = iter.next(); - var pairUnchecked = applyLambdaNode.execute(function, index++, result, elem); - var pair = (VmPair) typeNode.executeEagerly(frame, pairUnchecked); - result = pair.getSecond(); - loop++; - if (!(Boolean) pair.getFirst()) break; + LoopNode.reportLoopCount(this, self.getLength()); + return result; + } catch (VmException vme) { + CompilerDirectives.transferToInterpreter(); + if (vme.getCause() instanceof FoldedException fe) { + return fe.value; + } + throw vme; } - - LoopNode.reportLoopCount(this, loop); - return result; } } @@ -647,15 +592,23 @@ public abstract static class reduce extends ExternalMethod1Node { protected Object eval(VmSet self, VmFunction function) { self.checkNonEmpty(); - var iterator = self.iterator(); - var result = iterator.next(); - while (iterator.hasNext()) { - var elem = iterator.next(); - result = applyLambdaNode.execute(function, result, elem); - } + try { + var iterator = self.iterator(); + var result = iterator.next(); + while (iterator.hasNext()) { + var elem = iterator.next(); + result = applyLambdaNode.execute(function, result, elem); + } - LoopNode.reportLoopCount(this, self.getLength()); - return result; + LoopNode.reportLoopCount(this, self.getLength()); + return result; + } catch (VmException vme) { + CompilerDirectives.transferToInterpreter(); + if (vme.getCause() instanceof FoldedException fe) { + return fe.value; + } + throw vme; + } } } @@ -666,15 +619,23 @@ public abstract static class reduceOrNull extends ExternalMethod1Node { protected Object eval(VmSet self, VmFunction function) { if (self.isEmpty()) return VmNull.withoutDefault(); - var iterator = self.iterator(); - var result = iterator.next(); - while (iterator.hasNext()) { - var elem = iterator.next(); - result = applyLambdaNode.execute(function, result, elem); - } + try { + var iterator = self.iterator(); + var result = iterator.next(); + while (iterator.hasNext()) { + var elem = iterator.next(); + result = applyLambdaNode.execute(function, result, elem); + } - LoopNode.reportLoopCount(this, self.getLength()); - return result; + LoopNode.reportLoopCount(this, self.getLength()); + return result; + } catch (VmException vme) { + CompilerDirectives.transferToInterpreter(); + if (vme.getCause() instanceof FoldedException fe) { + return fe.value; + } + throw vme; + } } } diff --git a/pkl-core/src/main/resources/org/pkl/core/errorMessages.properties b/pkl-core/src/main/resources/org/pkl/core/errorMessages.properties index efe3de0be..cea80dc4c 100644 --- a/pkl-core/src/main/resources/org/pkl/core/errorMessages.properties +++ b/pkl-core/src/main/resources/org/pkl/core/errorMessages.properties @@ -1070,3 +1070,6 @@ invalidStringBase64=\ characterCodingException=\ Invalid bytes for charset "{0}". + +invalidFoldedCall=\ +Function `folded` should only be called from inside a fold function. diff --git a/pkl-core/src/test/files/LanguageSnippetTests/input/api/list.pkl b/pkl-core/src/test/files/LanguageSnippetTests/input/api/list.pkl index 99e5d9374..4c51a9a8f 100644 --- a/pkl-core/src/test/files/LanguageSnippetTests/input/api/list.pkl +++ b/pkl-core/src/test/files/LanguageSnippetTests/input/api/list.pkl @@ -199,7 +199,7 @@ examples { list1.dropLastWhile((x) -> true) list1.dropLastWhile((x) -> false) list1.dropLastWhile((x) -> x > 1) - module.catch(() -> list1.dropLastWhile((x) -> "wrong type")) + module.catch(() -> list1.dropLastWhile((_) -> "wrong type")) } ["fold()"] { @@ -207,6 +207,8 @@ examples { list1.fold(List(), (acc, x) -> acc.add(x)) List(1).fold(0, (acc, x) -> acc + x) List().fold(0, (acc, x) -> acc + x) + list2.fold(1, (acc, x) -> if (acc > 10) folded(acc) else acc * x) + module.catch(() -> folded(10)) } ["foldBack()"] { @@ -214,26 +216,14 @@ examples { list1.foldBack(List(), (x, acc) -> acc.add(x)) List(1).foldBack(0, (x, acc) -> x + acc) List().foldBack(0, (x, acc) -> x + acc) - } - - ["iterate()"] { - list1.iterate(0, (acc, x) -> Pair(x < 2, x + acc)) - list1.iterate(0, (acc, x) -> Pair(true, x + acc)) - module.catch(() -> list1.iterate(0, (acc, x) -> x + acc)) - module.catch(() -> list1.iterate(0, (acc, x) -> Pair(x, x + acc))) - } - - ["iterateIndexed()"] { - list1.iterateIndexed(0, (i, acc, x) -> Pair(i < 1, x + acc)) - list1.iterateIndexed(0, (i, acc, x) -> Pair(true, x + acc)) - module.catch(() -> list1.iterateIndexed(0, (i, acc, x) -> x + acc)) - module.catch(() -> list1.iterateIndexed(0, (i, acc, x) -> Pair(i, x + acc))) + list2.foldBack(1, (x, acc) -> if (acc > 10) folded(acc) else acc * x) } ["reduce()"] { list1.reduce((x, y) -> x + y) List(1).reduce((x, y) -> x + y) module.catch(() -> List().reduce((x, y) -> x + y)) + list2.reduce((acc, x) -> if (acc > 10) folded(acc) else acc * x) } ["groupBy()"] { @@ -360,6 +350,7 @@ examples { list1.foldIndexed(0, (idx, x, y) -> idx + x + y) List(1).foldIndexed(0, (idx, x, y) -> idx + x + y) List().foldIndexed(0, (idx, x, y) -> idx + x + y) + list2.foldIndexed(0, (idx, acc, x) -> if (acc > 10) folded(acc) else idx + acc + x) } ["toListing()"] { diff --git a/pkl-core/src/test/files/LanguageSnippetTests/input/api/set.pkl b/pkl-core/src/test/files/LanguageSnippetTests/input/api/set.pkl index 43fe336ac..92de02011 100644 --- a/pkl-core/src/test/files/LanguageSnippetTests/input/api/set.pkl +++ b/pkl-core/src/test/files/LanguageSnippetTests/input/api/set.pkl @@ -169,6 +169,7 @@ examples { set1.fold(List(), (acc, x) -> acc.add(x)) Set(1).fold(0, (acc, x) -> acc + x) Set().fold(0, (acc, x) -> acc + x) + set2.fold(1, (acc, x) -> if (acc > 10) folded(acc) else acc * x) } ["foldBack()"] { @@ -176,24 +177,14 @@ examples { set1.foldBack(List(), (x, acc) -> acc.add(x)) Set(1).foldBack(0, (x, acc) -> x + acc) Set().foldBack(0, (x, acc) -> x + acc) - } - - ["iterate()"] { - set1.iterate(0, (acc, x) -> Pair(x < 2, x + acc)) - set1.iterate(0, (acc, x) -> Pair(true, x + acc)) - module.catch(() -> set1.iterate(0, (acc, x) -> Pair(1, x + acc))) - } - - ["iterateIndexed()"] { - set1.iterateIndexed(0, (i, acc, x) -> Pair(i < 1, x + acc)) - set1.iterateIndexed(0, (i, acc, x) -> Pair(true, x + acc)) - module.catch(() -> set1.iterateIndexed(0, (i, acc, x) -> Pair(i, x + acc))) + set2.foldBack(1, (x, acc) -> if (acc > 10) folded(acc) else acc * x) } ["reduce()"] { set1.reduce((x, y) -> x + y) Set(1).reduce((x, y) -> x + y) module.catch(() -> Set().reduce((x, y) -> x + y)) + set2.reduce((acc, x) -> if (acc > 10) folded(acc) else acc * x) } ["groupBy()"] { @@ -306,6 +297,7 @@ examples { set1.foldIndexed(0, (idx, x, y) -> idx + x + y) Set(1).foldIndexed(0, (idx, x, y) -> idx + x + y) Set().foldIndexed(0, (idx, x, y) -> idx + x + y) + set2.foldIndexed(0, (idx, acc, x) -> if (acc > 10) folded(acc) else idx + acc + x) } ["toListing()"] { diff --git a/pkl-core/src/test/files/LanguageSnippetTests/output/api/list.pcf b/pkl-core/src/test/files/LanguageSnippetTests/output/api/list.pcf index 060eae599..61dac679a 100644 --- a/pkl-core/src/test/files/LanguageSnippetTests/output/api/list.pcf +++ b/pkl-core/src/test/files/LanguageSnippetTests/output/api/list.pcf @@ -165,29 +165,21 @@ examples { List(1, 2, 3) 1 0 + 24 + "Function `folded` should only be called from inside a fold function." } ["foldBack()"] { 6 List(3, 2, 1) 1 0 - } - ["iterate()"] { - 3 - 6 - "Expected value of type `Pair`, but got type `Int`. Value: 1" - "Expected value of type `Boolean`, but got type `Int`. Value: 1" - } - ["iterateIndexed()"] { - 3 - 6 - "Expected value of type `Pair`, but got type `Int`. Value: 1" - "Expected value of type `Boolean`, but got type `Int`. Value: 0" + 20 } ["reduce()"] { 6 1 "Expected a non-empty collection. Collection: List()" + 24 } ["groupBy()"] { Map(1, List(1), 2, List(2), 3, List(3), 4, List(4), 5, List(5)) @@ -298,6 +290,7 @@ examples { 9 1 0 + 16 } ["toListing()"] { new { diff --git a/pkl-core/src/test/files/LanguageSnippetTests/output/api/set.pcf b/pkl-core/src/test/files/LanguageSnippetTests/output/api/set.pcf index e9fe5ae71..32351f1b7 100644 --- a/pkl-core/src/test/files/LanguageSnippetTests/output/api/set.pcf +++ b/pkl-core/src/test/files/LanguageSnippetTests/output/api/set.pcf @@ -133,27 +133,20 @@ examples { List(1, 2, 3) 1 0 + 24 } ["foldBack()"] { 6 List(3, 2, 1) 1 0 - } - ["iterate()"] { - 3 - 6 - "Expected value of type `Boolean`, but got type `Int`. Value: 1" - } - ["iterateIndexed()"] { - 3 - 6 - "Expected value of type `Boolean`, but got type `Int`. Value: 0" + 20 } ["reduce()"] { 6 1 "Expected a non-empty collection. Collection: Set()" + 24 } ["groupBy()"] { Map(1, Set(1), 2, Set(2), 3, Set(3), 4, Set(4), 5, Set(5)) @@ -249,6 +242,7 @@ examples { 9 1 0 + 16 } ["toListing()"] { new { diff --git a/stdlib/base.pkl b/stdlib/base.pkl index 0bd75c418..276ae4223 100644 --- a/stdlib/base.pkl +++ b/stdlib/base.pkl @@ -2160,6 +2160,11 @@ external class Pair extends Any { external value: Second } +/// Wraps [value], terminating the fold with [value] as result. +/// Should only be called from inside a fold function like +/// `fold`, `foldIndexed`, `reduce`, etc. +external const function folded(value: T): T + /// Common base class for [List] and [Set]. /// /// The following operators are supported for all collections: @@ -2560,35 +2565,33 @@ abstract external class Collection extends Any { abstract function dropLastWhile(predicate: (Element) -> Boolean): Collection /// Folds this collection in iteration order using [operator], starting with [initial]. + /// + /// [folded] can be used to short-circuit the computation. abstract function fold(initial: Result, operator: (Result, Element) -> Result): Result /// Folds this collection in reverse iteration order using [operator], starting with [initial]. + /// + /// [folded] can be used to short-circuit the computation. @AlsoKnownAs { names { "foldRight" } } abstract function foldBack(initial: Result, operator: (Element, Result) -> Result): Result /// Folds this collection in iteration order using [operator], starting with [initial]. /// /// The first parameter of [operator] is the zero-based index of the current element. + /// + /// [folded] can be used to short-circuit the computation. abstract function foldIndexed(initial: Result, operator: (Int, Result, Element) -> Result): Result - /// Iterate over elements of this collection in order, accumulating a result. - /// The looping function returns a [Pair] where the first value is a [Boolean] indicating if the iteration - /// should keep going ([true]) or stop, and the second value is the accumulated result. - /// This function is similar to [fold], but it can short-circuit. - abstract function iterate(initial: Result, loopFunction: (Result, Element) -> Pair): Result - - /// Iterate over the indexes and elements of this collection in order, accumulating a result. - /// The looping function returns a [Pair] where the first value is a [Boolean] indicating if the iteration - /// should keep going ([true]) or stop, and the second value is the accumulated result. - /// This function is similar to [foldIndexed], but it can short-circuit. - abstract function iterateIndexed(initial: Result, loopFunction: (Int, Result, Element) -> Pair): Result - /// Folds this collection in iteration order using [operator], starting with the first element. /// /// Throws if this collection is empty. + /// + /// [folded] can be used to short-circuit the computation. abstract function reduce(operator: (Element|Result, Element) -> Result): Result /// Same as [reduce()] but returns [null] if this collection is empty. + /// + /// [folded] can be used to short-circuit the computation. abstract function reduceOrNull(operator: (Element|Result, Element) -> Result): Result? /// Groups the elements in this collection according to keys returned by [selector]. @@ -3062,9 +3065,6 @@ external class List extends Collection { external function foldBack(initial: Result, operator: (Element, Result) -> Result): Result external function foldIndexed(initial: Result, operator: (Int, Result, Element) -> Result): Result - external function iterate(initial: Result, loopFunction: (Element, Result) -> Pair): Result - external function iterateIndexed(initial: Result, loopFunction: (Int, Result, Element) -> Pair): Result - external function reduce(operator: (Element|Result, Element) -> Result): Result external function reduceOrNull(operator: (Element|Result, Element) -> Result): Result? @@ -3206,9 +3206,6 @@ external class Set extends Collection { external function foldBack(initial: Result, operator: (Element, Result) -> Result): Result external function foldIndexed(initial: Result, operator: (Int, Result, Element) -> Result): Result - external function iterate(initial: Result, loopFunction: (Element, Result) -> Pair): Result - external function iterateIndexed(initial: Result, loopFunction: (Int, Result, Element) -> Pair): Result - external function reduce(operator: (Element|Result, Element) -> Result): Result external function reduceOrNull(operator: (Element|Result, Element) -> Result): Result? From a3ec1941cf7ee628c3bc759ad47b40c3fbfedc5a Mon Sep 17 00:00:00 2001 From: Islon Scherer Date: Wed, 16 Jul 2025 11:19:41 +0200 Subject: [PATCH 4/4] Change folded function to not throw an exception --- .../java/org/pkl/core/runtime/BaseModule.java | 8 ++ .../org/pkl/core/stdlib/base/BaseNodes.java | 9 -- .../org/pkl/core/stdlib/base/ListNodes.java | 128 ++++++++---------- .../org/pkl/core/stdlib/base/SetNodes.java | 128 ++++++++---------- .../LanguageSnippetTests/input/api/list.pkl | 25 ++-- .../LanguageSnippetTests/input/api/set.pkl | 22 +-- .../LanguageSnippetTests/output/api/list.pcf | 1 - stdlib/base.pkl | 11 +- 8 files changed, 148 insertions(+), 184 deletions(-) diff --git a/pkl-core/src/main/java/org/pkl/core/runtime/BaseModule.java b/pkl-core/src/main/java/org/pkl/core/runtime/BaseModule.java index f0497dd57..5b0f93e49 100644 --- a/pkl-core/src/main/java/org/pkl/core/runtime/BaseModule.java +++ b/pkl-core/src/main/java/org/pkl/core/runtime/BaseModule.java @@ -179,6 +179,10 @@ public static VmClass getFunction5Class() { return Function5Class.instance; } + public static VmClass getFoldedClass() { + return FoldedClass.instance; + } + public static VmClass getPairClass() { return PairClass.instance; } @@ -379,6 +383,10 @@ private static final class Function5Class { static final VmClass instance = loadClass("Function5"); } + private static final class FoldedClass { + static final VmClass instance = loadClass("Folded"); + } + private static final class NonNullTypeAlias { static final VmTypeAlias instance = loadTypeAlias("NonNull"); } diff --git a/pkl-core/src/main/java/org/pkl/core/stdlib/base/BaseNodes.java b/pkl-core/src/main/java/org/pkl/core/stdlib/base/BaseNodes.java index a084e515a..5fef5731c 100644 --- a/pkl-core/src/main/java/org/pkl/core/stdlib/base/BaseNodes.java +++ b/pkl-core/src/main/java/org/pkl/core/stdlib/base/BaseNodes.java @@ -151,13 +151,4 @@ protected VmList eval(VirtualFrame frame, VmTyped self, Object args) { throw exceptionBuilder().bug("Node `BaseNodes.Bytes` should never be executed.").build(); } } - - public abstract static class folded extends ExternalMethod1Node { - @Specialization - protected void eval(VirtualFrame frame, VmTyped self, Object value) { - CompilerDirectives.transferToInterpreter(); - var foldedException = new FoldedException(value); - throw exceptionBuilder().evalError("invalidFoldedCall").withCause(foldedException).build(); - } - } } diff --git a/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java b/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java index 1524bf23d..7af7bcee0 100644 --- a/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java +++ b/pkl-core/src/main/java/org/pkl/core/stdlib/base/ListNodes.java @@ -739,22 +739,18 @@ public abstract static class fold extends ExternalMethod2Node { @Specialization protected Object eval(VmList self, Object initial, VmFunction function) { - try { - var iter = self.iterator(); - var result = initial; - while (iter.hasNext()) { - var elem = iter.next(); - result = applyLambdaNode.execute(function, result, elem); - } - LoopNode.reportLoopCount(this, self.getLength()); - return result; - } catch (VmException vme) { - CompilerDirectives.transferToInterpreter(); - if (vme.getCause() instanceof FoldedException fe) { - return fe.value; + var iter = self.iterator(); + var result = initial; + while (iter.hasNext()) { + var elem = iter.next(); + result = applyLambdaNode.execute(function, result, elem); + if (result instanceof VmTyped vmTyped + && vmTyped.getVmClass() == BaseModule.getFoldedClass()) { + return VmUtils.readMember(vmTyped, Identifier.VALUE); } - throw vme; } + LoopNode.reportLoopCount(this, self.getLength()); + return result; } } @@ -763,22 +759,18 @@ public abstract static class foldBack extends ExternalMethod2Node { @Specialization protected Object eval(VmList self, Object initial, VmFunction function) { - try { - var iter = self.reverseIterator(); - var result = initial; - while (iter.hasNext()) { - var elem = iter.next(); - result = applyLambdaNode.execute(function, elem, result); - } - LoopNode.reportLoopCount(this, self.getLength()); - return result; - } catch (VmException vme) { - CompilerDirectives.transferToInterpreter(); - if (vme.getCause() instanceof FoldedException fe) { - return fe.value; + var iter = self.reverseIterator(); + var result = initial; + while (iter.hasNext()) { + var elem = iter.next(); + result = applyLambdaNode.execute(function, elem, result); + if (result instanceof VmTyped vmTyped + && vmTyped.getVmClass() == BaseModule.getFoldedClass()) { + return VmUtils.readMember(vmTyped, Identifier.VALUE); } - throw vme; } + LoopNode.reportLoopCount(this, self.getLength()); + return result; } } @@ -787,25 +779,21 @@ public abstract static class foldIndexed extends ExternalMethod2Node { @Specialization protected Object eval(VmList self, Object initial, VmFunction function) { - try { - var iter = self.iterator(); - var result = initial; - long index = 0; - - while (iter.hasNext()) { - var elem = iter.next(); - result = applyLambdaNode.execute(function, index++, result, elem); - } + var iter = self.iterator(); + var result = initial; + long index = 0; - LoopNode.reportLoopCount(this, self.getLength()); - return result; - } catch (VmException vme) { - CompilerDirectives.transferToInterpreter(); - if (vme.getCause() instanceof FoldedException fe) { - return fe.value; + while (iter.hasNext()) { + var elem = iter.next(); + result = applyLambdaNode.execute(function, index++, result, elem); + if (result instanceof VmTyped vmTyped + && vmTyped.getVmClass() == BaseModule.getFoldedClass()) { + return VmUtils.readMember(vmTyped, Identifier.VALUE); } - throw vme; } + + LoopNode.reportLoopCount(this, self.getLength()); + return result; } } @@ -816,23 +804,19 @@ public abstract static class reduce extends ExternalMethod1Node { protected Object eval(VmList self, VmFunction function) { self.checkNonEmpty(); - try { - var iterator = self.iterator(); - var result = iterator.next(); - while (iterator.hasNext()) { - var elem = iterator.next(); - result = applyLambdaNode.execute(function, result, elem); - } - - LoopNode.reportLoopCount(this, self.getLength()); - return result; - } catch (VmException vme) { - CompilerDirectives.transferToInterpreter(); - if (vme.getCause() instanceof FoldedException fe) { - return fe.value; + var iterator = self.iterator(); + var result = iterator.next(); + while (iterator.hasNext()) { + var elem = iterator.next(); + result = applyLambdaNode.execute(function, result, elem); + if (result instanceof VmTyped vmTyped + && vmTyped.getVmClass() == BaseModule.getFoldedClass()) { + return VmUtils.readMember(vmTyped, Identifier.VALUE); } - throw vme; } + + LoopNode.reportLoopCount(this, self.getLength()); + return result; } } @@ -843,23 +827,19 @@ public abstract static class reduceOrNull extends ExternalMethod1Node { protected Object eval(VmList self, VmFunction function) { if (self.isEmpty()) return VmNull.withoutDefault(); - try { - var iterator = self.iterator(); - var result = iterator.next(); - while (iterator.hasNext()) { - var elem = iterator.next(); - result = applyLambdaNode.execute(function, result, elem); - } - - LoopNode.reportLoopCount(this, self.getLength()); - return result; - } catch (VmException vme) { - CompilerDirectives.transferToInterpreter(); - if (vme.getCause() instanceof FoldedException fe) { - return fe.value; + var iterator = self.iterator(); + var result = iterator.next(); + while (iterator.hasNext()) { + var elem = iterator.next(); + result = applyLambdaNode.execute(function, result, elem); + if (result instanceof VmTyped vmTyped + && vmTyped.getVmClass() == BaseModule.getFoldedClass()) { + return VmUtils.readMember(vmTyped, Identifier.VALUE); } - throw vme; } + + LoopNode.reportLoopCount(this, self.getLength()); + return result; } } diff --git a/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java b/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java index 53e8f6bbe..c40f5c9cf 100644 --- a/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java +++ b/pkl-core/src/main/java/org/pkl/core/stdlib/base/SetNodes.java @@ -515,22 +515,18 @@ public abstract static class fold extends ExternalMethod2Node { @Specialization protected Object eval(VmSet self, Object initial, VmFunction function) { - try { - var iter = self.iterator(); - var result = initial; - while (iter.hasNext()) { - var elem = iter.next(); - result = applyLambdaNode.execute(function, result, elem); - } - LoopNode.reportLoopCount(this, self.getLength()); - return result; - } catch (VmException vme) { - CompilerDirectives.transferToInterpreter(); - if (vme.getCause() instanceof FoldedException fe) { - return fe.value; + var iter = self.iterator(); + var result = initial; + while (iter.hasNext()) { + var elem = iter.next(); + result = applyLambdaNode.execute(function, result, elem); + if (result instanceof VmTyped vmTyped + && vmTyped.getVmClass() == BaseModule.getFoldedClass()) { + return VmUtils.readMember(vmTyped, Identifier.VALUE); } - throw vme; } + LoopNode.reportLoopCount(this, self.getLength()); + return result; } } @@ -539,22 +535,18 @@ public abstract static class foldBack extends ExternalMethod2Node { @Specialization protected Object eval(VmSet self, Object initial, VmFunction function) { - try { - var iter = self.reverseIterator(); - var result = initial; - while (iter.hasNext()) { - var elem = iter.next(); - result = applyLambdaNode.execute(function, elem, result); - } - LoopNode.reportLoopCount(this, self.getLength()); - return result; - } catch (VmException vme) { - CompilerDirectives.transferToInterpreter(); - if (vme.getCause() instanceof FoldedException fe) { - return fe.value; + var iter = self.reverseIterator(); + var result = initial; + while (iter.hasNext()) { + var elem = iter.next(); + result = applyLambdaNode.execute(function, elem, result); + if (result instanceof VmTyped vmTyped + && vmTyped.getVmClass() == BaseModule.getFoldedClass()) { + return VmUtils.readMember(vmTyped, Identifier.VALUE); } - throw vme; } + LoopNode.reportLoopCount(this, self.getLength()); + return result; } } @@ -563,25 +555,21 @@ public abstract static class foldIndexed extends ExternalMethod2Node { @Specialization protected Object eval(VmSet self, Object initial, VmFunction function) { - try { - var iter = self.iterator(); - var result = initial; - long index = 0; - - while (iter.hasNext()) { - var elem = iter.next(); - result = applyLambdaNode.execute(function, index++, result, elem); - } + var iter = self.iterator(); + var result = initial; + long index = 0; - LoopNode.reportLoopCount(this, self.getLength()); - return result; - } catch (VmException vme) { - CompilerDirectives.transferToInterpreter(); - if (vme.getCause() instanceof FoldedException fe) { - return fe.value; + while (iter.hasNext()) { + var elem = iter.next(); + result = applyLambdaNode.execute(function, index++, result, elem); + if (result instanceof VmTyped vmTyped + && vmTyped.getVmClass() == BaseModule.getFoldedClass()) { + return VmUtils.readMember(vmTyped, Identifier.VALUE); } - throw vme; } + + LoopNode.reportLoopCount(this, self.getLength()); + return result; } } @@ -592,23 +580,19 @@ public abstract static class reduce extends ExternalMethod1Node { protected Object eval(VmSet self, VmFunction function) { self.checkNonEmpty(); - try { - var iterator = self.iterator(); - var result = iterator.next(); - while (iterator.hasNext()) { - var elem = iterator.next(); - result = applyLambdaNode.execute(function, result, elem); - } - - LoopNode.reportLoopCount(this, self.getLength()); - return result; - } catch (VmException vme) { - CompilerDirectives.transferToInterpreter(); - if (vme.getCause() instanceof FoldedException fe) { - return fe.value; + var iterator = self.iterator(); + var result = iterator.next(); + while (iterator.hasNext()) { + var elem = iterator.next(); + result = applyLambdaNode.execute(function, result, elem); + if (result instanceof VmTyped vmTyped + && vmTyped.getVmClass() == BaseModule.getFoldedClass()) { + return VmUtils.readMember(vmTyped, Identifier.VALUE); } - throw vme; } + + LoopNode.reportLoopCount(this, self.getLength()); + return result; } } @@ -619,23 +603,19 @@ public abstract static class reduceOrNull extends ExternalMethod1Node { protected Object eval(VmSet self, VmFunction function) { if (self.isEmpty()) return VmNull.withoutDefault(); - try { - var iterator = self.iterator(); - var result = iterator.next(); - while (iterator.hasNext()) { - var elem = iterator.next(); - result = applyLambdaNode.execute(function, result, elem); - } - - LoopNode.reportLoopCount(this, self.getLength()); - return result; - } catch (VmException vme) { - CompilerDirectives.transferToInterpreter(); - if (vme.getCause() instanceof FoldedException fe) { - return fe.value; + var iterator = self.iterator(); + var result = iterator.next(); + while (iterator.hasNext()) { + var elem = iterator.next(); + result = applyLambdaNode.execute(function, result, elem); + if (result instanceof VmTyped vmTyped + && vmTyped.getVmClass() == BaseModule.getFoldedClass()) { + return VmUtils.readMember(vmTyped, Identifier.VALUE); } - throw vme; } + + LoopNode.reportLoopCount(this, self.getLength()); + return result; } } diff --git a/pkl-core/src/test/files/LanguageSnippetTests/input/api/list.pkl b/pkl-core/src/test/files/LanguageSnippetTests/input/api/list.pkl index 4c51a9a8f..2f3b4a737 100644 --- a/pkl-core/src/test/files/LanguageSnippetTests/input/api/list.pkl +++ b/pkl-core/src/test/files/LanguageSnippetTests/input/api/list.pkl @@ -58,11 +58,11 @@ facts { examples { ["every"] { - module.catch(() -> list1.every((x) -> "wrong type")) + module.catch(() -> list1.every((_) -> "wrong type")) } ["any"] { - module.catch(() -> list1.any((x) -> "wrong type")) + module.catch(() -> list1.any((_) -> "wrong type")) } ["length"] { @@ -72,7 +72,7 @@ examples { ["filter()"] { list1.filter((x) -> x > 1) list1.filter((x) -> x > 3) - module.catch(() -> list1. filter((x) -> "wrong type")) + module.catch(() -> list1. filter((_) -> "wrong type")) } ["map()"] { @@ -82,7 +82,7 @@ examples { ["flatMap()"] { list1.flatMap((x) -> List(x, x + 1)) - module.catch(() -> list1.flatMap((x) -> "wrong type")) + module.catch(() -> list1.flatMap((_) -> "wrong type")) } ["flatten()"] { @@ -113,19 +113,19 @@ examples { list1.count((x) -> x >= 2) list1.count((x) -> true) list1.count((x) -> false) - module.catch(() -> list1.count((x) -> "wrong type")) + module.catch(() -> list1.count((_) -> "wrong type")) } ["find()"] { list1.find((x) -> x >= 2) list1.find((x) -> true) - module.catch(() -> list1.find((x) -> "wrong type")) + module.catch(() -> list1.find((_) -> "wrong type")) } ["findLast()"] { list1.findLast((x) -> x >= 2) list1.findLast((x) -> true) - module.catch(() -> list1.findLast((x) -> "wrong type")) + module.catch(() -> list1.findLast((_) -> "wrong type")) } ["indexOf()"] { @@ -141,13 +141,13 @@ examples { ["findIndex()"] { list1.findIndex((x) -> x >= 2) List(1,2,3,2,1).findIndex((x) -> x >= 2) - module.catch(() -> list1.findIndex((x) -> "wrong type")) + module.catch(() -> list1.findIndex((_) -> "wrong type")) } ["findLastIndex()"] { list1.findLastIndex((x) -> x >= 2) List(1,2,3,2,1).findLastIndex((x) -> x >= 2) - module.catch(() -> list1.findLastIndex((x) -> "wrong type")) + module.catch(() -> list1.findLastIndex((_) -> "wrong type")) } ["take()"] { @@ -160,7 +160,7 @@ examples { list1.takeWhile((x) -> true) list1.takeWhile((x) -> false) list1.takeWhile((x) -> x < 3) - module.catch(() -> list1.takeWhile((x) -> "wrong type")) + module.catch(() -> list1.takeWhile((_) -> "wrong type")) } ["takeLast()"] { @@ -173,7 +173,7 @@ examples { list1.takeLastWhile((x) -> true) list1.takeLastWhile((x) -> false) list1.takeLastWhile((x) -> x > 1) - module.catch(() -> list1.takeLastWhile((x) -> "wrong type")) + module.catch(() -> list1.takeLastWhile((_) -> "wrong type")) } ["drop()"] { @@ -186,7 +186,7 @@ examples { list1.dropWhile((x) -> true) list1.dropWhile((x) -> false) list1.dropWhile((x) -> x < 3) - module.catch(() -> list1.dropWhile((x) -> "wrong type")) + module.catch(() -> list1.dropWhile((_) -> "wrong type")) } ["dropLast()"] { @@ -208,7 +208,6 @@ examples { List(1).fold(0, (acc, x) -> acc + x) List().fold(0, (acc, x) -> acc + x) list2.fold(1, (acc, x) -> if (acc > 10) folded(acc) else acc * x) - module.catch(() -> folded(10)) } ["foldBack()"] { diff --git a/pkl-core/src/test/files/LanguageSnippetTests/input/api/set.pkl b/pkl-core/src/test/files/LanguageSnippetTests/input/api/set.pkl index 92de02011..80221ac5e 100644 --- a/pkl-core/src/test/files/LanguageSnippetTests/input/api/set.pkl +++ b/pkl-core/src/test/files/LanguageSnippetTests/input/api/set.pkl @@ -42,11 +42,11 @@ facts { examples { ["every"] { - module.catch(() -> set1.every((x) -> "wrong type")) + module.catch(() -> set1.every((_) -> "wrong type")) } ["any"] { - module.catch(() -> set1.any((x) -> "wrong type")) + module.catch(() -> set1.any((_) -> "wrong type")) } ["length"] { @@ -56,7 +56,7 @@ examples { ["filter()"] { set1.filter((x) -> x > 1) set1. filter((x) -> x > 3) - module.catch(() -> set1. filter((x) -> "wrong type")) + module.catch(() -> set1. filter((_) -> "wrong type")) } ["map()"] { @@ -66,7 +66,7 @@ examples { ["flatMap()"] { set1.flatMap((x) -> List(x, x + 1)) - module.catch(() -> set1.flatMap((x) -> "wrong type")) + module.catch(() -> set1.flatMap((_) -> "wrong type")) } ["flatten()"] { @@ -97,19 +97,19 @@ examples { set1.count((x) -> x >= 2) set1.count((x) -> true) set1.count((x) -> false) - module.catch(() -> set1.count((x) -> "wrong type")) + module.catch(() -> set1.count((_) -> "wrong type")) } ["find()"] { set1.find((x) -> x >= 2) set1.find((x) -> true) - module.catch(() -> set1.find((x) -> "wrong type")) + module.catch(() -> set1.find((_) -> "wrong type")) } ["findLast()"] { set1.findLast((x) -> x >= 2) set1.findLast((x) -> true) - module.catch(() -> set1.findLast((x) -> "wrong type")) + module.catch(() -> set1.findLast((_) -> "wrong type")) } ["take()"] { @@ -122,7 +122,7 @@ examples { set1.takeWhile((x) -> true) set1.takeWhile((x) -> false) set1.takeWhile((x) -> x < 3) - module.catch(() -> set1.takeWhile((x) -> "wrong type")) + module.catch(() -> set1.takeWhile((_) -> "wrong type")) } ["takeLast()"] { @@ -135,7 +135,7 @@ examples { set1.takeLastWhile((x) -> true) set1.takeLastWhile((x) -> false) set1.takeLastWhile((x) -> x > 1) - module.catch(() -> set1.takeLastWhile((x) -> "wrong type")) + module.catch(() -> set1.takeLastWhile((_) -> "wrong type")) } ["drop()"] { @@ -148,7 +148,7 @@ examples { set1.dropWhile((x) -> true) set1.dropWhile((x) -> false) set1.dropWhile((x) -> x < 3) - module.catch(() -> set1.dropWhile((x) -> "wrong type")) + module.catch(() -> set1.dropWhile((_) -> "wrong type")) } ["dropLast()"] { @@ -161,7 +161,7 @@ examples { set1.dropLastWhile((x) -> true) set1.dropLastWhile((x) -> false) set1.dropLastWhile((x) -> x > 1) - module.catch(() -> set1.dropLastWhile((x) -> "wrong type")) + module.catch(() -> set1.dropLastWhile((_) -> "wrong type")) } ["fold()"] { diff --git a/pkl-core/src/test/files/LanguageSnippetTests/output/api/list.pcf b/pkl-core/src/test/files/LanguageSnippetTests/output/api/list.pcf index 61dac679a..7b6156403 100644 --- a/pkl-core/src/test/files/LanguageSnippetTests/output/api/list.pcf +++ b/pkl-core/src/test/files/LanguageSnippetTests/output/api/list.pcf @@ -166,7 +166,6 @@ examples { 1 0 24 - "Function `folded` should only be called from inside a fold function." } ["foldBack()"] { 6 diff --git a/stdlib/base.pkl b/stdlib/base.pkl index 276ae4223..93a351512 100644 --- a/stdlib/base.pkl +++ b/stdlib/base.pkl @@ -2160,10 +2160,17 @@ external class Pair extends Any { external value: Second } -/// Wraps [value], terminating the fold with [value] as result. +/// Wraps [v], terminating the fold with [v] as result. +/// /// Should only be called from inside a fold function like /// `fold`, `foldIndexed`, `reduce`, etc. -external const function folded(value: T): T +const function folded(v: Any): Any = new Folded { + value = v +} + +class Folded { + value: Any +} /// Common base class for [List] and [Set]. ///