diff --git a/docs/simfil-language.md b/docs/simfil-language.md index ace9fa90..fc0941af 100644 --- a/docs/simfil-language.md +++ b/docs/simfil-language.md @@ -34,14 +34,66 @@ Example: `*.["field name with spaces"]`. ### Symbols -Simfil parses identifiers containing only uppercase letters and underscores -as strings, but only if not on either side of a path operator `.`. -This means, that expressions like `**.field = ABC` get parsed as -`**.field = "ABC"`. Note that this is not the case if a symbol appears on -either side of `.`, such as `ABC.field`! +Simfil parses unquoted identifiers as field names. String values should be +written as quoted literals, for example `field = "ABC"`. -To force parsing a symbol as a field, you can put it in a path expression: -`_.FIELD` or use the subscript operator `[FIELD]`. +When schema metadata is available, the compiler may reinterpret an unquoted +standalone token as a schema symbol. This is a schema rewrite, not a parser +rule: without schema metadata, `ABC` is the field `ABC`. + +### Schema-Aware Field and Enum Resolution + +When the caller supplies a schema for the current model, simfil can use that schema while compiling, completing, and evaluating path expressions. This keeps short queries practical without changing the core path syntax. + +Schema-aware behavior is conservative: + +- A standalone scalar field name can resolve to the concrete schema path that owns that field. +- A recursive wildcard such as `**.speedLimitKmh` can skip schema branches that cannot contain `speedLimitKmh`, which avoids scanning arbitrary object branches when the schema is precise. +- An unquoted operand can resolve to a string constant when the schema proves it belongs to an enum domain. +- A standalone enum-like string literal can resolve to an equality comparison against the schema path that owns that enum value. +- If the same token can mean both a field and an enum-like value, field access wins. 
This keeps schema shorthand aligned with normal unquoted identifiers. + +Examples: + +```simfil +speedLimitKmh > 80 +``` + +can be compiled against a schema as the concrete path that owns `speedLimitKmh`, while: + +```simfil +SPEED_LIMIT_END +``` + +can be compiled as a comparison against schema paths whose enum domain contains `SPEED_LIMIT_END`. + +Other common schema-aware patterns: + +```simfil +**.speedLimitKmh > 80 +``` + +uses the recursive wildcard syntax while still allowing schema-guided pruning, and: + +```simfil +"SPEED_LIMIT_END" +``` + +uses the quoted enum value inserted by schema-aware completion. Schema mode can +also resolve unquoted enum tokens when the schema proves they are enum values, +but quoted strings are the explicit representation. + +Use explicit path syntax when you want to force field access: + +```simfil +_.WARNING_SIGN +``` + +or: + +```simfil +["WARNING_SIGN"] +``` ### Sub-Queries @@ -224,6 +276,7 @@ of `expr` are stored for debugging purposes; see `limit`. *Example* ``` trace(a.**.b{trace("sub", c == "test")}) +trace(**.speedLimitKmh, name="speed limits") ``` Arguments: diff --git a/include/simfil/environment.h b/include/simfil/environment.h index 03838728..4f7088db 100644 --- a/include/simfil/environment.h +++ b/include/simfil/environment.h @@ -22,6 +22,7 @@ namespace simfil class Expr; class Function; class Diagnostics; +class Schema; struct ResultFn; struct Debug; @@ -61,6 +62,8 @@ struct Trace struct Environment { public: + using QuerySchemaCallback = std::function; + /** * Construct a SIMFIL execution environment with a string cache, * which is used to map field names to short integer IDs. @@ -116,6 +119,12 @@ struct Environment [[nodiscard]] auto strings() const -> std::shared_ptr; + /** + * Query an object schema by its schema id. + * Returns nullptr if no callback is configured or the schema is unknown. 
+ */ + auto querySchema(SchemaId schemaId) const -> const Schema*; + public: std::unique_ptr warnMtx; std::vector> warnings; @@ -129,6 +138,16 @@ struct Environment /* constant ident -> value */ std::map constants; + QuerySchemaCallback querySchemaCallback; + + /** + * Enable cached schema-guided wildcard field traversal plans. + * + * Disabling this keeps the older behavior where wildcard field lookups only + * ask each node schema whether the requested field can appear below it. + */ + bool enableWildcardFieldPlans = true; + Debug* debug = nullptr; std::shared_ptr stringPool; }; diff --git a/include/simfil/expression-visitor.h b/include/simfil/expression-visitor.h index d01ccdfd..350c32dd 100644 --- a/include/simfil/expression-visitor.h +++ b/include/simfil/expression-visitor.h @@ -21,7 +21,9 @@ class UnpackExpr; class UnaryWordOpExpr; class BinaryWordOpExpr; class FieldExpr; +class WildcardFieldExpr; class PathExpr; +class PathAlternativesExpr; class AndExpr; class OrExpr; struct OperatorEq; @@ -53,7 +55,9 @@ class ExprVisitor virtual void visit(const EachExpr& expr); virtual void visit(const CallExpression& expr); virtual void visit(const PathExpr& expr); + virtual void visit(const PathAlternativesExpr& expr); virtual void visit(const FieldExpr& expr); + virtual void visit(const WildcardFieldExpr& expr); virtual void visit(const UnpackExpr& expr); virtual void visit(const UnaryWordOpExpr& expr); virtual void visit(const BinaryWordOpExpr& expr); diff --git a/include/simfil/expression.h b/include/simfil/expression.h index 39692f9d..d798789d 100644 --- a/include/simfil/expression.h +++ b/include/simfil/expression.h @@ -8,12 +8,16 @@ #include "simfil/result.h" #include +#include namespace simfil { +class Expr; class ExprVisitor; +using ExprPtr = std::unique_ptr; + class Expr { friend class AST; @@ -31,17 +35,16 @@ class Expr VALUE, }; - Expr() = delete; - explicit Expr(ExprId id) - : id_(id) - {} - explicit Expr(ExprId id, const Token& token) - : id_(id) + Expr() = 
default; + explicit Expr(const Token& token) { assert(token.end >= token.begin); sourceLocation_.offset = token.begin; sourceLocation_.size = token.end - token.begin; } + explicit Expr(SourceLocation location) + : sourceLocation_(location) + {} virtual ~Expr() = default; @@ -56,6 +59,26 @@ class Expr return false; } + /* Accept expression visitor */ + virtual auto accept(ExprVisitor& v) const -> void = 0; + + /* Get the number of child expressions */ + virtual auto numChildren() const -> std::size_t + { + return 0; + } + + /* Get the n-th child expression */ + virtual auto childAt(std::size_t) -> ExprPtr& + { + throw std::out_of_range("AST child index out of range"); + } + + virtual auto childAt(std::size_t) const -> const ExprPtr& + { + throw std::out_of_range("AST child index out of range"); + } + /* Debug */ virtual auto toString() const -> std::string = 0; @@ -90,11 +113,7 @@ class Expr return ieval(ctx, std::move(val), res); } - /* Accept expression visitor */ - virtual auto accept(ExprVisitor& v) const -> void = 0; - /* Source location the expression got parsed from */ - [[nodiscard]] auto sourceLocation() const -> SourceLocation { return sourceLocation_; @@ -110,12 +129,10 @@ class Expr return ieval(ctx, value, result); } - ExprId id_; + ExprId id_ = 0; SourceLocation sourceLocation_; }; -using ExprPtr = std::unique_ptr; - class AST { public: @@ -126,6 +143,8 @@ class AST ~AST(); + auto reenumerate() -> void; + auto expr() const -> const Expr& { return *expr_; @@ -137,6 +156,8 @@ class AST } private: + static auto reenumerate(Expr& expr, Expr::ExprId& nextId) -> void; + /* The original query string of the AST */ std::string queryString_; diff --git a/include/simfil/model/model.h b/include/simfil/model/model.h index 31951867..0cc4ca39 100644 --- a/include/simfil/model/model.h +++ b/include/simfil/model/model.h @@ -2,6 +2,7 @@ #pragma once #include "simfil/model/string-pool.h" +#include "simfil/model/schema.h" #include "simfil/byte-array.h" #include 
"tl/expected.hpp" #if defined(SIMFIL_WITH_MODEL_JSON) @@ -135,6 +136,16 @@ class Model : public std::enable_shared_from_this } else { using ModelType = detail::ModelTypeOf; + + // Merged views may expose nodes owned by another model, e.g. + // overlay feature references returned through a base tile. The + // node address must be interpreted by the model that created it. + if (auto owner = node.owningModel(); owner && owner.get() != this) { + if (auto typedOwner = dynamic_cast(owner.get())) { + return resolveInternal(res::tag{}, *typedOwner, node); + } + } + #if !defined(NDEBUG) // In debug builds, validate the model type to catch misuse early. auto typedModel = dynamic_cast(this); @@ -274,7 +285,9 @@ class ModelPool : public Model size_t stringDataBytes = 0; size_t stringRangeBytes = 0; size_t objectMemberBytes = 0; + size_t objectSchemaBytes = 0; size_t arrayMemberBytes = 0; + size_t arraySchemaBytes = 0; [[nodiscard]] size_t totalBytes() const { @@ -284,7 +297,9 @@ class ModelPool : public Model + stringDataBytes + stringRangeBytes + objectMemberBytes - + arrayMemberBytes; + + objectSchemaBytes + + arrayMemberBytes + + arraySchemaBytes; } }; @@ -299,12 +314,18 @@ class ModelPool : public Model struct Impl; std::unique_ptr impl_; + [[nodiscard]] SchemaId objectSchemaId(ArrayIndex members) const; + auto setObjectSchemaId(ArrayIndex members, SchemaId schemaId) -> tl::expected; + [[nodiscard]] SchemaId arraySchemaId(ArrayIndex members) const; + auto setArraySchemaId(ArrayIndex members, SchemaId schemaId) -> tl::expected; + /** * Protected object/array member storage access, * so derived ModelPools can create Object/Array-derived nodes. 
*/ Object::Storage& objectMemberStorage(); [[nodiscard]] Object::Storage const& objectMemberStorage() const; + Array::Storage& arrayMemberStorage(); [[nodiscard]] Array::Storage const& arrayMemberStorage() const; }; diff --git a/include/simfil/model/nodes.h b/include/simfil/model/nodes.h index b9309c45..ff97b642 100644 --- a/include/simfil/model/nodes.h +++ b/include/simfil/model/nodes.h @@ -8,6 +8,7 @@ #include #include "arena.h" +#include "schema.h" #include "string-pool.h" #include "simfil/byte-array.h" #include "simfil/error.h" @@ -56,8 +57,9 @@ enum class ValueType Bytes, TransientObject, Object, - Array - // If you add types, update TypeFlags::flags bit size! + Array, + // End + LAST_ }; using ScalarValueType = std::variant< @@ -276,6 +278,9 @@ struct ModelNode /// Get an Object model's field names [[nodiscard]] virtual StringId keyAt(int64_t i) const; + /// Get the schema id for schema-aware container nodes, or NoSchemaId otherwise. + [[nodiscard]] virtual SchemaId schema() const; + /// Get the number of children [[nodiscard]] virtual uint32_t size() const; @@ -288,6 +293,9 @@ struct ModelNode /// True if the node points at a valid model and address. [[nodiscard]] inline bool isResolved() const {return model_ && addr_;} + /// Return the model that owns this node address. 
+ [[nodiscard]] inline ModelConstPtr owningModel() const {return model_;} + /// Virtual destructor to allow polymorphism virtual ~ModelNode() = default; @@ -431,6 +439,7 @@ struct ModelNodeBase : public ModelNode [[nodiscard]] ModelNode::Ptr get(const StringId&) const override; [[nodiscard]] ModelNode::Ptr at(int64_t) const override; [[nodiscard]] StringId keyAt(int64_t) const override; + [[nodiscard]] SchemaId schema() const override; [[nodiscard]] uint32_t size() const override; bool iterate(IterCallback const&) const override {return true;} // NOLINT (allow discard) @@ -547,6 +556,9 @@ struct BaseArray : public MandatoryDerivedModelNodeBase bool forEach(std::function const& callback) const; + [[nodiscard]] SchemaId schema() const override; + auto setSchema(SchemaId schemaId) -> tl::expected; + [[nodiscard]] ValueType type() const override; [[nodiscard]] ModelNode::Ptr at(int64_t) const override; [[nodiscard]] uint32_t size() const override; @@ -610,6 +622,9 @@ struct BaseObject : public MandatoryDerivedModelNodeBase return addFieldInternal(name, static_cast(value)); } + [[nodiscard]] SchemaId schema() const override; + auto setSchema(SchemaId schemaId) -> tl::expected; + [[nodiscard]] ValueType type() const override; [[nodiscard]] ModelNode::Ptr at(int64_t) const override; [[nodiscard]] uint32_t size() const override; diff --git a/include/simfil/model/nodes.impl.h b/include/simfil/model/nodes.impl.h index d8a08e7e..09adbb8d 100644 --- a/include/simfil/model/nodes.impl.h +++ b/include/simfil/model/nodes.impl.h @@ -21,6 +21,18 @@ ValueType BaseArray::type() const return ValueType::Array; } +template +SchemaId BaseArray::schema() const +{ + return model().arraySchemaId(members_); +} + +template +auto BaseArray::setSchema(SchemaId schemaId) -> tl::expected +{ + return model().setArraySchemaId(members_, schemaId); +} + template ModelNode::Ptr BaseArray::at(int64_t i) const { @@ -96,6 +108,18 @@ ValueType BaseObject::type() const return ValueType::Object; } +template 
+SchemaId BaseObject::schema() const +{ + return model().objectSchemaId(members_); +} + +template +auto BaseObject::setSchema(SchemaId schemaId) -> tl::expected +{ + return model().setObjectSchemaId(members_, schemaId); +} + template ModelNode::Ptr BaseObject::at(int64_t i) const { diff --git a/include/simfil/model/schema.h b/include/simfil/model/schema.h new file mode 100644 index 00000000..c131c3c2 --- /dev/null +++ b/include/simfil/model/schema.h @@ -0,0 +1,842 @@ +#pragma once + +#include "simfil/model/string-pool.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace simfil +{ + +class Schema; + +using SchemaId = std::uint16_t; +constexpr SchemaId NoSchemaId = SchemaId{0}; +constexpr SchemaId MaxSchemaId = SchemaId{std::numeric_limits::max()}; + +/** + * One segment in a schema-derived query path. + * + * Field segments address object members. Array-element segments represent the + * non-recursive `*` operator needed to traverse array elements precisely. + */ +struct SchemaPathSegment +{ + enum class Kind { + Field, + ArrayElement, + }; + + Kind kind = Kind::Field; + StringId field = 0; + + auto operator<=>(const SchemaPathSegment&) const = default; +}; + +/** Sequence of schema path segments from a root schema to a reachable value. */ +using SchemaPath = std::vector; + +/** + * Concept defining a callback to query a Schema* by SchemaId. + */ +template +concept QuerySchemaFn = requires(const Fn& fn) { + { fn(SchemaId{}) } -> std::convertible_to; +}; +template +concept QueryMutableSchemaFn = requires(const Fn& fn) { + { fn(SchemaId{}) } -> std::convertible_to; +}; + +/** + * Abstract base class of ObjectSchema, ArraySchema and ValueSchema. + */ +class Schema +{ +public: + /** Schema kind */ + enum class Kind { + Object, + Array, + Value, + }; + + /** Finalization state */ + enum class State { + Dirty, + Finalizing, + Clean, + }; + + using SchemaIdStack = sfl::small_vector; + + virtual ~Schema() = default; + + /** + * Return this schema's kind. 
+ */ + virtual auto kind() const -> Kind = 0; + + /** + * Returns true if this schema or any of the schemas it refers to + * can possibly contain the given field. + */ + virtual auto canHaveField(StringId fieldId) const -> bool = 0; + + /** + * Returns true if this schema or any of the schemas it refers to + * can possibly contain the given enum-like string symbol. + */ + virtual auto canHaveEnumSymbol(StringId symbolId) const -> bool + { + return false; + } + + /** + * Finalize this schema and all schemas it refers to. + */ + virtual auto finalize(const std::function& queryFn) -> State + { + return State::Clean; + } + + /** + * @return All nested field names. + */ + virtual auto nestedFields() const & -> std::span = 0; + + /** + * Return field names directly available on this schema node. + * + * Completion uses this to suggest fields valid at the current node without + * also suggesting fields that only occur deeper in the schema graph. + */ + virtual auto directFields() const & -> std::span + { + return nestedFields(); + } + + /** + * @return All nested enum-like string symbols. + */ + virtual auto nestedEnumSymbols() const & -> std::span + { + return {}; + } + + /** + * Return enum-like string symbols accepted directly by this schema node. + * + * Unlike nestedEnumSymbols(), this does not include descendants and is used + * to derive precise schema paths for schema-backed rewrites. + */ + virtual auto directEnumSymbols() const & -> std::span + { + return {}; + } + + /** + * Enumerate precise paths to all fields with the requested name. + */ + static auto fieldPaths(SchemaId root, + const std::function& queryFn, + StringId field) -> std::vector + { + std::vector paths; + SchemaIdStack visited; + SchemaPath current; + collectFieldPaths(root, queryFn, field, visited, current, paths); + sortUniquePaths(paths); + return paths; + } + + /** + * Enumerate precise paths to all values that can hold the enum-like symbol. 
+ */ + static auto enumSymbolPaths(SchemaId root, + const std::function& queryFn, + StringId symbol) -> std::vector + { + std::vector paths; + SchemaIdStack visited; + SchemaPath current; + collectEnumSymbolPaths(root, queryFn, symbol, visited, current, paths); + sortUniquePaths(paths); + return paths; + } + + /** + * Return paths that should be compared with the supplied string symbol when + * the symbol appears as a full standalone query. Embedders can use this for + * schema-native aliases such as attribute type-code predicates. + */ + virtual auto symbolEqualityPaths( + StringId, + const std::function&) const -> std::vector + { + return {}; + } + + /** + * Return scalar field paths that should replace the supplied symbol when it + * appears as a standalone operand inside a larger expression. + */ + virtual auto scalarFieldPathsForSymbol( + StringId, + const std::function&) const -> std::vector + { + return {}; + } + + /** + * Return the first reachable scalar path below the supplied schema. + */ + static auto firstScalarFieldPath( + SchemaId root, + const std::function& queryFn) -> std::optional + { + SchemaIdStack visited; + SchemaPath current; + return firstScalarFieldPath(root, queryFn, visited, current); + } + + /** + * Return true once `canHaveField` is backed by finalized field caches. + */ + virtual auto finalized() const -> bool + { + return true; + } + + /** + * Monotonic counter for cache invalidation after schema mutations. + */ + virtual auto revision() const -> std::uint64_t + { + return 0; + } + +protected: + /** + * Append all fields reachable from this schema without relying on cached + * finalization state. This lets cyclic schema graphs still produce an exact + * field set by cutting recursion at already visited schema ids. 
+ */ + virtual auto collectNestedFields(const std::function& queryFn, + SchemaIdStack& visited, + std::vector& fields) const -> void = 0; + + /** + * Append all enum-like string symbols reachable from this schema without + * relying on cached finalization state. + */ + virtual auto collectNestedEnumSymbols(const std::function& queryFn, + SchemaIdStack& visited, + std::vector& symbols) const -> void + { + } + + /** + * Visit fields declared directly by this schema and their possible child + * schemas. The default is empty for scalar schemas. + */ + virtual auto forEachDirectField( + const std::function)>&) const -> void + { + } + + /** + * Visit possible array element schemas. The default is empty for non-arrays. + */ + virtual auto forEachElementSchema(const std::function&) const -> void + { + } + + /** + * Recursively collect schema paths to matching fields. + */ + static auto collectFieldPaths(SchemaId schemaId, + const std::function& queryFn, + StringId field, + SchemaIdStack& visited, + SchemaPath& current, + std::vector& paths) -> void + { + if (schemaId == NoSchemaId || std::ranges::find(visited, schemaId) != visited.end()) + return; + + auto const* schema = queryFn(schemaId); + if (!schema) + return; + + visited.push_back(schemaId); + + schema->forEachDirectField([&](StringId directField, std::span childSchemas) { + current.push_back({SchemaPathSegment::Kind::Field, directField}); + if (directField == field) + paths.push_back(current); + for (auto childSchemaId : childSchemas) + collectFieldPaths(childSchemaId, queryFn, field, visited, current, paths); + current.pop_back(); + }); + + schema->forEachElementSchema([&](SchemaId elementSchemaId) { + current.push_back({SchemaPathSegment::Kind::ArrayElement, 0}); + collectFieldPaths(elementSchemaId, queryFn, field, visited, current, paths); + current.pop_back(); + }); + + visited.pop_back(); + } + + /** + * Recursively collect schema paths to values accepting a matching enum-like + * string symbol. 
+ */ + static auto collectEnumSymbolPaths(SchemaId schemaId, + const std::function& queryFn, + StringId symbol, + SchemaIdStack& visited, + SchemaPath& current, + std::vector& paths) -> void + { + if (schemaId == NoSchemaId || std::ranges::find(visited, schemaId) != visited.end()) + return; + + auto const* schema = queryFn(schemaId); + if (!schema) + return; + + visited.push_back(schemaId); + + for (auto directSymbol : schema->directEnumSymbols()) { + if (directSymbol == symbol) + paths.push_back(current); + } + + schema->forEachDirectField([&](StringId directField, std::span childSchemas) { + current.push_back({SchemaPathSegment::Kind::Field, directField}); + for (auto childSchemaId : childSchemas) + collectEnumSymbolPaths(childSchemaId, queryFn, symbol, visited, current, paths); + current.pop_back(); + }); + + schema->forEachElementSchema([&](SchemaId elementSchemaId) { + current.push_back({SchemaPathSegment::Kind::ArrayElement, 0}); + collectEnumSymbolPaths(elementSchemaId, queryFn, symbol, visited, current, paths); + current.pop_back(); + }); + + visited.pop_back(); + } + + /** + * Recursively find the first scalar field path in schema declaration order. 
+ */ + static auto firstScalarFieldPath(SchemaId schemaId, + const std::function& queryFn, + SchemaIdStack& visited, + SchemaPath& current) -> std::optional + { + if (schemaId == NoSchemaId || std::ranges::find(visited, schemaId) != visited.end()) + return std::nullopt; + + auto const* schema = queryFn(schemaId); + if (!schema) + return std::nullopt; + + if (schema->kind() == Kind::Value) + return current; + + visited.push_back(schemaId); + + if (schema->kind() == Kind::Object) { + std::optional result; + schema->forEachDirectField([&](StringId directField, std::span childSchemas) { + if (result) + return; + + current.push_back({SchemaPathSegment::Kind::Field, directField}); + if (childSchemas.empty()) { + result = current; + } + else { + for (auto childSchemaId : childSchemas) { + result = firstScalarFieldPath(childSchemaId, queryFn, visited, current); + if (result) + break; + } + } + current.pop_back(); + }); + visited.pop_back(); + return result; + } + + std::optional result; + schema->forEachElementSchema([&](SchemaId elementSchemaId) { + if (result) + return; + + current.push_back({SchemaPathSegment::Kind::ArrayElement, 0}); + result = firstScalarFieldPath(elementSchemaId, queryFn, visited, current); + current.pop_back(); + }); + + visited.pop_back(); + return result; + } + + /** + * Keep path rewrites deterministic and avoid duplicate paths from combined + * schemas or shared references. + */ + static auto sortUniquePaths(std::vector& paths) -> void + { + std::ranges::sort(paths); + auto duplicates = std::ranges::unique(paths); + paths.erase(duplicates.begin(), duplicates.end()); + } + + /** + * Append reachable values through a schema id, using finalized child + * caches when possible and falling back to raw graph traversal for cycles. 
+ */ + template + static auto appendSchemaValues(SchemaId schemaId, + const std::function& queryFn, + SchemaIdStack& visited, + std::vector& values, + CachedValuesFn&& cachedValues, + CollectValuesFn&& collectValues) -> void + { + if (schemaId == NoSchemaId || std::ranges::find(visited, schemaId) != visited.end()) + return; + + auto* schema = queryFn(schemaId); + if (!schema) + return; + + visited.push_back(schemaId); + + if (schema->finalize(queryFn) == State::Clean) { + auto childValues = std::invoke(cachedValues, *schema); + values.insert(values.end(), childValues.begin(), childValues.end()); + return; + } + + std::invoke(collectValues, *schema, queryFn, visited, values); + } + + /** + * Append fields reachable through a schema id. + */ + static auto appendSchemaFields(SchemaId schemaId, + const std::function& queryFn, + SchemaIdStack& visited, + std::vector& fields) -> void + { + appendSchemaValues( + schemaId, + queryFn, + visited, + fields, + [](const Schema& schema) { return schema.nestedFields(); }, + [](const Schema& schema, const auto& query, auto& visitedSchemas, auto& values) { + schema.collectNestedFields(query, visitedSchemas, values); + }); + } + + /** + * Append enum-like string symbols reachable through a schema id. + */ + static auto appendSchemaEnumSymbols(SchemaId schemaId, + const std::function& queryFn, + SchemaIdStack& visited, + std::vector& symbols) -> void + { + appendSchemaValues( + schemaId, + queryFn, + visited, + symbols, + [](const Schema& schema) { return schema.nestedEnumSymbols(); }, + [](const Schema& schema, const auto& query, auto& visitedSchemas, auto& values) { + schema.collectNestedEnumSymbols(query, visitedSchemas, values); + }); + } + + /** + * Sort ids and remove duplicates. 
+ */ + static auto sortUnique(std::vector& values) -> void + { + std::ranges::sort(values); + auto duplicates = std::ranges::unique(values); + values.erase(duplicates.begin(), duplicates.end()); + } + + /** + * Shared finalization implementation for schemas that cache descendant + * fields and enum-like string symbols. + */ + static auto finalizeReachableMetadata(State& state, + std::vector& flatFields, + std::vector& flatEnumSymbols, + const std::function& queryFn, + const Schema& schema) -> State + { + if (state == State::Clean || state == State::Finalizing) + return state; + + state = State::Finalizing; + flatFields.clear(); + flatEnumSymbols.clear(); + + SchemaIdStack visitedFields; + schema.collectNestedFields(queryFn, visitedFields, flatFields); + sortUnique(flatFields); + + SchemaIdStack visitedEnumSymbols; + schema.collectNestedEnumSymbols(queryFn, visitedEnumSymbols, flatEnumSymbols); + sortUnique(flatEnumSymbols); + + state = State::Clean; + return State::Clean; + } + + /** + * Shared membership test; dirty schemas remain conservative. + */ + static auto containsField(State state, const std::vector& flatFields, StringId field) -> bool + { + if (state != State::Clean) + return true; + + auto iter = std::ranges::lower_bound(flatFields, field); + return iter != flatFields.end() && *iter == field; + } +}; + +/** + * Schema for object nodes. + * + * Stores direct fields and optional child schema ids per field. After + * `finalize()` it also caches all reachable child fields. 
+ */ +class ObjectSchema : public Schema +{ +public: + struct FieldSummary { + StringId field = 0; + sfl::small_vector schemas; + + auto operator<=>(const FieldSummary& other) const + { + return field <=> other.field; + } + }; + + auto kind() const -> Kind override + { + return Kind::Object; + } + + auto canHaveField(StringId field) const -> bool override + { + return containsField(state_, flatFields_, field); + } + + auto canHaveEnumSymbol(StringId symbol) const -> bool override + { + return containsField(state_, flatEnumSymbols_, symbol); + } + + /** + * Add a direct field and optional child schemas reachable through it. + */ + auto addField(StringId field, std::initializer_list schemas = {}) -> void + { + FieldSummary summary; + summary.field = field; + summary.schemas.insert(summary.schemas.end(), schemas.begin(), schemas.end()); + fields_.push_back(std::move(summary)); + directFields_.push_back(field); + state_ = State::Dirty; + ++revision_; + } + + /** + * Recompute the cached descendant field set from this schema and all + * reachable child schemas. 
+ */ + auto finalize(const std::function& lookup) -> State override + { + return finalizeReachableMetadata(state_, flatFields_, flatEnumSymbols_, lookup, *this); + } + + auto fields() const & -> std::span + { + return {fields_.begin(), fields_.end()}; + } + + auto forEachDirectField( + const std::function)>& fn) const -> void override + { + for (auto const& field : fields_) + fn(field.field, {field.schemas.begin(), field.schemas.end()}); + } + + auto nestedFields() const & -> std::span override + { + return {flatFields_.cbegin(), flatFields_.cend()}; + } + + auto directFields() const & -> std::span override + { + return {directFields_.cbegin(), directFields_.cend()}; + } + + auto nestedEnumSymbols() const & -> std::span override + { + return {flatEnumSymbols_.cbegin(), flatEnumSymbols_.cend()}; + } + + auto finalized() const -> bool override + { + return state_ == State::Clean; + } + + auto revision() const -> std::uint64_t override + { + return revision_; + } + +private: + auto collectNestedFields(const std::function& lookup, + SchemaIdStack& visited, + std::vector& fields) const -> void override + { + for (const auto& field : fields_) { + fields.push_back(field.field); + for (const auto& fieldSchemaId : field.schemas) + appendSchemaFields(fieldSchemaId, lookup, visited, fields); + } + } + + auto collectNestedEnumSymbols(const std::function& lookup, + SchemaIdStack& visited, + std::vector& symbols) const -> void override + { + for (const auto& field : fields_) { + for (const auto& fieldSchemaId : field.schemas) + appendSchemaEnumSymbols(fieldSchemaId, lookup, visited, symbols); + } + } + + sfl::small_vector fields_; + + std::vector directFields_; + std::vector flatFields_; // Ordered! + std::vector flatEnumSymbols_; // Ordered! + std::uint64_t revision_ = 0; + State state_ = State::Dirty; +}; + +/** + * Schema for scalar value nodes. + * + * Stores optional enum-like string symbols for schema-aware completion and + * parsing. 
Value schemas never contribute nested fields. + */ +class ValueSchema : public Schema +{ +public: + auto kind() const -> Kind override + { + return Kind::Value; + } + + auto canHaveField(StringId) const -> bool override + { + return false; + } + + auto canHaveEnumSymbol(StringId symbol) const -> bool override + { + return containsField(state_, enumSymbols_, symbol); + } + + /** + * Add an enum-like string symbol accepted by this value schema. + */ + auto addEnumSymbol(StringId symbol) -> void + { + enumSymbols_.push_back(symbol); + state_ = State::Dirty; + ++revision_; + } + + auto finalize(const std::function&) -> State override + { + if (state_ == State::Clean || state_ == State::Finalizing) + return state_; + + state_ = State::Finalizing; + sortUnique(enumSymbols_); + state_ = State::Clean; + return State::Clean; + } + + auto nestedFields() const & -> std::span override + { + return {}; + } + + auto nestedEnumSymbols() const & -> std::span override + { + return {enumSymbols_.cbegin(), enumSymbols_.cend()}; + } + + auto directEnumSymbols() const & -> std::span override + { + return {enumSymbols_.cbegin(), enumSymbols_.cend()}; + } + + auto finalized() const -> bool override + { + return state_ == State::Clean; + } + + auto revision() const -> std::uint64_t override + { + return revision_; + } + +private: + auto collectNestedFields(const std::function&, + SchemaIdStack&, + std::vector&) const -> void override + { + } + + auto collectNestedEnumSymbols(const std::function&, + SchemaIdStack&, + std::vector& symbols) const -> void override + { + symbols.insert(symbols.end(), enumSymbols_.begin(), enumSymbols_.end()); + } + + std::vector enumSymbols_; // Ordered after finalize(). + std::uint64_t revision_ = 0; + State state_ = State::Dirty; +}; + +/** + * Schema for array nodes. + * + * Stores the set of possible element schemas. After `finalize()` it caches + * all fields reachable through any element schema. 
+ */ +class ArraySchema : public Schema +{ +public: + auto kind() const -> Kind override + { + return Kind::Array; + } + + auto canHaveField(StringId field) const -> bool override + { + return containsField(state_, flatFields_, field); + } + + auto canHaveEnumSymbol(StringId symbol) const -> bool override + { + return containsField(state_, flatEnumSymbols_, symbol); + } + + /** + * Add possible schemas for elements contained in the array. + */ + auto addElementSchemas(std::initializer_list schemas) -> void + { + schemas_.insert(schemas_.end(), schemas.begin(), schemas.end()); + state_ = State::Dirty; + ++revision_; + } + + /** + * Recompute the cached descendant field set from all possible element + * schemas. + */ + auto finalize(const std::function& lookup) -> State override + { + return finalizeReachableMetadata(state_, flatFields_, flatEnumSymbols_, lookup, *this); + } + + auto nestedFields() const & -> std::span override + { + return {flatFields_.cbegin(), flatFields_.cend()}; + } + + auto nestedEnumSymbols() const & -> std::span override + { + return {flatEnumSymbols_.cbegin(), flatEnumSymbols_.cend()}; + } + + auto finalized() const -> bool override + { + return state_ == State::Clean; + } + + auto revision() const -> std::uint64_t override + { + return revision_; + } + + auto elementSchemas() const & -> std::span + { + return {schemas_.begin(), schemas_.end()}; + } + + auto forEachElementSchema(const std::function& fn) const -> void override + { + for (auto schemaId : schemas_) + fn(schemaId); + } + +private: + auto collectNestedFields(const std::function& lookup, + SchemaIdStack& visited, + std::vector& fields) const -> void override + { + for (const auto& schemaId : schemas_) + appendSchemaFields(schemaId, lookup, visited, fields); + } + + auto collectNestedEnumSymbols(const std::function& lookup, + SchemaIdStack& visited, + std::vector& symbols) const -> void override + { + for (const auto& schemaId : schemas_) + appendSchemaEnumSymbols(schemaId, lookup, 
visited, symbols); + } + + sfl::small_vector schemas_; + std::vector flatFields_; // Ordered! + std::vector flatEnumSymbols_; // Ordered! + std::uint64_t revision_ = 0; + State state_ = State::Dirty; +}; + +} diff --git a/include/simfil/parser.h b/include/simfil/parser.h index 0999e500..43fe4c2b 100644 --- a/include/simfil/parser.h +++ b/include/simfil/parser.h @@ -59,7 +59,6 @@ class Parser }; struct Context { - Expr::ExprId id = 0; bool inPath = false; }; @@ -104,11 +103,6 @@ class Parser auto mode() const -> Mode; auto relaxed() const -> bool; - /** - * Get the next expression id. - */ - auto nextId() -> Expr::ExprId; - Context ctx; Environment* const env; std::unordered_map prefixParsers; diff --git a/include/simfil/simfil.h b/include/simfil/simfil.h index 2b98016d..629b15fd 100644 --- a/include/simfil/simfil.h +++ b/include/simfil/simfil.h @@ -2,8 +2,11 @@ #pragma once +#include +#include #include #include +#include #include #include "simfil/expression.h" @@ -11,12 +14,87 @@ #include "simfil/diagnostics.h" #include "simfil/value.h" #include "simfil/error.h" +#include "simfil/model/schema.h" namespace simfil { struct ModelNode; +/** + * Rewrite families available during compilation. + */ +enum class RewriteMode { + None, + Schema, +}; + +/** + * Options used while parsing and rewriting a query. + */ +struct CompileOptions +{ + bool any = true; + RewriteMode rewriteMode = RewriteMode::None; + SchemaId rootSchema = NoSchemaId; +}; + +/** + * One schema path referenced by a compiled expression. + * + * The path is expressed relative to the root schema supplied to + * `referencedSchemaPaths`. If `viaWildcard` is set, the path came from a + * recursive wildcard-field lookup such as `**.foo`. + */ +struct ReferencedSchemaPath +{ + SchemaPath path; + SourceLocation location; + bool viaWildcard = false; + std::optional equalsStringLiteral; +}; + +/** + * Schema references discovered by static AST inspection. 
+ * + * The flags make the result conservative: callers can reject automatic scope + * decisions when the query contains broad wildcards or field access that cannot + * be tied to concrete schema paths. + */ +struct ReferencedSchemaPaths +{ + std::vector paths; + bool hasDynamicAccess = false; + bool hasUnresolvedAccess = false; + bool hasBroadWildcardAccess = false; +}; + +/** + * One static `field == "value"` comparison discovered in a compiled query. + * + * Only direct positive equality comparisons are reported. The field name is + * the exact AST field node text and is not interpreted by simfil. + */ +struct FieldStringComparison +{ + std::string fieldName; + std::string value; +}; + +/** + * Schema-independent query terms extracted from a compiled AST. + * + * `leafFields` contains the final field-like segment of field/path access, + * including recursive wildcard field names such as `**.speedLimitKmh`. + * `stringLiterals` contains string constants that appeared in the query. + */ +struct ReferencedQueryTerms +{ + std::set leafFields; + std::set stringLiterals; + std::vector positiveFieldStringComparisons; +}; + /** * Compile expression `src`. * Param: @@ -26,10 +104,45 @@ struct ModelNode; * Param: * any If true, wrap expression with call to `any(...)`. * Param: - * autoWildcard If true, expand constant expressions to `** == `. + * autoWildcard Deprecated compatibility switch. Ignored; use CompileOptions + * with RewriteMode::Schema and a root schema for rewrites. */ auto compile(Environment& env, std::string_view query, bool any = true, bool autoWildcard = false) -> tl::expected; +/** + * Compile expression `src` with explicit options. + * + * If rootSchema is set and schema rewrites are enabled, shorthand field/enum + * queries are classified through the schema instead of lexical heuristics. 
+ */ +auto compile(Environment& env, std::string_view query, CompileOptions options) -> tl::expected; + +/** + * Collect schema paths that are referenced by a compiled query. + * + * This is static analysis over the AST, not runtime evaluation: both sides of + * `and`/`or` are inspected, and schema-aware rewrites are resolved to the exact + * paths they can touch. + */ +auto referencedSchemaPaths(Environment& env, const AST& ast, SchemaId rootSchema) -> tl::expected; + +/** + * Return the symbol represented by a whole-query bare field or string literal. + * + * This is intentionally AST-based: callers that need exact-query shorthand + * handling should not re-tokenize the source string with ad-hoc rules. + */ +auto standaloneQuerySymbol(Environment& env, std::string_view query) -> tl::expected, Error>; + +/** + * Collect schema-independent terms referenced by a compiled query AST. + * + * This uses simfil's parser/rewriter output, but deliberately does not require + * a schema root. Callers can use the returned terms with their own schema + * indices without re-tokenizing the query string. + */ +auto referencedQueryTerms(const AST& ast) -> ReferencedQueryTerms; + /** * Evaluate compiled expression. 
* Param: diff --git a/include/simfil/value.h b/include/simfil/value.h index 3805a75d..2ee85853 100644 --- a/include/simfil/value.h +++ b/include/simfil/value.h @@ -5,6 +5,7 @@ #include #include #include +#include #include "model/nodes.h" #include "simfil/byte-array.h" @@ -104,11 +105,11 @@ inline auto valueType2String(ValueType t) -> const char* */ struct TypeFlags { - std::bitset<10> flags; + std::bitset(static_cast>(ValueType::LAST_))> flags; auto test(ValueType type) const { - return flags.test(static_cast>(type)); + return flags.test(static_cast(static_cast>(type))); } auto test(TypeFlags other) const diff --git a/repl/repl.cpp b/repl/repl.cpp index d74e6f89..d97cb8d6 100644 --- a/repl/repl.cpp +++ b/repl/repl.cpp @@ -17,7 +17,6 @@ #include #include #include -#include #include #if defined(WITH_READLINE) @@ -36,6 +35,7 @@ struct bool auto_wildcard = false; bool verbose = true; bool multi_threaded = true; + bool schema = true; } options; static void set_option(const std::string& option, bool& flag, std::string_view cmd) @@ -197,6 +197,8 @@ static void show_help() << "Options:\n" << " -D \n" << " Define a constant variable, set to value\n" + << " -s SCHEMA\n" + << " Use a JSON-Schema from SCHEMA file\n" << " -h\n" << " Show this help" << "\n"; @@ -240,6 +242,17 @@ int main(int argc, char *argv[]) #endif }; + auto load_schema = [](std::string_view) { + std::cerr << "Schema support is not implemented!\n"; + }; + + auto take_option_value = [&argv](std::string_view arg) -> std::string_view { + arg.remove_prefix(2); + if (!arg.empty() || argv[1] == nullptr) + return arg; + return *++argv; + }; + auto tail_args = false; while (*++argv != nullptr) { std::string_view arg = *argv; @@ -251,11 +264,17 @@ int main(int argc, char *argv[]) case 'h': show_help(); return 0; - case 'D': - arg.remove_prefix(2); - if (arg.empty()) { - arg = *++argv; + case 's': + arg = take_option_value(arg); + if (!arg.empty()) { + load_schema(arg); + } else { + std::cerr << "Missing schema 
file\n"; + return 1; } + break; + case 'D': + arg = take_option_value(arg); if (auto pos = arg.find('='); (pos != std::string::npos) && (pos > 0)) { constants.try_emplace(std::string(arg.substr(0, pos)), simfil::Value::make(std::string(arg.substr(pos + 1)))); } else { @@ -284,6 +303,7 @@ int main(int argc, char *argv[]) set_option("wildcard", options.auto_wildcard, cmd); set_option("verbose", options.verbose, cmd); set_option("mt", options.multi_threaded, cmd); + set_option("schema", options.schema, cmd); continue; } diff --git a/src/completion.cpp b/src/completion.cpp index 83f3be14..c654c78b 100644 --- a/src/completion.cpp +++ b/src/completion.cpp @@ -1,6 +1,7 @@ #include "completion.h" #include "expressions.h" +#include "simfil/model/schema.h" #include "simfil/model/string-pool.h" #include "simfil/result.h" #include "simfil/simfil.h" @@ -78,6 +79,141 @@ auto escapeKey(std::string_view str) return escaped; } +/// Escape a string value as a SIMFIL string literal completion. +auto escapeStringLiteral(std::string_view str) +{ + std::string escaped = "\""; + escaped.reserve(str.size() + 2); + + for (auto c : str) { + if (c == '"' || c == '\\') + escaped.push_back('\\'); + escaped.push_back(c); + } + + escaped.push_back('"'); + return escaped; +} + +/// Return the text that should be inserted for an enum-like string symbol. +auto enumSymbolCompletionText(std::string_view str) +{ + return escapeStringLiteral(str); +} + +/// Add one field completion candidate if it matches the current prefix. 
+auto completeFieldName( + std::string_view key, + std::string_view prefix, + bool caseSensitive, + simfil::Completion& comp, + simfil::SourceLocation loc) -> simfil::Result +{ + if (comp.size() >= comp.limit) + return simfil::Result::Stop; + + if (!startsWith(key, prefix, caseSensitive)) + return simfil::Result::Continue; + + if (needsEscaping(key)) + comp.add(escapeKey(key), loc, simfil::CompletionCandidate::Type::FIELD); + else + comp.add(std::string{key}, loc, simfil::CompletionCandidate::Type::FIELD); + + return simfil::Result::Continue; +} + +/// Complete fields listed by the node schema but not necessarily present in the model. +auto completeSchemaFields( + const simfil::Context& ctx, + const simfil::ModelNode& node, + std::string_view prefix, + simfil::Completion& comp, + simfil::SourceLocation loc) -> simfil::Result +{ + const auto* schema = ctx.env->querySchema(node.schema()); + if (!schema) + return simfil::Result::Continue; + + const auto caseSensitive = comp.options.smartCase && containsUppercaseCharacter(prefix); + for (auto fieldId : schema->directFields()) { + auto fieldName = ctx.env->strings()->resolve(fieldId); + if (!fieldName || fieldName->empty()) + continue; + + if (auto r = completeFieldName(*fieldName, prefix, caseSensitive, comp, loc); r != simfil::Result::Continue) + return r; + } + + return simfil::Result::Continue; +} + +/// Complete schema fields reachable below the current node as shorthand root tokens. 
+auto completeSchemaShorthandFields( + const simfil::Context& ctx, + const simfil::ModelNode& node, + std::string_view prefix, + simfil::Completion& comp, + simfil::SourceLocation loc) -> simfil::Result +{ + const auto* schema = ctx.env->querySchema(node.schema()); + if (!schema) + return simfil::Result::Continue; + + const auto directFields = schema->directFields(); + const auto isDirectField = [&](simfil::StringId fieldId) { + return std::ranges::find(directFields, fieldId) != directFields.end(); + }; + + const auto caseSensitive = comp.options.smartCase && containsUppercaseCharacter(prefix); + for (auto fieldId : schema->nestedFields()) { + if (comp.size() >= comp.limit) + return simfil::Result::Stop; + if (isDirectField(fieldId)) { + continue; + } + + auto fieldName = ctx.env->strings()->resolve(fieldId); + if (!fieldName || fieldName->empty()) + continue; + + if (auto r = completeFieldName(*fieldName, prefix, caseSensitive, comp, loc); r != simfil::Result::Continue) + return r; + } + + return simfil::Result::Continue; +} + +/// Complete enum-like string symbols reachable from the node schema. 
+auto completeSchemaEnumSymbols( + const simfil::Context& ctx, + const simfil::ModelNode& node, + std::string_view prefix, + simfil::Completion& comp, + simfil::SourceLocation loc) -> simfil::Result +{ + const auto* schema = ctx.env->querySchema(node.schema()); + if (!schema) + return simfil::Result::Continue; + + const auto caseSensitive = comp.options.smartCase && containsUppercaseCharacter(prefix); + for (auto symbolId : schema->nestedEnumSymbols()) { + if (comp.size() >= comp.limit) + return simfil::Result::Stop; + + auto symbol = ctx.env->strings()->resolve(symbolId); + if (!symbol || symbol->empty() || !startsWith(*symbol, prefix, caseSensitive)) + continue; + if (schema->canHaveField(symbolId)) { + continue; + } + + comp.add(enumSymbolCompletionText(*symbol), loc, simfil::CompletionCandidate::Type::CONSTANT); + } + + return simfil::Result::Continue; +} + /// Complete a function name staritng with `prefix` at `loc`. auto completeFunctions(const simfil::Context& ctx, std::string_view prefix, simfil::Completion& comp, simfil::SourceLocation loc) -> simfil::Result { @@ -94,13 +230,21 @@ auto completeFunctions(const simfil::Context& ctx, std::string_view prefix, simf } /// Complete a single WORD starting with `prefix` at `loc`. -auto completeWords(const simfil::Context& ctx, std::string_view prefix, simfil::Completion& comp, simfil::SourceLocation loc) -> simfil::Result +auto completeWords( + const simfil::Context& ctx, + std::string_view prefix, + simfil::Completion& comp, + simfil::SourceLocation loc, + const simfil::ModelNode* node = nullptr) -> simfil::Result { using simfil::Result; - // Generate completion candidates for uppercase string constants from string pool. + // String values from the model are string literals in SIMFIL. They are + // suggested quoted because bare words are fields unless schema compilation + // later proves that a token is an enum-like operand. 
auto stringPool = ctx.env->strings(); const auto& strings = stringPool->strings(); + const auto* schema = node ? ctx.env->querySchema(node->schema()) : nullptr; const auto caseSensitive = comp.options.smartCase && containsUppercaseCharacter(prefix); for (const auto& str : strings) { @@ -113,10 +257,24 @@ auto completeWords(const simfil::Context& ctx, std::string_view prefix, simfil:: }); if (isWORD && str.size() >= prefix.size() && startsWith(str, prefix, caseSensitive)) { - comp.add(str, loc, simfil::CompletionCandidate::Type::CONSTANT); + auto const stringId = stringPool->get(str); + if (schema && stringId != simfil::StringPool::Empty && schema->canHaveEnumSymbol(stringId)) { + // Schema enum values are completed below from schema metadata, + // which also prevents datasource string-pool duplicates. + continue; + } + comp.add(escapeStringLiteral(str), loc, simfil::CompletionCandidate::Type::CONSTANT); } } + if (node) { + if (auto r = completeSchemaShorthandFields(ctx, *node, prefix, comp, loc); r != Result::Continue) + return r; + + if (auto r = completeSchemaEnumSymbols(ctx, *node, prefix, comp, loc); r != Result::Continue) + return r; + } + return Result::Continue; } @@ -125,8 +283,8 @@ auto completeWords(const simfil::Context& ctx, std::string_view prefix, simfil:: namespace simfil { -CompletionFieldOrWordExpr::CompletionFieldOrWordExpr(ExprId id, std::string prefix, Completion* comp, const Token& token, bool inPath) - : Expr(id, token) +CompletionFieldOrWordExpr::CompletionFieldOrWordExpr(std::string prefix, Completion* comp, const Token& token, bool inPath) + : Expr(token) , prefix_(std::move(prefix)) , comp_(comp) , inPath_(inPath) @@ -151,12 +309,8 @@ auto CompletionFieldOrWordExpr::ieval(Context ctx, const Value& val, const Resul const auto caseSensitive = comp_->options.smartCase && containsUppercaseCharacter(prefix_); - // First we try to complete fields + // First we try to complete fields already present in the model. 
for (StringId id : node->fieldNames()) { - if (comp_->size() >= comp_->limit) { - return Result::Stop; - } - if (id == StringPool::Empty) continue; @@ -165,18 +319,16 @@ auto CompletionFieldOrWordExpr::ieval(Context ctx, const Value& val, const Resul continue; const auto& key = *keyPtr; - if (startsWith(key, prefix_, caseSensitive)) { - if (needsEscaping(key)) { - comp_->add(escapeKey(key), sourceLocation(), CompletionCandidate::Type::FIELD); - } else { - comp_->add(std::string{key}, sourceLocation(), CompletionCandidate::Type::FIELD); - } - } + if (auto r = completeFieldName(key, prefix_, caseSensitive, *comp_, sourceLocation()); r != Result::Continue) + return r; } + if (auto r = completeSchemaFields(ctx, *node, prefix_, *comp_, sourceLocation()); r != Result::Continue) + return r; + // If not in a path, we try to complete words and functions if (!inPath_) { - if (auto r = completeWords(ctx, prefix_, *comp_, sourceLocation()); r != Result::Continue) + if (auto r = completeWords(ctx, prefix_, *comp_, sourceLocation(), node); r != Result::Continue) return r; if (auto r = completeFunctions(ctx, prefix_, *comp_, sourceLocation()); r != Result::Continue) @@ -225,9 +377,8 @@ struct FindExpressionRange : ExprVisitor } -CompletionAndExpr::CompletionAndExpr(ExprId id, ExprPtr left, ExprPtr right, const Completion* comp) - : Expr(id) - , left_(std::move(left)) +CompletionAndExpr::CompletionAndExpr(ExprPtr left, ExprPtr right, const Completion* comp) + : left_(std::move(left)) , right_(std::move(right)) { FindExpressionRange leftRange; @@ -281,9 +432,8 @@ auto CompletionAndExpr::toString() const -> std::string return "(and ? 
?)"; } -CompletionOrExpr::CompletionOrExpr(ExprId id, ExprPtr left, ExprPtr right, const Completion* comp) - : Expr(id) - , left_(std::move(left)) +CompletionOrExpr::CompletionOrExpr(ExprPtr left, ExprPtr right, const Completion* comp) + : left_(std::move(left)) , right_(std::move(right)) { FindExpressionRange leftRange; @@ -337,8 +487,8 @@ auto CompletionOrExpr::toString() const -> std::string return "(or ? ?)"; } -CompletionWordExpr::CompletionWordExpr(ExprId id, std::string prefix, Completion* comp, const Token& token) - : Expr(id, token) +CompletionWordExpr::CompletionWordExpr(std::string prefix, Completion* comp, const Token& token) + : Expr(token) , prefix_(std::move(prefix)) , comp_(comp) {} @@ -358,7 +508,8 @@ auto CompletionWordExpr::ieval(Context ctx, const Value& val, const ResultFn& re if (ctx.phase == Context::Phase::Compilation) return res(ctx, Value::undef()); - if (auto r = completeWords(ctx, prefix_, *comp_, sourceLocation()); r != Result::Continue) + auto node = val.node(); + if (auto r = completeWords(ctx, prefix_, *comp_, sourceLocation(), node); r != Result::Continue) return r; return res(ctx, Value::undef()); diff --git a/src/completion.h b/src/completion.h index 8f9c51ba..13b8ad0e 100644 --- a/src/completion.h +++ b/src/completion.h @@ -45,7 +45,7 @@ struct Completion class CompletionFieldOrWordExpr : public Expr { public: - CompletionFieldOrWordExpr(ExprId id, std::string prefix, Completion* comp, const Token& token, bool inPath); + CompletionFieldOrWordExpr(std::string prefix, Completion* comp, const Token& token, bool inPath); auto type() const -> Type override; auto ieval(Context ctx, const Value& value, const ResultFn& result) const -> tl::expected override; @@ -60,7 +60,7 @@ class CompletionFieldOrWordExpr : public Expr class CompletionAndExpr : public Expr { public: - CompletionAndExpr(ExprId id, ExprPtr left, ExprPtr right, const Completion* comp); + CompletionAndExpr(ExprPtr left, ExprPtr right, const Completion* comp); auto type() 
const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; @@ -73,7 +73,7 @@ class CompletionAndExpr : public Expr class CompletionOrExpr : public Expr { public: - CompletionOrExpr(ExprId id, ExprPtr left, ExprPtr right, const Completion* comp); + CompletionOrExpr(ExprPtr left, ExprPtr right, const Completion* comp); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; @@ -86,7 +86,7 @@ class CompletionOrExpr : public Expr class CompletionWordExpr : public Expr { public: - CompletionWordExpr(ExprId id, std::string prefix, Completion* comp, const Token& token); + CompletionWordExpr(std::string prefix, Completion* comp, const Token& token); auto type() const -> Type override; auto constant() const -> bool override; diff --git a/src/diagnostics.cpp b/src/diagnostics.cpp index 062f4fdf..bce1fb68 100644 --- a/src/diagnostics.cpp +++ b/src/diagnostics.cpp @@ -253,6 +253,13 @@ auto Diagnostics::prepareIndices(const Expr& ast) -> void indices_[e.id()] = fieldIndex_++; } + auto visit(const WildcardFieldExpr& e) -> void override { + ExprVisitor::visit(e); + if (e.id() >= indices_.size()) + indices_.resize(e.id() + 1, Diagnostics::InvalidIndex); + indices_[e.id()] = fieldIndex_++; + } + auto visitComparisonOperator(const ComparisonExprBase& e) -> void { if (e.id() >= indices_.size()) diff --git a/src/environment.cpp b/src/environment.cpp index 88bbd07b..279a9656 100644 --- a/src/environment.cpp +++ b/src/environment.cpp @@ -59,6 +59,13 @@ auto Environment::strings() const -> std::shared_ptr { return stringPool; } +auto Environment::querySchema(SchemaId schemaId) const -> const Schema* +{ + if (!querySchemaCallback || schemaId == NoSchemaId) + return nullptr; + return querySchemaCallback(schemaId); +} + Context::Context(Environment* env, Diagnostics* diag, Context::Phase phase) : env(env) , diag(diag) diff --git a/src/expression-visitor.cpp 
b/src/expression-visitor.cpp index ec5a2acf..e188a580 100644 --- a/src/expression-visitor.cpp +++ b/src/expression-visitor.cpp @@ -13,6 +13,10 @@ ExprVisitor::~ExprVisitor() = default; void ExprVisitor::visit(const Expr& e) { index_++; + + const auto count = e.numChildren(); + for (auto i = 0; i < count; ++i) + e.childAt(i)->accept(*this); } void ExprVisitor::visit(const WildcardExpr& expr) @@ -38,58 +42,36 @@ void ExprVisitor::visit(const ConstExpr& expr) void ExprVisitor::visit(const SubscriptExpr& expr) { visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); - if (expr.index_) - expr.index_->accept(*this); } void ExprVisitor::visit(const SubExpr& expr) { visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); - if (expr.sub_) - expr.sub_->accept(*this); } void ExprVisitor::visit(const AnyExpr& expr) { visit(static_cast(expr)); - - for (const auto& arg : expr.args_) - if (arg) - arg->accept(*this); } void ExprVisitor::visit(const EachExpr& expr) { visit(static_cast(expr)); - - for (const auto& arg : expr.args_) - if (arg) - arg->accept(*this); } void ExprVisitor::visit(const CallExpression& expr) { visit(static_cast(expr)); - - for (const auto& arg : expr.args_) - if (arg) - arg->accept(*this); } void ExprVisitor::visit(const PathExpr& expr) { visit(static_cast(expr)); +} - if (expr.left_) - expr.left_->accept(*this); - if (expr.right_) - expr.right_->accept(*this); +void ExprVisitor::visit(const PathAlternativesExpr& expr) +{ + visit(static_cast(expr)); } void ExprVisitor::visit(const FieldExpr& expr) @@ -97,50 +79,34 @@ void ExprVisitor::visit(const FieldExpr& expr) visit(static_cast(expr)); } -void ExprVisitor::visit(const UnpackExpr& expr) +void ExprVisitor::visit(const WildcardFieldExpr& expr) { visit(static_cast(expr)); +} - if (expr.sub_) - expr.sub_->accept(*this); +void ExprVisitor::visit(const UnpackExpr& expr) +{ + visit(static_cast(expr)); } void ExprVisitor::visit(const UnaryWordOpExpr& expr) { 
visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); } void ExprVisitor::visit(const BinaryWordOpExpr& expr) { visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); - if (expr.right_) - expr.right_->accept(*this); } void ExprVisitor::visit(const AndExpr& expr) { visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); - if (expr.right_) - expr.right_->accept(*this); } void ExprVisitor::visit(const OrExpr& expr) { visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); - if (expr.right_) - expr.right_->accept(*this); } void ExprVisitor::visit(const BinaryExpr& e) diff --git a/src/expressions.cpp b/src/expressions.cpp index 82d77d2b..eeae9c44 100644 --- a/src/expressions.cpp +++ b/src/expressions.cpp @@ -2,7 +2,11 @@ #include "fmt/format.h" #include "simfil/environment.h" +#include "simfil/expression.h" +#include "simfil/model/string-pool.h" +#include "simfil/model/schema.h" #include "simfil/result.h" +#include "simfil/sourcelocation.h" #include "simfil/value.h" #include "simfil/function.h" #include "simfil/diagnostics.h" @@ -10,6 +14,7 @@ #include "fmt/core.h" #include "fmt/ranges.h" #include "src/expected.h" +#include #include #include @@ -77,8 +82,10 @@ auto boolify(const Value& v) -> bool } -WildcardExpr::WildcardExpr(ExprId id) - : Expr(id) +WildcardExpr::WildcardExpr() = default; + +WildcardExpr::WildcardExpr(SourceLocation location) + : Expr(location) {} auto WildcardExpr::type() const -> Type @@ -100,7 +107,8 @@ auto WildcardExpr::ieval(Context ctx, const Value& val, const ResultFn& ores) co [[nodiscard]] auto iterate(ModelNode const& val) noexcept -> tl::expected { - if (val.type() == ValueType::Null) [[unlikely]] + const auto valType = val.type(); + if (valType == ValueType::Null) [[unlikely]] return Result::Continue; auto result = res(ctx, Value::field(val)); @@ -143,9 +151,7 @@ auto WildcardExpr::toString() const -> std::string return "**"s; } 
-AnyChildExpr::AnyChildExpr(ExprId id) - : Expr(id) -{} +AnyChildExpr::AnyChildExpr() = default; auto AnyChildExpr::type() const -> Type { @@ -186,13 +192,12 @@ auto AnyChildExpr::toString() const -> std::string return "*"s; } -FieldExpr::FieldExpr(ExprId id, std::string name) - : Expr(id) - , name_(std::move(name)) +FieldExpr::FieldExpr(std::string name) + : name_(std::move(name)) {} -FieldExpr::FieldExpr(ExprId id, std::string name, const Token& token) - : Expr(id, token) +FieldExpr::FieldExpr(std::string name, const Token& token) + : Expr(token) , name_(std::move(name)) {} @@ -262,14 +267,22 @@ auto FieldExpr::toString() const -> std::string return name_; } -MultiConstExpr::MultiConstExpr(ExprId id, const std::vector& vec) - : Expr(id) - , values_(vec) +auto FieldExpr::isCurrent() const -> bool +{ + return name_ == "_"; +} + +auto FieldExpr::field() const -> std::string +{ + return name_; +} + +MultiConstExpr::MultiConstExpr(const std::vector& vec) + : values_(vec) {} -MultiConstExpr::MultiConstExpr(ExprId id, std::vector&& vec) - : Expr(id) - , values_(std::move(vec)) +MultiConstExpr::MultiConstExpr(std::vector&& vec) + : values_(std::move(vec)) {} auto MultiConstExpr::type() const -> Type @@ -308,8 +321,12 @@ auto MultiConstExpr::toString() const -> std::string return fmt::format("{{{}}}", fmt::join(items, " ")); } -ConstExpr::ConstExpr(ExprId id, Value value) - : Expr(id) +ConstExpr::ConstExpr(Value value) + : value_(std::move(value)) +{} + +ConstExpr::ConstExpr(Value value, const Token& token) + : Expr(token) , value_(std::move(value)) {} @@ -345,9 +362,8 @@ auto ConstExpr::value() const -> const Value& return value_; } -SubscriptExpr::SubscriptExpr(ExprId id, ExprPtr left, ExprPtr index) - : Expr(id) - , left_(std::move(left)) +SubscriptExpr::SubscriptExpr(ExprPtr left, ExprPtr index) + : left_(std::move(left)) , index_(std::move(index)) {} @@ -400,14 +416,28 @@ void SubscriptExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto 
SubscriptExpr::numChildren() const -> std::size_t +{ + return 2; +} + +auto SubscriptExpr::childAt(std::size_t index) -> ExprPtr& +{ + return detail::childAtOrThrow(index, left_, index_); +} + +auto SubscriptExpr::childAt(std::size_t index) const -> const ExprPtr& +{ + return detail::childAtOrThrow(index, left_, index_); +} + auto SubscriptExpr::toString() const -> std::string { return fmt::format("(index {} {})", left_->toString(), index_->toString()); } -SubExpr::SubExpr(ExprId id, ExprPtr left, ExprPtr sub) - : Expr(id) - , left_(std::move(left)) +SubExpr::SubExpr(ExprPtr left, ExprPtr sub) + : left_(std::move(left)) , sub_(std::move(sub)) {} @@ -453,9 +483,23 @@ void SubExpr::accept(ExprVisitor& v) const v.visit(*this); } -AnyExpr::AnyExpr(ExprId id, std::vector args) - : Expr(id) - , args_(std::move(args)) +auto SubExpr::numChildren() const -> std::size_t +{ + return 2; +} + +auto SubExpr::childAt(std::size_t index) -> ExprPtr& +{ + return detail::childAtOrThrow(index, left_, sub_); +} + +auto SubExpr::childAt(std::size_t index) const -> const ExprPtr& +{ + return detail::childAtOrThrow(index, left_, sub_); +} + +AnyExpr::AnyExpr(std::vector args) + : args_(std::move(args)) {} auto AnyExpr::type() const -> Type @@ -497,6 +541,21 @@ auto AnyExpr::accept(ExprVisitor& v) const -> void v.visit(*this); } +auto AnyExpr::numChildren() const -> std::size_t +{ + return args_.size(); +} + +auto AnyExpr::childAt(std::size_t index) -> ExprPtr& +{ + return detail::childAtOrThrow(index, args_); +} + +auto AnyExpr::childAt(std::size_t index) const -> const ExprPtr& +{ + return detail::childAtOrThrow(index, args_); +} + auto AnyExpr::toString() const -> std::string { if (args_.empty()) @@ -509,9 +568,8 @@ auto AnyExpr::toString() const -> std::string return fmt::format("(any {})", fmt::join(items, " ")); } -EachExpr::EachExpr(ExprId id, std::vector args) - : Expr(id) - , args_(std::move(args)) +EachExpr::EachExpr(std::vector args) + : args_(std::move(args)) {} auto 
EachExpr::type() const -> Type @@ -552,6 +610,21 @@ auto EachExpr::accept(ExprVisitor& v) const -> void v.visit(*this); } +auto EachExpr::numChildren() const -> std::size_t +{ + return args_.size(); +} + +auto EachExpr::childAt(std::size_t index) -> ExprPtr& +{ + return detail::childAtOrThrow(index, args_); +} + +auto EachExpr::childAt(std::size_t index) const -> const ExprPtr& +{ + return detail::childAtOrThrow(index, args_); +} + auto EachExpr::toString() const -> std::string { if (args_.empty()) @@ -564,9 +637,8 @@ auto EachExpr::toString() const -> std::string return fmt::format("(each {})", fmt::join(items, " ")); } -CallExpression::CallExpression(ExprId id, std::string name, std::vector args) - : Expr(id) - , name_(std::move(name)) +CallExpression::CallExpression(std::string name, std::vector args) + : name_(std::move(name)) , args_(std::move(args)) {} @@ -606,6 +678,21 @@ void CallExpression::accept(ExprVisitor& v) const v.visit(*this); } +auto CallExpression::numChildren() const -> std::size_t +{ + return args_.size(); +} + +auto CallExpression::childAt(std::size_t index) -> ExprPtr& +{ + return detail::childAtOrThrow(index, args_); +} + +auto CallExpression::childAt(std::size_t index) const -> const ExprPtr& +{ + return detail::childAtOrThrow(index, args_); +} + auto CallExpression::toString() const -> std::string { if (args_.empty()) @@ -618,8 +705,16 @@ auto CallExpression::toString() const -> std::string return fmt::format("({} {})", name_, fmt::join(items, " ")); } -PathExpr::PathExpr(ExprId id, ExprPtr left, ExprPtr right) - : Expr(id) +PathExpr::PathExpr(ExprPtr left, ExprPtr right) + : left_(std::move(left)) + , right_(std::move(right)) +{ + assert(left_.get()); + assert(right_.get()); +} + +PathExpr::PathExpr(ExprPtr left, ExprPtr right, SourceLocation location) + : Expr(location) , left_(std::move(left)) , right_(std::move(right)) { @@ -667,14 +762,111 @@ void PathExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto PathExpr::numChildren() 
const -> std::size_t +{ + return 2; +} + +auto PathExpr::childAt(std::size_t index) -> ExprPtr& +{ + return detail::childAtOrThrow(index, left_, right_); +} + +auto PathExpr::childAt(std::size_t index) const -> const ExprPtr& +{ + return detail::childAtOrThrow(index, left_, right_); +} + auto PathExpr::toString() const -> std::string { return fmt::format("(. {} {})", left_->toString(), right_->toString()); } -UnpackExpr::UnpackExpr(ExprId id, ExprPtr sub) - : Expr(id) - , sub_(std::move(sub)) +auto PathExpr::left() -> Expr* +{ + return left_.get(); +} + +auto PathExpr::left() const -> const Expr* +{ + return left_.get(); +} + +auto PathExpr::right() -> Expr* +{ + return right_.get(); +} + +auto PathExpr::right() const -> const Expr* +{ + return right_.get(); +} + +PathAlternativesExpr::PathAlternativesExpr(std::vector alternatives, SourceLocation location) + : Expr(location) + , alternatives_(std::move(alternatives)) +{} + +auto PathAlternativesExpr::type() const -> Type +{ + return Type::PATH; +} + +auto PathAlternativesExpr::ieval(Context ctx, const Value& val, const ResultFn& ores) const -> tl::expected +{ + CountedResultFn res(ores, ctx); + auto finalResult = Result::Continue; + + for (auto const& alternative : alternatives_) { + auto result = alternative->eval(ctx, val, LambdaResultFn([&res](Context ctx, Value&& vv) -> tl::expected { + return res(ctx, std::move(vv)); + })); + TRY_EXPECTED(result); + if (*result == Result::Stop) { + finalResult = Result::Stop; + break; + } + } + + res.ensureCall(); + return finalResult; +} + +auto PathAlternativesExpr::ieval(Context ctx, Value&& val, const ResultFn& ores) const -> tl::expected +{ + return ieval(ctx, val, ores); +} + +void PathAlternativesExpr::accept(ExprVisitor& v) const +{ + v.visit(*this); +} + +auto PathAlternativesExpr::numChildren() const -> std::size_t +{ + return alternatives_.size(); +} + +auto PathAlternativesExpr::childAt(std::size_t index) -> ExprPtr& +{ + return detail::childAtOrThrow(index, 
alternatives_); +} + +auto PathAlternativesExpr::childAt(std::size_t index) const -> const ExprPtr& +{ + return detail::childAtOrThrow(index, alternatives_); +} + +auto PathAlternativesExpr::toString() const -> std::string +{ + auto items = alternatives_ | std::views::transform([](const auto& arg) { + return arg->toString(); + }); + return fmt::format("(paths {})", fmt::join(items, " ")); +} + +UnpackExpr::UnpackExpr(ExprPtr sub) + : sub_(std::move(sub)) {} auto UnpackExpr::type() const -> Type @@ -717,14 +909,28 @@ void UnpackExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto UnpackExpr::numChildren() const -> std::size_t +{ + return 1; +} + +auto UnpackExpr::childAt(std::size_t index) -> ExprPtr& +{ + return detail::childAtOrThrow(index, sub_); +} + +auto UnpackExpr::childAt(std::size_t index) const -> const ExprPtr& +{ + return detail::childAtOrThrow(index, sub_); +} + auto UnpackExpr::toString() const -> std::string { return fmt::format("(... {})", sub_->toString()); } -UnaryWordOpExpr::UnaryWordOpExpr(ExprId id, std::string ident, ExprPtr left) - : Expr(id) - , ident_(std::move(ident)) +UnaryWordOpExpr::UnaryWordOpExpr(std::string ident, ExprPtr left) + : ident_(std::move(ident)) , left_(std::move(left)) {} @@ -756,14 +962,28 @@ void UnaryWordOpExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto UnaryWordOpExpr::numChildren() const -> std::size_t +{ + return 1; +} + +auto UnaryWordOpExpr::childAt(std::size_t index) -> ExprPtr& +{ + return detail::childAtOrThrow(index, left_); +} + +auto UnaryWordOpExpr::childAt(std::size_t index) const -> const ExprPtr& +{ + return detail::childAtOrThrow(index, left_); +} + auto UnaryWordOpExpr::toString() const -> std::string { return fmt::format("({} {})", ident_, left_->toString()); } -BinaryWordOpExpr::BinaryWordOpExpr(ExprId id, std::string ident, ExprPtr left, ExprPtr right) - : Expr(id) - , ident_(std::move(ident)) +BinaryWordOpExpr::BinaryWordOpExpr(std::string ident, ExprPtr left, ExprPtr right) + : 
ident_(std::move(ident)) , left_(std::move(left)) , right_(std::move(right)) {} @@ -806,14 +1026,28 @@ void BinaryWordOpExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto BinaryWordOpExpr::numChildren() const -> std::size_t +{ + return 2; +} + +auto BinaryWordOpExpr::childAt(std::size_t index) -> ExprPtr& +{ + return detail::childAtOrThrow(index, left_, right_); +} + +auto BinaryWordOpExpr::childAt(std::size_t index) const -> const ExprPtr& +{ + return detail::childAtOrThrow(index, left_, right_); +} + auto BinaryWordOpExpr::toString() const -> std::string { return fmt::format("({} {} {})", ident_, left_->toString(), right_->toString()); } -AndExpr::AndExpr(ExprId id, ExprPtr left, ExprPtr right) - : Expr(id) - , left_(std::move(left)) +AndExpr::AndExpr(ExprPtr left, ExprPtr right) + : left_(std::move(left)) , right_(std::move(right)) { assert(left_.get()); @@ -850,14 +1084,28 @@ void AndExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto AndExpr::numChildren() const -> std::size_t +{ + return 2; +} + +auto AndExpr::childAt(std::size_t index) -> ExprPtr& +{ + return detail::childAtOrThrow(index, left_, right_); +} + +auto AndExpr::childAt(std::size_t index) const -> const ExprPtr& +{ + return detail::childAtOrThrow(index, left_, right_); +} + auto AndExpr::toString() const -> std::string { return fmt::format("(and {} {})", left_->toString(), right_->toString()); } -OrExpr::OrExpr(ExprId id, ExprPtr left, ExprPtr right) - : Expr(id) - , left_(std::move(left)) +OrExpr::OrExpr(ExprPtr left, ExprPtr right) + : left_(std::move(left)) , right_(std::move(right)) { assert(left_.get()); @@ -895,9 +1143,338 @@ void OrExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto OrExpr::numChildren() const -> std::size_t +{ + return 2; +} + +auto OrExpr::childAt(std::size_t index) -> ExprPtr& +{ + return detail::childAtOrThrow(index, left_, right_); +} + +auto OrExpr::childAt(std::size_t index) const -> const ExprPtr& +{ + return detail::childAtOrThrow(index, 
left_, right_); +} + auto OrExpr::toString() const -> std::string { return fmt::format("(or {} {})", left_->toString(), right_->toString()); } +WildcardFieldExpr::WildcardFieldExpr(bool recurse, std::string name, SourceLocation location) + : Expr(location) + , name_(std::move(name)) + , recurse_(recurse) +{} + +auto WildcardFieldExpr::childSchemaMayHaveField(const Context& ctx, SchemaId schemaId) const -> bool +{ + if (schemaId == NoSchemaId) + return true; + + const auto* childSchema = ctx.env->querySchema(schemaId); + if (!childSchema || !childSchema->finalized()) + return true; + + return childSchema->canHaveField(nameId_); +} + +auto WildcardFieldExpr::buildObjectSchemaPlan(const Context& ctx, const ObjectSchema& schema) const -> SchemaPlan +{ + SchemaPlan plan; + plan.kind = SchemaPlan::Kind::Object; + plan.directField = false; + + for (const auto& field : schema.fields()) { + if (field.field == nameId_) + plan.directField = true; + + const auto descendsToTarget = field.schemas.empty() + || std::ranges::any_of(field.schemas, [this, &ctx](auto schemaId) { + return childSchemaMayHaveField(ctx, schemaId); + }); + if (descendsToTarget) + plan.objectChildFields.push_back(field.field); + } + + std::ranges::sort(plan.objectChildFields); + auto duplicates = std::ranges::unique(plan.objectChildFields); + plan.objectChildFields.erase(duplicates.begin(), duplicates.end()); + + const auto fieldCount = schema.fields().size(); + const auto sparseChildPlan = plan.objectChildFields.size() * 2 < fieldCount; + const auto skipsLargeDirectLookup = !plan.directField && fieldCount > 4; + if (!sparseChildPlan && !skipsLargeDirectLookup) { + plan.kind = SchemaPlan::Kind::Unknown; + plan.directField = true; + plan.objectChildFields.clear(); + } + + return plan; +} + +auto WildcardFieldExpr::buildSchemaPlan(const Context& ctx, const Schema& schema) const -> SchemaPlan +{ + SchemaPlan plan; + plan.canHaveField = schema.canHaveField(nameId_); + if (!plan.canHaveField) { + 
plan.directField = false; + return plan; + } + + if (schema.kind() == Schema::Kind::Object) { + if (const auto* objectSchema = dynamic_cast(&schema)) + return buildObjectSchemaPlan(ctx, *objectSchema); + return plan; + } + + if (schema.kind() == Schema::Kind::Array) { + if (dynamic_cast(&schema)) + plan.kind = SchemaPlan::Kind::Array; + plan.directField = false; + } + + return plan; +} + +auto WildcardFieldExpr::schemaPlan(const Context& ctx, SchemaId schemaId, const Schema& schema) const -> const SchemaPlan* +{ + if (schemaId == NoSchemaId || !schema.finalized()) + return nullptr; + + const auto planIndex = static_cast(schemaId); + const auto schemaRevision = schema.revision(); + if (planIndex < schemaPlans_.size()) { + const auto& cachedPlan = schemaPlans_[planIndex]; + if (cachedPlan && cachedPlan->schema == &schema && cachedPlan->schemaRevision == schemaRevision) + return &cachedPlan->plan; + } + + if (schemaPlans_.size() <= planIndex) + schemaPlans_.resize(planIndex + 1); + auto plan = buildSchemaPlan(ctx, schema); + schemaPlans_[planIndex] = std::make_unique( + CachedSchemaPlan{schemaId, &schema, schemaRevision, std::move(plan)}); + + return &schemaPlans_[planIndex]->plan; +} + +auto WildcardFieldExpr::type() const -> Type +{ + return Type::PATH; +} + +auto WildcardFieldExpr::ieval(Context ctx, const Value& val, const ResultFn& ores) const -> tl::expected +{ + if (ctx.phase == Context::Phase::Compilation) + return ores(ctx, Value::undef()); + + CountedResultFn res(ores, ctx); + + Diagnostics::FieldExprData* diag = nullptr; + if (ctx.diag) + diag = &ctx.diag->get(*this); + + if (diag) { + diag->location = sourceLocation(); + if (diag->name.empty()) + diag->name = name_; + } + + // Querying a field not in the string-pool + // is a no-op here (not true for FieldExpr). 
+ if (!nameId_) { + nameId_ = ctx.env->strings()->get(name_); + if (!nameId_) { + if (diag) + diag->evaluations++; + res.ensureCall(); + return {Result::Continue}; + } + } + + struct Iterate + { + Context& ctx; + ResultFn& res; + const WildcardFieldExpr& expr; + StringId field; + Diagnostics::FieldExprData* diag; + size_t maxDepth = 0; // 0 = recurse inf. + bool pruneRoot = true; + + struct SchemaDecision { + bool canHaveField = true; + const SchemaPlan* plan = nullptr; + }; + + [[nodiscard]] auto iterate(ModelNode const& val, size_t depth) noexcept -> tl::expected + { + if (maxDepth > 0 && depth > maxDepth) { + return Result::Continue; + } + + if (field == StringPool::StaticStringIds::Empty) + return Result::Continue; + + const auto valType = val.type(); + if (valType == ValueType::Null) [[unlikely]] + return Result::Continue; + + const auto schemaDecision = decideBySchema(val, depth); + if (!schemaDecision.canHaveField) + return Result::Continue; + + if (diag) + diag->evaluations++; + + const auto plan = matchingPlan(schemaDecision.plan, valType); + auto directResult = emitDirectField(val, plan, depth); + TRY_EXPECTED(directResult); + if (*directResult == Result::Stop) + return Result::Stop; + + // Once the requested non-recursive depth has been processed, avoid + // descending just to let the next call reject the child by depth. 
+ if (maxDepth > 0 && depth >= maxDepth) + return Result::Continue; + + if (plan && plan->kind == SchemaPlan::Kind::Object) { + return iterateObjectFields(val, *plan, depth); + } + + return iterateAllChildren(val, depth); + } + + [[nodiscard]] auto decideBySchema(ModelNode const& val, size_t depth) const noexcept -> SchemaDecision + { + if (!(depth > 0 || pruneRoot)) + return {}; + + const auto schemaId = val.schema(); + const auto* schema = ctx.env->querySchema(schemaId); + if (!schema) + return {}; + + if (ctx.env->enableWildcardFieldPlans) { + if (const auto* plan = expr.schemaPlan(ctx, schemaId, *schema)) + return {plan->canHaveField, plan}; + } + + // This is the original schema-pruning path: the current node can be + // rejected, but child iteration is still generic when it may match. + if (schema->canHaveField(field)) + return {}; + + return {false, nullptr}; + } + + [[nodiscard]] static auto matchingPlan(const SchemaPlan* plan, ValueType valType) noexcept -> const SchemaPlan* + { + if (!plan) + return nullptr; + + if (plan->kind == SchemaPlan::Kind::Object && valType == ValueType::Object) + return plan; + + if (plan->kind == SchemaPlan::Kind::Array && valType == ValueType::Array) + return plan; + + return nullptr; + } + + [[nodiscard]] auto emitDirectField(ModelNode const& val, + const SchemaPlan* plan, + size_t depth) noexcept -> tl::expected + { + const auto directFieldPossible = !plan + || (plan->kind == SchemaPlan::Kind::Object && plan->directField); + if (!directFieldPossible || (maxDepth > 0 && depth == 0)) + return Result::Continue; + + auto sub = val.get(field); + if (!sub) + return Result::Continue; + + if (diag) + diag->hits++; + + auto result = res(ctx, Value::field(*sub)); + TRY_EXPECTED(result); + return *result; + } + + [[nodiscard]] auto iterateObjectFields(ModelNode const& val, + const SchemaPlan& plan, + size_t depth) noexcept -> tl::expected + { + if (plan.objectChildFields.empty()) + return Result::Continue; + + // For dense plans the 
normal iterator is cheaper because it resolves + // each child once and lets that child's schema reject irrelevant paths. + if (plan.objectChildFields.size() * 2 >= val.size()) + return iterateAllChildren(val, depth); + + tl::expected finalResult = Result::Continue; + for (auto i = 0u; i < val.size(); ++i) { + if (!std::ranges::binary_search(plan.objectChildFields, val.keyAt(i))) + continue; + + auto child = val.at(i); + if (!child) + continue; + + auto subResult = iterate(*child, depth + 1); + if (!subResult) + return subResult; + + if (*subResult == Result::Stop) + return Result::Stop; + } + + return finalResult; + } + + [[nodiscard]] auto iterateAllChildren(ModelNode const& val, size_t depth) noexcept -> tl::expected + { + tl::expected finalResult = Result::Continue; + val.iterate(ModelNode::IterLambda([&, this](const auto& subNode) { + auto subResult = iterate(subNode, depth + 1); + if (!subResult) { + finalResult = std::move(subResult); + return false; + } + + if (*subResult == Result::Stop) { + finalResult = Result::Stop; + return false; + } + + return true; + })); + + return finalResult; + } + }; + + auto r = val.nodePtr() + ? Iterate{ctx, res, *this, nameId_, diag, recurse_ ? 0ul : 1ul, recurse_}.iterate(**val.nodePtr(), 0) + : tl::expected(Result::Continue); + res.ensureCall(); + return r; +} + +void WildcardFieldExpr::accept(ExprVisitor& v) const +{ + v.visit(*this); +} + +auto WildcardFieldExpr::toString() const -> std::string +{ + return fmt::format("{}.{}", recurse_ ? 
"**" : "*", name_); +} + } diff --git a/src/expressions.h b/src/expressions.h index c98257ec..ca50293f 100644 --- a/src/expressions.h +++ b/src/expressions.h @@ -5,17 +5,60 @@ #include "simfil/operator.h" #include "simfil/diagnostics.h" #include "simfil/expression-visitor.h" +#include "simfil/sourcelocation.h" +#include #include +#include +#include #include +#include +#include namespace simfil { +namespace detail +{ + +inline auto childAtOrThrow(std::size_t index, std::vector& children) -> ExprPtr& +{ + return children.at(index); +} + +inline auto childAtOrThrow(std::size_t index, const std::vector& children) -> const ExprPtr& +{ + return children.at(index); +} + +template +auto childAtOrThrow(std::size_t index, Children&... children) -> ExprPtr& +{ + std::array childPtrs{&children...}; + if (index >= childPtrs.size()) + throw std::out_of_range("AST child index out of range"); + return *childPtrs[index]; +} + +template +auto childAtOrThrow(std::size_t index, const Children&... children) -> const ExprPtr& +{ + std::array childPtrs{&children...}; + if (index >= childPtrs.size()) + throw std::out_of_range("AST child index out of range"); + return *childPtrs[index]; +} + +} + +/** + * Returns the current node and every child of it recursively. 
+ */ class WildcardExpr : public Expr { public: - explicit WildcardExpr(ExprId); + WildcardExpr(); + explicit WildcardExpr(SourceLocation location); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& ores) const -> tl::expected override; @@ -29,7 +72,7 @@ class WildcardExpr : public Expr class AnyChildExpr : public Expr { public: - explicit AnyChildExpr(ExprId); + AnyChildExpr(); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; @@ -40,8 +83,8 @@ class AnyChildExpr : public Expr class FieldExpr : public Expr { public: - FieldExpr(ExprId id, std::string name); - FieldExpr(ExprId id, std::string name, const Token& token); + explicit FieldExpr(std::string name); + FieldExpr(std::string name, const Token& token); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; @@ -49,6 +92,10 @@ class FieldExpr : public Expr void accept(ExprVisitor& v) const override; auto toString() const -> std::string override; + /** Returns true if this field is "_" */ + auto isCurrent() const -> bool; + auto field() const -> std::string; + std::string name_; mutable StringId nameId_ = {}; }; @@ -59,8 +106,8 @@ class MultiConstExpr : public Expr static constexpr size_t Limit = 10000; MultiConstExpr() = delete; - MultiConstExpr(ExprId id, const std::vector& vec); - MultiConstExpr(ExprId id, std::vector&& vec); + explicit MultiConstExpr(const std::vector& vec); + explicit MultiConstExpr(std::vector&& vec); auto type() const -> Type override; auto constant() const -> bool override; @@ -75,12 +122,27 @@ class ConstExpr : public Expr { public: ConstExpr() = delete; + template - ConstExpr(ExprId id, CType_&& value) - : Expr(id) + requires (!std::is_base_of_v>) + explicit ConstExpr(CType_&& value) + : value_(Value::make(std::forward(value))) + {} + + template + requires (!std::is_base_of_v>) + 
ConstExpr(CType_&& value, const Token& token) + : Expr(token) , value_(Value::make(std::forward(value))) {} - ConstExpr(ExprId id, Value value); + + ConstExpr(const ConstExpr&) = delete; + ConstExpr(ConstExpr&&) = delete; + auto operator=(const ConstExpr&) -> ConstExpr& = delete; + auto operator=(ConstExpr&&) -> ConstExpr& = delete; + + explicit ConstExpr(Value value); + ConstExpr(Value value, const Token& token); auto type() const -> Type override; auto constant() const -> bool override; @@ -97,11 +159,14 @@ class ConstExpr : public Expr class SubscriptExpr : public Expr { public: - SubscriptExpr(ExprId id, ExprPtr left, ExprPtr index); + SubscriptExpr(ExprPtr left, ExprPtr index); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& ores) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; + auto childAt(std::size_t index) const -> const ExprPtr& override; auto toString() const -> std::string override; ExprPtr left_; @@ -111,12 +176,15 @@ class SubscriptExpr : public Expr class SubExpr : public Expr { public: - SubExpr(ExprId id, ExprPtr left, ExprPtr sub); + SubExpr(ExprPtr left, ExprPtr sub); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& ores) const -> tl::expected override; auto ieval(Context ctx, Value&& val, const ResultFn& ores) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; + auto childAt(std::size_t index) const -> const ExprPtr& override; auto toString() const -> std::string override; ExprPtr left_, sub_; @@ -125,11 +193,14 @@ class SubExpr : public Expr class AnyExpr : public Expr { public: - AnyExpr(ExprId id, std::vector args); + explicit AnyExpr(std::vector args); auto type() const -> Type override; 
auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; + auto childAt(std::size_t index) const -> const ExprPtr& override; auto toString() const -> std::string override; std::vector args_; @@ -138,11 +209,14 @@ class AnyExpr : public Expr class EachExpr : public Expr { public: - EachExpr(ExprId id, std::vector args); + explicit EachExpr(std::vector args); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; + auto childAt(std::size_t index) const -> const ExprPtr& override; auto toString() const -> std::string override; std::vector args_; @@ -151,12 +225,15 @@ class EachExpr : public Expr class CallExpression : public Expr { public: - CallExpression(ExprId id, std::string name, std::vector args); + CallExpression(std::string name, std::vector args); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; auto ieval(Context ctx, Value&& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; + auto childAt(std::size_t index) const -> const ExprPtr& override; auto toString() const -> std::string override; std::string name_; @@ -167,17 +244,44 @@ class CallExpression : public Expr class PathExpr : public Expr { public: - PathExpr(ExprId id, ExprPtr left, ExprPtr right); + PathExpr(ExprPtr left, ExprPtr right); + PathExpr(ExprPtr left, ExprPtr right, SourceLocation location); auto type() const -> Type override; auto ieval(Context 
ctx, const Value& val, const ResultFn& ores) const -> tl::expected override; auto ieval(Context ctx, Value&& val, const ResultFn& ores) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; + auto childAt(std::size_t index) const -> const ExprPtr& override; auto toString() const -> std::string override; + auto left() -> Expr*; + auto left() const -> const Expr*; + auto right() -> Expr*; + auto right() const -> const Expr*; + ExprPtr left_, right_; }; +/** Evaluates a set of exact path alternatives and forwards every produced value. */ +class PathAlternativesExpr : public Expr +{ +public: + explicit PathAlternativesExpr(std::vector alternatives, SourceLocation location = {}); + + auto type() const -> Type override; + auto ieval(Context ctx, const Value& val, const ResultFn& ores) const -> tl::expected override; + auto ieval(Context ctx, Value&& val, const ResultFn& ores) const -> tl::expected override; + void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; + auto childAt(std::size_t index) const -> const ExprPtr& override; + auto toString() const -> std::string override; + + std::vector alternatives_; +}; + /** Calls `unpack` onto values of type Object. Forwards the value(s) otherwise. * * 1... 
=> 1 @@ -186,11 +290,14 @@ class PathExpr : public Expr class UnpackExpr : public Expr { public: - UnpackExpr(ExprId id, ExprPtr sub); + explicit UnpackExpr(ExprPtr sub); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; + auto childAt(std::size_t index) const -> const ExprPtr& override; auto toString() const -> std::string override; ExprPtr sub_; @@ -203,9 +310,8 @@ template class UnaryExpr : public Expr { public: - UnaryExpr(ExprId id, ExprPtr sub) - : Expr(id) - , sub_(std::move(sub)) + explicit UnaryExpr(ExprPtr sub) + : sub_(std::move(sub)) {} auto type() const -> Type override @@ -223,10 +329,24 @@ class UnaryExpr : public Expr })); } - void accept(ExprVisitor& v) const override + auto accept(ExprVisitor& v) const -> void override { v.visit(*this); - sub_->accept(v); + } + + auto numChildren() const -> std::size_t override + { + return 1; + } + + auto childAt(std::size_t index) -> ExprPtr& override + { + return detail::childAtOrThrow(index, sub_); + } + + auto childAt(std::size_t index) const -> const ExprPtr& override + { + return detail::childAtOrThrow(index, sub_); } auto toString() const -> std::string override @@ -244,14 +364,13 @@ template class BinaryExpr : public Expr { public: - BinaryExpr(ExprId id, ExprPtr left, ExprPtr right) - : Expr(id) - , left_(std::move(left)) + BinaryExpr(ExprPtr left, ExprPtr right) + : left_(std::move(left)) , right_(std::move(right)) {} - BinaryExpr(ExprId id, const Token& token, ExprPtr left, ExprPtr right) - : Expr(id, token) + BinaryExpr(const Token& token, ExprPtr left, ExprPtr right) + : Expr(token) , left_(std::move(left)) , right_(std::move(right)) {} @@ -273,11 +392,24 @@ class BinaryExpr : public Expr })); } - void accept(ExprVisitor& v) const override + auto accept(ExprVisitor& v) 
const -> void override { v.visit(*this); - left_->accept(v); - right_->accept(v); + } + + auto numChildren() const -> std::size_t override + { + return 2; + } + + auto childAt(std::size_t index) -> ExprPtr& override + { + return detail::childAtOrThrow(index, left_, right_); + } + + auto childAt(std::size_t index) const -> const ExprPtr& override + { + return detail::childAtOrThrow(index, left_, right_); } auto toString() const -> std::string override @@ -292,14 +424,13 @@ class ComparisonExprBase : public Expr { public: - ComparisonExprBase(ExprId id, ExprPtr left, ExprPtr right) - : Expr(id) - , left_(std::move(left)) + ComparisonExprBase(ExprPtr left, ExprPtr right) + : left_(std::move(left)) , right_(std::move(right)) {} - ComparisonExprBase(ExprId id, const Token& token, ExprPtr left, ExprPtr right) - : Expr(id, token) + ComparisonExprBase(const Token& token, ExprPtr left, ExprPtr right) + : Expr(token) , left_(std::move(left)) , right_(std::move(right)) {} @@ -309,6 +440,21 @@ class ComparisonExprBase : public Expr return Type::VALUE; } + auto numChildren() const -> std::size_t override + { + return 2; + } + + auto childAt(std::size_t index) -> ExprPtr& override + { + return detail::childAtOrThrow(index, left_, right_); + } + + auto childAt(std::size_t index) const -> const ExprPtr& override + { + return detail::childAtOrThrow(index, left_, right_); + } + ExprPtr left_, right_; }; @@ -352,11 +498,9 @@ class ComparisonExpr : public ComparisonExprBase })); } - void accept(ExprVisitor& v) const override + auto accept(ExprVisitor& v) const -> void override { v.visit(static_cast(*this)); - left_->accept(v); - right_->accept(v); } auto toString() const -> std::string override @@ -404,11 +548,14 @@ class BinaryExpr : public ComparisonExpr Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t 
index) -> ExprPtr& override; + auto childAt(std::size_t index) const -> const ExprPtr& override; auto toString() const -> std::string override; std::string ident_; @@ -418,11 +565,14 @@ class UnaryWordOpExpr : public Expr class BinaryWordOpExpr : public Expr { public: - BinaryWordOpExpr(ExprId id, std::string ident, ExprPtr left, ExprPtr right); + BinaryWordOpExpr(std::string ident, ExprPtr left, ExprPtr right); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; + auto childAt(std::size_t index) const -> const ExprPtr& override; auto toString() const -> std::string override; std::string ident_; @@ -432,11 +582,14 @@ class BinaryWordOpExpr : public Expr class AndExpr : public Expr { public: - AndExpr(ExprId id, ExprPtr left, ExprPtr right); + AndExpr(ExprPtr left, ExprPtr right); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; + auto childAt(std::size_t index) const -> const ExprPtr& override; auto toString() const -> std::string override; ExprPtr left_, right_; @@ -445,14 +598,66 @@ class AndExpr : public Expr class OrExpr : public Expr { public: - OrExpr(ExprId id, ExprPtr left, ExprPtr right); + OrExpr(ExprPtr left, ExprPtr right); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; + auto childAt(std::size_t index) const -> const ExprPtr& override; auto toString() const -> 
std::string override; ExprPtr left_, right_; }; +/** + * A specialized expression for queries of the form `**.field`, that + * takes object schema information into account. + * + * Special form of `WildcardExpr`. + */ +class WildcardFieldExpr : public Expr +{ +public: + explicit WildcardFieldExpr(bool recurse, std::string name, SourceLocation location = {}); + + auto type() const -> Type override; + auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; + void accept(ExprVisitor& v) const override; + auto toString() const -> std::string override; + + std::string name_; + mutable StringId nameId_ = {}; + const bool recurse_ = {}; + +private: + struct SchemaPlan { + enum class Kind { + Unknown, + Object, + Array, + }; + + Kind kind = Kind::Unknown; + bool canHaveField = true; + bool directField = true; + std::vector objectChildFields; + }; + + struct CachedSchemaPlan { + SchemaId schemaId = NoSchemaId; + const Schema* schema = nullptr; + std::uint64_t schemaRevision = 0; + SchemaPlan plan; + }; + + auto schemaPlan(const Context& ctx, SchemaId schemaId, const Schema& schema) const -> const SchemaPlan*; + auto buildSchemaPlan(const Context& ctx, const Schema& schema) const -> SchemaPlan; + auto buildObjectSchemaPlan(const Context& ctx, const ObjectSchema& schema) const -> SchemaPlan; + auto childSchemaMayHaveField(const Context& ctx, SchemaId schemaId) const -> bool; + + mutable std::vector> schemaPlans_; +}; + } diff --git a/src/model/json.cpp b/src/model/json.cpp index 7a0343ef..0686e520 100644 --- a/src/model/json.cpp +++ b/src/model/json.cpp @@ -88,8 +88,8 @@ auto buildModelNode(const json& input, ModelPool& model) -> tl::expected()); case json::value_t::string: { - // JSON strings are expected to participate in simfil's pooled-string - // facilities such as completion of uppercase constants. + // JSON strings participate in pooled-string facilities such as + // string-literal completion. 
auto stringId = model.strings()->emplace(input.get()); if (!stringId) { return tl::unexpected(stringId.error()); diff --git a/src/model/model.cpp b/src/model/model.cpp index 6b1ad9a1..d21ab1f1 100644 --- a/src/model/model.cpp +++ b/src/model/model.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -15,11 +16,13 @@ #include #include #include +#include #include #include #include #include "../expected.h" +#include "simfil/model/schema.h" namespace simfil { @@ -86,6 +89,10 @@ struct ModelPool::Impl ModelColumn strings_; ModelColumn byteArrays_; + ModelColumn objectSchemas_; + ModelColumn objectSingletonSchemas_; + ModelColumn arraySchemas_; + ModelColumn arraySingletonSchemas_; Object::Storage objectMemberArrays_; Array::Storage arrayMemberArrays_; } columns_; @@ -99,6 +106,10 @@ struct ModelPool::Impl s.text1b(columns_.stringData_, maxColumnSize); s.object(columns_.strings_); s.object(columns_.byteArrays_); + s.object(columns_.objectSchemas_); + s.object(columns_.objectSingletonSchemas_); + s.object(columns_.arraySchemas_); + s.object(columns_.arraySingletonSchemas_); s.ext(columns_.objectMemberArrays_, bitsery::ext::ArrayArenaExt{}); s.ext(columns_.arrayMemberArrays_, bitsery::ext::ArrayArenaExt{}); @@ -119,6 +130,20 @@ ModelPool::~ModelPool() // NOLINT std::vector ModelPool::checkForErrors() const { std::vector errors; + auto objectHasSchema = [&](ArrayIndex members) { + if (Object::Storage::is_singleton_handle(members)) { + return Object::Storage::singleton_payload(members) + < impl_->columns_.objectSingletonSchemas_.size(); + } + return members < impl_->columns_.objectSchemas_.size(); + }; + auto arrayHasSchema = [&](ArrayIndex members) { + if (Array::Storage::is_singleton_handle(members)) { + return Array::Storage::singleton_payload(members) + < impl_->columns_.arraySingletonSchemas_.size(); + } + return members < impl_->columns_.arraySchemas_.size(); + }; auto validateArrayIndex = [&](auto i, auto arrType, auto const& arena) { if 
(!arena.valid(static_cast(i))) {
@@ -142,6 +167,10 @@ std::vector ModelPool::checkForErrors() const
             if (node->addr().column() == Objects)
                 if (!validateArrayIndex(node->addr().index(), "object", impl_->columns_.objectMemberArrays_))
                     return;
+            if (node->addr().column() == Objects && !objectHasSchema(node->addr().index())) { // schema columns only cover this pool's Objects column
+                errors.emplace_back(fmt::format("Missing object schema index {}.", node->addr().index()));
+                return;
+            }
             for (auto const& [fieldName, fieldValue] : node->fields()) {
                 validatePooledString(fieldName);
                 validateModelNode(fieldValue);
@@ -151,6 +180,10 @@ std::vector ModelPool::checkForErrors() const
             if (node->addr().column() == Arrays)
                 if (!validateArrayIndex(node->addr().index(), "arrays", impl_->columns_.arrayMemberArrays_))
                     return;
+            if (node->addr().column() == Arrays && !arrayHasSchema(node->addr().index())) { // schema columns only cover this pool's Arrays column
+                errors.emplace_back(fmt::format("Missing array schema index {}.", node->addr().index()));
+                return;
+            }
             for (auto const& member : *node)
                 validateModelNode(member);
         }
@@ -163,16 +196,28 @@ std::vector ModelPool::checkForErrors() const
     };

     // Validate objects
-    for (auto i = 0; i < impl_->columns_.objectMemberArrays_.size(); ++i)
+    for (auto i = FirstRegularArrayIndex; i < impl_->columns_.objectMemberArrays_.size(); ++i)
         validateModelNode(ModelNode::Ptr::make(
             shared_from_this(),
             ModelNodeAddress{Objects, (uint32_t)i}));
+    for (auto i = 0u; i < impl_->columns_.objectMemberArrays_.singleton_handle_count(); ++i)
+        validateModelNode(ModelNode::Ptr::make(
+            shared_from_this(),
+            ModelNodeAddress{
+                Objects,
+                SingletonArrayHandleMask | static_cast(i)}));

     // Validate arrays
-    for (auto i = 0; i < impl_->columns_.arrayMemberArrays_.size(); ++i)
+    for (auto i = FirstRegularArrayIndex; i < impl_->columns_.arrayMemberArrays_.size(); ++i)
         validateModelNode(ModelNode::Ptr::make(
             shared_from_this(),
             ModelNodeAddress{Arrays, (uint32_t)i}));
+    for (auto i = 0u; i < impl_->columns_.arrayMemberArrays_.singleton_handle_count(); ++i)
+        validateModelNode(ModelNode::Ptr::make(
+            shared_from_this(),
+            ModelNodeAddress{
+                Arrays,
+ SingletonArrayHandleMask | static_cast(i)})); // Validate roots for (auto i = 0; i < numRoots(); ++i) @@ -205,12 +250,22 @@ void ModelPool::clear() clear_and_shrink(columns.strings_); clear_and_shrink(columns.stringData_); clear_and_shrink(columns.byteArrays_); + clear_and_shrink(columns.objectSchemas_); + clear_and_shrink(columns.objectSingletonSchemas_); + clear_and_shrink(columns.arraySchemas_); + clear_and_shrink(columns.arraySingletonSchemas_); clear_and_shrink(columns.objectMemberArrays_); clear_and_shrink(columns.arrayMemberArrays_); } tl::expected ModelPool::resolve(ModelNode const& n, ResolveFn const& cb) const { + // Merged/container views can surface child nodes from another model. Always + // let the owning model interpret its own column/index address. + if (auto owner = n.owningModel(); owner && owner.get() != this) { + return owner->resolve(n, cb); + } + auto checkBounds = [&n](auto const& vec) -> std::optional { auto idx = n.addr_.index(); if (idx >= vec.size()) @@ -293,12 +348,32 @@ void ModelPool::addRoot(ModelNode::Ptr const& rootNode) { model_ptr ModelPool::newObject(size_t initialFieldCapacity, bool fixedSize) { auto memberArrId = impl_->columns_.objectMemberArrays_.new_array(initialFieldCapacity, fixedSize); + if (Object::Storage::is_singleton_handle(memberArrId)) { + auto singletonIndex = Object::Storage::singleton_payload(memberArrId); + if (impl_->columns_.objectSingletonSchemas_.size() <= singletonIndex) + impl_->columns_.objectSingletonSchemas_.resize(singletonIndex + 1); + impl_->columns_.objectSingletonSchemas_[singletonIndex] = SchemaId{}; + } else { + if (impl_->columns_.objectSchemas_.size() <= memberArrId) + impl_->columns_.objectSchemas_.resize(memberArrId + 1); + impl_->columns_.objectSchemas_[memberArrId] = SchemaId{}; + } return model_ptr::make(shared_from_this(), ModelNodeAddress{Objects, (uint32_t)memberArrId}); } model_ptr ModelPool::newArray(size_t initialFieldCapacity, bool fixedSize) { auto memberArrId = 
impl_->columns_.arrayMemberArrays_.new_array(initialFieldCapacity, fixedSize); + if (Array::Storage::is_singleton_handle(memberArrId)) { + auto singletonIndex = Array::Storage::singleton_payload(memberArrId); + if (impl_->columns_.arraySingletonSchemas_.size() <= singletonIndex) + impl_->columns_.arraySingletonSchemas_.resize(singletonIndex + 1); + impl_->columns_.arraySingletonSchemas_[singletonIndex] = SchemaId{}; + } else { + if (impl_->columns_.arraySchemas_.size() <= memberArrId) + impl_->columns_.arraySchemas_.resize(memberArrId + 1); + impl_->columns_.arraySchemas_[memberArrId] = SchemaId{}; + } return model_ptr::make(shared_from_this(), ModelNodeAddress{Arrays, (uint32_t)memberArrId}); } @@ -434,7 +509,11 @@ ModelPool::SerializationSizeStats ModelPool::serializationSizeStats() const stats.stringRangeBytes = impl_->columns_.strings_.byte_size(); stats.stringRangeBytes += impl_->columns_.byteArrays_.byte_size(); stats.objectMemberBytes = impl_->columns_.objectMemberArrays_.byte_size(); + stats.objectSchemaBytes = impl_->columns_.objectSchemas_.byte_size() + + impl_->columns_.objectSingletonSchemas_.byte_size(); stats.arrayMemberBytes = impl_->columns_.arrayMemberArrays_.byte_size(); + stats.arraySchemaBytes = impl_->columns_.arraySchemas_.byte_size() + + impl_->columns_.arraySingletonSchemas_.byte_size(); return stats; } @@ -447,6 +526,70 @@ Object::Storage& ModelPool::objectMemberStorage() { return impl_->columns_.objectMemberArrays_; } +SchemaId ModelPool::objectSchemaId(ArrayIndex members) const +{ + if (Object::Storage::is_singleton_handle(members)) { + auto singletonIndex = Object::Storage::singleton_payload(members); + if (singletonIndex >= impl_->columns_.objectSingletonSchemas_.size()) + return {}; + return SchemaId{impl_->columns_.objectSingletonSchemas_[singletonIndex]}; + } + if (members >= impl_->columns_.objectSchemas_.size()) + return {}; + return SchemaId{impl_->columns_.objectSchemas_[members]}; +} + +auto 
ModelPool::setObjectSchemaId(ArrayIndex members, SchemaId schemaId) -> tl::expected +{ + if (!impl_->columns_.objectMemberArrays_.valid(members)) + return tl::unexpected(Error::RuntimeError, "Object schema index out of range."); + + if (Object::Storage::is_singleton_handle(members)) { + auto singletonIndex = Object::Storage::singleton_payload(members); + if (impl_->columns_.objectSingletonSchemas_.size() <= singletonIndex) + impl_->columns_.objectSingletonSchemas_.resize(singletonIndex + 1); + impl_->columns_.objectSingletonSchemas_[singletonIndex] = schemaId; + return {}; + } + + if (impl_->columns_.objectSchemas_.size() <= members) + impl_->columns_.objectSchemas_.resize(members + 1); + impl_->columns_.objectSchemas_[members] = schemaId; + return {}; +} + +SchemaId ModelPool::arraySchemaId(ArrayIndex members) const +{ + if (Array::Storage::is_singleton_handle(members)) { + auto singletonIndex = Array::Storage::singleton_payload(members); + if (singletonIndex >= impl_->columns_.arraySingletonSchemas_.size()) + return {}; + return SchemaId{impl_->columns_.arraySingletonSchemas_[singletonIndex]}; + } + if (members >= impl_->columns_.arraySchemas_.size()) + return {}; + return SchemaId{impl_->columns_.arraySchemas_[members]}; +} + +auto ModelPool::setArraySchemaId(ArrayIndex members, SchemaId schemaId) -> tl::expected +{ + if (!impl_->columns_.arrayMemberArrays_.valid(members)) + return tl::unexpected(Error::RuntimeError, "Array schema index out of range."); + + if (Array::Storage::is_singleton_handle(members)) { + auto singletonIndex = Array::Storage::singleton_payload(members); + if (impl_->columns_.arraySingletonSchemas_.size() <= singletonIndex) + impl_->columns_.arraySingletonSchemas_.resize(singletonIndex + 1); + impl_->columns_.arraySingletonSchemas_[singletonIndex] = schemaId; + return {}; + } + + if (impl_->columns_.arraySchemas_.size() <= members) + impl_->columns_.arraySchemas_.resize(members + 1); + impl_->columns_.arraySchemas_[members] = schemaId; + 
return {}; +} + Object::Storage const& ModelPool::objectMemberStorage() const { return impl_->columns_.objectMemberArrays_; @@ -480,6 +623,7 @@ tl::expected ModelPool::read(const std::vector& input, con "Failed to read ModelPool: Error {}", static_cast>(s.adapter().error()))); } + return {}; } diff --git a/src/model/nodes.cpp b/src/model/nodes.cpp index 9274ea62..d15b342c 100644 --- a/src/model/nodes.cpp +++ b/src/model/nodes.cpp @@ -63,6 +63,13 @@ StringId ModelNode::keyAt(int64_t i) const { return result; } +SchemaId ModelNode::schema() const { + SchemaId result = NoSchemaId; + if (model_) + model_->resolve(*this, Model::Lambda([&](auto&& resolved) { result = resolved.schema(); })); + return result; +} + /// Get the number of children uint32_t ModelNode::size() const { uint32_t result = 0; @@ -186,6 +193,11 @@ StringId ModelNodeBase::keyAt(int64_t) const return 0; } +SchemaId ModelNodeBase::schema() const +{ + return NoSchemaId; +} + uint32_t ModelNodeBase::size() const { return 0; diff --git a/src/model/string-pool.cpp b/src/model/string-pool.cpp index e50354dd..ccce9e6c 100644 --- a/src/model/string-pool.cpp +++ b/src/model/string-pool.cpp @@ -53,37 +53,42 @@ StringPool::StringPool() StringPool::StringPool(const StringPool& other) { - std::unique_lock lockThis(stringStoreMutex_, std::defer_lock); - std::shared_lock lockOther(other.stringStoreMutex_, std::defer_lock); - std::lock(lockThis, lockOther); + // `this` is not observable while its copy constructor runs, so only the + // source pool needs synchronization. Locking both rwlocks through + // std::lock trips Helgrind's rwlock bookkeeping on some CI runners. + std::shared_lock lockOther(other.stringStoreMutex_); // Copy storedStrings_. storedStrings_ = other.storedStrings_; - // Map from old string data pointer to new string_view. - std::unordered_map strDataToNewStrView; + // Map every string from the source pool to the equivalent view owned by this copy. 
+ std::unordered_map copiedViewForSourceView; - // Build the mapping from old string data pointers to new string_views. for (size_t i = 0; i < other.storedStrings_.size(); ++i) { - strDataToNewStrView[other.storedStrings_[i].data()] = storedStrings_[i]; + copiedViewForSourceView.emplace( + std::string_view(other.storedStrings_[i]), + std::string_view(storedStrings_[i])); } - // Rebuild idForString_ with new string_views pointing into this->storedStrings_. + auto copiedViewFor = [&](std::string_view oldStrView) -> std::string_view { + auto it = copiedViewForSourceView.find(oldStrView); + if (it != copiedViewForSourceView.end()) + return it->second; + + // This should not happen if the source pool only references its own storage. + raise("Failed to rebuild StringPool copy: unresolved stored string view"); + }; + + // Rebuild both lookup maps with string_views pointing into this->storedStrings_. idForString_.clear(); for (const auto& [oldStrView, id] : other.idForString_) { - // Get the new string_view corresponding to the old string data pointer. - auto it = strDataToNewStrView.find(oldStrView.data()); - if (it != strDataToNewStrView.end()) { - idForString_.emplace(it->second, id); - } - else { - // This should not happen if everything is consistent. - raise("Failed to rebuild idForString_ in StringPool copy constructor"); - } + idForString_.emplace(copiedViewFor(oldStrView), id); } - // Copy stringForId_. - stringForId_ = other.stringForId_; + stringForId_.clear(); + for (const auto& [id, oldStrView] : other.stringForId_) { + stringForId_.emplace(id, copiedViewFor(oldStrView)); + } // Copy other member variables. 
nextId_ = other.nextId_; diff --git a/src/parser.cpp b/src/parser.cpp index ca7be97d..1211cd93 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -113,11 +113,6 @@ auto Parser::relaxed() const -> bool return mode_ == Mode::Relaxed; } -auto Parser::nextId() -> Expr::ExprId -{ - return ctx.id++; -} - auto Parser::parseInfix(expected left, int prec) -> expected { TRY_EXPECTED(left); @@ -127,7 +122,7 @@ auto Parser::parseInfix(expected left, int prec) -> expected(nextId()); + return std::make_unique(); return unexpected( Error::ParserError, @@ -176,7 +171,7 @@ auto Parser::parseTo(Token::Type type) -> expected if (!*expr) { if (relaxed()) - return std::make_unique(nextId()); + return std::make_unique(); return unexpected( Error::ParserError, diff --git a/src/rewrite-rules.h b/src/rewrite-rules.h new file mode 100644 index 00000000..e34b9541 --- /dev/null +++ b/src/rewrite-rules.h @@ -0,0 +1,107 @@ +#pragma once + +#include "simfil/value.h" +#include "simfil/expression.h" + +#include "expressions.h" + +#include +#include + +namespace simfil +{ + +using RewriteRule = std::function; + +/** + * Apply a list of rewrite-rules top-down to an expression (sub-)tree. + */ +inline auto rewriteTopDown(ExprPtr expr, std::span rules, const RewriteRule* sourceRule = nullptr) -> ExprPtr +{ + for (const auto& rule : rules) { + // Prevent rule self-recursion. 
+        if (&rule == sourceRule)
+            continue;
+
+        auto rewrite = rule(expr);
+        if (rewrite && rewrite.get() != expr.get()) {
+            return rewriteTopDown(std::move(rewrite), rules, &rule); // NOLINT
+        }
+    }
+
+    const auto count = expr->numChildren();
+    for (std::size_t i = 0; i < count; ++i) { // same index type as numChildren()/childAt()
+        auto& child = expr->childAt(i);
+        child = rewriteTopDown(std::move(child), rules, nullptr);
+    }
+
+    return expr; // by-value parameter: moved implicitly on return
+}
+
+/** Rewrite `PathExpr(AnyChildExpr, FieldExpr)` -> `WildcardFieldExpr(non-recursive)` */
+inline auto rewriteAnyChildField(ExprPtr& expr) -> ExprPtr
+{
+    if (const auto* path = dynamic_cast(expr.get())) {
+        const auto* lhs = dynamic_cast(path->left());
+        const auto* rhs = dynamic_cast(path->right());
+        if (lhs && rhs && !rhs->isCurrent()) {
+            return std::make_unique(false, rhs->field(), rhs->sourceLocation());
+        }
+    }
+
+    return nullptr;
+}
+
+/** Rewrite `PathExpr(WildcardExpr, _) | PathExpr(_, WildcardExpr)` -> `WildcardExpr` */
+inline auto rewriteWildcardThis(ExprPtr& expr) -> ExprPtr
+{
+    auto rewrite = [](const Expr* left, const Expr* right) -> std::unique_ptr {
+        const auto* lhs = dynamic_cast(left);
+        if (const auto* rhs = dynamic_cast(right); lhs && rhs && rhs->isCurrent()) {
+            return std::make_unique(lhs->sourceLocation());
+        }
+        return nullptr;
+    };
+
+    if (const auto* path = dynamic_cast(expr.get())) {
+        if (auto replacement = rewrite(path->left(), path->right()))
+            return std::move(replacement);
+        if (auto replacement = rewrite(path->right(), path->left()))
+            return std::move(replacement);
+    }
+
+    return nullptr;
+}
+
+/** Rewrite `PathExpr(PathExpr(?, WildcardExpr), FieldExpr)` -> `PathExpr(?, WildcardFieldExpr(field))` */
+inline auto rewriteAnyWildcardField(ExprPtr& expr) -> ExprPtr
+{
+    if (auto* path = dynamic_cast(expr.get())) {
+        auto* lhs = dynamic_cast(path->left());
+        const auto* rhs = dynamic_cast(path->right());
+        if (lhs && rhs && !rhs->isCurrent()) { // `_` is the current node, not a field name (cf. sibling rules)
+            const auto* lhsRhs = dynamic_cast(lhs->right());
+            if (lhsRhs) {
+                return
std::make_unique(std::move(lhs->left_), + std::make_unique(true, rhs->field(), rhs->sourceLocation())); + } + } + } + return nullptr; +} + +/** Rewrite `PathExpr(WildcardExpr, FieldExpr)` -> `WildcardFieldExpr(field)` */ +inline auto rewriteWildcardField(ExprPtr& expr) -> ExprPtr +{ + if (auto* path = dynamic_cast(expr.get())) { + const auto* lhs = dynamic_cast(path->left()); + const auto* rhs = dynamic_cast(path->right()); + if (lhs && rhs && !rhs->isCurrent()) { + return std::make_unique(true, rhs->field(), rhs->sourceLocation()); + } + } + + return nullptr; +} + +} diff --git a/src/simfil.cpp b/src/simfil.cpp index ae6f6738..bbc7ebae 100644 --- a/src/simfil.cpp +++ b/src/simfil.cpp @@ -15,11 +15,13 @@ #include "fmt/core.h" #include "expressions.h" -#include "expression-patterns.h" #include "completion.h" #include "expected.h" +#include "expression-patterns.h" +#include "rewrite-rules.h" #include +#include #include #include #include @@ -47,6 +49,16 @@ static constexpr std::string_view TypenameString("string"); static constexpr std::string_view TypenameBytes("bytes"); } +static const std::array bottomUpRewriteRules = { + rewriteWildcardThis, + rewriteWildcardField, +}; + +static const std::array topDownRewriteRules = { + rewriteAnyWildcardField, + rewriteAnyChildField, +}; + /** * Parser precedence groups. */ @@ -67,19 +79,630 @@ enum Precedence { }; /** - * Returns if a word should be parsed as a symbol (string). - * This is true for all UPPER_CASE words. + * Extract the user-facing string from a single field or string-literal query. + */ +static auto schemaLookupName(const Expr& expr) -> std::optional +{ + if (auto const* field = dynamic_cast(&expr)) { + return field->field(); + } + + if (auto const* constant = dynamic_cast(&expr)) { + auto const& value = constant->value(); + if (value.isa(ValueType::String)) { + return value.as(); + } + } + + return std::nullopt; +} + +/** + * Return names eligible for operand rewrites. 
Quoted string literals stay + * values; unquoted words are parsed as fields and may be reinterpreted by + * schema metadata below. + */ +static auto schemaOperandShorthandName(const Expr& expr, std::string_view query) -> std::optional +{ + if (auto const* field = dynamic_cast(&expr)) { + return field->field(); + } + + if (auto const* constant = dynamic_cast(&expr)) { + auto const loc = constant->sourceLocation(); + if (loc.size == 0 || loc.offset + loc.size > query.size()) { + return std::nullopt; + } + if (loc.offset < query.size() && (query[loc.offset] == '"' || query[loc.offset] == '\'')) { + return std::nullopt; + } + auto const& value = constant->value(); + if (value.isa(ValueType::String)) { + return value.as(); + } + } + + return std::nullopt; +} + +/** + * Convert a schema path to a SIMFIL path expression. + */ +static auto pathExpressionFromSchemaPath(Environment& env, const SchemaPath& path, SourceLocation location) -> expected +{ + ExprPtr expr = std::make_unique("_"); + for (auto const& segment : path) { + ExprPtr next; + switch (segment.kind) { + case SchemaPathSegment::Kind::Field: { + auto fieldName = env.strings()->resolve(segment.field); + if (!fieldName) { + return unexpected(Error::ParserError, "Schema path contains an unknown field string id."); + } + next = std::make_unique(std::string(*fieldName)); + break; + } + case SchemaPathSegment::Kind::ArrayElement: + next = std::make_unique(); + break; + } + expr = std::make_unique(std::move(expr), std::move(next), location); + } + return expr; +} + +/** + * Build `exact.path == enumValue` expressions for all schema-derived paths. 
+ */ +static auto enumPathExpression( + Environment& env, + std::vector const& paths, + std::string enumValue, + SourceLocation location) -> expected +{ + ExprPtr result; + for (auto const& path : paths) { + auto lhs = pathExpressionFromSchemaPath(env, path, location); + TRY_EXPECTED(lhs); + + auto comparison = std::make_unique>( + std::move(*lhs), + std::make_unique(Value::make(std::string(enumValue)))); + + if (!result) + result = std::move(comparison); + else + result = std::make_unique(std::move(result), std::move(comparison)); + } + return result; +} + +static auto schemaQuery(Environment& env) -> std::function +{ + return [&env](SchemaId schemaId) -> const Schema* { + return env.querySchema(schemaId); + }; +} + +static auto stringIdForSchemaLookup(Environment& env, std::string_view name) -> std::optional +{ + if (auto existing = env.strings()->get(name); existing != StringPool::Empty) { + return existing; + } + + auto inserted = env.strings()->emplace(name); + if (!inserted) { + return std::nullopt; + } + return *inserted; +} + +static auto expressionForSingleSchemaPath( + Environment& env, + SchemaPath const& path, + SourceLocation location) -> expected +{ + return pathExpressionFromSchemaPath(env, path, location); +} + +static auto expressionForSchemaPathAlternatives( + Environment& env, + std::vector const& paths, + SourceLocation location, + std::string_view) -> expected +{ + if (paths.empty()) { + return nullptr; + } + + if (paths.size() == 1) { + return expressionForSingleSchemaPath(env, paths.front(), location); + } + + std::vector alternatives; + alternatives.reserve(paths.size()); + for (auto const& path : paths) { + auto alternative = expressionForSingleSchemaPath(env, path, location); + TRY_EXPECTED(alternative); + if (*alternative) { + alternatives.push_back(std::move(*alternative)); + } + } + + if (alternatives.empty()) { + return nullptr; + } + return std::make_unique(std::move(alternatives), location); +} + +/** + * Rewrite a single field/enum 
query by using schema metadata as source of truth. + */ +static auto rewriteStandaloneNameBySchema(Environment& env, ExprPtr expr, SchemaId rootSchema) -> expected +{ + if (rootSchema == NoSchemaId || !expr) + return expr; + + auto name = schemaLookupName(*expr); + if (!name) + return expr; + + // Querying the root schema may materialize schema-owned strings in + // completion/compile-local environments. + (void) env.querySchema(rootSchema); + + auto stringId = stringIdForSchemaLookup(env, *name); + if (!stringId) + return expr; + + auto querySchema = schemaQuery(env); + auto const* root = env.querySchema(rootSchema); + if (root) { + auto symbolEqualityPaths = root->symbolEqualityPaths(*stringId, querySchema); + if (!symbolEqualityPaths.empty()) + return enumPathExpression(env, symbolEqualityPaths, std::move(*name), expr->sourceLocation()); + } + + auto fieldPaths = Schema::fieldPaths(rootSchema, querySchema, *stringId); + if (!fieldPaths.empty()) + return std::make_unique(true, std::move(*name), expr->sourceLocation()); + + auto enumPaths = Schema::enumSymbolPaths(rootSchema, querySchema, *stringId); + if (!enumPaths.empty()) + return enumPathExpression(env, enumPaths, std::move(*name), expr->sourceLocation()); + + return expr; +} + +static auto rewriteOperandShorthandBySchema( + Environment& env, + std::string_view query, + ExprPtr expr, + SchemaId rootSchema, + bool isRoot, + bool insidePath) -> expected +{ + if (!expr || rootSchema == NoSchemaId) { + return expr; + } + + auto const entersPath = insidePath || dynamic_cast(expr.get()) != nullptr; + if (!isRoot && !insidePath) { + if (auto name = schemaOperandShorthandName(*expr, query)) { + if (auto stringId = stringIdForSchemaLookup(env, *name)) { + if (auto const* root = env.querySchema(rootSchema)) { + auto paths = root->scalarFieldPathsForSymbol(*stringId, schemaQuery(env)); + if (!paths.empty()) { + auto replacement = expressionForSchemaPathAlternatives( + env, + paths, + expr->sourceLocation(), + *name); + 
TRY_EXPECTED(replacement); + if (*replacement) { + return std::move(*replacement); + } + } + + auto querySchema = schemaQuery(env); + auto enumPaths = Schema::enumSymbolPaths(rootSchema, querySchema, *stringId); + if (!enumPaths.empty()) { + auto fieldPaths = Schema::fieldPaths(rootSchema, querySchema, *stringId); + if (!fieldPaths.empty()) { + return expr; + } + + return std::make_unique(Value::make(std::move(*name))); + } + } + } + } + } + + auto const count = expr->numChildren(); + for (auto i = 0u; i < count; ++i) { + auto& child = expr->childAt(i); + auto rewritten = rewriteOperandShorthandBySchema(env, query, std::move(child), rootSchema, false, entersPath); + TRY_EXPECTED(rewritten); + child = std::move(*rewritten); + } + + return expr; +} + +static auto fieldPathSegment(Environment& env, std::string_view fieldName) -> std::optional +{ + auto fieldId = env.strings()->get(fieldName); + if (fieldId == StringPool::Empty) { + return std::nullopt; + } + return SchemaPathSegment{SchemaPathSegment::Kind::Field, fieldId}; +} + +static auto stringConstValue(const Expr& expr) -> std::optional +{ + auto const* constant = dynamic_cast(&expr); + if (!constant) { + return std::nullopt; + } + auto const& value = constant->value(); + if (!value.isa(ValueType::String)) { + return std::nullopt; + } + return value.as(); +} + +static auto fieldNodeName(const Expr& expr) -> std::optional +{ + if (auto const* field = dynamic_cast(&expr)) { + return field->field(); + } + return std::nullopt; +} + +static auto addReferencedQueryStringLiteral(ReferencedQueryTerms& terms, std::string literal) -> void +{ + if (!literal.empty()) { + terms.stringLiterals.insert(std::move(literal)); + } +} + +static auto addReferencedQueryLeafField(ReferencedQueryTerms& terms, std::string fieldName) -> void +{ + if (!fieldName.empty()) { + terms.leafFields.insert(std::move(fieldName)); + } +} + +static auto collectReferencedQueryTermsFromExpr(const Expr& expr, ReferencedQueryTerms& terms) -> void; + 
+static auto collectReferencedQueryTermsFromPathLeaf(const Expr& expr, ReferencedQueryTerms& terms) -> void +{ + if (auto const* field = dynamic_cast(&expr)) { + addReferencedQueryLeafField(terms, field->field()); + return; + } + if (auto const* wildcardField = dynamic_cast(&expr)) { + addReferencedQueryLeafField(terms, wildcardField->name_); + return; + } + if (auto const* subscript = dynamic_cast(&expr)) { + if (auto literal = stringConstValue(*subscript->index_)) { + addReferencedQueryLeafField(terms, *literal); + addReferencedQueryStringLiteral(terms, std::move(*literal)); + return; + } + } + collectReferencedQueryTermsFromExpr(expr, terms); +} + +static auto collectReferencedQueryComparison( + ReferencedQueryTerms& terms, + const Expr& lhs, + const Expr& rhs) -> void +{ + auto fieldName = fieldNodeName(lhs); + auto literal = stringConstValue(rhs); + if (fieldName && literal) { + terms.positiveFieldStringComparisons.push_back({std::move(*fieldName), std::move(*literal)}); + } +} + +static auto collectReferencedQueryTermsFromExpr(const Expr& expr, ReferencedQueryTerms& terms) -> void +{ + if (auto const* constant = dynamic_cast(&expr)) { + if (auto literal = stringConstValue(*constant)) { + addReferencedQueryStringLiteral(terms, std::move(*literal)); + } + return; + } + if (auto const* field = dynamic_cast(&expr)) { + addReferencedQueryLeafField(terms, field->field()); + return; + } + if (auto const* wildcardField = dynamic_cast(&expr)) { + addReferencedQueryLeafField(terms, wildcardField->name_); + return; + } + if (auto const* path = dynamic_cast(&expr)) { + collectReferencedQueryTermsFromPathLeaf(*path->right(), terms); + return; + } + if (auto const* subscript = dynamic_cast(&expr)) { + if (auto literal = stringConstValue(*subscript->index_)) { + addReferencedQueryLeafField(terms, *literal); + addReferencedQueryStringLiteral(terms, std::move(*literal)); + return; + } + } + if (auto const* eq = dynamic_cast*>(&expr)) { + 
collectReferencedQueryComparison(terms, *eq->left_, *eq->right_); + collectReferencedQueryComparison(terms, *eq->right_, *eq->left_); + } + for (auto i = 0u; i < expr.numChildren(); ++i) { + collectReferencedQueryTermsFromExpr(*expr.childAt(i), terms); + } +} + +/** + * Flatten a static field path expression to a schema path. Returns nullopt for + * dynamic expressions, broad wildcards, or operators that cannot name one path. */ -static auto isSymbolWord(std::string_view sv) -> bool +static auto flattenReferencedPath(Environment& env, const Expr& expr) -> expected, Error> { - auto numUpperCaseLetters = 0; - return std::ranges::all_of(sv.begin(), sv.end(), [&numUpperCaseLetters](auto c) { - if (std::isupper(c)) { - ++numUpperCaseLetters; - return true; - } - return c == '_' || std::isdigit(c) != 0; - }) && numUpperCaseLetters > 0; + if (auto const* field = dynamic_cast(&expr)) { + if (field->isCurrent()) { + return SchemaPath{}; + } + auto segment = fieldPathSegment(env, field->field()); + if (!segment) { + return std::nullopt; + } + return SchemaPath{*segment}; + } + + if (auto const* wildcardField = dynamic_cast(&expr)) { + if (!wildcardField->recurse_) { + return std::nullopt; + } + auto segment = fieldPathSegment(env, wildcardField->name_); + if (!segment) { + return std::nullopt; + } + return SchemaPath{*segment}; + } + + if (auto const* path = dynamic_cast(&expr)) { + auto left = flattenReferencedPath(env, *path->left()); + TRY_EXPECTED(left); + if (!*left) { + return std::nullopt; + } + + SchemaPath result = std::move(**left); + if (auto const* field = dynamic_cast(path->right())) { + auto segment = fieldPathSegment(env, field->field()); + if (!segment) { + return std::nullopt; + } + result.push_back(*segment); + return result; + } + if (dynamic_cast(path->right())) { + result.push_back({SchemaPathSegment::Kind::ArrayElement, 0}); + return result; + } + if (auto const* subscript = dynamic_cast(path->right())) { + auto right = flattenReferencedPath(env, 
*subscript); + TRY_EXPECTED(right); + if (!*right) { + return std::nullopt; + } + result.insert(result.end(), (*right)->begin(), (*right)->end()); + return result; + } + return std::nullopt; + } + + if (auto const* subscript = dynamic_cast(&expr)) { + auto left = flattenReferencedPath(env, *subscript->left_); + TRY_EXPECTED(left); + if (!*left) { + return std::nullopt; + } + auto index = stringConstValue(*subscript->index_); + if (!index) { + return std::nullopt; + } + SchemaPath result = std::move(**left); + auto segment = fieldPathSegment(env, *index); + if (!segment) { + return std::nullopt; + } + result.push_back(*segment); + return result; + } + + return std::nullopt; +} + +static auto addReferencedPath( + ReferencedSchemaPaths& result, + SchemaPath path, + SourceLocation location, + bool viaWildcard, + std::optional equalsStringLiteral = std::nullopt) -> void +{ + if (path.empty()) { + return; + } + if (std::ranges::any_of(result.paths, [&](auto const& existing) { + return existing.path == path + && existing.viaWildcard == viaWildcard + && existing.equalsStringLiteral == equalsStringLiteral; + })) { + return; + } + result.paths.push_back({std::move(path), location, viaWildcard, std::move(equalsStringLiteral)}); +} + +static auto schemaPathIsReachable(Environment& env, SchemaId rootSchema, const SchemaPath& path) -> bool +{ + auto leafField = std::ranges::find_if( + path.rbegin(), + path.rend(), + [](auto const& segment) { + return segment.kind == SchemaPathSegment::Kind::Field; + }); + if (leafField == path.rend()) { + return true; + } + + auto querySchema = [&env](SchemaId schemaId) -> const Schema* { + return env.querySchema(schemaId); + }; + auto possiblePaths = Schema::fieldPaths(rootSchema, querySchema, leafField->field); + return std::ranges::find(possiblePaths, path) != possiblePaths.end(); +} + +static auto schemaPathEndsWith(const SchemaPath& path, const SchemaPath& suffix) -> bool +{ + if (suffix.size() > path.size()) { + return false; + } + return 
std::equal(suffix.rbegin(), suffix.rend(), path.rbegin()); +} + +static auto schemaPathsMatchingSuffix(Environment& env, SchemaId rootSchema, const SchemaPath& suffix) -> std::vector +{ + auto leafField = std::ranges::find_if( + suffix.rbegin(), + suffix.rend(), + [](auto const& segment) { + return segment.kind == SchemaPathSegment::Kind::Field; + }); + if (leafField == suffix.rend()) { + return {}; + } + + auto querySchema = [&env](SchemaId schemaId) -> const Schema* { + return env.querySchema(schemaId); + }; + auto possiblePaths = Schema::fieldPaths(rootSchema, querySchema, leafField->field); + std::erase_if(possiblePaths, [&](auto const& path) { + return !schemaPathEndsWith(path, suffix); + }); + return possiblePaths; +} + +static auto collectReferencedSchemaPaths( + Environment& env, + const Expr& expr, + SchemaId rootSchema, + ReferencedSchemaPaths& result) -> expected +{ + if (auto const* eq = dynamic_cast*>(&expr)) { + auto addComparisonPath = [&](Expr const& maybePath, Expr const& maybeLiteral) -> expected { + auto literal = stringConstValue(maybeLiteral); + if (!literal) { + return false; + } + + auto path = flattenReferencedPath(env, maybePath); + TRY_EXPECTED(path); + if (!*path) { + return false; + } + + auto pathValue = std::move(**path); + if (schemaPathIsReachable(env, rootSchema, pathValue)) { + addReferencedPath(result, std::move(pathValue), maybePath.sourceLocation(), false, std::move(*literal)); + } + else { + auto expandedPaths = schemaPathsMatchingSuffix(env, rootSchema, pathValue); + if (expandedPaths.empty()) { + result.hasUnresolvedAccess = true; + } + for (auto& expandedPath : expandedPaths) { + addReferencedPath(result, std::move(expandedPath), maybePath.sourceLocation(), false, *literal); + } + } + return true; + }; + + auto leftAdded = addComparisonPath(*eq->left_, *eq->right_); + TRY_EXPECTED(leftAdded); + if (*leftAdded) { + return {}; + } + + auto rightAdded = addComparisonPath(*eq->right_, *eq->left_); + TRY_EXPECTED(rightAdded); + 
if (*rightAdded) { + return {}; + } + } + + if (dynamic_cast(&expr)) { + result.hasBroadWildcardAccess = true; + return {}; + } + + if (auto const* wildcardField = dynamic_cast(&expr)) { + // Non-recursive child wildcards (`*.foo`) cannot currently be mapped + // to exact schema paths without exposing child traversal internals. + if (!wildcardField->recurse_) { + result.hasDynamicAccess = true; + return {}; + } + + auto fieldId = env.strings()->get(wildcardField->name_); + if (fieldId == StringPool::Empty) { + result.hasUnresolvedAccess = true; + return {}; + } + + auto querySchema = [&env](SchemaId schemaId) -> const Schema* { + return env.querySchema(schemaId); + }; + auto paths = Schema::fieldPaths(rootSchema, querySchema, fieldId); + if (paths.empty()) { + result.hasUnresolvedAccess = true; + return {}; + } + for (auto& path : paths) { + addReferencedPath(result, std::move(path), wildcardField->sourceLocation(), true); + } + return {}; + } + + if (dynamic_cast(&expr) + || dynamic_cast(&expr) + || dynamic_cast(&expr)) { + auto path = flattenReferencedPath(env, expr); + TRY_EXPECTED(path); + if (*path) { + if (schemaPathIsReachable(env, rootSchema, **path)) { + addReferencedPath(result, std::move(**path), expr.sourceLocation(), false); + } + else { + result.hasUnresolvedAccess = true; + } + return {}; + } + if (dynamic_cast(&expr)) { + result.hasDynamicAccess = true; + } + else { + result.hasUnresolvedAccess = true; + } + } + + for (auto i = 0u; i < expr.numChildren(); ++i) { + auto childResult = collectReferencedSchemaPaths(env, *expr.childAt(i), rootSchema, result); + TRY_EXPECTED(childResult); + } + return {}; } /** @@ -116,7 +739,7 @@ static auto scopedNotInPath(Parser& p) { * Tries to evaluate the input expression on a stub context. * Returns the evaluated result on success, otherwise the original expression is returned. 
*/ -static auto simplifyOrForward(Environment* env, expected expr) -> expected +static auto simplifyOrForward(const RewriteRule* currentRule, Environment* env, expected expr) -> expected { if (!expr) return expr; @@ -149,16 +772,52 @@ static auto simplifyOrForward(Environment* env, expected expr) - env->warn("Expression is always "s + values[0].toString(), (*expr)->toString()); if (values.size() == 1) - return std::make_unique((*expr)->id(), std::move(values[0])); + return std::make_unique(std::move(values[0])); if (values.size() > 1) - return std::make_unique((*expr)->id(), std::vector(std::make_move_iterator(values.begin()), - std::make_move_iterator(values.end()))); + return std::make_unique(std::vector(std::make_move_iterator(values.begin()), + std::make_move_iterator(values.end()))); + + /* Apply bottom-up rewrite rules */ + for (const auto& rule : bottomUpRewriteRules) { + /* Prevent rule self-recursion */ + if (&rule == currentRule) + continue; + + if (auto rewrite = rule(*expr)) { + /* If a rewrite rule matched we try to simplify and re-write its output again */ + return simplifyOrForward(&rule, env, std::move(rewrite)); + } + } return expr; } +static auto simplifyOrForward(Environment* env, expected expr) -> expected +{ + return simplifyOrForward(nullptr, env, std::move(expr)); +} + + AST::~AST() = default; +auto AST::reenumerate() -> void +{ + if (!expr_) + return; + + auto nextId = Expr::ExprId{0}; + reenumerate(*expr_, nextId); +} + +auto AST::reenumerate(Expr& expr, Expr::ExprId& nextId) -> void +{ + expr.id_ = nextId++; + + const auto count = expr.numChildren(); + for (auto i = 0u; i < count; ++i) + reenumerate(*expr.childAt(i), nextId); +} + /** * Parser wrapper for parsing and & or operators. 
* @@ -174,12 +833,10 @@ class AndOrParser : public InfixParselet return right; if (t.type == Token::OP_AND) - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::move(left), + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*right))); else if (t.type == Token::OP_OR) - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::move(left), + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*right))); assert(0); return nullptr; @@ -205,12 +862,10 @@ class CompletionAndOrParser : public InfixParselet return right; if (t.type == Token::OP_AND) - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::move(left), + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*right), comp_)); else if (t.type == Token::OP_OR) - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::move(left), + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*right), comp_)); assert(0); return nullptr; @@ -231,7 +886,7 @@ class CastParser : public InfixParselet { auto type = p.consume(); if (type.type == Token::C_NULL) - return std::make_unique(p.nextId(), Value::null()); + return std::make_unique(Value::null()); if (type.type != Token::Type::WORD) return unexpected(Error::InvalidType, fmt::format("'as' expected typename got {}", type.toString())); @@ -239,17 +894,17 @@ class CastParser : public InfixParselet auto name = std::get(type.value); return simplifyOrForward(p.env, [&]() -> expected { if (name == strings::TypenameNull) - return std::make_unique(p.nextId(), Value::null()); + return std::make_unique(Value::null()); if (name == strings::TypenameBool) - return std::make_unique>(p.nextId(), std::move(left)); + return std::make_unique>(std::move(left)); if (name == strings::TypenameInt) - return std::make_unique>(p.nextId(), std::move(left)); + return std::make_unique>(std::move(left)); if (name == strings::TypenameFloat) - return 
std::make_unique>(p.nextId(), std::move(left)); + return std::make_unique>(std::move(left)); if (name == strings::TypenameString) - return std::make_unique>(p.nextId(), std::move(left)); + return std::make_unique>(std::move(left)); if (name == strings::TypenameBytes) - return std::make_unique>(p.nextId(), std::move(left)); + return std::make_unique>(std::move(left)); return unexpected(Error::InvalidType, fmt::format("Invalid type name for cast '{}'", name)); }()); @@ -277,8 +932,7 @@ class BinaryOpParser : public InfixParselet if (!right) return right; - return simplifyOrForward(p.env, std::make_unique>(p.nextId(), - t, + return simplifyOrForward(p.env, std::make_unique>(t, std::move(left), std::move(*right))); } @@ -303,7 +957,7 @@ class UnaryOpParser : public PrefixParselet if (!sub) return sub; - return simplifyOrForward(p.env, std::make_unique>(p.nextId(), std::move(*sub))); + return simplifyOrForward(p.env, std::make_unique>(std::move(*sub))); } }; @@ -315,7 +969,7 @@ class UnaryPostOpParser : public InfixParselet { auto parse(Parser& p, ExprPtr left, Token t) const -> expected override { - return p.parseInfix(simplifyOrForward(p.env, std::make_unique>(p.nextId(), std::move(left))), 0); + return p.parseInfix(simplifyOrForward(p.env, std::make_unique>(std::move(left))), 0); } auto precedence() const -> int override @@ -331,7 +985,7 @@ class UnpackOpParser : public InfixParselet { auto parse(Parser& p, ExprPtr left, Token t) const -> expected override { - return p.parseInfix(simplifyOrForward(p.env, std::make_unique(p.nextId(), std::move(left))), 0); + return p.parseInfix(simplifyOrForward(p.env, std::make_unique(std::move(left))), 0); } auto precedence() const -> int override @@ -353,14 +1007,12 @@ class WordOpParser : public InfixParselet return right; if (*right) - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::get(t.value), + return simplifyOrForward(p.env, std::make_unique(std::get(t.value), std::move(left), std::move(*right))); /* 
Parse as unary operator */ - return p.parseInfix(simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::get(t.value), + return p.parseInfix(simplifyOrForward(p.env, std::make_unique(std::get(t.value), std::move(left))), 0); } @@ -380,7 +1032,7 @@ class ScalarParser : public PrefixParselet { auto parse(Parser& p, Token t) const -> expected override { - return std::make_unique(p.nextId(), std::get(t.value)); + return std::make_unique(std::get(t.value), t); } }; @@ -394,7 +1046,7 @@ class RegExpParser : public PrefixParselet auto parse(Parser& p, Token t) const -> expected override { auto value = ReType::Type.make(std::get(t.value)); - return std::make_unique(p.nextId(), std::move(value)); + return std::make_unique(std::move(value), t); } }; @@ -415,7 +1067,7 @@ class ConstParser : public PrefixParselet auto parse(Parser& p, Token t) const -> expected override { - return std::make_unique(p.nextId(), value_); + return std::make_unique(value_, t); } Value value_; @@ -450,10 +1102,7 @@ class SubscriptParser : public PrefixParselet, public InfixParselet if (!body) return body; - auto outerId = p.nextId(); - auto innerId = p.nextId(); - return simplifyOrForward(p.env, std::make_unique(outerId, - std::make_unique(innerId, "_"), + return simplifyOrForward(p.env, std::make_unique(std::make_unique("_"), std::move(*body))); } @@ -464,8 +1113,7 @@ class SubscriptParser : public PrefixParselet, public InfixParselet if (!body) return body; - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::move(left), + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*body))); } @@ -491,10 +1139,7 @@ class SubSelectParser : public PrefixParselet, public InfixParselet auto body = p.parseTo(Token::RBRACE); TRY_EXPECTED(body); - auto outerId = p.nextId(); - auto innerId = p.nextId(); - return simplifyOrForward(p.env, std::make_unique(outerId, - std::make_unique(innerId, "_"), + return simplifyOrForward(p.env, std::make_unique(std::make_unique("_"), 
std::move(*body))); } @@ -503,8 +1148,7 @@ class SubSelectParser : public PrefixParselet, public InfixParselet auto _ = scopedNotInPath(p); auto body = p.parseTo(Token::RBRACE); TRY_EXPECTED(body); - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::move(left), + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*body))); } @@ -526,15 +1170,15 @@ class WordParser : public PrefixParselet { /* Self */ if (t.type == Token::SELF) - return std::make_unique(p.nextId(), "_", t); + return std::make_unique("_", t); /* Any Child */ if (t.type == Token::OP_TIMES) - return std::make_unique(p.nextId()); + return std::make_unique(); /* Wildcard */ if (t.type == Token::WILDCARD) - return std::make_unique(p.nextId()); + return std::make_unique(); auto word = std::get(t.value); @@ -547,25 +1191,21 @@ class WordParser : public PrefixParselet TRY_EXPECTED(arguments); if (word == "any") { - return simplifyOrForward(p.env, std::make_unique(p.nextId(), std::move(*arguments))); + return simplifyOrForward(p.env, std::make_unique(std::move(*arguments))); } else if (word == "each" || word == "all") { - return simplifyOrForward(p.env, std::make_unique(p.nextId(), std::move(*arguments))); + return simplifyOrForward(p.env, std::make_unique(std::move(*arguments))); } else { - return simplifyOrForward(p.env, std::make_unique(p.nextId(), word, std::move(*arguments))); + return simplifyOrForward(p.env, std::make_unique(word, std::move(*arguments))); } } else if (!p.ctx.inPath) { - /* Parse Symbols (words in upper-case) */ - if (isSymbolWord(word)) { - return std::make_unique(p.nextId(), Value::make(std::move(word))); - } /* Constant */ - else if (auto constant = p.env->findConstant(word)) { - return std::make_unique(p.nextId(), *constant); + if (auto constant = p.env->findConstant(word)) { + return std::make_unique(*constant, t); } } /* Single field name */ - return std::make_unique(p.nextId(), std::move(word), t); + return simplifyOrForward(p.env, 
std::make_unique(std::move(word), t)); } }; @@ -583,15 +1223,15 @@ class CompletionWordParser : public WordParser { /* Self */ if (t.type == Token::SELF) - return std::make_unique(p.nextId(), "_"); + return std::make_unique("_"); /* Any Child */ if (t.type == Token::OP_TIMES) - return std::make_unique(p.nextId()); + return std::make_unique(); /* Wildcard */ if (t.type == Token::WILDCARD) - return std::make_unique(p.nextId()); + return std::make_unique(); auto word = std::get(t.value); @@ -607,26 +1247,19 @@ class CompletionWordParser : public WordParser auto arguments = p.parseList(Token::RPAREN); TRY_EXPECTED(arguments); - return simplifyOrForward(p.env, std::make_unique(p.nextId(), word, std::move(*arguments))); + return simplifyOrForward(p.env, std::make_unique(word, std::move(*arguments))); } else if (!p.ctx.inPath) { - /* Parse Symbols (words in upper-case) */ - if (isSymbolWord(word)) { - if (t.containsPoint(comp_->point)) { - return std::make_unique(p.nextId(), word.substr(0, comp_->point - t.begin), comp_, t); - } - return std::make_unique(p.nextId(), Value::make(std::move(word))); - } /* Constant */ - else if (auto constant = p.env->findConstant(word)) { - return std::make_unique(p.nextId(), *constant); + if (auto constant = p.env->findConstant(word)) { + return std::make_unique(*constant, t); } } /* Single field name */ if (t.containsPoint(comp_->point)) { - return std::make_unique(p.nextId(), word.substr(0, comp_->point - t.begin), comp_, t, p.ctx.inPath); + return std::make_unique(word.substr(0, comp_->point - t.begin), comp_, t, p.ctx.inPath); } - return std::make_unique(p.nextId(), std::move(word)); + return simplifyOrForward(p.env, std::make_unique(std::move(word))); } Completion* comp_; @@ -641,6 +1274,20 @@ class CompletionWordParser : public WordParser class PathParser : public InfixParselet { public: + /** Return a source range covering `left . right` for downstream AST rewrites. 
*/ + static auto pathSourceLocation(Expr const& left, Expr const& right, Token const& dot) -> SourceLocation + { + auto leftLocation = left.sourceLocation(); + auto rightLocation = right.sourceLocation(); + auto dotBegin = static_cast(dot.begin); + auto dotEnd = static_cast(dot.end); + auto begin = leftLocation.size == 0 ? dotBegin : std::min(leftLocation.offset, dotBegin); + auto end = rightLocation.size == 0 + ? dotEnd + : std::max(rightLocation.offset + rightLocation.size, dotEnd); + return SourceLocation(begin, end - begin); + } + auto parse(Parser& p, ExprPtr left, Token t) const -> expected override { auto inPath = true; @@ -653,7 +1300,8 @@ class PathParser : public InfixParselet auto right = p.parsePrecedence(precedence()); TRY_EXPECTED(right); - return std::make_unique(p.nextId(), std::move(left), std::move(*right)); + auto location = pathSourceLocation(*left, **right, t); + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*right), location)); } auto precedence() const -> int override @@ -684,10 +1332,11 @@ class CompletionPathParser : public PathParser if (!*right) { Token expectedWord(Token::WORD, "", t.end, t.end); - right = std::make_unique(p.nextId(), "", comp_, expectedWord, p.ctx.inPath); + right = std::make_unique("", comp_, expectedWord, p.ctx.inPath); } - return std::make_unique(p.nextId(), std::move(left), std::move(*right)); + auto location = pathSourceLocation(*left, **right, t); + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*right), location)); } Completion* comp_; @@ -805,7 +1454,17 @@ static auto setupParser(Parser& p) p.infixParsers[Token::DOT] = &pathParser; } -auto compile(Environment& env, std::string_view query, bool any, bool autoWildcard) -> expected +auto compile(Environment& env, std::string_view query, bool any, bool) -> expected +{ + return compile( + env, + query, + CompileOptions{ + .any = any, + .rewriteMode = RewriteMode::None}); +} + +auto compile(Environment& env, 
std::string_view query, CompileOptions options) -> expected { auto tokens = tokenize(query); TRY_EXPECTED(tokens); @@ -817,31 +1476,38 @@ auto compile(Environment& env, std::string_view query, bool any, bool autoWildca auto root = p.parse(); TRY_EXPECTED(root); - /* Expand a single value to `** == ` */ - if (autoWildcard && *root && (*root)->constant()) { - auto outerId = p.nextId(); - auto innerId = p.nextId(); - root = std::make_unique>( - outerId, std::make_unique(innerId), std::move(*root)); + if (options.rewriteMode == RewriteMode::Schema && options.rootSchema != NoSchemaId) { + root = rewriteStandaloneNameBySchema(env, std::move(*root), options.rootSchema); + TRY_EXPECTED(root); } if (!*root) return unexpected(Error::ParserError, "Expression is null"); - if (any) { + if (options.any) { std::vector args; args.emplace_back(std::move(*root)); - return simplifyOrForward(p.env, std::make_unique(p.nextId(), std::move(args))); + return simplifyOrForward(p.env, std::make_unique(std::move(args))); } else { return root; } }(); TRY_EXPECTED(expr); + if (options.rewriteMode == RewriteMode::Schema && options.rootSchema != NoSchemaId) { + expr = rewriteOperandShorthandBySchema(env, query, std::move(*expr), options.rootSchema, true, false); + TRY_EXPECTED(expr); + } + + /* Apply AST rewrite rules */ + expr = rewriteTopDown(std::move(*expr), topDownRewriteRules); + if (!p.match(Token::Type::NIL)) return unexpected(Error::ExpectedEOF, "Expected end-of-input; got "s + p.current().toString()); - return std::make_unique(std::string(query), std::move(*expr)); + auto ast = std::make_unique(std::string(query), std::move(*expr)); + ast->reenumerate(); + return ast; } auto complete(Environment& env, std::string_view query, size_t point, const ModelNode& node, const CompletionOptions& options) -> expected, Error> @@ -924,6 +1590,47 @@ auto complete(Environment& env, std::string_view query, size_t point, const Mode return candidates; } +auto referencedSchemaPaths(Environment& env, 
const AST& ast, SchemaId rootSchema) -> expected +{ + ReferencedSchemaPaths result; + if (rootSchema == NoSchemaId) { + result.hasUnresolvedAccess = true; + return result; + } + + (void) env.querySchema(rootSchema); + auto collected = collectReferencedSchemaPaths(env, ast.expr(), rootSchema, result); + TRY_EXPECTED(collected); + return result; +} + +auto referencedQueryTerms(const AST& ast) -> ReferencedQueryTerms +{ + ReferencedQueryTerms result; + collectReferencedQueryTermsFromExpr(ast.expr(), result); + return result; +} + +auto standaloneQuerySymbol(Environment& env, std::string_view query) -> expected, Error> +{ + auto ast = compile( + env, + query, + CompileOptions{ + .any = false, + .rewriteMode = RewriteMode::None}); + TRY_EXPECTED(ast); + + auto const& expr = (*ast)->expr(); + if (auto const* field = dynamic_cast(&expr)) { + return field->field(); + } + if (auto literal = stringConstValue(expr)) { + return literal; + } + return std::nullopt; +} + auto eval(Environment& env, const AST& ast, const ModelNode& node, Diagnostics* diag) -> expected, Error> { if (!node.model_) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 53b8792b..96038825 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -8,6 +8,7 @@ add_executable(test.simfil common.hpp common.cpp token.cpp + schema.cpp simfil.cpp diagnostics.cpp completion.cpp diff --git a/test/common.hpp b/test/common.hpp index 3291a10f..5ad26bc4 100644 --- a/test/common.hpp +++ b/test/common.hpp @@ -15,6 +15,14 @@ #include #include +#if __has_include() +# include +#else +# define RUNNING_ON_VALGRIND false +# define CALLGRIND_START_INSTRUMENTATION (void)0 +# define CALLGRIND_STOP_INSTRUMENTATION (void)0 +#endif + using namespace simfil; static const char* const TestModel = R"json( diff --git a/test/completion.cpp b/test/completion.cpp index d45762e9..911c0568 100644 --- a/test/completion.cpp +++ b/test/completion.cpp @@ -1,10 +1,12 @@ #include "src/completion.h" #include +#include #include 
#include #include "common.hpp" #include "simfil/environment.h" +#include "simfil/model/schema.h" #include "src/expected.h" const auto model = R"json( @@ -52,6 +54,73 @@ auto EXPECT_COMPLETION(std::string_view query, std::optional point, std: } } +auto CompleteSchemaQuery(std::string_view query, std::optional point = {}) +{ + auto parsed = simfil::json::parse(R"json({"present": 1})json"); + REQUIRE(parsed); + + auto model = std::move(*parsed); + auto strings = model->strings(); + auto schemaOnlyField = strings->emplace("schemaOnlyField").value(); + auto specialField = strings->emplace("schema field").value(); + auto enumCarrier = strings->emplace("Carrier").value(); + auto enumFast = strings->emplace("FAST_MODE").value(); + auto nested = strings->emplace("nested").value(); + + auto valueSchema = std::make_unique(); + valueSchema->addEnumSymbol(enumCarrier); + valueSchema->addEnumSymbol(enumFast); + + auto objectSchema = std::make_unique(); + objectSchema->addField(schemaOnlyField, {SchemaId{2}}); + objectSchema->addField(specialField); + objectSchema->addField(nested, {SchemaId{3}}); + + auto nestedSchema = std::make_unique(); + nestedSchema->addField(enumFast); + + auto registry = std::make_shared>>(); + (*registry)[SchemaId{1}] = std::move(objectSchema); + (*registry)[SchemaId{2}] = std::move(valueSchema); + (*registry)[SchemaId{3}] = std::move(nestedSchema); + + auto lookup = [registry](SchemaId schemaId) -> Schema* { + if (auto it = registry->find(schemaId); it != registry->end()) + return it->second.get(); + return nullptr; + }; + for (auto const& [_, schema] : *registry) + schema->finalize(lookup); + + auto root = model->root(0); + REQUIRE(root); + auto rootObj = model->resolve(**root); + REQUIRE(rootObj); + REQUIRE(rootObj->setSchema(SchemaId{1})); + + Environment env(strings); + env.querySchemaCallback = [registry](SchemaId schemaId) -> const Schema* { + if (auto it = registry->find(schemaId); it != registry->end()) + return it->second.get(); + return 
nullptr; + }; + + CompletionOptions opts; + opts.showWildcardHints = false; + return complete(env, query, point.value_or(query.size()), **root, opts).value(); +} + +auto EXPECT_SCHEMA_COMPLETION(std::string_view query, std::string_view what, Type type) +{ + auto found = false; + for (const auto& item : CompleteSchemaQuery(query)) { + INFO(" Item: " << item.text); + if (item.text == what && item.type == type) + found = true; + } + REQUIRE(found); +} + TEST_CASE("CompleteField", "[completion.field.incompleteQuery]") { EXPECT_COMPLETION("((oth", {}, "other"); } @@ -77,7 +146,7 @@ TEST_CASE("CompleteField", "[completion.sub-field]") { } TEST_CASE("CompleteString", "[completion.string-const]") { - EXPECT_COMPLETION("1 > C", {}, "CONSTANT_1"); + EXPECT_COMPLETION("1 > C", {}, "\"CONSTANT_1\""); } TEST_CASE("Complete Function", "[completion.function]") { @@ -106,13 +175,13 @@ TEST_CASE("Complete in unclosed expression", "[completion.complete-in-unclosed-e } TEST_CASE("Complete SmartCas", "[completion.smart-case]") { - // Complete both the field and the constants + // Complete both the field and string literals from the model. EXPECT_COMPLETION("cons", {}, "constant", Type::FIELD, 4); - EXPECT_COMPLETION("cons", {}, "CONSTANT_1", Type::CONSTANT); + EXPECT_COMPLETION("cons", {}, "\"CONSTANT_1\"", Type::CONSTANT); // Do not complete the field - EXPECT_COMPLETION("CONS", {}, "CONSTANT_1", Type::CONSTANT, 4); // 3 entries bc. 
of `** =` - EXPECT_COMPLETION("CONS", {}, "CONSTANT_2", Type::CONSTANT); + EXPECT_COMPLETION("CONS", {}, "\"CONSTANT_1\"", Type::CONSTANT, 3); + EXPECT_COMPLETION("CONS", {}, "\"CONSTANT_2\"", Type::CONSTANT); } TEST_CASE("Complete Field with Special Characters", "[copletion.escape-field]") { @@ -131,8 +200,8 @@ TEST_CASE("Complete And/Or", "[copletion.and-or]") { } TEST_CASE("Complete Wildcard Hint", "[completion.generate-eq-value-hint]") { - EXPECT_COMPLETION("A_CONST", {}, "** = A_CONST", Type::HINT); EXPECT_COMPLETION("A_CONST", {}, "**.A_CONST", Type::HINT); + EXPECT_COMPLETION("\"A_CONST\"", {}, "** = \"A_CONST\"", Type::HINT); EXPECT_COMPLETION("field", {}, "**.field", Type::HINT); } @@ -154,3 +223,18 @@ TEST_CASE("Sort Completion", "[completion.sorted]") { return l.text < r.text; })); } + +TEST_CASE("Complete schema fields", "[completion.schema-field]") { + EXPECT_SCHEMA_COMPLETION("schema", "schemaOnlyField", Type::FIELD); + EXPECT_SCHEMA_COMPLETION("schema", "[\"schema field\"]", Type::FIELD); +} + +TEST_CASE("Complete schema enum symbols", "[completion.schema-enum]") { + EXPECT_SCHEMA_COMPLETION("Car", "\"Carrier\"", Type::CONSTANT); + EXPECT_SCHEMA_COMPLETION("FAST", "FAST_MODE", Type::FIELD); + + for (const auto& item : CompleteSchemaQuery("FAST")) { + INFO(" Item: " << item.text); + REQUIRE(item.text != "\"FAST_MODE\""); + } +} diff --git a/test/performance.cpp b/test/performance.cpp index c1e8de01..65c4d1b5 100644 --- a/test/performance.cpp +++ b/test/performance.cpp @@ -1,6 +1,5 @@ #include "simfil/simfil.h" #include "simfil/model/model.h" - #include #include #include diff --git a/test/schema.cpp b/test/schema.cpp new file mode 100644 index 00000000..8e7b8543 --- /dev/null +++ b/test/schema.cpp @@ -0,0 +1,1103 @@ +#include "simfil/diagnostics.h" +#include "simfil/model/nodes.h" +#include "simfil/simfil.h" +#include "simfil/environment.h" +#include "simfil/model/schema.h" +#include "simfil/model/model.h" +#include "simfil/model/json.h" +#include 
"common.hpp" + +#include +#include +#include +#include +#include +#include + +using namespace simfil; + +namespace +{ + +class SchemaRegistry +{ +public: + std::map> schemas; + + // Enable schema lookup. + // + // By having this flag we do not cheat the price of + // the function call for the no-schema benchmarks instead + // of setting the environments query pointer to null. + bool enabled = true; + + auto get(SchemaId id) const -> const Schema* + { + if (!enabled) + return nullptr; + + if (auto i = schemas.find(id); i != schemas.end()) + return i->second.get(); + return nullptr; + } + + auto get(SchemaId id) -> Schema* + { + if (!enabled) + return nullptr; + + if (auto i = schemas.find(id); i != schemas.end()) + return i->second.get(); + return nullptr; + } + + auto finalize() -> void + { + auto& self = *this; + for (const auto& [_, value] : schemas) { + value->finalize([&self](auto id) { return self(id); }); + } + } + + auto operator()(SchemaId id) -> Schema* + { + return get(id); + } + + auto operator()(SchemaId id) const -> const Schema* + { + return get(id); + } + + auto asFunction() const & -> std::function + { + return [this](SchemaId id) { + return (*this)(id); + }; + } +}; + +} + +TEST_CASE("Object schema id assignment", "[model.schema]") { + auto model = std::make_shared(); + + auto obj = model->newObject(0); + REQUIRE(obj->schema() == NoSchemaId); + + obj->setSchema(SchemaId{1}); + REQUIRE(obj->schema() == SchemaId{1}); +} + +TEST_CASE("Singleton object schema id assignment", "[model.schema]") { + auto model = std::make_shared(); + + auto obj = model->newObject(1, true); + REQUIRE(obj->schema() == NoSchemaId); + + REQUIRE(obj->addField("field", int64_t{1})); + obj->setSchema(SchemaId{1}); + REQUIRE(obj->schema() == SchemaId{1}); + REQUIRE(model->validate()); +} + +TEST_CASE("Array schema id assignment", "[model.schema]") { + auto model = std::make_shared(); + + auto arr = model->newArray(0); + REQUIRE(arr->schema() == NoSchemaId); + + 
arr->setSchema(SchemaId{1}); + REQUIRE(arr->schema() == SchemaId{1}); +} + +TEST_CASE("Singleton array schema id assignment", "[model.schema]") { + auto model = std::make_shared(); + + auto arr = model->newArray(1, true); + REQUIRE(arr->schema() == NoSchemaId); + + arr->append(int64_t(1)); + arr->setSchema(SchemaId{1}); + REQUIRE(arr->schema() == SchemaId{1}); + REQUIRE(model->validate()); +} + +TEST_CASE("Object schema finalization", "[model.schema]") { + auto strings = std::make_shared(); + const auto a = strings->emplace("a").value(); + const auto b = strings->emplace("b").value(); + const auto c = strings->emplace("c").value(); + const auto link = strings->emplace("link").value(); + const auto back = strings->emplace("back").value(); + const auto missing = strings->emplace("missing").value(); + const auto enumA = strings->emplace("ENUM_A").value(); + const auto enumB = strings->emplace("ENUM_B").value(); + const auto missingEnum = strings->emplace("MISSING_ENUM").value(); + + SECTION("dirty schemas are conservative") { + ObjectSchema schema; + schema.addField(a); + + // No finalize() called, so canHaveField must return `true`. + REQUIRE(schema.canHaveField(a)); + REQUIRE(schema.canHaveField(missing)); + + schema.finalize([](SchemaId) { return nullptr; }); + REQUIRE(schema.canHaveField(a)); + REQUIRE(!schema.canHaveField(missing)); + } + + SECTION("acyclic schemas finalize fields") { + std::vector schemas(3); + schemas[1].addField(a, {SchemaId{2}}); + schemas[2].addField(b); + + auto lookup = [&schemas](SchemaId schemaId) { + const auto index = static_cast(schemaId); + return index < schemas.size() ? 
&schemas[index] : nullptr; + }; + + schemas[1].finalize(lookup); + + REQUIRE(schemas[1].canHaveField(a)); + REQUIRE(schemas[1].canHaveField(b)); + REQUIRE_FALSE(schemas[1].canHaveField(c)); + } + + SECTION("cyclic schemas collect reachable fields") { + std::vector schemas(3); + schemas[1].addField(link, {SchemaId{2}}); + schemas[1].addField(c); + schemas[2].addField(back, {SchemaId{1}}); + + auto lookup = [&schemas](SchemaId schemaId) { + const auto index = static_cast(schemaId); + return index < schemas.size() ? &schemas[index] : nullptr; + }; + + schemas[1].finalize(lookup); + schemas[2].finalize(lookup); + + REQUIRE(schemas[1].canHaveField(link)); + REQUIRE(schemas[1].canHaveField(back)); + REQUIRE(schemas[1].canHaveField(c)); + REQUIRE_FALSE(schemas[1].canHaveField(missing)); + + REQUIRE(schemas[2].canHaveField(link)); + REQUIRE(schemas[2].canHaveField(back)); + REQUIRE(schemas[2].canHaveField(c)); + REQUIRE_FALSE(schemas[2].canHaveField(missing)); + } + + SECTION("array schemas finalize element fields") { + ObjectSchema objectA; + objectA.addField(a); + + ObjectSchema objectB; + objectB.addField(b); + + ArraySchema arraySchema; + arraySchema.addElementSchemas({SchemaId{1}, SchemaId{2}}); + + auto lookup = [&objectA, &objectB](SchemaId schemaId) -> Schema* { + switch (schemaId) { + case SchemaId{1}: + return &objectA; + case SchemaId{2}: + return &objectB; + default: + return nullptr; + } + }; + + arraySchema.finalize(lookup); + + REQUIRE(arraySchema.canHaveField(a)); + REQUIRE(arraySchema.canHaveField(b)); + REQUIRE_FALSE(arraySchema.canHaveField(c)); + } + + SECTION("value schemas finalize enum symbols") { + ValueSchema schema; + schema.addEnumSymbol(enumB); + schema.addEnumSymbol(enumA); + schema.addEnumSymbol(enumA); + + // Dirty value schemas are conservative until finalized. 
+ REQUIRE(schema.canHaveEnumSymbol(missingEnum)); + + schema.finalize([](SchemaId) { return nullptr; }); + REQUIRE(schema.canHaveEnumSymbol(enumA)); + REQUIRE(schema.canHaveEnumSymbol(enumB)); + REQUIRE_FALSE(schema.canHaveEnumSymbol(missingEnum)); + REQUIRE(schema.nestedEnumSymbols().size() == 2); + } + + SECTION("object and array schemas collect reachable enum symbols") { + ObjectSchema objectSchema; + objectSchema.addField(a, {SchemaId{1}}); + + ArraySchema arraySchema; + arraySchema.addElementSchemas({SchemaId{2}}); + + ValueSchema enumSchema; + enumSchema.addEnumSymbol(enumA); + enumSchema.addEnumSymbol(enumB); + + auto lookup = [&](SchemaId schemaId) -> Schema* { + switch (schemaId) { + case SchemaId{1}: + return &enumSchema; + case SchemaId{2}: + return &objectSchema; + default: + return nullptr; + } + }; + + objectSchema.finalize(lookup); + arraySchema.finalize(lookup); + + REQUIRE(objectSchema.canHaveEnumSymbol(enumA)); + REQUIRE(objectSchema.canHaveEnumSymbol(enumB)); + REQUIRE_FALSE(objectSchema.canHaveEnumSymbol(missingEnum)); + + REQUIRE(arraySchema.canHaveEnumSymbol(enumA)); + REQUIRE(arraySchema.canHaveEnumSymbol(enumB)); + REQUIRE_FALSE(arraySchema.canHaveEnumSymbol(missingEnum)); + } +} + +TEST_CASE("Array schema serialization", "[model.schema]") { + auto model = std::make_shared(); + auto arr = model->newArray(1); + arr->append(int64_t(42)); + REQUIRE(arr->setSchema(SchemaId{7})); + model->addRoot(arr); + + std::stringstream stream; + REQUIRE(model->write(stream)); + + const auto input = std::vector(std::istreambuf_iterator(stream), {}); + auto recoveredModel = std::make_shared(); + REQUIRE(recoveredModel->read(input)); + + auto recoveredRoot = recoveredModel->root(0); + REQUIRE(recoveredRoot); + REQUIRE((*recoveredRoot)->type() == ValueType::Array); + REQUIRE((*recoveredRoot)->schema() == SchemaId{7}); +} + +TEST_CASE("Schema rewrites enum symbols to exact paths", "[model.schema]") +{ + auto model = json::parse(R"json( + { + "status": "Other", + 
"items": [ + {"kind": "Other"} + ], + "unrelated": { + "value": "Carrier" + }, + "CARRIER": 7 + } + )json").value(); + + auto registry = SchemaRegistry{}; + auto strings = model->strings(); + auto status = strings->get("status"); + auto items = strings->get("items"); + auto kind = strings->get("kind"); + auto carrierField = strings->get("CARRIER"); + auto carrierEnum = strings->get("Carrier"); + + auto rootSchema = std::make_unique(); + rootSchema->addField(status, {SchemaId{2}}); + rootSchema->addField(items, {SchemaId{3}}); + rootSchema->addField(carrierField); + + auto enumSchema = std::make_unique(); + enumSchema->addEnumSymbol(carrierEnum); + + auto arraySchema = std::make_unique(); + arraySchema->addElementSchemas({SchemaId{4}}); + + auto itemSchema = std::make_unique(); + itemSchema->addField(kind, {SchemaId{2}}); + + registry.schemas[SchemaId{1}] = std::move(rootSchema); + registry.schemas[SchemaId{2}] = std::move(enumSchema); + registry.schemas[SchemaId{3}] = std::move(arraySchema); + registry.schemas[SchemaId{4}] = std::move(itemSchema); + registry.finalize(); + + auto root = model->root(0); + REQUIRE(root); + auto rootObj = model->resolve(**root); + REQUIRE(rootObj); + REQUIRE(rootObj->setSchema(SchemaId{1})); + + Environment env(strings); + env.querySchemaCallback = registry.asFunction(); + + auto enumAst = compile(env, "Carrier", CompileOptions{ + .any = false, + .rewriteMode = RewriteMode::Schema, + .rootSchema = SchemaId{1}}); + REQUIRE(enumAst); + INFO((*enumAst)->expr().toString()); + REQUIRE((*enumAst)->expr().toString().find("**") == std::string::npos); + REQUIRE((*enumAst)->expr().toString().find("status") != std::string::npos); + REQUIRE((*enumAst)->expr().toString().find("kind") != std::string::npos); + + auto enumResult = eval(env, **enumAst, **root, nullptr); + REQUIRE(enumResult); + REQUIRE(enumResult->size() == 1); + REQUIRE(enumResult->front().isa(ValueType::Bool)); + REQUIRE_FALSE(enumResult->front().as()); + + auto fieldAst = 
compile(env, "CARRIER", CompileOptions{ + .any = false, + .rewriteMode = RewriteMode::Schema, + .rootSchema = SchemaId{1}}); + REQUIRE(fieldAst); + REQUIRE((*fieldAst)->expr().toString() == "**.CARRIER"); + + auto enumRefs = referencedSchemaPaths(env, **enumAst, SchemaId{1}); + REQUIRE(enumRefs); + REQUIRE_FALSE(enumRefs->hasBroadWildcardAccess); + REQUIRE_FALSE(enumRefs->hasDynamicAccess); + REQUIRE(enumRefs->paths.size() == 2); + REQUIRE(std::ranges::all_of(enumRefs->paths, [](auto const& ref) { + return ref.location == SourceLocation{0, 7}; + })); + REQUIRE(std::ranges::any_of(enumRefs->paths, [&](auto const& ref) { + return ref.path.size() == 1 && ref.path[0].field == status && !ref.viaWildcard; + })); + REQUIRE(std::ranges::any_of(enumRefs->paths, [&](auto const& ref) { + return ref.path.size() == 3 + && ref.path[0].field == items + && ref.path[1].kind == SchemaPathSegment::Kind::ArrayElement + && ref.path[2].field == kind + && !ref.viaWildcard; + })); + + auto fieldRefs = referencedSchemaPaths(env, **fieldAst, SchemaId{1}); + REQUIRE(fieldRefs); + REQUIRE(fieldRefs->paths.size() == 1); + REQUIRE(fieldRefs->paths.front().viaWildcard); + REQUIRE(fieldRefs->paths.front().location == SourceLocation{0, 7}); + REQUIRE(fieldRefs->paths.front().path.size() == 1); + REQUIRE(fieldRefs->paths.front().path.front().field == carrierField); + + auto unresolvedAst = compile(env, "unrelated.value", CompileOptions{ + .any = false, + .rewriteMode = RewriteMode::None, + .rootSchema = SchemaId{1}}); + REQUIRE(unresolvedAst); + auto unresolvedRefs = referencedSchemaPaths(env, **unresolvedAst, SchemaId{1}); + REQUIRE(unresolvedRefs); + REQUIRE(unresolvedRefs->paths.empty()); + REQUIRE(unresolvedRefs->hasUnresolvedAccess); + + REQUIRE(strings->get("absent") == StringPool::Empty); + auto absentAst = compile(env, "absent", CompileOptions{ + .any = false, + .rewriteMode = RewriteMode::None, + .rootSchema = SchemaId{1}}); + REQUIRE(absentAst); + auto absentRefs = 
referencedSchemaPaths(env, **absentAst, SchemaId{1}); + REQUIRE(absentRefs); + REQUIRE(absentRefs->paths.empty()); + REQUIRE(absentRefs->hasUnresolvedAccess); + REQUIRE(strings->get("absent") == StringPool::Empty); + + auto childWildcardAst = compile(env, "*.CARRIER", CompileOptions{ + .any = false, + .rewriteMode = RewriteMode::None, + .rootSchema = SchemaId{1}}); + REQUIRE(childWildcardAst); + auto childWildcardRefs = referencedSchemaPaths(env, **childWildcardAst, SchemaId{1}); + REQUIRE(childWildcardRefs); + REQUIRE(childWildcardRefs->paths.empty()); + REQUIRE(childWildcardRefs->hasDynamicAccess); +} + +TEST_CASE("Schema operand shorthand rewrites only source tokens", "[model.schema]") +{ + auto model = json::parse(R"json( + { + "$name": "speed", + "value": 50, + "primary": 70, + "secondary": 90, + "unit": "MPH" + } + )json").value(); + + auto strings = model->strings(); + auto alias = strings->emplace("speed").value(); + auto multiAlias = strings->emplace("limit").value(); + auto enumSymbol = strings->emplace("MPH").value(); + auto name = strings->emplace("$name").value(); + auto value = strings->emplace("value").value(); + auto primary = strings->emplace("primary").value(); + auto secondary = strings->emplace("secondary").value(); + auto unit = strings->emplace("unit").value(); + + class AliasSchema final : public ObjectSchema + { + public: + AliasSchema(StringId alias, StringId multiAlias, StringId name, StringId value, StringId primary, StringId secondary, StringId unit) + : alias_(alias) + , multiAlias_(multiAlias) + , name_(name) + , value_(value) + , primary_(primary) + , secondary_(secondary) + , unit_(unit) + { + addField(name_); + addField(value_); + addField(primary_); + addField(secondary_); + addField(unit_, {SchemaId{2}}); + } + + auto symbolEqualityPaths( + StringId symbol, + const std::function&) const -> std::vector override + { + if (symbol != alias_) + return {}; + return {SchemaPath{{SchemaPathSegment::Kind::Field, name_}}}; + } + + auto 
scalarFieldPathsForSymbol( + StringId symbol, + const std::function&) const -> std::vector override + { + if (symbol == alias_) { + return {SchemaPath{{SchemaPathSegment::Kind::Field, value_}}}; + } + if (symbol == multiAlias_) { + return { + SchemaPath{{SchemaPathSegment::Kind::Field, primary_}}, + SchemaPath{{SchemaPathSegment::Kind::Field, secondary_}}, + }; + } + else { + return {}; + } + } + + private: + StringId alias_; + StringId multiAlias_; + StringId name_; + StringId value_; + StringId primary_; + StringId secondary_; + StringId unit_; + }; + + SchemaRegistry registry; + auto enumSchema = std::make_unique(); + enumSchema->addEnumSymbol(enumSymbol); + registry.schemas[SchemaId{1}] = std::make_unique(alias, multiAlias, name, value, primary, secondary, unit); + registry.schemas[SchemaId{2}] = std::move(enumSchema); + registry.finalize(); + + auto root = model->root(0); + REQUIRE(root); + auto rootObj = model->resolve(**root); + REQUIRE(rootObj); + REQUIRE(rootObj->setSchema(SchemaId{1})); + + Environment env(strings); + env.querySchemaCallback = registry.asFunction(); + + auto standaloneAst = compile(env, "speed", CompileOptions{ + .any = false, + .rewriteMode = RewriteMode::Schema, + .rootSchema = SchemaId{1}}); + REQUIRE(standaloneAst); + INFO((*standaloneAst)->expr().toString()); + REQUIRE((*standaloneAst)->expr().toString().find("$name") != std::string::npos); + REQUIRE((*standaloneAst)->expr().toString().find("value") == std::string::npos); + + auto standaloneResult = eval(env, **standaloneAst, **root, nullptr); + REQUIRE(standaloneResult); + REQUIRE(standaloneResult->size() == 1); + REQUIRE(standaloneResult->front().isa(ValueType::Bool)); + REQUIRE(standaloneResult->front().as()); + + auto expressionAst = compile(env, "speed > 40", CompileOptions{ + .any = false, + .rewriteMode = RewriteMode::Schema, + .rootSchema = SchemaId{1}}); + REQUIRE(expressionAst); + INFO((*expressionAst)->expr().toString()); + 
REQUIRE((*expressionAst)->expr().toString().find("value") != std::string::npos); + + auto expressionResult = eval(env, **expressionAst, **root, nullptr); + REQUIRE(expressionResult); + REQUIRE(expressionResult->size() == 1); + REQUIRE(expressionResult->front().isa(ValueType::Bool)); + REQUIRE(expressionResult->front().as()); + + auto enumOperandAst = compile(env, "unit == MPH", CompileOptions{ + .any = false, + .rewriteMode = RewriteMode::Schema, + .rootSchema = SchemaId{1}}); + REQUIRE(enumOperandAst); + INFO((*enumOperandAst)->expr().toString()); + REQUIRE((*enumOperandAst)->expr().toString().find("\"MPH\"") != std::string::npos); + + auto enumOperandResult = eval(env, **enumOperandAst, **root, nullptr); + REQUIRE(enumOperandResult); + REQUIRE(enumOperandResult->size() == 1); + REQUIRE(enumOperandResult->front().isa(ValueType::Bool)); + REQUIRE(enumOperandResult->front().as()); + + auto quotedAst = compile(env, R"("speed" == speed)", CompileOptions{ + .any = false, + .rewriteMode = RewriteMode::Schema, + .rootSchema = SchemaId{1}}); + REQUIRE(quotedAst); + INFO((*quotedAst)->expr().toString()); + REQUIRE((*quotedAst)->expr().toString().find("\"speed\"") != std::string::npos); + REQUIRE((*quotedAst)->expr().toString().find("value") != std::string::npos); + + auto anyAlternativeAst = compile(env, "limit > 80", CompileOptions{ + .any = true, + .rewriteMode = RewriteMode::Schema, + .rootSchema = SchemaId{1}}); + REQUIRE(anyAlternativeAst); + INFO((*anyAlternativeAst)->expr().toString()); + REQUIRE((*anyAlternativeAst)->expr().toString().find("(paths") != std::string::npos); + REQUIRE((*anyAlternativeAst)->expr().toString().find("primary") != std::string::npos); + REQUIRE((*anyAlternativeAst)->expr().toString().find("secondary") != std::string::npos); + + auto anyAlternativeResult = eval(env, **anyAlternativeAst, **root, nullptr); + REQUIRE(anyAlternativeResult); + REQUIRE(anyAlternativeResult->size() == 1); + 
REQUIRE(anyAlternativeResult->front().isa(ValueType::Bool)); + REQUIRE(anyAlternativeResult->front().as()); + + auto eachAlternativeAst = compile(env, "each(limit > 80)", CompileOptions{ + .any = false, + .rewriteMode = RewriteMode::Schema, + .rootSchema = SchemaId{1}}); + REQUIRE(eachAlternativeAst); + auto eachAlternativeResult = eval(env, **eachAlternativeAst, **root, nullptr); + REQUIRE(eachAlternativeResult); + REQUIRE(eachAlternativeResult->size() == 1); + REQUIRE(eachAlternativeResult->front().isa(ValueType::Bool)); + REQUIRE_FALSE(eachAlternativeResult->front().as()); + + auto eachAllAlternativeAst = compile(env, "each(limit < 100)", CompileOptions{ + .any = false, + .rewriteMode = RewriteMode::Schema, + .rootSchema = SchemaId{1}}); + REQUIRE(eachAllAlternativeAst); + auto eachAllAlternativeResult = eval(env, **eachAllAlternativeAst, **root, nullptr); + REQUIRE(eachAllAlternativeResult); + REQUIRE(eachAllAlternativeResult->size() == 1); + REQUIRE(eachAllAlternativeResult->front().isa(ValueType::Bool)); + REQUIRE(eachAllAlternativeResult->front().as()); +} + +// A minimal test that makes sure a field not in the schema +// is pruned if we query for it via **.field. +TEST_CASE("WildcardFieldExpr Field Pruning", "[model.schema]") +{ + auto jsonModel = R"json( + { + "field": 123 + } + )json"; + auto model = json::parse(jsonModel).value(); + auto registry = SchemaRegistry{}; + auto strings = model->strings(); + auto fieldId = strings->get("field"); + + // We need to add "noField" to the StringPool to prevent + // evaluation skipping the expression. 
+ (void)strings->emplace("noField"); + + // Build a simple schema + auto schemaName = strings->emplace("schema1").value(); + auto schema1 = std::make_unique(); + schema1->addField(fieldId, { NoSchemaId }); + + registry.schemas[(SchemaId)schemaName] = std::move(schema1); + registry.finalize(); + + // Assign schemas to the model + auto root = model->root(0); + REQUIRE(root); + + auto rootObj = model->resolve(*root.value()); + REQUIRE(rootObj); + REQUIRE(rootObj->setSchema((SchemaId)schemaName)); + REQUIRE(rootObj->schema() == (SchemaId)schemaName); + + // Run a query and check if pruning of unknown fields works + Environment env(strings); + env.querySchemaCallback = registry.asFunction(); + + auto ast = compile(env, "**.noField", false, false); + REQUIRE(ast); + + Diagnostics diagWithPruning; + registry.enabled = true; + auto resultWithPruning = eval(env, *ast.value(), *model->root(0).value(), &diagWithPruning); + REQUIRE(resultWithPruning); + + Diagnostics diagNoPruning; + registry.enabled = false; + auto resultNoPruning = eval(env, *ast.value(), *model->root(0).value(), &diagNoPruning); + REQUIRE(resultNoPruning); + + // We compare field evaluations for both runs + auto withPruningData = diagWithPruning.fieldData_[0]; + auto noPruningData = diagNoPruning.fieldData_[0]; + REQUIRE(withPruningData.evaluations < noPruningData.evaluations); +} + +TEST_CASE("WildcardFieldExpr Array Field Pruning", "[model.schema]") +{ + auto jsonModel = R"json( + [ + { + "field": 123 + } + ] + )json"; + auto model = json::parse(jsonModel).value(); + auto registry = SchemaRegistry{}; + auto strings = model->strings(); + auto fieldId = strings->get("field"); + + (void)strings->emplace("noField"); + + constexpr auto objectSchemaId = SchemaId{1}; + constexpr auto arraySchemaId = SchemaId{2}; + + auto objectSchema = std::make_unique(); + objectSchema->addField(fieldId, { NoSchemaId }); + registry.schemas[objectSchemaId] = std::move(objectSchema); + + auto arraySchema = std::make_unique(); + 
arraySchema->addElementSchemas({objectSchemaId}); + registry.schemas[arraySchemaId] = std::move(arraySchema); + registry.finalize(); + + auto root = model->root(0); + REQUIRE(root); + auto rootArray = model->resolve(*root.value()); + + REQUIRE(rootArray); + REQUIRE(rootArray->setSchema(arraySchemaId)); + REQUIRE(rootArray->schema() == arraySchemaId); + + Environment env(strings); + env.querySchemaCallback = registry.asFunction(); + + auto ast = compile(env, "**.noField", false, false); + REQUIRE(ast); + + auto modelRoot = model->root(0); + REQUIRE(modelRoot); + + Diagnostics diagWithPruning; + registry.enabled = true; + auto resultWithPruning = eval(env, **ast, **modelRoot, &diagWithPruning); + REQUIRE(resultWithPruning); + + Diagnostics diagNoPruning; + registry.enabled = false; + auto resultNoPruning = eval(env, **ast, **modelRoot, &diagNoPruning); + REQUIRE(resultNoPruning); + + auto withPruningData = diagWithPruning.fieldData_[0]; + auto noPruningData = diagNoPruning.fieldData_[0]; + REQUIRE(withPruningData.evaluations < noPruningData.evaluations); +} + +TEST_CASE("WildcardFieldExpr non-recursive queries ignore partial root schemas", "[model.schema]") +{ + auto jsonModel = R"json( + { + "object": { + "field": 123 + } + } + )json"; + auto model = json::parse(jsonModel).value(); + auto registry = SchemaRegistry{}; + auto strings = model->strings(); + auto objectId = strings->get("object"); + (void)strings->emplace("field"); + + const auto rootSchemaId = SchemaId{1}; + auto rootSchema = std::make_unique(); + rootSchema->addField(objectId, { NoSchemaId }); + registry.schemas[rootSchemaId] = std::move(rootSchema); + registry.finalize(); + + auto root = model->root(0); + REQUIRE(root); + auto rootObj = model->resolve(*root.value()); + REQUIRE(rootObj); + REQUIRE(rootObj->setSchema(rootSchemaId)); + + Environment env(strings); + env.querySchemaCallback = registry.asFunction(); + + auto ast = compile(env, "*.field", false, false); + REQUIRE(ast); + + auto result = 
eval(env, **ast, **root, nullptr); + REQUIRE(result); + REQUIRE(result->size() == 1); + REQUIRE((*result)[0].toString() == "123"); +} + +TEST_CASE("WildcardFieldExpr schema plan cache follows schema mutations", "[model.schema]") +{ + auto jsonModel = R"json( + { + "target": 123 + } + )json"; + auto model = json::parse(jsonModel).value(); + auto registry = SchemaRegistry{}; + auto strings = model->strings(); + auto targetId = strings->get("target"); + auto otherId = strings->emplace("other").value(); + + const auto rootSchemaId = SchemaId{1}; + auto rootSchema = std::make_unique(); + auto* rootSchemaPtr = rootSchema.get(); + rootSchema->addField(otherId); + registry.schemas[rootSchemaId] = std::move(rootSchema); + registry.finalize(); + + auto root = model->root(0); + REQUIRE(root); + auto rootObj = model->resolve(*root.value()); + REQUIRE(rootObj); + REQUIRE(rootObj->setSchema(rootSchemaId)); + + Environment env(strings); + env.querySchemaCallback = registry.asFunction(); + + auto ast = compile(env, "**.target", false, false); + REQUIRE(ast); + + auto beforeSchemaUpdate = eval(env, **ast, **root, nullptr); + REQUIRE(beforeSchemaUpdate); + REQUIRE(beforeSchemaUpdate->size() == 1); + REQUIRE((*beforeSchemaUpdate)[0].isa(ValueType::Null)); + + rootSchemaPtr->addField(targetId); + registry.finalize(); + + auto afterSchemaUpdate = eval(env, **ast, **root, nullptr); + REQUIRE(afterSchemaUpdate); + REQUIRE(afterSchemaUpdate->size() == 1); + REQUIRE((*afterSchemaUpdate)[0].toString() == "123"); +} + +TEST_CASE("Schema query performance", "[perf.schema]") { + if (RUNNING_ON_VALGRIND) { // NOLINT + SKIP("Skipping benchmarks when running under valgrind"); + } + + constexpr auto n = std::size_t{10'000}; + static_assert(n % 2 == 0, "n must be even"); + + const auto payloadASchemaId = SchemaId{1}; + const auto payloadBSchemaId = SchemaId{2}; + const auto xASchemaId = SchemaId{3}; + const auto xBSchemaId = SchemaId{4}; + const auto yASchemaId = SchemaId{5}; + const auto 
yBSchemaId = SchemaId{6}; + const auto rootObjASchemaId = SchemaId{7}; + const auto rootObjBSchemaId = SchemaId{8}; + const auto arraySchemaId = SchemaId{9}; + + auto strings = std::make_shared(); + auto model = std::make_shared(strings); + auto registry = SchemaRegistry{}; + + const auto aId = strings->emplace("a").value(); + const auto bId = strings->emplace("b").value(); + const auto yId = strings->emplace("y").value(); + const auto xId = strings->emplace("x").value(); + const auto missingId = strings->emplace("missing").value(); + const auto payloadId = strings->emplace("payload").value(); + + auto payloadASchema = std::make_unique(); + payloadASchema->addField(xId, { xASchemaId }); + registry.schemas[payloadASchemaId] = std::move(payloadASchema); + + auto payloadBSchema = std::make_unique(); + payloadBSchema->addField(xId, { xBSchemaId }); + registry.schemas[payloadBSchemaId] = std::move(payloadBSchema); + + auto xASchema = std::make_unique(); + xASchema->addField(yId, { yASchemaId }); + registry.schemas[xASchemaId] = std::move(xASchema); + + auto xBSchema = std::make_unique(); + xBSchema->addField(yId, { yBSchemaId }); + registry.schemas[xBSchemaId] = std::move(xBSchema); + + auto yASchema = std::make_unique(); + yASchema->addField(aId); + registry.schemas[yASchemaId] = std::move(yASchema); + + auto yBSchema = std::make_unique(); + yBSchema->addField(bId); + registry.schemas[yBSchemaId] = std::move(yBSchema); + + auto rootObjASchema = std::make_unique(); + rootObjASchema->addField(payloadId, { payloadASchemaId }); + registry.schemas[rootObjASchemaId] = std::move(rootObjASchema); + + auto rootObjBSchema = std::make_unique(); + rootObjBSchema->addField(payloadId, { payloadBSchemaId }); + registry.schemas[rootObjBSchemaId] = std::move(rootObjBSchema); + + auto arraySchema = std::make_unique(); + arraySchema->addElementSchemas({ rootObjASchemaId, rootObjBSchemaId }); + registry.schemas[arraySchemaId] = std::move(arraySchema); + registry.finalize(); + + auto root 
= model->newArray(n); + for (auto i = 0u; i < n; ++i) { + auto obj = model->newObject(1, true); + auto payload = model->newObject(1, true); + auto x = model->newObject(1, true); + auto y = model->newObject(1, true); + + if (i % 2 == 0) { + y->addField("a", int64_t(1)); + y->setSchema(yASchemaId); + x->setSchema(xASchemaId); + payload->setSchema(payloadASchemaId); + obj->setSchema(rootObjASchemaId); + } else { + y->addField("b", int64_t(1)); + y->setSchema(yBSchemaId); + x->setSchema(xBSchemaId); + payload->setSchema(payloadBSchemaId); + obj->setSchema(rootObjBSchemaId); + } + + x->addField("y", y); + payload->addField("x", x); + obj->addField("payload", payload); + root->append(obj); + } + + REQUIRE(root->setSchema(arraySchemaId)); + model->addRoot(root); + + Environment env(strings); + env.querySchemaCallback = registry.asFunction(); + + auto modelRoot = model->root(0); + REQUIRE(modelRoot); + + auto aAst = compile(env, "count(**.a == 1)", false, false); + REQUIRE(aAst); + + auto missingAst = compile(env, "count(**.missing == 1)", false, false); + REQUIRE(missingAst); + + registry.enabled = false; + BENCHMARK("Query nested field 'a' recursive without schema") { + auto res = eval(env, **aAst, **modelRoot, nullptr); + REQUIRE(res); + REQUIRE(res->size() == 1); + + auto count = res->front().template as(); + REQUIRE(count == int64_t(n / 2)); + return count; + }; + + BENCHMARK("Query missing field 'missing' without schema") { + auto res = eval(env, **missingAst, **modelRoot, nullptr); + REQUIRE(res); + REQUIRE(res->size() == 1); + + auto count = res->front().template as(); + REQUIRE(count == 0); + return count; + }; + + registry.enabled = true; + BENCHMARK("Query nested field 'a' recursive with schema") { + auto res = eval(env, **aAst, **modelRoot, nullptr); + REQUIRE(res); + REQUIRE(res->size() == 1); + + auto count = res->front().template as(); + REQUIRE(count == int64_t(n / 2)); + return count; + }; + + BENCHMARK("Query missing field 'missing' with schema") { + auto 
res = eval(env, **missingAst, **modelRoot, nullptr); + REQUIRE(res); + REQUIRE(res->size() == 1); + + auto count = res->front().template as(); + REQUIRE(count == 0); + return count; + }; +} + +TEST_CASE("Sparse wide schema query performance", "[perf.schema]") { + if (RUNNING_ON_VALGRIND) { // NOLINT + SKIP("Skipping benchmarks when running under valgrind"); + } + + constexpr auto objectCount = std::size_t{2'000}; + constexpr auto branchCount = std::size_t{32}; + + const auto targetBranchSchemaId = SchemaId{1}; + const auto targetPayloadSchemaId = SchemaId{2}; + const auto noiseBranchSchemaId = SchemaId{3}; + const auto rootObjectSchemaId = SchemaId{4}; + const auto arraySchemaId = SchemaId{5}; + + auto strings = std::make_shared(); + auto model = std::make_shared(strings); + auto registry = SchemaRegistry{}; + + const auto targetId = strings->emplace("target").value(); + const auto payloadId = strings->emplace("payload").value(); + const auto noiseId = strings->emplace("noise").value(); + + std::vector branchNames; + std::vector branchIds; + branchNames.reserve(branchCount); + branchIds.reserve(branchCount); + for (auto branchIndex = std::size_t{0}; branchIndex < branchCount; ++branchIndex) { + branchNames.push_back("branch" + std::to_string(branchIndex)); + branchIds.push_back(strings->emplace(branchNames.back()).value()); + } + + auto targetBranchSchema = std::make_unique(); + targetBranchSchema->addField(payloadId, { targetPayloadSchemaId }); + registry.schemas[targetBranchSchemaId] = std::move(targetBranchSchema); + + auto targetPayloadSchema = std::make_unique(); + targetPayloadSchema->addField(targetId); + registry.schemas[targetPayloadSchemaId] = std::move(targetPayloadSchema); + + auto noiseBranchSchema = std::make_unique(); + noiseBranchSchema->addField(noiseId); + registry.schemas[noiseBranchSchemaId] = std::move(noiseBranchSchema); + + auto rootObjectSchema = std::make_unique(); + rootObjectSchema->addField(branchIds.front(), { targetBranchSchemaId }); + 
for (auto branchIndex = std::size_t{1}; branchIndex < branchCount; ++branchIndex) + rootObjectSchema->addField(branchIds[branchIndex], { noiseBranchSchemaId }); + registry.schemas[rootObjectSchemaId] = std::move(rootObjectSchema); + + auto arraySchema = std::make_unique(); + arraySchema->addElementSchemas({ rootObjectSchemaId }); + registry.schemas[arraySchemaId] = std::move(arraySchema); + registry.finalize(); + + auto root = model->newArray(objectCount); + for (auto objectIndex = std::size_t{0}; objectIndex < objectCount; ++objectIndex) { + auto obj = model->newObject(branchCount, true); + + auto targetBranch = model->newObject(1, true); + auto targetPayload = model->newObject(1, true); + targetPayload->addField("target", int64_t(1)); + REQUIRE(targetPayload->setSchema(targetPayloadSchemaId)); + targetBranch->addField("payload", targetPayload); + REQUIRE(targetBranch->setSchema(targetBranchSchemaId)); + obj->addField(branchNames.front(), targetBranch); + + for (auto branchIndex = std::size_t{1}; branchIndex < branchCount; ++branchIndex) { + auto noiseBranch = model->newObject(1, true); + noiseBranch->addField("noise", static_cast(objectIndex + branchIndex)); + REQUIRE(noiseBranch->setSchema(noiseBranchSchemaId)); + obj->addField(branchNames[branchIndex], noiseBranch); + } + + REQUIRE(obj->setSchema(rootObjectSchemaId)); + root->append(obj); + } + + REQUIRE(root->setSchema(arraySchemaId)); + model->addRoot(root); + + Environment env(strings); + env.querySchemaCallback = registry.asFunction(); + + auto modelRoot = model->root(0); + REQUIRE(modelRoot); + + auto targetAst = compile(env, "count(**.target == 1)", false, false); + REQUIRE(targetAst); + + auto exactPathAst = compile(env, "count(*." 
+ branchNames.front() + ".payload.target == 1)", false, false); + REQUIRE(exactPathAst); + + registry.enabled = false; + BENCHMARK("Query sparse wide field 'target' recursive without schema") { + auto res = eval(env, **targetAst, **modelRoot, nullptr); + REQUIRE(res); + REQUIRE(res->size() == 1); + + auto count = res->front().template as(); + REQUIRE(count == int64_t(objectCount)); + return count; + }; + + registry.enabled = true; + env.enableWildcardFieldPlans = false; + BENCHMARK("Query sparse wide field 'target' recursive with basic schema pruning") { + auto res = eval(env, **targetAst, **modelRoot, nullptr); + REQUIRE(res); + REQUIRE(res->size() == 1); + + auto count = res->front().template as(); + REQUIRE(count == int64_t(objectCount)); + return count; + }; + + env.enableWildcardFieldPlans = true; + BENCHMARK("Query sparse wide field 'target' recursive with schema field plans") { + auto res = eval(env, **targetAst, **modelRoot, nullptr); + REQUIRE(res); + REQUIRE(res->size() == 1); + + auto count = res->front().template as(); + REQUIRE(count == int64_t(objectCount)); + return count; + }; + + registry.enabled = false; + env.enableWildcardFieldPlans = false; + BENCHMARK("Query sparse wide field 'target' via exact path without schema") { + auto res = eval(env, **exactPathAst, **modelRoot, nullptr); + REQUIRE(res); + REQUIRE(res->size() == 1); + + auto count = res->front().template as(); + REQUIRE(count == int64_t(objectCount)); + return count; + }; +} diff --git a/test/simfil.cpp b/test/simfil.cpp index 7ec6511c..040a5808 100644 --- a/test/simfil.cpp +++ b/test/simfil.cpp @@ -57,9 +57,11 @@ TEST_CASE("Path", "[ast.path]") { TEST_CASE("Wildcard", "[ast.wildcard]") { REQUIRE_AST("*", "*"); REQUIRE_AST("**", "**"); - REQUIRE_AST("**.a", "(. ** a)"); - REQUIRE_AST("a.**.b", "(. (. a **) b)"); - REQUIRE_AST("a.**.b.**.c", "(. (. (. (. a **) b) **) c)"); + REQUIRE_AST("*.a", "*.a"); + REQUIRE_AST("**.a", "**.a"); /* Optimization rewrites this from (. 
** a) to **.a */ + REQUIRE_AST("**.a.b.c", "(. (. **.a b) c)"); + REQUIRE_AST("a.**.b", "(. a **.b)"); + REQUIRE_AST("a.**.b.**.c", "(. (. a **.b) **.c)"); REQUIRE_AST("* == *", "(== * *)"); /* Do not optimize away */ REQUIRE_AST("** == **", "(== ** **)"); /* Do not optimize away */ @@ -244,13 +246,13 @@ TEST_CASE("CompareIncompatibleTypes", "[ast.compare-incompatible]") { REQUIRE_AST("range(0,10)!=\"A\"", "true"); } -TEST_CASE("Auto Expand Constant", "[ast.auto-expand-constant]") { +TEST_CASE("Deprecated auto wildcard has no non-schema fallback", "[ast.auto-expand-constant]") { REQUIRE_AST_AUTOWILDCARD("a = 1", "(== a 1)"); REQUIRE_AST_AUTOWILDCARD("a.* = 1", "(== (. a *) 1)"); REQUIRE_AST_AUTOWILDCARD("** = 1", "(== ** 1)"); - REQUIRE_AST_AUTOWILDCARD("1", "(== ** 1)"); - REQUIRE_AST_AUTOWILDCARD("1+4", "(== ** 5)"); - REQUIRE_AST_AUTOWILDCARD("ABC", "(== ** \"ABC\")"); + REQUIRE_AST_AUTOWILDCARD("1", "1"); + REQUIRE_AST_AUTOWILDCARD("1+4", "5"); + REQUIRE_AST_AUTOWILDCARD("ABC", "ABC"); } TEST_CASE("CompareIncompatibleTypesFields", "[ast.compare-incompatible-types-fields]") { @@ -348,16 +350,16 @@ TEST_CASE("Constants", "[ast.constant]") { REQUIRE_AST("a_number", "123"); } -TEST_CASE("Symbols", "[ast.symbol]") { - REQUIRE_AST("ABC", "\"ABC\""); - REQUIRE_AST("ABC == ABC", "true"); +TEST_CASE("Unquoted words are fields without schema metadata", "[ast.symbol]") { + REQUIRE_AST("ABC", "ABC"); + REQUIRE_AST("ABC == ABC", "(== ABC ABC)"); REQUIRE_AST("a.ABC", "(. a ABC)"); REQUIRE_AST("a.ABC.DEF", "(. (. a ABC) DEF)"); - REQUIRE_AST("a.(ABC)", "(. a \"ABC\")"); + REQUIRE_AST("a.(ABC)", "(. a ABC)"); REQUIRE_AST("a.(_.ABC)", "(. a (. _ ABC))"); - REQUIRE_AST("a[ABC]", "(index a \"ABC\")"); + REQUIRE_AST("a[ABC]", "(index a ABC)"); REQUIRE_AST("a[_.ABC]", "(index a (. _ ABC))"); - REQUIRE_AST("a{ABC}", "(sub a \"ABC\")"); + REQUIRE_AST("a{ABC}", "(sub a ABC)"); REQUIRE_AST("a{_.ABC}", "(sub a (. 
_ ABC))"); } @@ -425,6 +427,7 @@ TEST_CASE("Path Wildcard", "[yaml.path-wildcard]") { REQUIRE_RESULT("sub.*", R"(sub a|sub b|{"a":"sub sub a","b":"sub sub b"})"); REQUIRE_RESULT("sub.**", R"({"a":"sub a","b":"sub b","sub":{"a":"sub sub a","b":"sub sub b"}}|sub a|sub b|)" R"({"a":"sub sub a","b":"sub sub b"}|sub sub a|sub sub b)"); + REQUIRE_RESULT("**.a", "1|sub a|sub sub a"); REQUIRE_RESULT("(sub.*.{typeof _ != 'model'} + sub.*.{typeof _ != 'model'})._", "sub asub a|sub asub b|sub bsub a|sub bsub b"); /* . filters null */ REQUIRE_RESULT("sub.*.{typeof _ != 'model'} + sub.*.{typeof _ != 'model'}", "sub asub a|sub asub b|sub bsub a|sub bsub b"); /* {_} filters null */ REQUIRE_RESULT("count(*)", "12"); @@ -714,6 +717,24 @@ TEST_CASE("Switch Model String Pool", "[model.setStrings]") REQUIRE(oldFieldDict->size() != newFieldDict->size()); } +TEST_CASE("StringPool copy owns lookup views", "[string-pool]") +{ + auto source = std::make_shared(); + auto id = source->emplace("owned-dynamic-field"); + REQUIRE(id); + + auto sourceView = source->resolve(*id); + REQUIRE(sourceView); + + auto copy = std::make_shared(*source); + auto copyView = copy->resolve(*id); + REQUIRE(copyView); + + REQUIRE(*copyView == *sourceView); + REQUIRE(copyView->data() != sourceView->data()); + REQUIRE(copy->get("owned-dynamic-field") == *id); +} + TEST_CASE("Exception Handler", "[exception]") { bool handlerCalled = false; @@ -754,6 +775,13 @@ TEST_CASE("Visit AST", "[visit.ast]") visitedFieldName = expr.name_; } + + auto visit(const WildcardFieldExpr& expr) -> void override + { + ExprVisitor::visit(expr); + + visitedFieldName = expr.name_; + } }; Visitor visitor; @@ -761,3 +789,63 @@ TEST_CASE("Visit AST", "[visit.ast]") REQUIRE(visitor.visitedFieldName == "field"); } + +TEST_CASE("Visitors traverse unary children once", "[visit.ast]") +{ + UnaryExpr expr(std::make_unique("field")); + + struct Visitor : ExprVisitor + { + int fieldVisits = 0; + + using ExprVisitor::visit; + + auto visit(const 
FieldExpr& expr) -> void override + { + ExprVisitor::visit(expr); + ++fieldVisits; + } + }; + + Visitor visitor; + expr.accept(visitor); + + REQUIRE(visitor.fieldVisits == 1); +} + +TEST_CASE("Parsed token locations are preserved", "[ast.source-location]") +{ + Environment env(Environment::WithNewStringCache); + + auto fieldAst = compile(env, "field", false, false); + REQUIRE(fieldAst); + + const auto* fieldExpr = dynamic_cast(&(*fieldAst)->expr()); + REQUIRE(fieldExpr); + REQUIRE(fieldExpr->sourceLocation().offset == 0); + REQUIRE(fieldExpr->sourceLocation().size == 5); + + auto binaryAst = compile(env, "field + 1", false, false); + REQUIRE(binaryAst); + + const auto* binaryExpr = dynamic_cast*>(&(*binaryAst)->expr()); + REQUIRE(binaryExpr); + REQUIRE(binaryExpr->sourceLocation().offset == 6); + REQUIRE(binaryExpr->sourceLocation().size == 1); +} + +TEST_CASE("AST expr ids are reenumerated after rewrites", "[ast.expr-id]") +{ + auto ast = Compile("**.field = 123", false); + + std::vector ids; + const auto collectIds = [&](const auto& self, const Expr& expr) -> void { + ids.emplace_back(expr.id()); + for (auto i = 0u; i < expr.numChildren(); ++i) + self(self, *expr.childAt(i)); + }; + + collectIds(collectIds, ast->expr()); + + REQUIRE(ids == std::vector{0, 1, 2}); +} diff --git a/test/value.cpp b/test/value.cpp index c01640a3..35cc59f5 100644 --- a/test/value.cpp +++ b/test/value.cpp @@ -1,9 +1,12 @@ #include #include #include +#include +#include #include "simfil/value.h" #include "simfil/model/model.h" +#include "simfil/model/schema.h" #include "simfil/token.h" #include "simfil/transient.h"