diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a4ce80a9..f0a86d9c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,13 @@ +## dbt-databricks 1.12.0 (TBD) + +### Features + +- Add support for metric views as a materialization ([#1285](https://github.com/databricks/dbt-databricks/pull/1285)) +- Add support for row filters ([#1294](https://github.com/databricks/dbt-databricks/pull/1294)) +- Add support for Python UDFs ([#1336](https://github.com/databricks/dbt-databricks/pull/1336)) +- Add support for key-only `databricks_tags` for table and column tagging. This can now be configured by setting tag values to empty strings `""` or `None`. ([#1339](https://github.com/databricks/dbt-databricks/pull/1339)) +- Replace information_schema queries with DESCRIBE TABLE EXTENDED AS JSON for metadata fetching in incremental, materialized view, and view materializations (DBR 17.3+, falls back to info_schema on older runtimes) + ## dbt-databricks 1.11.8 (TBD) ### Features @@ -60,6 +70,7 @@ ### Features - Add `query_id` to `SQLQueryStatus` events to improve query tracing and debugging +- Add support for Row Filters ([#1294](https://github.com/databricks/dbt-databricks/pull/1294)) ### Fixes diff --git a/dbt/adapters/databricks/__version__.py b/dbt/adapters/databricks/__version__.py index 92ef49c34..2d7e71e1c 100644 --- a/dbt/adapters/databricks/__version__.py +++ b/dbt/adapters/databricks/__version__.py @@ -1 +1 @@ -version = "1.11.7" +version = "1.12.0a1" diff --git a/dbt/adapters/databricks/dbr_capabilities.py b/dbt/adapters/databricks/dbr_capabilities.py index 9b7c918d7..38bd133ef 100644 --- a/dbt/adapters/databricks/dbr_capabilities.py +++ b/dbt/adapters/databricks/dbr_capabilities.py @@ -13,13 +13,14 @@ class DBRCapability(Enum): """Named capabilities that depend on DBR version.""" - TIMESTAMPDIFF = "timestampdiff" - ICEBERG = "iceberg" COMMENT_ON_COLUMN = "comment_on_column" - JSON_COLUMN_METADATA = "json_column_metadata" - STREAMING_TABLE_JSON_METADATA = "streaming_table_json_metadata" + DESCRIBE_TABLE_EXTENDED_AS_JSON = "describe_table_extended_as_json" + ICEBERG = "iceberg" INSERT_BY_NAME = "insert_by_name" + JSON_COLUMN_METADATA = "json_column_metadata" REPLACE_ON = "replace_on" + STREAMING_TABLE_JSON_METADATA = "streaming_table_json_metadata" + TIMESTAMPDIFF = "timestampdiff" @dataclass @@ -61,6 +62,9 @@ class DBRCapabilities: DBRCapability.REPLACE_ON: CapabilitySpec( min_version=(17, 1), ), + DBRCapability.DESCRIBE_TABLE_EXTENDED_AS_JSON: CapabilitySpec( + min_version=(17, 3), + ), } def __init__( diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index 4d6964495..e535d36b4 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -1,3 +1,4 @@ +import json import posixpath import re from abc import ABC, abstractmethod @@ -11,6 +12,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, NamedTuple, Optional, Union, cast from uuid import uuid4 +import agate from dbt.adapters.base import AdapterConfig, PythonJobHelper from dbt.adapters.base.impl import catch_as_completed, log_code_execution from dbt.adapters.base.meta import available @@ -76,6 +78,7 @@ from dbt.adapters.databricks.relation_configs.materialized_view import ( MaterializedViewConfig, ) +from dbt.adapters.databricks.relation_configs.metric_view import MetricViewConfig from dbt.adapters.databricks.relation_configs.streaming_table import ( StreamingTableConfig, ) @@ -399,6 +402,23 @@ def require_capability(self, capability: DBRCapability) -> None: f"Current connection does not meet this requirement." ) + def is_describe_as_json_supported(self, relation: DatabricksRelation) -> bool: + """ + Check if DESCRIBE TABLE EXTENDED AS JSON can be used for the relation. + """ + return ( + not relation.is_hive_metastore() + and not relation.is_foreign_table + and self.has_capability(DBRCapability.DESCRIBE_TABLE_EXTENDED_AS_JSON) + ) + + def fetch_json_metadata(self, relation: DatabricksRelation) -> dict[str, Any]: + """Fetch the JSON metadata for a relation using DESCRIBE TABLE EXTENDED AS JSON.""" + kwargs = {"relation": relation} + describe_results = self.execute_macro("describe_table_extended_as_json", kwargs=kwargs) + json_metadata = json.loads(describe_results.rows[0].get("json_metadata")) + return json_metadata + def list_schemas(self, database: Optional[str]) -> list[str]: results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}) return [row[0] for row in results] @@ -937,6 +957,8 @@ def get_relation_config(self, relation: DatabricksRelation) -> DatabricksRelatio return IncrementalTableAPI.get_from_relation(self, relation) elif relation.type == DatabricksRelationType.View: return ViewAPI.get_from_relation(self, relation) + elif relation.type == DatabricksRelationType.MetricView: + return MetricViewAPI.get_from_relation(self, relation) else: raise NotImplementedError(f"Relation type {relation.type} is not supported.") @@ -952,6 +974,8 @@ def get_config_from_model(self, model: RelationConfig) -> DatabricksRelationConf return IncrementalTableAPI.get_from_relation_config(model) elif model.config.materialized == "view": return ViewAPI.get_from_relation_config(model) + elif model.config.materialized == "metric_view": + return MetricViewAPI.get_from_relation_config(model) else: raise NotImplementedError( f"Materialization {model.config.materialized} is not supported." @@ -1077,10 +1101,17 @@ def _describe_relation( ) kwargs = {"relation": relation} - results["information_schema.views"] = get_first_row( - adapter.execute_macro("get_view_description", kwargs=kwargs) - ) + if adapter.is_describe_as_json_supported(relation): + json_metadata = adapter.fetch_json_metadata(relation) + results["information_schema.views"] = ( + DatabricksDescribeJsonMetadata.parse_view_description(json_metadata) + ) + else: + results["information_schema.views"] = get_first_row( + adapter.execute_macro("get_view_description", kwargs=kwargs) + ) results["show_tblproperties"] = adapter.execute_macro("fetch_tbl_properties", kwargs=kwargs) + results["row_filters"] = adapter.execute_macro("fetch_row_filters", kwargs=kwargs) return results @@ -1104,6 +1135,7 @@ def _describe_relation( kwargs = {"relation": relation} results["show_tblproperties"] = adapter.execute_macro("fetch_tbl_properties", kwargs=kwargs) + results["row_filters"] = adapter.execute_macro("fetch_row_filters", kwargs=kwargs) return results @@ -1126,16 +1158,27 @@ def _describe_relation( results["information_schema.column_tags"] = adapter.execute_macro( "fetch_column_tags", kwargs=kwargs ) - results["non_null_constraint_columns"] = adapter.execute_macro( - "fetch_non_null_constraint_columns", kwargs=kwargs - ) - results["primary_key_constraints"] = adapter.execute_macro( - "fetch_primary_key_constraints", kwargs=kwargs - ) - results["foreign_key_constraints"] = adapter.execute_macro( - "fetch_foreign_key_constraints", kwargs=kwargs - ) - results["column_masks"] = adapter.execute_macro("fetch_column_masks", kwargs=kwargs) + results["row_filters"] = adapter.execute_macro("fetch_row_filters", kwargs=kwargs) + + if adapter.is_describe_as_json_supported(relation): + json_metadata = adapter.fetch_json_metadata(relation) + relation_metadata = DatabricksDescribeJsonMetadata.from_json_metadata(json_metadata) + results["non_null_constraint_columns"] = relation_metadata.non_null_constraints + results["primary_key_constraints"] = relation_metadata.primary_key_constraints + results["foreign_key_constraints"] = relation_metadata.foreign_key_constraints + results["column_masks"] = relation_metadata.column_masks + else: + results["non_null_constraint_columns"] = adapter.execute_macro( + "fetch_non_null_constraint_columns", kwargs=kwargs + ) + results["primary_key_constraints"] = adapter.execute_macro( + "fetch_primary_key_constraints", kwargs=kwargs + ) + results["foreign_key_constraints"] = adapter.execute_macro( + "fetch_foreign_key_constraints", kwargs=kwargs + ) + results["column_masks"] = adapter.execute_macro("fetch_column_masks", kwargs=kwargs) + results["show_tblproperties"] = adapter.execute_macro("fetch_tbl_properties", kwargs=kwargs) kwargs = {"table_name": relation} @@ -1159,9 +1202,16 @@ def _describe_relation( results = {} kwargs = {"relation": relation} - results["information_schema.views"] = get_first_row( - adapter.execute_macro("get_view_description", kwargs=kwargs) - ) + if adapter.is_describe_as_json_supported(relation): + json_metadata = adapter.fetch_json_metadata(relation) + results["information_schema.views"] = ( + DatabricksDescribeJsonMetadata.parse_view_description(json_metadata) + ) + else: + results["information_schema.views"] = get_first_row( + adapter.execute_macro("get_view_description", kwargs=kwargs) + ) + results["information_schema.tags"] = adapter.execute_macro("fetch_tags", kwargs=kwargs) results["show_tblproperties"] = adapter.execute_macro("fetch_tbl_properties", kwargs=kwargs) @@ -1170,3 +1220,137 @@ def _describe_relation( DESCRIBE_TABLE_EXTENDED_MACRO_NAME, kwargs=kwargs ) return results + + +class MetricViewAPI(RelationAPIBase[MetricViewConfig]): + relation_type = DatabricksRelationType.MetricView + + @classmethod + def config_type(cls) -> type[MetricViewConfig]: + return MetricViewConfig + + @classmethod + def _describe_relation( + cls, adapter: DatabricksAdapter, relation: DatabricksRelation + ) -> RelationResults: + results = {} + kwargs = {"relation": relation} + results["information_schema.tags"] = adapter.execute_macro("fetch_tags", kwargs=kwargs) + results["show_tblproperties"] = adapter.execute_macro("fetch_tbl_properties", kwargs=kwargs) + kwargs = {"table_name": relation} + results["describe_extended"] = adapter.execute_macro( + DESCRIBE_TABLE_EXTENDED_MACRO_NAME, kwargs=kwargs + ) + return results + +@dataclass +class DatabricksDescribeJsonMetadata: + column_masks: Optional["agate.Table"] = None + foreign_key_constraints: Optional["agate.Table"] = None + non_null_constraints: Optional["agate.Table"] = None + primary_key_constraints: Optional["agate.Table"] = None + view_description: Optional["agate.Row"] = None + + @classmethod + def from_json_metadata(cls, json_metadata: dict[str, Any]) -> "DatabricksDescribeJsonMetadata": + """Parse and convert the json metadata into structured metadata for the adapter to use.""" + return DatabricksDescribeJsonMetadata( + column_masks=cls.parse_column_masks(json_metadata), + foreign_key_constraints=cls.parse_foreign_key_constraints(json_metadata), + non_null_constraints=cls.parse_non_null_constraints(json_metadata), + primary_key_constraints=cls.parse_primary_key_constraints(json_metadata), + view_description=cls.parse_view_description(json_metadata), + ) + + @classmethod + def parse_column_masks(cls, json_metadata: dict[str, Any]) -> agate.Table: + """Parse json metadata into an agate Table of column masks (info_schema format).""" + raw_masks = json_metadata.get("column_masks", []) + rows = [] + for mask in raw_masks: + column_name = mask["column_name"] + fn = mask["function_name"] + mask_name = f"{fn['catalog_name']}.{fn['schema_name']}.{fn['function_name']}" + using_columns = ",".join(mask.get("using_column_names", [])) + rows.append((column_name, mask_name, using_columns or None)) + + return agate.Table( + rows=rows, + column_names=["column_name", "mask_name", "using_columns"], + column_types=[agate.Text(), agate.Text(), agate.Text()], + ) + + @classmethod + def parse_foreign_key_constraints(cls, json_metadata: dict[str, Any]) -> agate.Table: + """Parse json metadata into an agate Table of FK constraints (info_schema format).""" + table_constraint = re.sub(r"\s+", " ", json_metadata.get("table_constraints", "").strip()) + pairs = re.findall(r"\(\s*(\w+)\s*,(.*?)\)(?=\s*,\s*\(|\s*\])", table_constraint) + fk_rows = [] + for name, constraint in pairs: + constraint = constraint.strip() + if re.search(r"FOREIGN\s+KEY", constraint): + fk_part, ref_part = constraint.split("REFERENCES", 1) + from_cols = re.findall(r"`([^`]+)`", fk_part) + ref_parts = re.findall(r"`([^`]+)`", ref_part) + to_catalog = ref_parts[0] + to_schema = ref_parts[1] + to_table = ref_parts[2] + to_cols = ref_parts[3:] + for from_col, to_col in zip(from_cols, to_cols): + fk_rows.append([name, from_col, to_catalog, to_schema, to_table, to_col]) + + fk_column_names = [ + "constraint_name", + "from_column", + "to_catalog", + "to_schema", + "to_table", + "to_column", + ] + fk_columns_types = [ + agate.Text(), + agate.Text(), + agate.Text(), + agate.Text(), + agate.Text(), + agate.Text(), + ] + return agate.Table(fk_rows, fk_column_names, fk_columns_types) + + @classmethod + def parse_non_null_constraints(cls, json_metadata: dict[str, Any]) -> agate.Table: + """Parse json metadata into an agate Table of non-null constraints (info_schema format).""" + columns = json_metadata.get("columns", []) + + non_null_cols = [column["name"] for column in columns if not column.get("nullable")] + return agate.Table( + rows=[[col] for col in non_null_cols], + column_names=["column_name"], + column_types=[agate.Text()], + ) + + @classmethod + def parse_primary_key_constraints(cls, json_metadata: dict[str, Any]) -> agate.Table: + """Parse json metadata into an agate Table of PK constraints (info_schema format).""" + table_constraint = re.sub(r"\s+", " ", json_metadata.get("table_constraints", "").strip()) + pairs = re.findall(r"\(\s*(\w+)\s*,(.*?)\)(?=\s*,\s*\(|\s*\])", table_constraint) + pk_rows = [] + for name, constraint in pairs: + constraint = constraint.strip() + parts = re.findall(r"`([^`]+)`", constraint) + if re.search(r"PRIMARY\s+KEY", constraint): + for col in parts: + pk_rows.append([name, col]) + + pk_column_names = ["constraint_name", "column_name"] + pk_columns_types = [agate.Text(), agate.Text()] + return agate.Table(pk_rows, pk_column_names, pk_columns_types) + + @classmethod + def parse_view_description(cls, json_metadata: dict[str, Any]) -> "agate.Row": + """Parse json metadata into an agate Row for the view description (info_schema format).""" + view_text = json_metadata.get("view_text", None) + if view_text is None: + return agate.Row(values=set()) + else: + return agate.Row(values=(view_text,), keys=("view_definition",)) diff --git a/dbt/adapters/databricks/relation.py b/dbt/adapters/databricks/relation.py index 1b18e3ece..efb5e307c 100644 --- a/dbt/adapters/databricks/relation.py +++ b/dbt/adapters/databricks/relation.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from typing import Any, Optional, Type # noqa -from dbt.adapters.base.relation import BaseRelation, InformationSchema, Policy +from dbt.adapters.base.relation import BaseRelation, FunctionConfig, InformationSchema, Policy from dbt.adapters.contracts.relation import ( ComponentName, ) @@ -14,6 +14,7 @@ from dbt_common.utils import filter_null_values from dbt.adapters.databricks.constraints import TypedConstraint, process_constraint +from dbt.adapters.databricks.logging import logger from dbt.adapters.databricks.utils import remove_undefined KEY_TABLE_PROVIDER = "Provider" @@ -49,6 +50,15 @@ def render(self) -> str: """Return the type formatted for SQL statements (replace underscores with spaces)""" return self.value.replace("_", " ").upper() + def render_for_alter(self) -> str: + """Return the type formatted for ALTER statements. + + Metric views use ALTER VIEW (not ALTER METRIC VIEW) syntax. + """ + if self == DatabricksRelationType.MetricView: + return "VIEW" + return self.render() + class DatabricksTableType(StrEnum): External = "external" @@ -132,10 +142,18 @@ def is_hive_metastore(self) -> bool: def is_materialized_view(self) -> bool: return self.type == DatabricksRelationType.MaterializedView + @property + def is_metric_view(self) -> bool: + return self.type == DatabricksRelationType.MetricView + @property def is_streaming_table(self) -> bool: return self.type == DatabricksRelationType.StreamingTable + @property + def is_foreign_table(self) -> bool: + return self.type == DatabricksRelationType.Foreign + @property def is_external_table(self) -> bool: return self.databricks_table_type == DatabricksTableType.External @@ -245,6 +263,35 @@ def render_constraints_for_create(self) -> str: def render(self) -> str: return super().render().lower() + def get_function_config(self, model: dict[str, Any]) -> Optional[FunctionConfig]: + if model.get("resource_type") == "function" and model.get("language") == "python": + config = model.get("config", {}) + runtime_version = config.get("runtime_version") + entry_point = config.get("entry_point") + + # Databricks does not use runtime_version or entry_point in SQL. + # Provide defaults to satisfy dbt-adapters validation. + if not runtime_version: + runtime_version = "3.11" + logger.debug( + "runtime_version not specified for Python UDF; " + "defaulting to '3.11' (not used in Databricks SQL)" + ) + if not entry_point: + entry_point = model.get("name", "main") + logger.debug( + f"entry_point not specified for Python UDF; " + f"defaulting to '{entry_point}' (not used in Databricks SQL)" + ) + + return FunctionConfig( + language=model.get("language", ""), + type=config.get("type", ""), + runtime_version=runtime_version, + entry_point=entry_point, + ) + return super().get_function_config(model) + def is_hive_metastore(database: Optional[str], temporary: Optional[bool] = False) -> bool: return (database is None or database.lower() == "hive_metastore") and not temporary diff --git a/dbt/adapters/databricks/relation_configs/column_tags.py b/dbt/adapters/databricks/relation_configs/column_tags.py index 93d5ea890..26c8f6aa5 100644 --- a/dbt/adapters/databricks/relation_configs/column_tags.py +++ b/dbt/adapters/databricks/relation_configs/column_tags.py @@ -55,7 +55,7 @@ def from_relation_results(cls, results: RelationResults) -> ColumnTagsConfig: # row contains [column_name, tag_name, tag_value] column_name = str(row[0]) tag_name = str(row[1]) - tag_value = str(row[2]) + tag_value = "" if row[2] is None else str(row[2]) if column_name not in set_column_tags: set_column_tags[column_name] = {} @@ -79,7 +79,7 @@ def from_relation_config(cls, relation_config: RelationConfig) -> ColumnTagsConf if databricks_tags: if isinstance(databricks_tags, dict): set_column_tags[col["name"]] = { - str(k): str(v) for k, v in databricks_tags.items() + str(k): "" if v is None else str(v) for k, v in databricks_tags.items() } else: raise DbtRuntimeError("databricks_tags must be a dictionary") diff --git a/dbt/adapters/databricks/relation_configs/incremental.py b/dbt/adapters/databricks/relation_configs/incremental.py index af8d213d9..0044d8d84 100644 --- a/dbt/adapters/databricks/relation_configs/incremental.py +++ b/dbt/adapters/databricks/relation_configs/incremental.py @@ -7,6 +7,7 @@ from dbt.adapters.databricks.relation_configs.comment import CommentProcessor from dbt.adapters.databricks.relation_configs.constraints import ConstraintsProcessor from dbt.adapters.databricks.relation_configs.liquid_clustering import LiquidClusteringProcessor +from dbt.adapters.databricks.relation_configs.row_filter import RowFilterProcessor from dbt.adapters.databricks.relation_configs.tags import TagsProcessor from dbt.adapters.databricks.relation_configs.tblproperties import TblPropertiesProcessor @@ -18,6 +19,7 @@ class IncrementalTableConfig(DatabricksRelationConfigBase): ColumnMaskProcessor, ColumnTagsProcessor, ConstraintsProcessor, + RowFilterProcessor, TagsProcessor, TblPropertiesProcessor, LiquidClusteringProcessor, diff --git a/dbt/adapters/databricks/relation_configs/materialized_view.py b/dbt/adapters/databricks/relation_configs/materialized_view.py index 10aaf6ad8..503234a64 100644 --- a/dbt/adapters/databricks/relation_configs/materialized_view.py +++ b/dbt/adapters/databricks/relation_configs/materialized_view.py @@ -20,6 +20,7 @@ from dbt.adapters.databricks.relation_configs.refresh import ( RefreshProcessor, ) +from dbt.adapters.databricks.relation_configs.row_filter import RowFilterProcessor from dbt.adapters.databricks.relation_configs.tags import TagsProcessor from dbt.adapters.databricks.relation_configs.tblproperties import ( TblPropertiesProcessor, @@ -35,6 +36,7 @@ class MaterializedViewConfig(DatabricksRelationConfigBase): RefreshProcessor, QueryProcessor, TagsProcessor, + RowFilterProcessor, ] def get_changeset( @@ -42,7 +44,7 @@ def get_changeset( ) -> Optional[DatabricksRelationChangeSet]: changes: dict[str, DatabricksComponentConfig] = {} requires_refresh = False - updateable_component_keys = ["refresh", "tags"] + updateable_component_keys = ["refresh", "tags", "row_filter"] for component in self.config_components: key = component.name diff --git a/dbt/adapters/databricks/relation_configs/metric_view.py b/dbt/adapters/databricks/relation_configs/metric_view.py new file mode 100644 index 000000000..a6bf81690 --- /dev/null +++ b/dbt/adapters/databricks/relation_configs/metric_view.py @@ -0,0 +1,101 @@ +from typing import ClassVar, Optional + +from dbt.adapters.contracts.relation import RelationConfig +from dbt.adapters.relation_configs.config_base import RelationResults +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.databricks.relation_configs.base import ( + DatabricksComponentConfig, + DatabricksComponentProcessor, + DatabricksRelationConfigBase, +) +from dbt.adapters.databricks.relation_configs.tags import TagsProcessor +from dbt.adapters.databricks.relation_configs.tblproperties import TblPropertiesProcessor + + +class MetricViewQueryConfig(DatabricksComponentConfig): + """Component encapsulating the YAML definition of a metric view.""" + + query: str + + def get_diff(self, other: "MetricViewQueryConfig") -> Optional["MetricViewQueryConfig"]: + # Normalize whitespace for comparison + self_normalized = " ".join(self.query.split()) + other_normalized = " ".join(other.query.split()) + if self_normalized != other_normalized: + return self + return None + + +class MetricViewQueryProcessor(DatabricksComponentProcessor[MetricViewQueryConfig]): + """Processor for metric view YAML definitions. + + Metric views store their YAML definitions in information_schema.views, but wrapped + in $$ delimiters. This processor extracts and compares the YAML content. + """ + + name: ClassVar[str] = "query" + + @classmethod + def from_relation_results(cls, result: RelationResults) -> MetricViewQueryConfig: + from dbt.adapters.databricks.logging import logger + + # Get the view text from DESCRIBE EXTENDED output + describe_extended = result.get("describe_extended") + if not describe_extended: + raise DbtRuntimeError( + f"Cannot find metric view description. Result keys: {list(result.keys())}" + ) + + # Find the "View Text" row in DESCRIBE EXTENDED output + view_definition = None + for row in describe_extended: + if row[0] == "View Text": + view_definition = row[1] + break + + logger.debug( + f"MetricViewQueryProcessor: view_definition = " + f"{view_definition[:200] if view_definition else 'None'}" + ) + + if not view_definition: + raise DbtRuntimeError("Metric view has no 'View Text' in DESCRIBE EXTENDED output") + + view_definition = view_definition.strip() + + # Extract YAML content from $$ delimiters if present + # Format: $$ yaml_content $$ + # Check start/end explicitly to avoid issues with embedded $$ in YAML content + trimmed = view_definition.strip() + if trimmed.startswith("$$") and trimmed.endswith("$$"): + # Strip the leading and trailing $$ markers + view_definition = trimmed[2:-2].strip() + + return MetricViewQueryConfig(query=view_definition) + + @classmethod + def from_relation_config(cls, relation_config: RelationConfig) -> MetricViewQueryConfig: + query = relation_config.compiled_code + + if query: + return MetricViewQueryConfig(query=query.strip()) + else: + raise DbtRuntimeError( + f"Cannot compile metric view {relation_config.identifier} with no YAML definition" + ) + + +class MetricViewConfig(DatabricksRelationConfigBase): + """Config for metric views. + + Metric views use YAML definitions stored in information_schema.views wrapped in $$ delimiters. + Changes to the YAML definition can be applied via ALTER VIEW AS. + Tags and tblproperties can also be altered incrementally. + """ + + config_components = [ + TagsProcessor, + TblPropertiesProcessor, + MetricViewQueryProcessor, + ] diff --git a/dbt/adapters/databricks/relation_configs/row_filter.py b/dbt/adapters/databricks/relation_configs/row_filter.py new file mode 100644 index 000000000..b344b5edf --- /dev/null +++ b/dbt/adapters/databricks/relation_configs/row_filter.py @@ -0,0 +1,206 @@ +import csv +from io import StringIO +from typing import ClassVar, Optional + +from dbt.adapters.contracts.relation import RelationConfig +from dbt.adapters.relation_configs.config_base import RelationResults + +from dbt.adapters.databricks.relation_configs.base import ( + DatabricksComponentConfig, + DatabricksComponentProcessor, + get_config_value, +) + + +class RowFilterConfig(DatabricksComponentConfig): + """Row filter definition (function + columns). + + This class represents both the desired/existing state of a row filter AND + the diff result. When used as a diff result: + - should_unset=True means "remove the existing filter" + - should_unset=False with function set means "apply this filter" + - is_change=True indicates this is a diff that should trigger ALTER + + The is_change field is critical for streaming tables which use `diff or value` + pattern (streaming_table.py:56). Without it, unchanged row filters would still + trigger ALTER statements because the fallback value has a truthy function field. + """ + + # Fully qualified function name (catalog.schema.function) + function: Optional[str] = None + # Column names passed to the filter function + columns: tuple[str, ...] = () + # True when this instance represents a diff meaning "unset/drop the filter" + should_unset: bool = False + # True when this represents an actual change that should trigger ALTER + # (distinguishes diff result from `diff or value` fallback state) + is_change: bool = False + + @staticmethod + def _normalize_function(func: Optional[str]) -> Optional[str]: + """Normalize function name for comparison (lowercase, strip backticks).""" + if func is None: + return None + return func.lower().replace("`", "") + + @staticmethod + def _normalize_columns(cols: tuple[str, ...]) -> tuple[str, ...]: + """Normalize column names for comparison.""" + return tuple(c.lower().replace("`", "") for c in cols) + + def get_diff(self, other: "RowFilterConfig") -> Optional["RowFilterConfig"]: + """Compare desired state (self) with existing state (other). + + Returns: + - None if no changes needed + - RowFilterConfig with should_unset=True, is_change=True if filter should be removed + - RowFilterConfig with the filter config and is_change=True if filter + should be set/updated + """ + # Case 1: No filter desired, no filter exists -> no change + if self.function is None and other.function is None: + return None + + # Case 2: No filter desired, filter exists -> unset it + if self.function is None and other.function is not None: + return RowFilterConfig(should_unset=True, is_change=True) + + # Case 3: Filter desired, compare with existing + if self._normalize_function(self.function) == self._normalize_function(other.function): + if self._normalize_columns(self.columns) == self._normalize_columns(other.columns): + return None # No change + + # Filter is new or changed -> return new instance with is_change=True + # (can't return self because model is frozen and we need is_change=True) + return RowFilterConfig( + function=self.function, + columns=self.columns, + is_change=True, + ) + + +class RowFilterProcessor(DatabricksComponentProcessor[RowFilterConfig]): + """Processor for extracting row filter config from relations and model nodes.""" + + name: ClassVar[str] = "row_filter" + + @classmethod + def from_relation_results(cls, results: RelationResults) -> RowFilterConfig: + """Extract existing row filter from INFORMATION_SCHEMA results.""" + table = results.get("row_filters") + + if not table or len(table.rows) == 0: + return RowFilterConfig() + + # Handle multiple rows case (ABAC, platform bugs, etc.) + if len(table.rows) > 1: + filter_names = [row[3] for row in table.rows] # filter_name is index 3 + raise ValueError( + f"Multiple row filters found: {filter_names}. " + f"This may indicate ABAC-derived filters or a platform issue. " + f"dbt expects a single row filter per table." + ) + + # Unity Catalog returns one row per table (single filter constraint) + row = table.rows[0] + # Columns: table_catalog(0), table_schema(1), table_name(2), + # filter_name(3), target_columns(4) + filter_name = row[3] # Already fully qualified: catalog.schema.function + target_columns = row[4] # Comma-separated column list + + # filter_name is already fully qualified from INFORMATION_SCHEMA (catalog.schema.func) + # Store raw - backticks are added at SQL generation time + function = filter_name + + # Parse target_columns (handle quoted values with commas) + columns = cls._parse_target_columns(target_columns) + + return RowFilterConfig(function=function, columns=tuple(columns)) + + @classmethod + def from_relation_config(cls, relation_config: RelationConfig) -> RowFilterConfig: + """Extract row filter config from dbt model node.""" + row_filter = get_config_value(relation_config, "row_filter") + + if not row_filter: + return RowFilterConfig() + + function = row_filter.get("function") + columns = row_filter.get("columns", []) + + if not function: + return RowFilterConfig() + + # Normalize string to list + if isinstance(columns, str): + columns = [columns] + + # Validate columns is non-empty when function is set + if not columns or len(columns) == 0: + raise ValueError( + f"Row filter function '{function}' requires a non-empty 'columns' value. " + f'Example: columns: region OR columns: ["region_id", "country_code"]' + ) + + # Validate each column element is a non-empty string + for i, col in enumerate(columns): + if not isinstance(col, str) or not col.strip(): + raise ValueError( + f"Row filter column at index {i} must be a non-empty string. Got: {repr(col)}" + ) + + # Qualify function name if not already qualified + function = cls._qualify_function_name(function, relation_config) + + return RowFilterConfig(function=function, columns=tuple(columns)) + + @classmethod + def _qualify_function_name(cls, function: str, relation_config: RelationConfig) -> str: + """Ensure function name is fully qualified with catalog.schema. + + Handle 1-part, 2-part, 3-part names explicitly. + + IMPORTANT: This logic must stay in sync with the Jinja + `qualify_row_filter_function()` macro. Both use the same rules: + - 1-part: qualify with relation's database.schema + - 2-part: reject as ambiguous + - 3-part: use as-is + - 4+ parts: reject + """ + parts = function.replace("`", "").split(".") + + if len(parts) == 1: + # Unqualified: fn -> catalog.schema.fn + catalog = relation_config.database + schema = relation_config.schema + return f"{catalog}.{schema}.{parts[0]}" + elif len(parts) == 2: + # Ambiguous: schema.fn - reject with clear error + raise ValueError( + f"Row filter function '{function}' is ambiguous. " + f"Use either unqualified name (e.g., 'my_filter') or " + f"fully qualified name (e.g., 'catalog.schema.my_filter')." + ) + elif len(parts) == 3: + return f"{parts[0]}.{parts[1]}.{parts[2]}" + else: + raise ValueError( + f"Row filter function '{function}' has too many parts. " + f"Expected format: 'catalog.schema.function_name'." + ) + + @classmethod + def _parse_target_columns(cls, target_columns: Optional[str]) -> list[str]: + """Parse target_columns string from INFORMATION_SCHEMA. + + Handles quoted values with commas. + """ + if not target_columns: + return [] + + # Use CSV parser to handle quoted strings with embedded commas + # skipinitialspace=True handles space after comma: '"col1", "col2"' + reader = csv.reader(StringIO(target_columns), skipinitialspace=True) + for row in reader: + return [col.strip() for col in row] + return [] diff --git a/dbt/adapters/databricks/relation_configs/streaming_table.py b/dbt/adapters/databricks/relation_configs/streaming_table.py index b020272cd..5333d5b2d 100644 --- a/dbt/adapters/databricks/relation_configs/streaming_table.py +++ b/dbt/adapters/databricks/relation_configs/streaming_table.py @@ -16,6 +16,7 @@ ) from dbt.adapters.databricks.relation_configs.query import DescribeQueryProcessor from dbt.adapters.databricks.relation_configs.refresh import RefreshConfig, RefreshProcessor +from dbt.adapters.databricks.relation_configs.row_filter import RowFilterProcessor from dbt.adapters.databricks.relation_configs.tags import TagsProcessor from dbt.adapters.databricks.relation_configs.tblproperties import ( TblPropertiesProcessor, @@ -31,7 +32,7 @@ class StreamingTableConfig(DatabricksRelationConfigBase): RefreshProcessor, TagsProcessor, DescribeQueryProcessor, - TagsProcessor, + RowFilterProcessor, ] def get_changeset( diff --git a/dbt/adapters/databricks/relation_configs/tags.py b/dbt/adapters/databricks/relation_configs/tags.py index 9286bc9b3..757c3fdf0 100644 --- a/dbt/adapters/databricks/relation_configs/tags.py +++ b/dbt/adapters/databricks/relation_configs/tags.py @@ -33,7 +33,7 @@ def from_relation_results(cls, results: RelationResults) -> TagsConfig: if table: for row in table.rows: - tags[str(row[0])] = str(row[1]) + tags[str(row[0])] = "" if row[1] is None else str(row[1]) return TagsConfig(set_tags=tags) @@ -43,7 +43,7 @@ def from_relation_config(cls, relation_config: RelationConfig) -> TagsConfig: if not tags: return TagsConfig(set_tags=dict()) if isinstance(tags, dict): - tags = {str(k): str(v) for k, v in tags.items()} + tags = {str(k): "" if v is None else str(v) for k, v in tags.items()} return TagsConfig(set_tags=tags) else: raise DbtRuntimeError("databricks_tags must be a dictionary") diff --git a/dbt/include/databricks/macros/adapters/metadata.sql b/dbt/include/databricks/macros/adapters/metadata.sql index 44f9c0a15..08f01d231 100644 --- a/dbt/include/databricks/macros/adapters/metadata.sql +++ b/dbt/include/databricks/macros/adapters/metadata.sql @@ -129,9 +129,17 @@ SELECT NULL ) AS databricks_table_type FROM `system`.`information_schema`.`tables` -WHERE table_catalog = '{{ relation.database|lower }}' +WHERE table_catalog = '{{ relation.database|lower }}' AND table_schema = '{{ relation.schema|lower }}' {%- if relation.identifier %} AND table_name = '{{ relation.identifier|lower }}' {% endif %} {% endmacro %} + +{% macro describe_table_extended_as_json(relation) %} + {{ return(run_query_as(describe_table_extended_as_json_sql(relation), 'describe_table_extended_as_json')) }} +{% endmacro %} + +{% macro describe_table_extended_as_json_sql(relation) %} +DESCRIBE TABLE EXTENDED {{ relation.render() }} AS JSON +{% endmacro %} diff --git a/dbt/include/databricks/macros/materializations/functions/scalar.sql b/dbt/include/databricks/macros/materializations/functions/scalar.sql index 41b284344..cc8b6a504 100644 --- a/dbt/include/databricks/macros/materializations/functions/scalar.sql +++ b/dbt/include/databricks/macros/materializations/functions/scalar.sql @@ -7,4 +7,32 @@ {% macro databricks__scalar_function_body_sql() %} RETURN {{ model.compiled_code }} -{% endmacro %} \ No newline at end of file +{% endmacro %} + +{# Python UDF signature macro #} +{% macro databricks__scalar_function_create_replace_signature_python(target_relation) %} + CREATE OR REPLACE FUNCTION {{ target_relation.render() }} ({{ formatted_scalar_function_args_sql() }}) + RETURNS {{ model.returns.data_type }} + LANGUAGE PYTHON + AS +{% endmacro %} + +{# Python UDF body macro - uses dollar-quoting #} +{% macro databricks__scalar_function_body_python() %} +$$ +{{ model.compiled_code }} +$$ +{% endmacro %} + +{# Main Python UDF macro - combines signature and body #} +{% macro databricks__scalar_function_python(target_relation) %} + {#- Warn if user explicitly provided no-op config fields -#} + {%- if model.config.get('runtime_version') -%} + {{ exceptions.warn("'runtime_version' is accepted for compatibility but has no effect on Databricks Python UDFs. Databricks manages the Python runtime internally.") }} + {%- endif -%} + {%- if model.config.get('entry_point') -%} + {{ exceptions.warn("'entry_point' is accepted for compatibility but has no effect on Databricks Python UDFs. The function body is used directly.") }} + {%- endif -%} + {{ databricks__scalar_function_create_replace_signature_python(target_relation) }} + {{ databricks__scalar_function_body_python() }} +{% endmacro %} diff --git a/dbt/include/databricks/macros/materializations/incremental/incremental.sql b/dbt/include/databricks/macros/materializations/incremental/incremental.sql index ac717ae45..c8f09220f 100644 --- a/dbt/include/databricks/macros/materializations/incremental/incremental.sql +++ b/dbt/include/databricks/macros/materializations/incremental/incremental.sql @@ -177,6 +177,7 @@ {% set tags = _configuration_changes.changes.get("tags", None) %} {% set tblproperties = _configuration_changes.changes.get("tblproperties", None) %} {% set liquid_clustering = _configuration_changes.changes.get("liquid_clustering") %} + {% set row_filter = _configuration_changes.changes.get("row_filter") %} {% set constraints = _configuration_changes.changes.get("constraints") %} {% if tags is not none %} {% do apply_tags(target_relation, tags.set_tags) %} @@ -187,6 +188,9 @@ {% if liquid_clustering is not none %} {% do apply_liquid_clustered_cols(target_relation, liquid_clustering) %} {% endif %} + {% if row_filter is not none %} + {{ apply_row_filter(target_relation, row_filter) }} + {% endif %} {#- Incremental constraint application requires information_schema access (see fetch_*_constraints macros) -#} {% set contract_config = config.get('contract') %} {% if constraints and contract_config and contract_config.enforced and not target_relation.is_hive_metastore() %} diff --git a/dbt/include/databricks/macros/materializations/metric_view.sql b/dbt/include/databricks/macros/materializations/metric_view.sql new file mode 100644 index 000000000..b3884dce1 --- /dev/null +++ b/dbt/include/databricks/macros/materializations/metric_view.sql @@ -0,0 +1,41 @@ +{% materialization metric_view, adapter='databricks' -%} + {%- set existing_relation = load_relation_with_metadata(this) -%} + {%- set target_relation = this.incorporate(type='metric_view') -%} + {% set grant_config = config.get('grants') %} + {% set tags = config.get('databricks_tags') %} + {% set sql = adapter.clean_sql(sql) %} + + {{ run_pre_hooks() }} + + {% if existing_relation %} + {#- Only use alter path if existing relation is actually a metric_view -#} + {% if existing_relation.is_metric_view and relation_should_be_altered(existing_relation) %} + {% set configuration_changes = get_configuration_changes(existing_relation) %} + {% if configuration_changes and configuration_changes.changes %} + {% if configuration_changes.requires_full_refresh %} + {{ replace_with_metric_view(existing_relation, target_relation) }} + {% else %} + {{ alter_metric_view(target_relation, configuration_changes.changes) }} + {% endif %} + {% else %} + {# No changes detected - run a no-op statement for dbt tracking #} + {% call statement('main') %} + select 1 + {% endcall %} + {% endif %} + {% else %} + {{ replace_with_metric_view(existing_relation, target_relation) }} + {% endif %} + {% else %} + {% call statement('main') -%} + {{ get_create_metric_view_as_sql(target_relation, sql) }} + {%- endcall %} + {{ apply_tags(target_relation, tags) }} + {% endif %} + + {% do apply_grants(target_relation, grant_config, should_revoke=should_revoke(existing_relation, full_refresh_mode=True)) %} + + {{ run_post_hooks() }} + + {{ return({'relations': [target_relation]}) }} +{%- endmaterialization %} diff --git a/dbt/include/databricks/macros/materializations/table.sql b/dbt/include/databricks/macros/materializations/table.sql index b749c5148..f7b4204d8 100644 --- a/dbt/include/databricks/macros/materializations/table.sql +++ b/dbt/include/databricks/macros/materializations/table.sql @@ -15,7 +15,7 @@ {% set staging_relation = make_staging_relation(target_relation) %} {{ run_pre_hooks() }} - + {% call statement('main', language=language) %} {{ get_create_intermediate_table(intermediate_relation, compiled_code, language) }} {% endcall %} diff --git a/dbt/include/databricks/macros/materializations/view.sql b/dbt/include/databricks/macros/materializations/view.sql index 8c824e5de..7ecc0829d 100644 --- a/dbt/include/databricks/macros/materializations/view.sql +++ b/dbt/include/databricks/macros/materializations/view.sql @@ -87,7 +87,7 @@ {% macro relation_should_be_altered(existing_relation) %} {% set update_via_alter = config.get('view_update_via_alter', False) | as_bool %} - {% if existing_relation.is_view and update_via_alter %} + {% if (existing_relation.is_view or existing_relation.is_metric_view) and update_via_alter %} {% if existing_relation.is_hive_metastore() %} {{ exceptions.raise_compiler_error("Cannot update a view in the Hive metastore via ALTER VIEW. Please set `view_update_via_alter: false` in your model configuration.") }} {% endif %} diff --git a/dbt/include/databricks/macros/relations/components/query.sql b/dbt/include/databricks/macros/relations/components/query.sql index 514772844..6b31065bc 100644 --- a/dbt/include/databricks/macros/relations/components/query.sql +++ b/dbt/include/databricks/macros/relations/components/query.sql @@ -8,7 +8,7 @@ {% endmacro %} {% macro get_alter_query_sql(target_relation, query) -%} - ALTER {{ target_relation.type.render() }} {{ target_relation.render() }} AS ( + ALTER {{ target_relation.type.render_for_alter() }} {{ target_relation.render() }} AS ( {{ query }} ) -{%- endmacro %} \ No newline at end of file +{%- endmacro %} diff --git a/dbt/include/databricks/macros/relations/components/row_filter.sql b/dbt/include/databricks/macros/relations/components/row_filter.sql new file mode 100644 index 000000000..ace2757e6 --- /dev/null +++ b/dbt/include/databricks/macros/relations/components/row_filter.sql @@ -0,0 +1,204 @@ +{#-- +ROW FILTER MACROS +================= + +Implements row-level security via Unity Catalog ROW FILTER clause. + +NOTE: This file contains intentional duplication between CREATE and ALTER paths. +See row_filter.py _qualify_function_name() for the Python equivalent logic. +Keep both implementations in sync. + +ARCHITECTURAL NOTE: CREATE and ALTER Paths +------------------------------------------ +1. CREATE PATH (get_create_row_filter_clause): + - Used during initial table/MV/ST creation + - Reads raw config.get('row_filter') and validates in Jinja + - Row filter is embedded in the CREATE statement + +2. ALTER PATH (alter_set_row_filter / alter_drop_row_filter): + - Used for subsequent changes in incremental materializations (incremental, MV, ST) + - Receives pre-validated RowFilterConfig from Python + - When should_unset=True, drops the filter; when function is set, applies it + +The duplication exists because: +- CREATE path reads raw config (Python processing hasn't happened yet) +- ALTER path uses Python-processed config via RowFilterProcessor.from_relation_config() +--#} + +{#-- ===== FETCH MACROS ===== --#} + +{%- macro fetch_row_filters(relation) -%} + {%- if relation.is_hive_metastore() -%} + {{ exceptions.raise_compiler_error("Row filters are not supported for Hive Metastore") }} + {%- endif -%} + {%- call statement('list_row_filters', fetch_result=True) -%} + {{ fetch_row_filters_sql(relation) }} + {%- endcall -%} + {%- do return(load_result('list_row_filters').table) -%} +{%- endmacro -%} + + +{%- macro fetch_row_filters_sql(relation) -%} + SELECT + table_catalog, + table_schema, + table_name, + filter_name, + target_columns + FROM `{{ relation.database }}`.`information_schema`.`row_filters` + WHERE table_catalog = '{{ relation.database | lower }}' + AND table_schema = '{{ relation.schema | lower }}' + AND table_name = '{{ relation.identifier | lower }}' +{%- endmacro -%} + + +{#-- ===== HELPER MACROS ===== --#} + +{%- macro quote_row_filter_columns(columns) -%} + {#- Quote column names with backticks and join with comma. + Shared by alter_set_row_filter and get_create_row_filter_clause. -#} + {%- set quoted = [] -%} + {%- for col in columns -%} + {%- do quoted.append('`' ~ col ~ '`') -%} + {%- endfor -%} + {{- quoted | join(', ') -}} +{%- endmacro -%} + + +{%- macro quote_row_filter_function(function) -%} + {#- Quote a fully qualified function name: catalog.schema.func -> `catalog`.`schema`.`func` + Function names are stored raw internally; backticks are added at SQL generation time. -#} + {%- set parts = function.split('.') -%} + {%- if parts | length == 3 -%} + `{{ parts[0] }}`.`{{ parts[1] }}`.`{{ parts[2] }}` + {%- else -%} + {#- Fallback for unexpected format -#} + `{{ function }}` + {%- endif -%} +{%- endmacro -%} + + +{#-- ===== ALTER MACROS (SQL-returning, don't execute) ===== --#} + +{%- macro alter_set_row_filter(relation, row_filter) -%} + ALTER {{ relation.type.render() }} {{ relation.render() }} + SET ROW FILTER {{ quote_row_filter_function(row_filter.function) }} + ON ({{ quote_row_filter_columns(row_filter.columns) }}) +{%- endmacro -%} + + +{%- macro alter_drop_row_filter(relation) -%} + ALTER {{ relation.type.render() }} {{ relation.render() }} + DROP ROW FILTER +{%- endmacro -%} + + +{#-- ===== APPLY MACRO (executes SQL) ===== --#} + +{%- macro apply_row_filter(target_relation, row_filter) -%} + {#- Executes SQL immediately via call statement('main') -#} + {#- row_filter is a RowFilterConfig object with is_change, should_unset, function fields -#} + {%- if target_relation.is_hive_metastore() -%} + {{ exceptions.raise_compiler_error("Row filters are not supported for Hive Metastore") }} + {%- endif -%} + + {%- if row_filter.is_change -%} + {%- if row_filter.should_unset -%} + {%- call statement('main') -%} + {{ alter_drop_row_filter(target_relation) }} + {%- endcall -%} + {%- elif row_filter.function -%} + {%- call statement('main') -%} + {{ alter_set_row_filter(target_relation, row_filter) }} + {%- endcall -%} + {%- endif -%} + {%- endif -%} +{%- endmacro -%} + + +{#-- ===== DROP IF EXISTS MACRO (for table materialization) ===== --#} + +{%- macro drop_row_filter_if_exists(relation) -%} + {#- Drops any existing row filter from a relation before CREATE OR REPLACE. + Used by table materialization when the model has no row filter configured + but the existing table might have one from a previous run. -#} + {%- if not relation.is_hive_metastore() -%} + {%- set existing_filters = fetch_row_filters(relation) -%} + {%- if existing_filters | length > 0 -%} + {%- call statement('drop_row_filter') -%} + {{ alter_drop_row_filter(relation) }} + {%- endcall -%} + {%- endif -%} + {%- endif -%} +{%- endmacro -%} + + +{#-- ===== CREATE CLAUSE MACROS ===== --#} + +{%- macro qualify_row_filter_function(function, relation) -%} + {#- Handle 1-part and 3-part, reject 2-part. -#} + {#- IMPORTANT: Keep in sync with Python _qualify_function_name() -#} + {%- set parts = function.replace('`', '').split('.') -%} + + {%- if parts | length == 1 -%} + {#- Unqualified: fn -> catalog.schema.fn -#} + {{ relation.database }}.{{ relation.schema }}.{{ parts[0] }} + {%- elif parts | length == 2 -%} + {#- Ambiguous: reject -#} + {{ exceptions.raise_compiler_error( + "Row filter function '" ~ function ~ "' is ambiguous. " ~ + "Use either unqualified name (e.g., 'my_filter') or " ~ + "fully qualified name (e.g., 'catalog.schema.my_filter')." + ) }} + {%- elif parts | length == 3 -%} + {#- Fully qualified -#} + {{ parts[0] }}.{{ parts[1] }}.{{ parts[2] }} + {%- else -%} + {{ exceptions.raise_compiler_error( + "Row filter function '" ~ function ~ "' has too many parts." + ) }} + {%- endif -%} +{%- endmacro -%} + + +{%- macro get_create_row_filter_clause(relation) -%} + {%- set row_filter = config.get('row_filter') -%} + + {%- if not row_filter or not row_filter.get('function') -%} + {#-- Model has no row filter - drop any existing filter from table --#} + {{ drop_row_filter_if_exists(relation) }} + {%- else -%} + {#-- Model has row filter - generate the WITH ROW FILTER clause --#} + {%- set columns = row_filter.get('columns', []) -%} + + {#- Normalize string to list -#} + {%- if columns is string -%} + {%- set columns = [columns] -%} + {%- endif -%} + + {#- Validate columns is non-empty -#} + {%- if not columns or columns | length == 0 -%} + {{ exceptions.raise_compiler_error( + "Row filter function '" ~ row_filter.get('function') ~ "' requires a non-empty 'columns' value." + ) }} + {%- endif -%} + + {%- set function = qualify_row_filter_function(row_filter.get('function'), relation) -%} + WITH ROW FILTER {{ quote_row_filter_function(function) }} ON ({{ quote_row_filter_columns(columns) }}) + {%- endif -%} +{%- endmacro -%} + + +{%- macro row_filter_exists() -%} + {%- set row_filter = config.get('row_filter') -%} + {%- set has_function = row_filter and row_filter.get('function') -%} + {%- set columns = row_filter.get('columns') if row_filter else none -%} + + {#- Normalize string to list -#} + {%- if columns is string -%} + {%- set columns = [columns] -%} + {%- endif -%} + + {%- set has_columns = columns and columns | length > 0 -%} + {%- do return(has_function and has_columns) -%} +{%- endmacro -%} diff --git a/dbt/include/databricks/macros/relations/config.sql b/dbt/include/databricks/macros/relations/config.sql index 4c6ae8910..c5c867522 100644 --- a/dbt/include/databricks/macros/relations/config.sql +++ b/dbt/include/databricks/macros/relations/config.sql @@ -3,4 +3,4 @@ {%- set model_config = adapter.get_config_from_model(config.model) -%} {%- set configuration_changes = model_config.get_changeset(existing_config) -%} {% do return(configuration_changes) %} -{%- endmacro -%} \ No newline at end of file +{%- endmacro -%} diff --git a/dbt/include/databricks/macros/relations/create.sql b/dbt/include/databricks/macros/relations/create.sql index fa96d19b6..abc442078 100644 --- a/dbt/include/databricks/macros/relations/create.sql +++ b/dbt/include/databricks/macros/relations/create.sql @@ -11,6 +11,9 @@ {%- elif relation.is_streaming_table -%} {{ get_create_streaming_table_as_sql(relation, sql) }} + {%- elif relation.is_metric_view -%} + {{ get_create_metric_view_as_sql(relation, sql) }} + {%- else -%} {{- exceptions.raise_compiler_error("`get_create_sql` has not been implemented for: " ~ relation.type ) -}} diff --git a/dbt/include/databricks/macros/relations/drop.sql b/dbt/include/databricks/macros/relations/drop.sql index 464ac51d2..710ad800f 100644 --- a/dbt/include/databricks/macros/relations/drop.sql +++ b/dbt/include/databricks/macros/relations/drop.sql @@ -3,7 +3,7 @@ {{ drop_materialized_view(relation) }} {%- elif relation.is_streaming_table-%} {{ drop_streaming_table(relation) }} - {%- elif relation.is_view -%} + {%- elif relation.is_view or relation.is_metric_view -%} {{ drop_view(relation) }} {%- else -%} {{ drop_table(relation) }} diff --git a/dbt/include/databricks/macros/relations/materialized_view/alter.sql b/dbt/include/databricks/macros/relations/materialized_view/alter.sql index 0ac14d96f..d5d1aea94 100644 --- a/dbt/include/databricks/macros/relations/materialized_view/alter.sql +++ b/dbt/include/databricks/macros/relations/materialized_view/alter.sql @@ -46,6 +46,18 @@ {%- if tags and tags.set_tags and tags.set_tags != [] -%} {{ return_statements.append(alter_set_tags(relation, tags.set_tags)) }} {%- endif -%} + + {#- Row filter handling - append SQL to list, don't execute -#} + {#- is_change guard prevents false alters when row_filter is unchanged -#} + {%- set row_filter = configuration_changes.changes.get("row_filter") -%} + {%- if row_filter and row_filter.is_change -%} + {%- if row_filter.should_unset -%} + {{ return_statements.append(alter_drop_row_filter(relation)) }} + {%- elif row_filter.function -%} + {{ return_statements.append(alter_set_row_filter(relation, row_filter)) }} + {%- endif -%} + {%- endif -%} + {% do return(return_statements) %} {%- endif -%} {% endmacro %} diff --git a/dbt/include/databricks/macros/relations/materialized_view/create.sql b/dbt/include/databricks/macros/relations/materialized_view/create.sql index 5f0e49525..00483574f 100644 --- a/dbt/include/databricks/macros/relations/materialized_view/create.sql +++ b/dbt/include/databricks/macros/relations/materialized_view/create.sql @@ -33,6 +33,7 @@ create or replace materialized view {{ target_relation.render() }} {{ get_column_and_constraints_sql(target_relation, columns_and_constraints[0]) }} + {{ get_create_row_filter_clause(target_relation) }} {{ get_create_sql_partition_by(partition_by) }} {{ liquid_clustered_cols() }} {{ get_create_sql_comment(comment) }} diff --git a/dbt/include/databricks/macros/relations/metric_view/alter.sql b/dbt/include/databricks/macros/relations/metric_view/alter.sql new file mode 100644 index 000000000..e3bfde7e5 --- /dev/null +++ b/dbt/include/databricks/macros/relations/metric_view/alter.sql @@ -0,0 +1,51 @@ +{% macro alter_metric_view(target_relation, changes) %} + {{ log("Updating metric view via ALTER") }} + {{ adapter.dispatch('alter_metric_view', 'dbt')(target_relation, changes) }} +{% endmacro %} + +{% macro databricks__alter_metric_view(target_relation, changes) %} + {% set tags = changes.get("tags") %} + {% set tblproperties = changes.get("tblproperties") %} + {% set query = changes.get("query") %} + + {# Handle YAML definition changes via ALTER VIEW AS #} + {% if query %} + {% call statement('main') %} + {{ get_alter_metric_view_as_sql(target_relation, query.query) }} + {% endcall %} + {% else %} + {# Ensure statement('main') is called for dbt to track the run #} + {% call statement('main') %} + select 1 + {% endcall %} + {% endif %} + + {% if tags %} + {{ apply_tags(target_relation, tags.set_tags) }} + {% endif %} + {% if tblproperties %} + {{ apply_tblproperties(target_relation, tblproperties.tblproperties) }} + {% endif %} +{% endmacro %} + +{% macro get_alter_metric_view_as_sql(relation, yaml_content) -%} + {{ adapter.dispatch('get_alter_metric_view_as_sql', 'dbt')(relation, yaml_content) }} +{%- endmacro %} + +{% macro databricks__get_alter_metric_view_as_sql(relation, yaml_content) %} +alter view {{ relation.render() }} as $$ +{{ yaml_content }} +$$ +{% endmacro %} + +{% macro replace_with_metric_view(existing_relation, target_relation) %} + {% set sql = adapter.clean_sql(sql) %} + {% set tags = config.get('databricks_tags') %} + {% set tblproperties = config.get('tblproperties') %} + {{ execute_multiple_statements(get_replace_sql(existing_relation, target_relation, sql)) }} + {%- do apply_tags(target_relation, tags) -%} + + {% if tblproperties %} + {{ apply_tblproperties(target_relation, tblproperties) }} + {% endif %} +{% endmacro %} diff --git a/dbt/include/databricks/macros/relations/metric_view/create.sql b/dbt/include/databricks/macros/relations/metric_view/create.sql new file mode 100644 index 000000000..6c8b386b9 --- /dev/null +++ b/dbt/include/databricks/macros/relations/metric_view/create.sql @@ -0,0 +1,12 @@ +{% macro get_create_metric_view_as_sql(relation, sql) -%} + {{ adapter.dispatch('get_create_metric_view_as_sql', 'dbt')(relation, sql) }} +{%- endmacro %} + +{% macro databricks__get_create_metric_view_as_sql(relation, sql) %} +create or replace view {{ relation.render() }} +with metrics +language yaml +as $$ +{{ sql }} +$$ +{% endmacro %} \ No newline at end of file diff --git a/dbt/include/databricks/macros/relations/metric_view/replace.sql b/dbt/include/databricks/macros/relations/metric_view/replace.sql new file mode 100644 index 000000000..2855790ef --- /dev/null +++ b/dbt/include/databricks/macros/relations/metric_view/replace.sql @@ -0,0 +1,7 @@ +{% macro get_replace_metric_view_sql(target_relation, sql) %} + {{ adapter.dispatch('get_replace_metric_view_sql', 'dbt')(target_relation, sql) }} +{% endmacro %} + +{% macro databricks__get_replace_metric_view_sql(target_relation, sql) %} + {{ get_create_metric_view_as_sql(target_relation, sql) }} +{% endmacro %} \ No newline at end of file diff --git a/dbt/include/databricks/macros/relations/replace.sql b/dbt/include/databricks/macros/relations/replace.sql index 72484ede2..ff44c9d4d 100644 --- a/dbt/include/databricks/macros/relations/replace.sql +++ b/dbt/include/databricks/macros/relations/replace.sql @@ -9,6 +9,12 @@ {{ exceptions.raise_not_implemented('get_replace_sql not implemented for target of table') }} {% endif %} + {#- Metric views always support CREATE OR REPLACE (no delta/file_format dependency) -#} + {#- Note: existing relation is typed as VIEW from DB, so check target for metric_view -#} + {% if target_relation.is_metric_view %} + {{ return(get_replace_metric_view_sql(target_relation, sql)) }} + {% endif %} + {% set safe_replace = config.get('use_safer_relation_operations', False) | as_bool %} {% set file_format = adapter.resolve_file_format(config) %} {% set is_replaceable = existing_relation.type == target_relation.type and existing_relation.can_be_replaced and file_format == "delta" %} diff --git a/dbt/include/databricks/macros/relations/streaming_table/alter.sql b/dbt/include/databricks/macros/relations/streaming_table/alter.sql index b45493991..a8c87648f 100644 --- a/dbt/include/databricks/macros/relations/streaming_table/alter.sql +++ b/dbt/include/databricks/macros/relations/streaming_table/alter.sql @@ -44,6 +44,18 @@ {%- if tags and tags.set_tags and tags.set_tags != [] -%} {{ return_statements.append(alter_set_tags(relation, tags.set_tags)) }} {%- endif -%} + + {#- Row filter handling - append SQL to list, don't execute -#} + {#- is_change guard prevents false alters from `diff or value` fallback in streaming_table.py:56 -#} + {%- set row_filter = configuration_changes.changes.get("row_filter") -%} + {%- if row_filter and row_filter.is_change -%} + {%- if row_filter.should_unset -%} + {{ return_statements.append(alter_drop_row_filter(relation)) }} + {%- elif row_filter.function -%} + {{ return_statements.append(alter_set_row_filter(relation, row_filter)) }} + {%- endif -%} + {%- endif -%} + {% do return(return_statements) %} {%- endif -%} {% endmacro %} diff --git a/dbt/include/databricks/macros/relations/streaming_table/create.sql b/dbt/include/databricks/macros/relations/streaming_table/create.sql index cadc68d1d..ab55e10fc 100644 --- a/dbt/include/databricks/macros/relations/streaming_table/create.sql +++ b/dbt/include/databricks/macros/relations/streaming_table/create.sql @@ -29,6 +29,7 @@ {#-- We don't enrich the relation with model constraints because they are not supported for streaming tables --#} CREATE STREAMING TABLE {{ relation.render() }} {{ get_column_and_constraints_sql(relation, columns_and_constraints[0]) }} + {{ get_create_row_filter_clause(relation) }} {{ get_create_sql_partition_by(partition_by) }} {{ liquid_clustered_cols() }} {{ get_create_sql_comment(comment) }} diff --git a/dbt/include/databricks/macros/relations/table/alter.sql b/dbt/include/databricks/macros/relations/table/alter.sql index ca2a7690e..095de9024 100644 --- a/dbt/include/databricks/macros/relations/table/alter.sql +++ b/dbt/include/databricks/macros/relations/table/alter.sql @@ -9,6 +9,7 @@ {% set liquid_clustering = configuration_changes.changes.get("liquid_clustering")%} {% set constraints = configuration_changes.changes.get("constraints") %} {% set column_masks = configuration_changes.changes.get("column_masks") %} + {% set row_filter = configuration_changes.changes.get("row_filter") %} {% if tags is not none %} {% do apply_tags(target_relation, tags.set_tags) %} {%- endif -%} @@ -33,5 +34,8 @@ {% if column_masks %} {{ apply_column_masks(target_relation, column_masks) }} {% endif %} + {% if row_filter %} + {{ apply_row_filter(target_relation, row_filter) }} + {% endif %} {%- endif -%} {% endmacro %} \ No newline at end of file diff --git a/dbt/include/databricks/macros/relations/table/create.sql b/dbt/include/databricks/macros/relations/table/create.sql index 447d9bf54..aa62a5255 100644 --- a/dbt/include/databricks/macros/relations/table/create.sql +++ b/dbt/include/databricks/macros/relations/table/create.sql @@ -46,6 +46,7 @@ {{ file_format_clause(catalog_relation) }} {{ databricks__options_clause(catalog_relation) }} {{ partition_cols(label="partitioned by") }} + {{ get_create_row_filter_clause(target_relation) }} {{ liquid_clustered_cols() }} {{ clustered_cols(label="clustered by") }} {{ location_clause(catalog_relation) }} @@ -74,6 +75,7 @@ {{ file_format_clause(catalog_relation) }} {{ databricks__options_clause(catalog_relation) }} {{ partition_cols(label="partitioned by") }} + {{ get_create_row_filter_clause(relation) }} {{ liquid_clustered_cols() }} {{ clustered_cols(label="clustered by") }} {{ location_clause(catalog_relation) }} diff --git a/dbt/include/databricks/macros/relations/tags.sql b/dbt/include/databricks/macros/relations/tags.sql index 6ba6e29a9..49defdf44 100644 --- a/dbt/include/databricks/macros/relations/tags.sql +++ b/dbt/include/databricks/macros/relations/tags.sql @@ -29,7 +29,7 @@ {%- endmacro -%} {% macro alter_set_tags(relation, tags) -%} - ALTER {{ relation.type.render() }} {{ relation.render() }} SET TAGS ( + ALTER {{ relation.type.render_for_alter() }} {{ relation.render() }} SET TAGS ( {% for tag in tags -%} '{{ tag }}' = '{{ tags[tag] }}' {%- if not loop.last %}, {% endif -%} {%- endfor %} diff --git a/dbt/include/databricks/macros/relations/tblproperties.sql b/dbt/include/databricks/macros/relations/tblproperties.sql index 333ca1efc..adbfcf55e 100644 --- a/dbt/include/databricks/macros/relations/tblproperties.sql +++ b/dbt/include/databricks/macros/relations/tblproperties.sql @@ -21,7 +21,7 @@ {% set tblproperty_statment = databricks__tblproperties_clause(tblproperties) %} {% if tblproperty_statment %} {%- call statement('main') -%} - ALTER {{ relation.type.render() }} {{ relation.render() }} SET {{ tblproperty_statment}} + ALTER {{ relation.type.render_for_alter() }} {{ relation.render() }} SET {{ tblproperty_statment}} {%- endcall -%} {% endif %} {%- endmacro -%} diff --git a/dbt/include/databricks/macros/relations/view/create.sql b/dbt/include/databricks/macros/relations/view/create.sql index 79e74001c..736f6d4f7 100644 --- a/dbt/include/databricks/macros/relations/view/create.sql +++ b/dbt/include/databricks/macros/relations/view/create.sql @@ -2,6 +2,9 @@ {% if column_mask_exists() %} {% do exceptions.raise_compiler_error("Column masks are not supported for views.") %} {% endif %} + {% if row_filter_exists() %} + {% do exceptions.raise_compiler_error("Row filters are not supported for views.") %} + {% endif %} {{ log("Creating view " ~ relation) }} create or replace view {{ relation.render() }} {%- if config.persist_column_docs() -%} diff --git a/tests/functional/adapter/column_tags/fixtures.py b/tests/functional/adapter/column_tags/fixtures.py index 5c26f835b..0e282d222 100644 --- a/tests/functional/adapter/column_tags/fixtures.py +++ b/tests/functional/adapter/column_tags/fixtures.py @@ -14,6 +14,8 @@ databricks_tags: pii: "true" sensitive: "true" + key_only: "" + null_value: """ updated_column_tag_model = """ @@ -30,6 +32,8 @@ databricks_tags: pii: "true" sensitive: "true" + key_only: "" + null_value: """ column_tags_seed = """ diff --git a/tests/functional/adapter/column_tags/test_column_tags.py b/tests/functional/adapter/column_tags/test_column_tags.py index e2407caed..1154b8a7e 100644 --- a/tests/functional/adapter/column_tags/test_column_tags.py +++ b/tests/functional/adapter/column_tags/test_column_tags.py @@ -32,6 +32,8 @@ def test_column_tags(self, project): expected_tags = { ("account_number", "pii", "true"), ("account_number", "sensitive", "true"), + ("account_number", "key_only", ""), + ("account_number", "null_value", ""), } actual_tags = {(row[0], row[1], row[2]) for row in tags} assert actual_tags == expected_tags @@ -52,6 +54,8 @@ def test_column_tags(self, project): ("id", "pii", "false"), ("account_number", "pii", "true"), ("account_number", "sensitive", "true"), + ("account_number", "key_only", ""), + ("account_number", "null_value", ""), } actual_tags = {(row[0], row[1], row[2]) for row in tags} assert actual_tags == expected_tags diff --git a/tests/functional/adapter/fixtures.py b/tests/functional/adapter/fixtures.py index ea8a42433..408158654 100644 --- a/tests/functional/adapter/fixtures.py +++ b/tests/functional/adapter/fixtures.py @@ -1,6 +1,12 @@ import pytest +class MaterializationV1Mixin: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"use_materialization_v2": False}} + + class MaterializationV2Mixin: @pytest.fixture(scope="class") def project_config_update(self): diff --git a/tests/functional/adapter/functions/fixtures.py b/tests/functional/adapter/functions/fixtures.py new file mode 100644 index 000000000..4a887898e --- /dev/null +++ b/tests/functional/adapter/functions/fixtures.py @@ -0,0 +1,65 @@ +DATABRICKS_PYTHON_UDF_BODY = """ +def price_for_xlarge(price): + return price * 2 + +return price_for_xlarge(price) +""" + +DATABRICKS_PYTHON_UDF_YML = """ +functions: + - name: price_for_xlarge + description: Calculate the price for the xlarge version of a standard item + config: + entry_point: price_for_xlarge + runtime_version: "3.11" + arguments: + - name: price + data_type: float + description: The price of the standard item + returns: + data_type: float + description: The resulting xlarge price +""" + + +DATABRICKS_MULTI_ARG_PYTHON_UDF_BODY = """ +def total_price(price, quantity): + return price * quantity + +return total_price(price, quantity) +""" + +DATABRICKS_MULTI_ARG_PYTHON_UDF_YML = """ +functions: + - name: total_price + description: Calculate total price from unit price and quantity + arguments: + - name: price + data_type: float + - name: quantity + data_type: int + returns: + data_type: float +""" + + +PYTHON_UDF_V1 = """ +return price * 2 +""" + +PYTHON_UDF_V2 = """ +return price * 3 +""" + +PYTHON_UDF_YML_V1 = """ +functions: + - name: price_for_xlarge + config: + entry_point: price_for_xlarge + runtime_version: "3.11" + arguments: + - name: price + data_type: float + returns: + data_type: float +""" diff --git a/tests/functional/adapter/functions/test_udfs.py b/tests/functional/adapter/functions/test_udfs.py index 22c9c0766..923021160 100644 --- a/tests/functional/adapter/functions/test_udfs.py +++ b/tests/functional/adapter/functions/test_udfs.py @@ -1,5 +1,152 @@ -from dbt.tests.adapter.functions.test_udfs import UDFsBasic +import pytest +from dbt.artifacts.schemas.results import RunStatus +from dbt.contracts.graph.nodes import FunctionNode +from dbt.tests.adapter.functions.test_udfs import ( + UDFsBasic, +) +from dbt.tests.util import run_dbt, write_file +from tests.functional.adapter.functions.fixtures import ( + DATABRICKS_MULTI_ARG_PYTHON_UDF_BODY, + DATABRICKS_MULTI_ARG_PYTHON_UDF_YML, + DATABRICKS_PYTHON_UDF_BODY, + DATABRICKS_PYTHON_UDF_YML, + PYTHON_UDF_V1, + PYTHON_UDF_V2, + PYTHON_UDF_YML_V1, +) + +@pytest.mark.skip_profile("databricks_cluster") class TestDatabricksUDFs(UDFsBasic): + """Basic SQL UDF test - requires Unity Catalog""" + pass + + +@pytest.mark.skip_profile("databricks_cluster") +class TestDatabricksPythonUDFSupported(UDFsBasic): + """Test that Python UDFs work on Databricks with Unity Catalog. + + Verifies: + - Python UDF creates successfully + - The UDF executes correctly and returns expected results + """ + + @pytest.fixture(scope="class") + def functions(self): + return { + "price_for_xlarge.py": DATABRICKS_PYTHON_UDF_BODY, + "price_for_xlarge.yml": DATABRICKS_PYTHON_UDF_YML, + } + + def test_udfs(self, project): + """Test Python UDF creation and execution on Databricks.""" + result = run_dbt(["build"]) + + # Verify build succeeded + assert len(result.results) == 1 + node_result = result.results[0] + assert node_result.status == RunStatus.Success + assert isinstance(node_result.node, FunctionNode) + assert node_result.node.name == "price_for_xlarge" + + # Verify the UDF actually works by executing it + result = run_dbt(["show", "--inline", "SELECT {{ function('price_for_xlarge') }}(100)"]) + assert len(result.results) == 1 + select_value = int(result.results[0].agate_table.rows[0].values()[0]) + assert select_value == 200, f"Expected 200, got {select_value}" + + +@pytest.mark.skip_profile("databricks_cluster") +class TestDatabricksPythonUDFLifecycle: + """Test the full lifecycle of a Python UDF. + + Verifies: + 1. Initial create - function is created successfully + 2. Subsequent runs - function is replaced (idempotent) + 3. Code change - function is updated with new implementation + """ + + @pytest.fixture(scope="class") + def functions(self): + return { + "price_for_xlarge.py": PYTHON_UDF_V1, + "price_for_xlarge.yml": PYTHON_UDF_YML_V1, + } + + def test_python_udf_lifecycle(self, project): + """Test the full lifecycle of a Python UDF.""" + + # =========================================== + # Phase 1: Initial Create + # =========================================== + result = run_dbt(["build"]) + + assert len(result.results) == 1 + assert result.results[0].status == RunStatus.Success + + # Verify function works: price * 2 + result = run_dbt(["show", "--inline", "SELECT {{ function('price_for_xlarge') }}(100)"]) + assert int(result.results[0].agate_table.rows[0].values()[0]) == 200 + + # =========================================== + # Phase 2: Subsequent Run (Idempotent) + # =========================================== + result = run_dbt(["build"]) + + assert len(result.results) == 1 + assert result.results[0].status == RunStatus.Success + + # Function still works the same way + result = run_dbt(["show", "--inline", "SELECT {{ function('price_for_xlarge') }}(100)"]) + assert int(result.results[0].agate_table.rows[0].values()[0]) == 200 + + # =========================================== + # Phase 3: Code Change (price * 2 -> price * 3) + # =========================================== + # Update the Python file + write_file(PYTHON_UDF_V2, project.project_root, "functions", "price_for_xlarge.py") + + result = run_dbt(["build"]) + + assert len(result.results) == 1 + assert result.results[0].status == RunStatus.Success + + # Verify new implementation works: price * 3 + result = run_dbt(["show", "--inline", "SELECT {{ function('price_for_xlarge') }}(100)"]) + assert int(result.results[0].agate_table.rows[0].values()[0]) == 300 + + +@pytest.mark.skip_profile("databricks_cluster") +class TestDatabricksMultiArgPythonUDF(UDFsBasic): + """Test that Python UDFs with multiple arguments work on Databricks. + + Verifies: + - Multi-arg Python UDF creates successfully + - The UDF executes correctly with multiple arguments + """ + + @pytest.fixture(scope="class") + def functions(self): + return { + "total_price.py": DATABRICKS_MULTI_ARG_PYTHON_UDF_BODY, + "total_price.yml": DATABRICKS_MULTI_ARG_PYTHON_UDF_YML, + } + + def test_udfs(self, project): + """Test multi-arg Python UDF creation and execution on Databricks.""" + result = run_dbt(["build"]) + + # Verify build succeeded + assert len(result.results) == 1 + node_result = result.results[0] + assert node_result.status == RunStatus.Success + assert isinstance(node_result.node, FunctionNode) + assert node_result.node.name == "total_price" + + # Verify the UDF actually works by executing it with two arguments + result = run_dbt(["show", "--inline", "SELECT {{ function('total_price') }}(25.0, 4)"]) + assert len(result.results) == 1 + select_value = int(result.results[0].agate_table.rows[0].values()[0]) + assert select_value == 100, f"Expected 100, got {select_value}" diff --git a/tests/functional/adapter/metric_views/fixtures.py b/tests/functional/adapter/metric_views/fixtures.py new file mode 100644 index 000000000..180a8c796 --- /dev/null +++ b/tests/functional/adapter/metric_views/fixtures.py @@ -0,0 +1,63 @@ +source_table = """ +{{ config(materialized='table') }} + +select 1 as id, 100 as revenue, 'completed' as status, '2024-01-01' as order_date +union all +select 2 as id, 200 as revenue, 'pending' as status, '2024-01-02' as order_date +union all +select 3 as id, 150 as revenue, 'completed' as status, '2024-01-03' as order_date +""" + +basic_metric_view = """ +{{ config(materialized='metric_view') }} + +version: 0.1 +source: "{{ ref('source_orders') }}" +dimensions: + - name: order_date + expr: order_date + - name: status + expr: status +measures: + - name: total_orders + expr: count(1) + - name: total_revenue + expr: sum(revenue) +""" + +metric_view_with_filter = """ +{{ config(materialized='metric_view') }} + +version: 0.1 +source: "{{ ref('source_orders') }}" +filter: status = 'completed' +dimensions: + - name: order_date + expr: order_date +measures: + - name: completed_orders + expr: count(1) + - name: completed_revenue + expr: sum(revenue) +""" + +metric_view_with_config = """ +{{ + config( + materialized='metric_view', + databricks_tags={ + 'team': 'analytics', + 'environment': 'test' + } + ) +}} + +version: 0.1 +source: "{{ ref('source_orders') }}" +dimensions: + - name: status + expr: status +measures: + - name: order_count + expr: count(1) +""" diff --git a/tests/functional/adapter/metric_views/test_metric_view_configuration_changes.py b/tests/functional/adapter/metric_views/test_metric_view_configuration_changes.py new file mode 100644 index 000000000..9693904b2 --- /dev/null +++ b/tests/functional/adapter/metric_views/test_metric_view_configuration_changes.py @@ -0,0 +1,201 @@ +import pytest +from dbt.tests import util +from dbt.tests.util import run_dbt + +from tests.functional.adapter.metric_views.fixtures import ( + source_table, +) + +# Test fixture for metric view with tags configuration +metric_view_with_tags = """ +{{ + config( + materialized='metric_view', + view_update_via_alter=true, + databricks_tags={ + 'team': 'analytics', + 'environment': 'test' + } + ) +}} + +version: 0.1 +source: "{{ ref('source_orders') }}" +dimensions: + - name: status + expr: status +measures: + - name: total_orders + expr: count(1) + - name: total_revenue + expr: sum(revenue) +""" + +# Updated tag configuration for testing ALTER +metric_view_with_updated_tags = """ +{{ + config( + materialized='metric_view', + view_update_via_alter=true, + databricks_tags={ + 'team': 'data-engineering', + 'environment': 'production', + 'owner': 'dbt-team' + } + ) +}} + +version: 0.1 +source: "{{ ref('source_orders') }}" +dimensions: + - name: status + expr: status +measures: + - name: total_orders + expr: count(1) + - name: total_revenue + expr: sum(revenue) +""" + +# Changed YAML definition that requires CREATE OR REPLACE +metric_view_with_changed_definition = """ +{{ + config( + materialized='metric_view', + view_update_via_alter=true, + databricks_tags={ + 'team': 'analytics', + 'environment': 'test' + } + ) +}} + +version: 0.1 +source: "{{ ref('source_orders') }}" +dimensions: + - name: status + expr: status + - name: order_date + expr: order_date +measures: + - name: total_orders + expr: count(1) + - name: total_revenue + expr: sum(revenue) + - name: avg_revenue + expr: avg(revenue) +""" + + +@pytest.mark.skip_profile("databricks_cluster") +class TestMetricViewConfigurationChanges: + """Test metric view configuration change handling""" + + @pytest.fixture(scope="class") + def models(self): + return { + "source_orders.sql": source_table, + "config_change_metrics.sql": metric_view_with_tags, + } + + def test_metric_view_tag_only_changes_via_alter(self, project): + """Test that tag-only changes use ALTER instead of CREATE OR REPLACE""" + # First run creates the metric view + results = run_dbt(["run"]) + assert len(results) == 2 + assert all(result.status == "success" for result in results) + + # Update the model with different tags + util.write_file(metric_view_with_updated_tags, "models", "config_change_metrics.sql") + + # Second run should use ALTER for tags + results = run_dbt(["run", "--models", "config_change_metrics"]) + assert len(results) == 1 + assert results[0].status == "success" + + # Verify the metric view still works + metric_view_name = f"{project.database}.{project.test_schema}.config_change_metrics" + query_result = project.run_sql( + f""" + SELECT + status, + MEASURE(total_orders) as order_count, + MEASURE(total_revenue) as revenue + FROM {metric_view_name} + GROUP BY status + ORDER BY status + """, + fetch="all", + ) + + assert len(query_result) == 2 + status_data = {row[0]: (row[1], row[2]) for row in query_result} + assert status_data["completed"] == (2, 250) + assert status_data["pending"] == (1, 200) + + def test_metric_view_definition_changes_require_replace(self, project): + """Test that YAML definition changes use CREATE OR REPLACE""" + # First run creates the metric view + results = run_dbt(["run"]) + assert len(results) == 2 + assert all(result.status == "success" for result in results) + + # Update the model with changed YAML definition + util.write_file(metric_view_with_changed_definition, "models", "config_change_metrics.sql") + + # Second run should use CREATE OR REPLACE for YAML changes + results = run_dbt(["run", "--models", "config_change_metrics"]) + assert len(results) == 1 + assert results[0].status == "success" + + # Verify the updated metric view works with new measure + metric_view_name = f"{project.database}.{project.test_schema}.config_change_metrics" + query_result = project.run_sql( + f""" + SELECT + status, + MEASURE(total_orders) as order_count, + MEASURE(total_revenue) as revenue, + MEASURE(avg_revenue) as avg_revenue + FROM {metric_view_name} + GROUP BY status + ORDER BY status + """, + fetch="all", + ) + + assert len(query_result) == 2 + status_data = {row[0]: (row[1], row[2], row[3]) for row in query_result} + assert status_data["completed"] == (2, 250, 125.0) # (100+150)/2 = 125 + assert status_data["pending"] == (1, 200, 200.0) + + def test_no_changes_skip_materialization(self, project): + """Test that no changes result in no-op""" + # First run creates the metric view + results = run_dbt(["run"]) + assert len(results) == 2 + assert all(result.status == "success" for result in results) + + # Second run with no changes should be a no-op + results = run_dbt(["run", "--models", "config_change_metrics"]) + assert len(results) == 1 + assert results[0].status == "success" + + # Verify the metric view still works + metric_view_name = f"{project.database}.{project.test_schema}.config_change_metrics" + query_result = project.run_sql( + f""" + SELECT + status, + MEASURE(total_orders) as order_count + FROM {metric_view_name} + GROUP BY status + ORDER BY status + """, + fetch="all", + ) + + assert len(query_result) == 2 + status_data = {row[0]: row[1] for row in query_result} + assert status_data["completed"] == 2 + assert status_data["pending"] == 1 diff --git a/tests/functional/adapter/metric_views/test_metric_view_materialization.py b/tests/functional/adapter/metric_views/test_metric_view_materialization.py new file mode 100644 index 000000000..cee060184 --- /dev/null +++ b/tests/functional/adapter/metric_views/test_metric_view_materialization.py @@ -0,0 +1,213 @@ +import pytest +from dbt.tests.util import get_manifest, run_dbt + +from tests.functional.adapter.metric_views.fixtures import ( + basic_metric_view, + metric_view_with_config, + metric_view_with_filter, + source_table, +) + + +@pytest.mark.skip_profile("databricks_cluster") +class TestBasicMetricViewMaterialization: + """Test basic metric view materialization functionality""" + + @pytest.fixture(scope="class") + def models(self): + return { + "source_orders.sql": source_table, + "order_metrics.sql": basic_metric_view, + } + + def test_metric_view_creation(self, project): + """Test that metric view materialization creates a metric view successfully""" + # Run dbt to create the models + results = run_dbt(["run"]) + assert len(results) == 2 + + # Verify both models ran successfully + assert all(result.status == "success" for result in results) + + # Check that the metric view was created + manifest = get_manifest(project.project_root) + metric_view_node = manifest.nodes["model.test.order_metrics"] + assert metric_view_node.config.materialized == "metric_view" + + # Test if the metric view actually works by querying it with MEASURE() + # This is the most important test - if this works, the metric view was created correctly + metric_view_name = f"{project.database}.{project.test_schema}.order_metrics" + + try: + # Query the metric view using MEASURE() function - this is the real test + query_result = project.run_sql( + f""" + SELECT + status, + MEASURE(total_orders) as order_count, + MEASURE(total_revenue) as revenue + FROM {metric_view_name} + GROUP BY status + ORDER BY status + """, + fetch="all", + ) + print(f"Metric view query result: {query_result}") + + # If we got results, verify the data is correct + if query_result: + assert len(query_result) == 2, f"Expected 2 status groups, got {len(query_result)}" + + # Check data: 2 completed orders worth 250, 1 pending order worth 200 + status_data = {row[0]: (row[1], row[2]) for row in query_result} + print(f"Status data: {status_data}") + + assert "completed" in status_data, "Missing 'completed' status" + assert "pending" in status_data, "Missing 'pending' status" + + completed_count, completed_revenue = status_data["completed"] + pending_count, pending_revenue = status_data["pending"] + + assert completed_count == 2, f"Expected 2 completed orders, got {completed_count}" + assert completed_revenue == 250, ( + f"Expected 250 completed revenue, got {completed_revenue}" + ) + assert pending_count == 1, f"Expected 1 pending order, got {pending_count}" + assert pending_revenue == 200, ( + f"Expected 200 pending revenue, got {pending_revenue}" + ) + + print("✅ Metric view query successful with correct data!") + else: + # fetch=True returned None, but let's try without fetch to see if it executes + project.run_sql( + f"SELECT MEASURE(total_orders) FROM {metric_view_name} LIMIT 1", fetch=False + ) + print("✅ Metric view query executed without error (but fetch returned None)") + + except Exception as e: + assert False, f"Metric view query failed: {e}" + + def test_metric_view_query(self, project): + """Test that the metric view can be queried using MEASURE() function""" + # First run dbt to create the models + run_dbt(["run"]) + + # Query the metric view using MEASURE() function + query_result = project.run_sql( + f""" + select + status, + measure(total_orders) as order_count, + measure(total_revenue) as revenue + from {project.database}.{project.test_schema}.order_metrics + group by status + order by status + """, + fetch="all", + ) + + # Verify we get expected results + assert len(query_result) == 2 # Should have 'completed' and 'pending' status + + # Check the data makes sense + completed_row = next(row for row in query_result if row[0] == "completed") + pending_row = next(row for row in query_result if row[0] == "pending") + + assert completed_row[1] == 2 # 2 completed orders + assert completed_row[2] == 250 # 100 + 150 revenue + assert pending_row[1] == 1 # 1 pending order + assert pending_row[2] == 200 # 200 revenue + + +@pytest.mark.skip_profile("databricks_cluster") +class TestMetricViewWithFilter: + """Test metric view materialization with filters""" + + @pytest.fixture(scope="class") + def models(self): + return { + "source_orders.sql": source_table, + "filtered_metrics.sql": metric_view_with_filter, + } + + def test_metric_view_with_filter_creation(self, project): + """Test that metric view with filter works correctly""" + # Run dbt to create the models + results = run_dbt(["run"]) + assert len(results) == 2 + + # Verify both models ran successfully + assert all(result.status == "success" for result in results) + + def test_metric_view_with_filter_query(self, project): + """Test that filtered metric view returns expected results""" + # First run dbt to create the models + run_dbt(["run"]) + + # Query the filtered metric view + query_result = project.run_sql( + f""" + select + measure(completed_orders) as order_count, + measure(completed_revenue) as revenue + from {project.database}.{project.test_schema}.filtered_metrics + """, + fetch="all", + ) + + # Should only see completed orders (2 orders with 250 total revenue) + assert len(query_result) == 1 + row = query_result[0] + assert row[0] == 2 # 2 completed orders + assert row[1] == 250 # 100 + 150 revenue from completed orders only + + +@pytest.mark.skip_profile("databricks_cluster") +class TestMetricViewConfiguration: + """Test metric view materialization with configuration options""" + + @pytest.fixture(scope="class") + def models(self): + return { + "source_orders.sql": source_table, + "config_metrics.sql": metric_view_with_config, + } + + def test_metric_view_with_tags(self, project): + """Test that metric view works with databricks_tags using ALTER VIEW""" + # Run dbt to create the models + results = run_dbt(["run"]) + assert len(results) == 2 + + # Verify both models ran successfully + assert all(result.status == "success" for result in results) + + # Check that the metric view was created + manifest = get_manifest(project.project_root) + config_node = manifest.nodes["model.test.config_metrics"] + assert config_node.config.materialized == "metric_view" + + # Verify the metric view works by querying it + metric_view_name = f"{project.database}.{project.test_schema}.config_metrics" + + query_result = project.run_sql( + f""" + SELECT + status, + MEASURE(order_count) as count + FROM {metric_view_name} + GROUP BY status + ORDER BY status + """, + fetch="all", + ) + + # Should have results showing tags were applied without error + assert query_result is not None + assert len(query_result) == 2 # completed and pending statuses + + # Check the data is correct + status_data = {row[0]: row[1] for row in query_result} + assert status_data["completed"] == 2 + assert status_data["pending"] == 1 diff --git a/tests/functional/adapter/metric_views/test_metric_view_simple_changes.py b/tests/functional/adapter/metric_views/test_metric_view_simple_changes.py new file mode 100644 index 000000000..aa012d17d --- /dev/null +++ b/tests/functional/adapter/metric_views/test_metric_view_simple_changes.py @@ -0,0 +1,74 @@ +import pytest +from dbt.tests.util import run_dbt + +from tests.functional.adapter.metric_views.fixtures import ( + source_table, +) + +# Test fixture for metric view without view_update_via_alter +metric_view_without_alter = """ +{{ + config( + materialized='metric_view', + databricks_tags={ + 'team': 'analytics', + 'environment': 'test' + } + ) +}} + +version: 0.1 +source: "{{ ref('source_orders') }}" +dimensions: + - name: status + expr: status +measures: + - name: total_orders + expr: count(1) + - name: total_revenue + expr: sum(revenue) +""" + + +@pytest.mark.skip_profile("databricks_cluster") +class TestMetricViewSimpleChanges: + """Test basic metric view behavior without configuration change detection""" + + @pytest.fixture(scope="class") + def models(self): + return { + "source_orders.sql": source_table, + "simple_metrics.sql": metric_view_without_alter, + } + + def test_metric_view_always_recreates(self, project): + """Test that metric view recreates without view_update_via_alter""" + # First run creates the metric view + results = run_dbt(["run"]) + assert len(results) == 2 + assert all(result.status == "success" for result in results) + + # Second run should recreate the metric view (full refresh behavior) + results = run_dbt(["run", "--models", "simple_metrics"]) + assert len(results) == 1 + assert results[0].status == "success" + + # Verify the metric view still works + metric_view_name = f"{project.database}.{project.test_schema}.simple_metrics" + query_result = project.run_sql( + f""" + SELECT + status, + MEASURE(total_orders) as order_count, + MEASURE(total_revenue) as revenue + FROM {metric_view_name} + GROUP BY status + ORDER BY status + """, + fetch="all", + ) + + assert len(query_result) == 2 + status_data = {row[0]: (row[1], row[2]) for row in query_result} + assert status_data["completed"] == (2, 250) + assert status_data["pending"] == (1, 200) diff --git a/tests/functional/adapter/metric_views/test_metric_view_update_via_alter.py b/tests/functional/adapter/metric_views/test_metric_view_update_via_alter.py new file mode 100644 index 000000000..94bc4c2f8 --- /dev/null +++ b/tests/functional/adapter/metric_views/test_metric_view_update_via_alter.py @@ -0,0 +1,294 @@ +import pytest +from dbt.tests import util + +from tests.functional.adapter.metric_views import fixtures + + +class BaseUpdateMetricView: + @pytest.fixture(scope="class") + def models(self): + return { + "source_orders.sql": fixtures.source_table, + "test_metric_view.sql": fixtures.metric_view_with_config, + } + + +class BaseUpdateMetricViewQuery(BaseUpdateMetricView): + # Subclasses should set this to indicate expected execution path + expect_alter_flow: bool = True + + def test_metric_view_update_query(self, project): + """Test that metric view query can be updated via ALTER VIEW AS""" + # First run creates the metric view + util.run_dbt(["run"]) + + # Update the query by changing the metric definition + updated_metric_view = """ +{{ + config( + materialized='metric_view', + databricks_tags={ + 'team': 'analytics', + 'environment': 'test' + } + ) +}} + +version: 0.1 +source: "{{ ref('source_orders') }}" +dimensions: + - name: status + expr: status + - name: order_date + expr: order_date +measures: + - name: order_count + expr: count(1) + - name: total_revenue + expr: sum(revenue) +""" + util.write_file(updated_metric_view, "models", "test_metric_view.sql") + + # Second run should update via ALTER or REPLACE depending on config + _, logs = util.run_dbt_and_capture(["--debug", "run", "--models", "test_metric_view"]) + + # Verify the correct execution path was taken + util.assert_message_in_logs("Updating metric view via ALTER", logs, self.expect_alter_flow) + + # Verify the metric view works with new definition + metric_view_name = f"{project.database}.{project.test_schema}.test_metric_view" + query_result = project.run_sql( + f""" + SELECT + status, + order_date, + MEASURE(order_count) as count, + MEASURE(total_revenue) as revenue + FROM {metric_view_name} + GROUP BY status, order_date + ORDER BY status, order_date + """, + fetch="all", + ) + + assert len(query_result) == 3 + + +class BaseUpdateMetricViewTblProperties(BaseUpdateMetricView): + def test_metric_view_update_tblproperties(self, project): + """Test that metric view tblproperties can be updated via ALTER""" + # First run creates the metric view with tags + util.run_dbt(["run"]) + + # Update with tblproperties added + updated_metric_view = """ +{{ + config( + materialized='metric_view', + databricks_tags={ + 'team': 'analytics', + 'environment': 'test' + }, + tblproperties={ + 'quality': 'gold' + } + ) +}} + +version: 0.1 +source: "{{ ref('source_orders') }}" +dimensions: + - name: status + expr: status +measures: + - name: order_count + expr: count(1) +""" + util.write_file(updated_metric_view, "models", "test_metric_view.sql") + + # Second run should update via ALTER + results = util.run_dbt(["run", "--models", "test_metric_view"]) + assert len(results) == 1 + assert results[0].status == "success" + + # Verify tblproperties were set + results = project.run_sql( + f"show tblproperties {project.database}.{project.test_schema}.test_metric_view", + fetch="all", + ) + + # Check that 'quality' property exists + tblprops = {row[0]: row[1] for row in results} + assert tblprops.get("quality") == "gold" + + +class BaseUpdateMetricViewTags(BaseUpdateMetricView): + def test_metric_view_update_tags(self, project): + """Test that metric view tags can be updated via ALTER""" + # First run creates the metric view with initial tags + util.run_dbt(["run"]) + + # Update the tags + updated_metric_view = """ +{{ + config( + materialized='metric_view', + databricks_tags={ + 'team': 'data-science', + 'environment': 'test', + 'priority': 'high' + } + ) +}} + +version: 0.1 +source: "{{ ref('source_orders') }}" +dimensions: + - name: status + expr: status +measures: + - name: order_count + expr: count(1) +""" + util.write_file(updated_metric_view, "models", "test_metric_view.sql") + + # Second run should update via ALTER + results = util.run_dbt(["run", "--models", "test_metric_view"]) + assert len(results) == 1 + assert results[0].status == "success" + + # Verify tags were updated + results = project.run_sql( + f""" + SELECT TAG_NAME, TAG_VALUE FROM {project.database}.information_schema.table_tags + WHERE schema_name = '{project.test_schema}' AND table_name = 'test_metric_view' + ORDER BY TAG_NAME + """, + fetch="all", + ) + + # Check that we have all three tags + tags = {row[0]: row[1] for row in results} + assert tags.get("team") == "data-science" + assert tags.get("environment") == "test" + assert tags.get("priority") == "high" + + +class BaseUpdateMetricViewNothing(BaseUpdateMetricView): + """Test that no-op updates work correctly""" + + def test_metric_view_update_nothing(self, project): + """Test that metric view with no changes doesn't error""" + # First run creates the metric view + util.run_dbt(["run"]) + + # Second run with no changes - should be no-op + results = util.run_dbt(["run", "--models", "test_metric_view"]) + assert len(results) == 1 + assert results[0].status == "success" + + # Verify the metric view still works + metric_view_name = f"{project.database}.{project.test_schema}.test_metric_view" + query_result = project.run_sql( + f""" + SELECT + status, + MEASURE(order_count) as count + FROM {metric_view_name} + GROUP BY status + ORDER BY status + """, + fetch="all", + ) + + assert len(query_result) == 2 + + +# Test classes with view_update_via_alter enabled +@pytest.mark.skip_profile("databricks_cluster") +class TestUpdateMetricViewViaAlterQuery(BaseUpdateMetricViewQuery): + expect_alter_flow = True + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "flags": {"use_materialization_v2": True}, + "models": { + "+view_update_via_alter": True, + }, + } + + +@pytest.mark.skip_profile("databricks_cluster") +class TestUpdateMetricViewViaAlterTblProperties(BaseUpdateMetricViewTblProperties): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "flags": {"use_materialization_v2": True}, + "models": { + "+view_update_via_alter": True, + }, + } + + +@pytest.mark.skip_profile("databricks_cluster") +class TestUpdateMetricViewViaAlterTags(BaseUpdateMetricViewTags): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "flags": {"use_materialization_v2": True}, + "models": { + "+view_update_via_alter": True, + }, + } + + +@pytest.mark.skip_profile("databricks_cluster") +class TestUpdateMetricViewViaAlterNothing(BaseUpdateMetricViewNothing): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "flags": {"use_materialization_v2": True}, + "models": { + "+view_update_via_alter": True, + }, + } + + +# Test classes WITHOUT view_update_via_alter (default replace behavior) +@pytest.mark.skip_profile("databricks_cluster") +class TestUpdateMetricViewViaReplaceQuery(BaseUpdateMetricViewQuery): + expect_alter_flow = False + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "flags": {"use_materialization_v2": True}, + "models": { + "+view_update_via_alter": False, + }, + } + + +@pytest.mark.skip_profile("databricks_cluster") +class TestUpdateMetricViewViaReplaceTblProperties(BaseUpdateMetricViewTblProperties): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "flags": {"use_materialization_v2": True}, + "models": { + "+view_update_via_alter": False, + }, + } + + +@pytest.mark.skip_profile("databricks_cluster") +class TestUpdateMetricViewViaReplaceTags(BaseUpdateMetricViewTags): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "flags": {"use_materialization_v2": True}, + "models": { + "+view_update_via_alter": False, + }, + } diff --git a/tests/functional/adapter/row_filters/__init__.py b/tests/functional/adapter/row_filters/__init__.py new file mode 100644 index 000000000..6b95e0de6 --- /dev/null +++ b/tests/functional/adapter/row_filters/__init__.py @@ -0,0 +1 @@ +# Empty init file for test package diff --git a/tests/functional/adapter/row_filters/fixtures.py b/tests/functional/adapter/row_filters/fixtures.py new file mode 100644 index 000000000..ba5744430 --- /dev/null +++ b/tests/functional/adapter/row_filters/fixtures.py @@ -0,0 +1,89 @@ +base_model_sql = """ +{{ config( + materialized = 'table' +) }} +SELECT 'user1' as user_id, 'region_a' as region, CAST(100 AS BIGINT) as amount +""" + +base_model_mv = """ +{{ config( + materialized='materialized_view', +) }} +SELECT 'user1' as user_id, 'region_a' as region, CAST(100 AS BIGINT) as amount +""" + +row_filter_seed = """ +user_id,region,amount +user1,region_a,100 +""".strip() + +base_model_streaming_table = """ +{{ config( + materialized='streaming_table', +) }} +SELECT * FROM stream {{ ref('base_model_seed') }} +""" + +model_with_row_filter = """ +version: 2 +models: + - name: base_model + config: + row_filter: + function: region_filter + columns: [region] + columns: + - name: user_id + data_type: string + - name: region + data_type: string + - name: amount + data_type: bigint +""" + +model_updated_filter = """ +version: 2 +models: + - name: base_model + config: + row_filter: + function: user_filter + columns: [user_id] + columns: + - name: user_id + data_type: string + - name: region + data_type: string + - name: amount + data_type: bigint +""" + +model_no_filter = """ +version: 2 +models: + - name: base_model + columns: + - name: user_id + data_type: string + - name: region + data_type: string + - name: amount + data_type: bigint +""" + +# For view failure test +view_model_sql = """ +{{ config( + materialized = 'view' +) }} +SELECT 'user1' as user_id, 'region_a' as region, CAST(100 AS BIGINT) as amount +""" + +# For safe_relation_replace path test +base_model_safe_sql = """ +{{ config( + materialized = 'table', + use_safer_relation_operations = true +) }} +SELECT 'user1' as user_id, 'region_a' as region, CAST(100 AS BIGINT) as amount +""" diff --git a/tests/functional/adapter/row_filters/test_row_filter.py b/tests/functional/adapter/row_filters/test_row_filter.py new file mode 100644 index 000000000..39b4f960b --- /dev/null +++ b/tests/functional/adapter/row_filters/test_row_filter.py @@ -0,0 +1,355 @@ +import pytest +from dbt.tests.util import run_dbt, write_file + +from tests.functional.adapter.fixtures import MaterializationV1Mixin, MaterializationV2Mixin +from tests.functional.adapter.row_filters.fixtures import ( + base_model_mv, + base_model_sql, + base_model_streaming_table, + model_no_filter, + model_updated_filter, + model_with_row_filter, + row_filter_seed, + view_model_sql, +) + + +class BaseRowFilterMixin: + """Base mixin with helper methods for row filter tests. + + Does not include materialization version - subclasses should combine + with MaterializationV1Mixin or MaterializationV2Mixin as needed. + """ + + def create_filter_udfs(self, project): + """Create test UDFs for row filtering.""" + project.run_sql( + f""" + CREATE OR REPLACE FUNCTION {project.database}.{project.test_schema}.region_filter( + region STRING + ) + RETURNS BOOLEAN + RETURN region = 'region_a' + """ + ) + project.run_sql( + f""" + CREATE OR REPLACE FUNCTION {project.database}.{project.test_schema}.user_filter( + user_id STRING + ) + RETURNS BOOLEAN + RETURN user_id = 'user1' + """ + ) + + def get_row_filters(self, project, table_name): + """Query INFORMATION_SCHEMA for row filters. + + Uses catalog-specific path and includes table_catalog filter + for multi-catalog safety. + """ + return project.run_sql( + f""" + SELECT filter_name FROM {project.database}.information_schema.row_filters + WHERE table_catalog = '{project.database}' + AND table_schema = '{project.test_schema}' + AND table_name = '{table_name}' + """, + fetch="all", + ) + + +class RowFilterMixin(BaseRowFilterMixin, MaterializationV2Mixin): + """Row filter mixin for V2 materialization tests.""" + + pass + + +@pytest.mark.skip_profile("databricks_cluster") +class TestRowFilterTable(RowFilterMixin): + """Test row filters on table models.""" + + @pytest.fixture(scope="class") + def models(self): + return { + "base_model.sql": base_model_sql, + "schema.yml": model_with_row_filter, + } + + def test_row_filter_lifecycle(self, project): + """Test create, update, and remove row filter lifecycle.""" + # Setup + self.create_filter_udfs(project) + + # 1. Create with filter (tests unqualified function qualification) + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "region_filter" in filters[0][0].lower() + + # 2. Update filter + write_file(model_updated_filter, "models", "schema.yml") + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "user_filter" in filters[0][0].lower() + + # 3. Remove filter + write_file(model_no_filter, "models", "schema.yml") + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 0 + + +@pytest.mark.skip_profile("databricks_cluster") +class TestIncrementalRowFilter(RowFilterMixin): + """Test row filters on incremental models.""" + + @pytest.fixture(scope="class") + def models(self): + return { + "base_model.sql": base_model_sql.replace("table", "incremental"), + "schema.yml": model_with_row_filter, + } + + def test_incremental_row_filter_lifecycle(self, project): + """Test create, update, and remove row filter lifecycle on incremental model.""" + self.create_filter_udfs(project) + + # 1. CREATE with filter (initial run) + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "region_filter" in filters[0][0].lower() + + # 2. Incremental run should preserve filter (no config change) + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "region_filter" in filters[0][0].lower() + + # 3. UPDATE filter (incremental run with config change) + write_file(model_updated_filter, "models", "schema.yml") + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "user_filter" in filters[0][0].lower() + + # 4. REMOVE filter (incremental run removing config) + write_file(model_no_filter, "models", "schema.yml") + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 0 + + +@pytest.mark.skip_profile("databricks_cluster", "databricks_uc_cluster") +class TestMaterializedViewRowFilter(RowFilterMixin): + """Test row filters on materialized view models.""" + + @pytest.fixture(scope="class") + def models(self): + return { + "base_model.sql": base_model_mv, + "schema.yml": model_with_row_filter, + } + + def test_mv_row_filter_lifecycle(self, project): + """Test create, update, and remove row filter lifecycle on MV.""" + self.create_filter_udfs(project) + + # 1. CREATE with filter + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "region_filter" in filters[0][0].lower() + + # 2. UPDATE filter + write_file(model_updated_filter, "models", "schema.yml") + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "user_filter" in filters[0][0].lower() + + # 3. DROP filter + write_file(model_no_filter, "models", "schema.yml") + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 0 + + +@pytest.mark.skip_profile("databricks_cluster", "databricks_uc_cluster") +class TestStreamingTableRowFilter(RowFilterMixin): + """Test row filters on streaming table models.""" + + @pytest.fixture(scope="class") + def seeds(self): + return { + "base_model_seed.csv": row_filter_seed, + } + + @pytest.fixture(scope="class") + def models(self): + return { + "base_model.sql": base_model_streaming_table, + "schema.yml": model_with_row_filter, + } + + @pytest.fixture(scope="class", autouse=True) + def setup_streaming_table_seed(self, project): + """Run seed once for the entire class to avoid Delta table ID conflicts.""" + run_dbt(["seed"]) + + @pytest.fixture(scope="function", autouse=True) + def cleanup_streaming_table(self, project): + project.run_sql(f"DROP TABLE IF EXISTS {project.database}.{project.test_schema}.base_model") + yield + + def test_streaming_table_row_filter_lifecycle(self, project): + """Test create, update, and remove row filter lifecycle on ST.""" + self.create_filter_udfs(project) + + # 1. CREATE with filter + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "region_filter" in filters[0][0].lower() + + # 2. UPDATE filter + write_file(model_updated_filter, "models", "schema.yml") + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "user_filter" in filters[0][0].lower() + + # 3. DROP filter + write_file(model_no_filter, "models", "schema.yml") + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 0 + + +@pytest.mark.skip_profile("databricks_cluster") +class TestViewRowFilterFailure(MaterializationV2Mixin): + """Test that row filters on regular views fail with clear error.""" + + @pytest.fixture(scope="class") + def models(self): + return { + "base_model.sql": view_model_sql, + "schema.yml": model_with_row_filter, + } + + def test_view_row_filter_failure(self, project): + """Verify row filters on views fail appropriately.""" + result = run_dbt(["run"], expect_pass=False) + assert result.results[0].status != "success" + # Note: The exact error depends on whether it's compile-time or runtime + # For views, the WITH ROW FILTER syntax itself may fail + + +# ============================================================================ +# SAFE_RELATION_REPLACE PATH TEST +# ============================================================================ + + +@pytest.mark.skip_profile("databricks_cluster") +class TestRowFilterTableSafeReplace(RowFilterMixin): + """Test row filters on table models with safe_relation_replace path.""" + + @pytest.fixture(scope="class") + def models(self): + from tests.functional.adapter.row_filters.fixtures import base_model_safe_sql + + return { + "base_model.sql": base_model_safe_sql, + "schema.yml": model_with_row_filter, + } + + def test_safe_replace_row_filter_lifecycle(self, project): + """Test create, update, and remove row filter lifecycle with safe_relation_replace.""" + self.create_filter_udfs(project) + + # 1. Create with filter (initial run - no existing relation) + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "region_filter" in filters[0][0].lower() + + # 2. Update filter (existing relation - should take SAFE_RELATION_REPLACE path) + write_file(model_updated_filter, "models", "schema.yml") + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "user_filter" in filters[0][0].lower() + + # 3. Remove filter (existing relation - should take SAFE_RELATION_REPLACE path) + write_file(model_no_filter, "models", "schema.yml") + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 0 + + +# ============================================================================ +# V1 MATERIALIZATION TESTS +# ============================================================================ + + +@pytest.mark.skip_profile("databricks_cluster") +class TestRowFilterTableV1(BaseRowFilterMixin, MaterializationV1Mixin): + """Test row filters on table models with V1 materialization.""" + + @pytest.fixture(scope="class") + def models(self): + return { + "base_model.sql": base_model_sql, + "schema.yml": model_with_row_filter, + } + + def test_row_filter_lifecycle(self, project): + """Test create, update, and remove row filter lifecycle in V1.""" + self.create_filter_udfs(project) + + # 1. Create with filter + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "region_filter" in filters[0][0].lower() + + # 2. Update filter + write_file(model_updated_filter, "models", "schema.yml") + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + assert "user_filter" in filters[0][0].lower() + + # 3. Remove filter + write_file(model_no_filter, "models", "schema.yml") + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 0 + + +@pytest.mark.skip_profile("databricks_cluster") +class TestIncrementalRowFilterV1(BaseRowFilterMixin, MaterializationV1Mixin): + """Test row filters on incremental models with V1 materialization.""" + + @pytest.fixture(scope="class") + def models(self): + return { + "base_model.sql": base_model_sql.replace("table", "incremental"), + "schema.yml": model_with_row_filter, + } + + def test_incremental_row_filter(self, project): + """Test row filter on incremental model lifecycle in V1.""" + self.create_filter_udfs(project) + + # Initial run + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 + + # Incremental run should preserve filter + run_dbt(["run"]) + filters = self.get_row_filters(project, "base_model") + assert len(filters) == 1 diff --git a/tests/functional/adapter/tags/fixtures.py b/tests/functional/adapter/tags/fixtures.py index 67422fd93..3dad09a82 100644 --- a/tests/functional/adapter/tags/fixtures.py +++ b/tests/functional/adapter/tags/fixtures.py @@ -1,7 +1,7 @@ tags_sql = """ {{ config( materialized = 'table', - databricks_tags = {'a': 'b', 'c': 'd'}, + databricks_tags = {'a': 'b', 'c': 'd', 'k': ''}, ) }} select cast(1 as bigint) as id, 'hello' as msg, 'blue' as color @@ -19,7 +19,7 @@ streaming_table_tags_sql = """ {{ config( materialized='streaming_table', - databricks_tags = {'a': 'b', 'c': 'd'}, + databricks_tags = {'a': 'b', 'c': 'd', 'k': ''}, ) }} select * from stream {{ ref('my_seed') }} @@ -54,4 +54,5 @@ def model(dbt, spark): databricks_tags: a: b c: d + k: "" """ diff --git a/tests/functional/adapter/tags/test_databricks_tags.py b/tests/functional/adapter/tags/test_databricks_tags.py index a4eeb7741..e28d11510 100644 --- a/tests/functional/adapter/tags/test_databricks_tags.py +++ b/tests/functional/adapter/tags/test_databricks_tags.py @@ -23,8 +23,8 @@ def test_tags(self, project): " where schema_name = '{schema}' and table_name='tags'", fetch="all", ) - assert len(results) == 2 - expected_tags = {("a", "b"), ("c", "d")} + assert len(results) == 3 + expected_tags = {("a", "b"), ("c", "d"), ("k", "")} actual_tags = set((row[0], row[1]) for row in results) assert actual_tags == expected_tags @@ -56,8 +56,8 @@ def test_updated_tags(self, project): " where schema_name = '{schema}' and table_name='tags'", fetch="all", ) - assert len(results) == 3 - expected_tags = {("a", "b"), ("c", "d"), ("e", "f")} + assert len(results) == 4 + expected_tags = {("a", "b"), ("c", "d"), ("k", ""), ("e", "f")} actual_tags = set((row[0], row[1]) for row in results) assert actual_tags == expected_tags @@ -151,7 +151,7 @@ def test_updated_tags(self, project): " where schema_name = '{schema}' and table_name='tags'", fetch="all", ) - assert len(results) == 3 + assert len(results) == 4 @pytest.mark.python diff --git a/tests/unit/macros/adapters/test_metadata_macros.py b/tests/unit/macros/adapters/test_metadata_macros.py index 23eb0b0e9..643f99148 100644 --- a/tests/unit/macros/adapters/test_metadata_macros.py +++ b/tests/unit/macros/adapters/test_metadata_macros.py @@ -197,6 +197,13 @@ def test_check_schema_exists_sql_with_hyphenated_database(self, template_bundle) expected_sql = "SHOW SCHEMAS IN `data_engineering-uc-dev` LIKE 'my_schema'" self.assert_sql_equal(result, expected_sql) + def test_describe_table_extended_as_json_sql(self, template_bundle, relation): + result = self.run_macro( + template_bundle.template, "describe_table_extended_as_json_sql", relation + ) + expected_sql = "DESCRIBE TABLE EXTENDED `some_database`.`some_schema`.`some_table` AS JSON" + self.assert_sql_equal(result, expected_sql) + def test_case_sensitivity(self, template_bundle): relation = Mock() relation.database = "TEST_DB" diff --git a/tests/unit/macros/materializations/functions/__init__.py b/tests/unit/macros/materializations/functions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/macros/materializations/functions/test_scalar_functions.py b/tests/unit/macros/materializations/functions/test_scalar_functions.py new file mode 100644 index 000000000..a89a244f7 --- /dev/null +++ b/tests/unit/macros/materializations/functions/test_scalar_functions.py @@ -0,0 +1,231 @@ +from unittest.mock import Mock + +import pytest + +from tests.unit.macros.base import MacroTestBase + + +class ScalarFunctionTestBase(MacroTestBase): + """Base class for scalar function tests with shared fixtures""" + + @pytest.fixture(scope="class") + def template_name(self) -> str: + return "scalar.sql" + + @pytest.fixture(scope="class") + def macro_folders_to_load(self) -> list: + return ["macros/materializations/functions", "macros"] + + @pytest.fixture + def default_context(self) -> dict: + """Override default context to add formatted_scalar_function_args_sql""" + # Get the base context + mock_adapter = Mock() + mock_adapter.quote = lambda identifier: f"`{identifier}`" + + def mock_relation_create(database=None, schema=None, identifier=None, type=None): + mock_relation = Mock() + if database and schema and type == "table": + mock_relation.render = Mock(return_value=f"{schema}.{identifier}") + elif database and schema: + mock_relation.render = Mock(return_value=f"`{database}`.`{schema}`.`{identifier}`") + elif schema: + mock_relation.render = Mock(return_value=f"{schema}.{identifier}") + else: + mock_relation.render = Mock(return_value=identifier) + return mock_relation + + from dbt.adapters.databricks.column import DatabricksColumn + + mock_api = Mock(Column=DatabricksColumn) + mock_api.Relation.create = mock_relation_create + + # Create mock model with arguments for formatted_scalar_function_args_sql + mock_model = Mock() + mock_model.arguments = [] + + def formatted_scalar_function_args_sql(): + """Mock implementation of formatted_scalar_function_args_sql""" + args = [] + for arg in mock_model.arguments: + args.append(f"{arg.name} {arg.data_type}") + return ", ".join(args) + + context = { + "validation": Mock(), + "model": mock_model, + "exceptions": Mock(), + "config": Mock(), + "statement": lambda r, caller: r, + "adapter": mock_adapter, + "var": Mock(), + "log": Mock(return_value=""), + "return": lambda r: r, + "is_incremental": Mock(return_value=False), + "api": mock_api, + "local_md5": lambda s: f"hash({s})", + "formatted_scalar_function_args_sql": formatted_scalar_function_args_sql, + } + return context + + +class TestScalarFunctionSQL(ScalarFunctionTestBase): + """Tests for SQL UDF macro generation""" + + def setup_model_for_sql_udf(self, context): + """Configure model mock for SQL UDF""" + context["model"].language = "sql" + context["model"].compiled_code = "SELECT value * 2" + context["model"].returns = Mock() + context["model"].returns.data_type = "FLOAT" + arg = Mock() + arg.name = "value" + arg.data_type = "FLOAT" + context["model"].arguments = [arg] + + def test_sql_udf_signature(self, template_bundle): + self.setup_model_for_sql_udf(template_bundle.context) + sql = self.run_macro( + template_bundle.template, + "databricks__scalar_function_create_replace_signature_sql", + template_bundle.relation, + ) + assert "language sql" in sql + assert "create or replace function" in sql + + def test_sql_udf_body(self, template_bundle): + self.setup_model_for_sql_udf(template_bundle.context) + sql = self.run_macro(template_bundle.template, "databricks__scalar_function_body_sql") + assert "return" in sql + assert "select value * 2" in sql + + +class TestScalarFunctionPython(ScalarFunctionTestBase): + """Tests for Python UDF macro generation""" + + def setup_model_for_python_udf(self, context, runtime_version="3.11", entry_point="entry"): + """Configure model mock for Python UDF""" + context["model"].language = "python" + context["model"].compiled_code = "return value * 2" + context["model"].returns = Mock() + context["model"].returns.data_type = "FLOAT" + arg = Mock() + arg.name = "value" + arg.data_type = "FLOAT" + context["model"].arguments = [arg] + config_values = {"runtime_version": runtime_version, "entry_point": entry_point} + context["model"].config.get = lambda k, default=None: config_values.get(k, default) + + def test_python_udf_signature(self, template_bundle): + self.setup_model_for_python_udf(template_bundle.context) + sql = self.run_macro( + template_bundle.template, + "databricks__scalar_function_create_replace_signature_python", + template_bundle.relation, + ) + assert "language python" in sql + assert "create or replace function" in sql + # Databricks does not support RUNTIME_VERSION or HANDLER clauses + assert "runtime_version" not in sql + assert "handler" not in sql + + def test_python_udf_body_dollar_quoting(self, template_bundle): + self.setup_model_for_python_udf(template_bundle.context) + sql = self.run_macro_raw( + template_bundle.template, "databricks__scalar_function_body_python" + ) + assert "$$" in sql + assert "return value * 2" in sql + + def test_python_udf_full(self, template_bundle): + self.setup_model_for_python_udf(template_bundle.context) + sql = self.run_macro( + template_bundle.template, + "databricks__scalar_function_python", + template_bundle.relation, + ) + assert "language python" in sql + assert "runtime_version" not in sql + assert "handler" not in sql + assert "$$" in sql + + def test_python_udf_warns_when_runtime_version_set(self, template_bundle): + """Test that warning is emitted when runtime_version is explicitly configured""" + self.setup_model_for_python_udf(template_bundle.context, runtime_version="3.11") + self.run_macro( + template_bundle.template, + "databricks__scalar_function_python", + template_bundle.relation, + ) + # exceptions.warn should have been called with runtime_version message + calls = template_bundle.context["exceptions"].warn.call_args_list + warn_messages = [str(c) for c in calls] + assert any("runtime_version" in msg for msg in warn_messages), ( + f"Expected warning about runtime_version, got: {warn_messages}" + ) + + def test_python_udf_warns_when_entry_point_set(self, template_bundle): + """Test that warning is emitted when entry_point is explicitly configured""" + self.setup_model_for_python_udf(template_bundle.context, entry_point="my_handler") + self.run_macro( + template_bundle.template, + "databricks__scalar_function_python", + template_bundle.relation, + ) + # exceptions.warn should have been called with entry_point message + calls = template_bundle.context["exceptions"].warn.call_args_list + warn_messages = [str(c) for c in calls] + assert any("entry_point" in msg for msg in warn_messages), ( + f"Expected warning about entry_point, got: {warn_messages}" + ) + + def setup_model_for_multi_arg_python_udf(self, context): + """Configure model mock for a multi-argument Python UDF (price FLOAT, quantity INT).""" + context["model"].language = "python" + context["model"].compiled_code = "return price * quantity" + context["model"].returns = Mock() + context["model"].returns.data_type = "FLOAT" + arg_price = Mock() + arg_price.name = "price" + arg_price.data_type = "FLOAT" + arg_quantity = Mock() + arg_quantity.name = "quantity" + arg_quantity.data_type = "INT" + context["model"].arguments = [arg_price, arg_quantity] + config_values = {"runtime_version": None, "entry_point": None} + context["model"].config.get = lambda k, default=None: config_values.get(k, default) + + def test_python_udf_multi_arg_signature(self, template_bundle): + """Test that multi-arg Python UDF generates correct signature with both args.""" + self.setup_model_for_multi_arg_python_udf(template_bundle.context) + sql = self.run_macro( + template_bundle.template, + "databricks__scalar_function_create_replace_signature_python", + template_bundle.relation, + ) + assert "price float, quantity int" in sql + assert "language python" in sql + assert "create or replace function" in sql + + def test_python_udf_multi_arg_body(self, template_bundle): + """Test that multi-arg Python UDF body contains dollar-quoting and code.""" + self.setup_model_for_multi_arg_python_udf(template_bundle.context) + sql = self.run_macro_raw( + template_bundle.template, "databricks__scalar_function_body_python" + ) + assert "$$" in sql + assert "return price * quantity" in sql + + def test_python_udf_multi_arg_full(self, template_bundle): + """Test full multi-arg Python UDF SQL: signature + body, no runtime_version/handler.""" + self.setup_model_for_multi_arg_python_udf(template_bundle.context) + sql = self.run_macro( + template_bundle.template, + "databricks__scalar_function_python", + template_bundle.relation, + ) + assert "price float, quantity int" in sql + assert "language python" in sql + assert "$$" in sql + assert "runtime_version" not in sql + assert "handler" not in sql diff --git a/tests/unit/macros/relations/test_metric_view_create.py b/tests/unit/macros/relations/test_metric_view_create.py new file mode 100644 index 000000000..d8e88ded1 --- /dev/null +++ b/tests/unit/macros/relations/test_metric_view_create.py @@ -0,0 +1,148 @@ +import pytest + +from tests.unit.macros.base import MacroTestBase + + +class TestGetCreateMetricViewAsSQL(MacroTestBase): + @pytest.fixture(scope="class") + def template_name(self) -> str: + return "create.sql" + + @pytest.fixture(scope="class") + def macro_folders_to_load(self) -> list: + return ["macros", "macros/relations/metric_view"] + + def test_basic_metric_view_creation(self, template_bundle): + """Test that get_create_metric_view_as_sql generates correct Databricks SQL""" + yaml_spec = """version: 0.1 +source: orders +dimensions: + - name: order_date + expr: order_date +measures: + - name: order_count + expr: count(1)""" + + result = self.run_macro_raw( + template_bundle.template, + "databricks__get_create_metric_view_as_sql", + template_bundle.relation, + yaml_spec, + ) + + expected = """create or replace view `some_database`.`some_schema`.`some_table` +with metrics +language yaml +as $$ +version: 0.1 +source: orders +dimensions: + - name: order_date + expr: order_date +measures: + - name: order_count + expr: count(1) +$$""" + + # For metric views, we need to preserve YAML formatting exactly + assert result.strip() == expected.strip() + + def test_metric_view_with_filter(self, template_bundle): + """Test metric view generation with filter clause""" + yaml_spec = """version: 0.1 +source: orders +filter: status = 'completed' +dimensions: + - name: order_date + expr: order_date +measures: + - name: revenue + expr: sum(amount)""" + + result = self.run_macro_raw( + template_bundle.template, + "databricks__get_create_metric_view_as_sql", + template_bundle.relation, + yaml_spec, + ) + + expected = """create or replace view `some_database`.`some_schema`.`some_table` +with metrics +language yaml +as $$ +version: 0.1 +source: orders +filter: status = 'completed' +dimensions: + - name: order_date + expr: order_date +measures: + - name: revenue + expr: sum(amount) +$$""" + + assert result.strip() == expected.strip() + + def test_complex_metric_view(self, template_bundle): + """Test metric view with multiple dimensions and measures""" + yaml_spec = """version: 0.1 +source: customer_orders +filter: order_date >= '2024-01-01' +dimensions: + - name: customer_segment + expr: customer_type + - name: order_month + expr: date_trunc('MONTH', order_date) +measures: + - name: total_orders + expr: count(1) + - name: total_revenue + expr: sum(order_total) + - name: avg_order_value + expr: avg(order_total)""" + + result = self.run_macro_raw( + template_bundle.template, + "databricks__get_create_metric_view_as_sql", + template_bundle.relation, + yaml_spec, + ) + + # Check that all key parts are present + assert "create or replace view" in result.lower() + assert "with metrics" in result.lower() + assert "language yaml" in result.lower() + assert "as $$" in result.lower() + assert result.strip().endswith("$$") + assert "version: 0.1" in result + assert "source: customer_orders" in result + assert "filter: order_date >= '2024-01-01'" in result + assert "customer_segment" in result + assert "total_revenue" in result + + def test_generic_macro_dispatcher(self, template_bundle): + """Test that the generic get_create_metric_view_as_sql macro works""" + yaml_spec = """version: 0.1 +source: test_table +measures: + - name: count + expr: count(1)""" + + # Mock the adapter dispatch to return our databricks implementation + template_bundle.context["adapter"].dispatch.return_value = getattr( + template_bundle.template.module, "databricks__get_create_metric_view_as_sql" + ) + + result = self.run_macro_raw( + template_bundle.template, + "get_create_metric_view_as_sql", + template_bundle.relation, + yaml_spec, + ) + + # Should generate the same output as the databricks-specific macro + assert "create or replace view" in result.lower() + assert "with metrics" in result.lower() + assert "language yaml" in result.lower() + assert "version: 0.1" in result + assert "source: test_table" in result diff --git a/tests/unit/macros/relations/test_row_filter_macros.py b/tests/unit/macros/relations/test_row_filter_macros.py new file mode 100644 index 000000000..04a539180 --- /dev/null +++ b/tests/unit/macros/relations/test_row_filter_macros.py @@ -0,0 +1,409 @@ +""" +Unit tests for row filter Jinja macros. + +These tests verify that the Jinja macros in row_filter.sql produce correct SQL +and handle edge cases properly. This complements the Python unit tests in +tests/unit/relation_configs/test_row_filter.py which test the Python-side logic. + +The goal is to ensure parity between: +- Python: RowFilterProcessor._qualify_function_name() +- Jinja: qualify_row_filter_function() +""" + +from unittest.mock import Mock + +import pytest + +from dbt.adapters.databricks.relation import DatabricksRelationType +from tests.unit.macros.base import MacroTestBase + + +class TestQuoteRowFilterColumns(MacroTestBase): + """Tests for the quote_row_filter_columns helper macro.""" + + @pytest.fixture(scope="class") + def template_name(self) -> str: + return "row_filter.sql" + + @pytest.fixture(scope="class") + def macro_folders_to_load(self) -> list: + return ["macros/relations/components", "macros"] + + def test_single_column(self, template_bundle): + sql = self.run_macro(template_bundle.template, "quote_row_filter_columns", ["col1"]) + assert sql == "`col1`" + + def test_multiple_columns(self, template_bundle): + sql = self.run_macro(template_bundle.template, "quote_row_filter_columns", ["a", "b", "c"]) + assert sql == "`a`, `b`, `c`" + + def test_reserved_words(self, template_bundle): + sql = self.run_macro( + template_bundle.template, "quote_row_filter_columns", ["select", "order", "table"] + ) + assert sql == "`select`, `order`, `table`" + + def test_column_with_spaces(self, template_bundle): + sql = self.run_macro( + template_bundle.template, "quote_row_filter_columns", ["my column", "another col"] + ) + assert sql == "`my column`, `another col`" + + +class TestQuoteRowFilterFunction(MacroTestBase): + """Tests for the quote_row_filter_function helper macro. + + This macro adds backticks to a raw function name at SQL generation time. + Function names are stored raw internally (e.g., 'cat.schema.fn') and quoted + only when generating SQL (e.g., '`cat`.`schema`.`fn`'). + """ + + @pytest.fixture(scope="class") + def template_name(self) -> str: + return "row_filter.sql" + + @pytest.fixture(scope="class") + def macro_folders_to_load(self) -> list: + return ["macros/relations/components", "macros"] + + def test_three_part_function_name(self, template_bundle): + """Three-part name should get each part quoted separately.""" + sql = self.run_macro( + template_bundle.template, "quote_row_filter_function", "cat.schema.my_filter" + ) + expected = "`cat`.`schema`.`my_filter`" + assert sql == self.clean_sql(expected) + + def test_function_with_special_chars(self, template_bundle): + """Function names with special characters should be properly quoted.""" + sql = self.run_macro( + template_bundle.template, "quote_row_filter_function", "my_cat.my_schema.my_filter_v2" + ) + expected = "`my_cat`.`my_schema`.`my_filter_v2`" + assert sql == self.clean_sql(expected) + + def test_fallback_for_unexpected_format(self, template_bundle): + """Non-3-part names should fallback to quoting the whole string.""" + sql = self.run_macro( + template_bundle.template, "quote_row_filter_function", "just_a_function" + ) + expected = "`just_a_function`" + assert sql == self.clean_sql(expected) + + +class TestQualifyRowFilterFunction(MacroTestBase): + """Tests for the qualify_row_filter_function macro. + + These tests should mirror the Python tests for _qualify_function_name() + in tests/unit/relation_configs/test_row_filter.py. + """ + + @pytest.fixture(scope="class") + def template_name(self) -> str: + return "row_filter.sql" + + @pytest.fixture(scope="class") + def macro_folders_to_load(self) -> list: + return ["macros/relations/components", "macros"] + + def test_unqualified_function(self, template_bundle): + """1-part name should be qualified with relation's database.schema.""" + sql = self.run_macro( + template_bundle.template, + "qualify_row_filter_function", + "my_filter", + template_bundle.relation, + ) + expected = "some_database.some_schema.my_filter" + assert sql == self.clean_sql(expected) + + def test_fully_qualified_function(self, template_bundle): + """3-part name should be returned raw.""" + sql = self.run_macro( + template_bundle.template, + "qualify_row_filter_function", + "cat.schema.fn", + template_bundle.relation, + ) + expected = "cat.schema.fn" + assert sql == self.clean_sql(expected) + + def test_function_with_existing_backticks(self, template_bundle): + """Function name with backticks should have them stripped and returned raw.""" + sql = self.run_macro( + template_bundle.template, + "qualify_row_filter_function", + "`cat`.`schema`.`fn`", + template_bundle.relation, + ) + expected = "cat.schema.fn" + assert sql == self.clean_sql(expected) + + def test_two_part_function_raises_error(self, context, template_bundle): + """2-part name is ambiguous and should raise an error.""" + context["exceptions"] = Mock() + context["exceptions"].raise_compiler_error = Mock(side_effect=Exception("Test error")) + + with pytest.raises(Exception, match="Test error"): + self.run_macro( + template_bundle.template, + "qualify_row_filter_function", + "schema.fn", + template_bundle.relation, + ) + + # Verify the error message mentions ambiguity + call_args = context["exceptions"].raise_compiler_error.call_args[0][0] + assert "ambiguous" in call_args.lower() + + def test_four_part_function_raises_error(self, context, template_bundle): + """4+ part name should raise an error.""" + context["exceptions"] = Mock() + context["exceptions"].raise_compiler_error = Mock(side_effect=Exception("Test error")) + + with pytest.raises(Exception, match="Test error"): + self.run_macro( + template_bundle.template, + "qualify_row_filter_function", + "a.b.c.d", + template_bundle.relation, + ) + + # Verify the error message mentions too many parts + call_args = context["exceptions"].raise_compiler_error.call_args[0][0] + assert "too many parts" in call_args.lower() + + +class TestAlterSetRowFilter(MacroTestBase): + """Tests for the alter_set_row_filter macro.""" + + @pytest.fixture(scope="class") + def template_name(self) -> str: + return "row_filter.sql" + + @pytest.fixture(scope="class") + def macro_folders_to_load(self) -> list: + return ["macros/relations/components", "macros"] + + def test_basic_alter_set(self, template_bundle): + """Test basic ALTER TABLE SET ROW FILTER.""" + row_filter = Mock() + row_filter.function = "cat.schema.my_filter" + row_filter.columns = ("col1",) + + sql = self.run_macro( + template_bundle.template, "alter_set_row_filter", template_bundle.relation, row_filter + ) + expected = ( + "alter table `some_database`.`some_schema`.`some_table` " + "set row filter `cat`.`schema`.`my_filter` on (`col1`)" + ) + assert sql == self.clean_sql(expected) + + def test_alter_set_multiple_columns(self, template_bundle): + """Test ALTER with multiple columns.""" + row_filter = Mock() + row_filter.function = "cat.schema.my_filter" + row_filter.columns = ("region", "country_code") + + sql = self.run_macro( + template_bundle.template, "alter_set_row_filter", template_bundle.relation, row_filter + ) + expected = ( + "alter table `some_database`.`some_schema`.`some_table` " + "set row filter `cat`.`schema`.`my_filter` on (`region`, `country_code`)" + ) + assert sql == self.clean_sql(expected) + + def test_alter_set_on_materialized_view(self, template_bundle): + """Test ALTER on materialized view.""" + template_bundle.relation.type = DatabricksRelationType.MaterializedView + template_bundle.relation.type.render = Mock(return_value="MATERIALIZED VIEW") + + row_filter = Mock() + row_filter.function = "cat.schema.my_filter" + row_filter.columns = ("col1",) + + sql = self.run_macro( + template_bundle.template, "alter_set_row_filter", template_bundle.relation, row_filter + ) + assert "materialized view" in sql.lower() + + +class TestAlterDropRowFilter(MacroTestBase): + """Tests for the alter_drop_row_filter macro.""" + + @pytest.fixture(scope="class") + def template_name(self) -> str: + return "row_filter.sql" + + @pytest.fixture(scope="class") + def macro_folders_to_load(self) -> list: + return ["macros/relations/components", "macros"] + + def test_drop_row_filter_table(self, template_bundle): + """Test DROP ROW FILTER on table.""" + sql = self.run_macro( + template_bundle.template, "alter_drop_row_filter", template_bundle.relation + ) + expected = "alter table `some_database`.`some_schema`.`some_table` drop row filter" + assert sql == self.clean_sql(expected) + + def test_drop_row_filter_materialized_view(self, template_bundle): + """Test DROP ROW FILTER on materialized view.""" + template_bundle.relation.type = DatabricksRelationType.MaterializedView + template_bundle.relation.type.render = Mock(return_value="MATERIALIZED VIEW") + + sql = self.run_macro( + template_bundle.template, "alter_drop_row_filter", template_bundle.relation + ) + assert "materialized view" in sql.lower() + assert "drop row filter" in sql.lower() + + +class TestGetCreateRowFilterClause(MacroTestBase): + """Tests for the get_create_row_filter_clause macro.""" + + @pytest.fixture(scope="class") + def template_name(self) -> str: + return "row_filter.sql" + + @pytest.fixture(scope="class") + def macro_folders_to_load(self) -> list: + return ["macros/relations/components", "macros"] + + def test_no_config(self, config, template_bundle): + """Test with no row_filter config.""" + # config is empty by default + sql = self.run_macro( + template_bundle.template, "get_create_row_filter_clause", template_bundle.relation + ) + assert sql == "" + + def test_with_valid_config(self, config, template_bundle): + """Test with valid row_filter config.""" + config["row_filter"] = {"function": "my_filter", "columns": ["region"]} + + sql = self.run_macro( + template_bundle.template, "get_create_row_filter_clause", template_bundle.relation + ) + expected = "with row filter `some_database`.`some_schema`.`my_filter` on (`region`)" + assert sql == self.clean_sql(expected) + + def test_with_multiple_columns(self, config, template_bundle): + """Test with multiple columns.""" + config["row_filter"] = { + "function": "multi_col_filter", + "columns": ["region", "country_code"], + } + + sql = self.run_macro( + template_bundle.template, "get_create_row_filter_clause", template_bundle.relation + ) + assert "`region`, `country_code`" in sql.lower() + + def test_with_string_column_normalized(self, config, template_bundle): + """Test that string columns value is normalized to list.""" + config["row_filter"] = {"function": "my_filter", "columns": "region"} + + sql = self.run_macro( + template_bundle.template, "get_create_row_filter_clause", template_bundle.relation + ) + expected = "with row filter `some_database`.`some_schema`.`my_filter` on (`region`)" + assert sql == self.clean_sql(expected) + + def test_with_fully_qualified_function(self, config, template_bundle): + """Test with fully qualified function name.""" + config["row_filter"] = { + "function": "other_cat.other_schema.other_filter", + "columns": ["col1"], + } + + sql = self.run_macro( + template_bundle.template, "get_create_row_filter_clause", template_bundle.relation + ) + assert "`other_cat`.`other_schema`.`other_filter`" in sql.lower() + + def test_empty_columns_raises_error(self, config, context, template_bundle): + """Test that empty columns raises an error.""" + config["row_filter"] = {"function": "my_filter", "columns": []} + context["exceptions"] = Mock() + context["exceptions"].raise_compiler_error = Mock(side_effect=Exception("Test error")) + + with pytest.raises(Exception, match="Test error"): + self.run_macro( + template_bundle.template, "get_create_row_filter_clause", template_bundle.relation + ) + + # Verify the error message mentions non-empty columns + call_args = context["exceptions"].raise_compiler_error.call_args[0][0] + assert "non-empty" in call_args.lower() or "columns" in call_args.lower() + + +class TestFetchRowFiltersSql(MacroTestBase): + """Tests for the fetch_row_filters_sql macro.""" + + @pytest.fixture(scope="class") + def template_name(self) -> str: + return "row_filter.sql" + + @pytest.fixture(scope="class") + def macro_folders_to_load(self) -> list: + return ["macros/relations/components", "macros"] + + def test_fetch_row_filters_sql(self, template_bundle): + """Test the SQL generated for fetching row filters.""" + sql = self.run_macro( + template_bundle.template, "fetch_row_filters_sql", template_bundle.relation + ) + expected = """ + SELECT + table_catalog, + table_schema, + table_name, + filter_name, + target_columns + FROM `some_database`.`information_schema`.`row_filters` + WHERE table_catalog = 'some_database' + AND table_schema = 'some_schema' + AND table_name = 'some_table' + """ + assert sql == self.clean_sql(expected) + + +class TestDropRowFilterIfExists(MacroTestBase): + """Tests for the drop_row_filter_if_exists macro.""" + + @pytest.fixture(scope="class") + def template_name(self) -> str: + return "row_filter.sql" + + @pytest.fixture(scope="class") + def macro_folders_to_load(self) -> list: + return ["macros/relations/components", "macros"] + + def test_skips_hive_metastore(self, template_bundle): + """Test that Hive Metastore relations are skipped entirely.""" + template_bundle.relation.is_hive_metastore = Mock(return_value=True) + + sql = self.run_macro( + template_bundle.template, "drop_row_filter_if_exists", template_bundle.relation + ) + + # Should return empty - no DROP statement generated + assert sql == "" + # Verify is_hive_metastore was called + template_bundle.relation.is_hive_metastore.assert_called() + + def test_generates_correct_drop_sql(self, template_bundle): + """Test that alter_drop_row_filter generates correct SQL. + + This indirectly tests that drop_row_filter_if_exists would call + alter_drop_row_filter with the right arguments when filters exist. + """ + # Verify the underlying SQL generation is correct + sql = self.run_macro( + template_bundle.template, "alter_drop_row_filter", template_bundle.relation + ) + expected = "alter table `some_database`.`some_schema`.`some_table` drop row filter" + assert sql == self.clean_sql(expected) diff --git a/tests/unit/macros/relations/test_table_macros.py b/tests/unit/macros/relations/test_table_macros.py index 9b3730962..344c3a885 100644 --- a/tests/unit/macros/relations/test_table_macros.py +++ b/tests/unit/macros/relations/test_table_macros.py @@ -12,7 +12,12 @@ def template_name(self) -> str: @pytest.fixture(scope="class") def macro_folders_to_load(self) -> list: - return ["macros/relations/table", "macros/relations", "macros"] + return [ + "macros/relations/table", + "macros/relations", + "macros", + "macros/relations/components", + ] @pytest.fixture(scope="class") def databricks_template_names(self) -> list: @@ -21,6 +26,7 @@ def databricks_template_names(self) -> list: "tblproperties.sql", "location.sql", "liquid_clustering.sql", + "row_filter.sql", ] @pytest.fixture diff --git a/tests/unit/macros/relations/test_view_macros.py b/tests/unit/macros/relations/test_view_macros.py index c2c3fd9ae..5a84232ff 100644 --- a/tests/unit/macros/relations/test_view_macros.py +++ b/tests/unit/macros/relations/test_view_macros.py @@ -27,6 +27,7 @@ def test_macros_create_view_as_tblproperties(self, config, template_bundle): template_bundle.context["get_columns_in_query"] = Mock(return_value=[]) template_bundle.context["column_mask_exists"] = Mock(return_value=False) template_bundle.context["column_tags_exist"] = Mock(return_value=False) + template_bundle.context["row_filter_exists"] = Mock(return_value=False) sql = self.render_create_view_as(template_bundle) expected = ( diff --git a/tests/unit/relation_configs/test_column_tags_config.py b/tests/unit/relation_configs/test_column_tags_config.py index 099f9da30..2205138e8 100644 --- a/tests/unit/relation_configs/test_column_tags_config.py +++ b/tests/unit/relation_configs/test_column_tags_config.py @@ -26,7 +26,7 @@ def test_from_relation_results__some(self): "information_schema.column_tags": Table( rows=[ ["col1", "tag_a", "value_a"], - ["col1", "tag_b", "value_b"], + ["col1", "tag_b", ""], # key-only tag ["col2", "tag_c", "value_c"], ], column_names=["column_name", "tag_name", "tag_value"], @@ -35,7 +35,7 @@ def test_from_relation_results__some(self): spec = ColumnTagsProcessor.from_relation_results(results) assert spec == ColumnTagsConfig( set_column_tags={ - "col1": {"tag_a": "value_a", "tag_b": "value_b"}, + "col1": {"tag_a": "value_a", "tag_b": ""}, "col2": {"tag_c": "value_c"}, } ) @@ -54,14 +54,18 @@ def test_from_relation_config__without_column_tags(self): def test_from_relation_config__with_dict(self): model = Mock() model.columns = { - "email": {"_extra": {"databricks_tags": {"pii": "true", "env": "prod"}}}, + "email": { + "_extra": { + "databricks_tags": {"pii": "", "env": "prod", "priority": 0, "enabled": False} + } + }, "id": {"_extra": {}}, "created_at": {}, } spec = ColumnTagsProcessor.from_relation_config(model) assert spec == ColumnTagsConfig( set_column_tags={ - "email": {"pii": "true", "env": "prod"}, + "email": {"pii": "", "env": "prod", "priority": "0", "enabled": "False"}, } ) @@ -71,14 +75,16 @@ def test_from_relation_config__with_column_info(self): "id": ColumnInfo(name="id", _extra={}), "email": ColumnInfo( name="email", - _extra={"databricks_tags": {"pii": "true", "env": "prod"}}, + _extra={ + "databricks_tags": {"pii": "", "env": "prod", "priority": 0, "enabled": False} + }, ), "created_at": ColumnInfo(name="created_at"), } spec = ColumnTagsProcessor.from_relation_config(model) assert spec == ColumnTagsConfig( set_column_tags={ - "email": {"pii": "true", "env": "prod"}, + "email": {"pii": "", "env": "prod", "priority": "0", "enabled": "False"}, } ) diff --git a/tests/unit/relation_configs/test_incremental_config.py b/tests/unit/relation_configs/test_incremental_config.py index 5b99fec74..a543ef455 100644 --- a/tests/unit/relation_configs/test_incremental_config.py +++ b/tests/unit/relation_configs/test_incremental_config.py @@ -13,6 +13,7 @@ ) from dbt.adapters.databricks.relation_configs.incremental import IncrementalTableConfig from dbt.adapters.databricks.relation_configs.liquid_clustering import LiquidClusteringConfig +from dbt.adapters.databricks.relation_configs.row_filter import RowFilterConfig from dbt.adapters.databricks.relation_configs.tags import TagsConfig from dbt.adapters.databricks.relation_configs.tblproperties import TblPropertiesConfig @@ -143,5 +144,6 @@ def test_from_results(self): }, unset_column_masks=[], ), + "row_filter": RowFilterConfig(), } ) diff --git a/tests/unit/relation_configs/test_materialized_view_config.py b/tests/unit/relation_configs/test_materialized_view_config.py index 14177a8e2..44604f849 100644 --- a/tests/unit/relation_configs/test_materialized_view_config.py +++ b/tests/unit/relation_configs/test_materialized_view_config.py @@ -10,6 +10,7 @@ from dbt.adapters.databricks.relation_configs.partitioning import PartitionedByConfig from dbt.adapters.databricks.relation_configs.query import QueryConfig from dbt.adapters.databricks.relation_configs.refresh import RefreshConfig +from dbt.adapters.databricks.relation_configs.row_filter import RowFilterConfig from dbt.adapters.databricks.relation_configs.tags import TagsConfig from dbt.adapters.databricks.relation_configs.tblproperties import TblPropertiesConfig @@ -55,6 +56,7 @@ def test_from_results(self): "refresh": RefreshConfig(), "query": QueryConfig(query="select * from foo"), "tags": TagsConfig(set_tags={"a": "b", "c": "d"}), + "row_filter": RowFilterConfig(), } ) @@ -83,6 +85,7 @@ def test_from_model_node(self): "refresh": RefreshConfig(), "query": QueryConfig(query="select * from foo"), "tags": TagsConfig(set_tags={"a": "b", "c": "d"}), + "row_filter": RowFilterConfig(), } ) @@ -96,6 +99,7 @@ def test_get_changeset__no_changes(self): "refresh": RefreshConfig(), "query": QueryConfig(query="select * from foo"), "tags": TagsConfig(set_tags={"a": "b", "c": "d"}), + "row_filter": RowFilterConfig(), } ) new = MaterializedViewConfig( @@ -107,6 +111,7 @@ def test_get_changeset__no_changes(self): "refresh": RefreshConfig(), "query": QueryConfig(query="select * from foo"), "tags": TagsConfig(set_tags={"a": "b", "c": "d"}), + "row_filter": RowFilterConfig(), } ) @@ -122,6 +127,7 @@ def test_get_changeset__some_changes(self): "refresh": RefreshConfig(), "query": QueryConfig(query="select * from foo"), "tags": TagsConfig(set_tags={}), + "row_filter": RowFilterConfig(), } ) new = MaterializedViewConfig( @@ -133,6 +139,7 @@ def test_get_changeset__some_changes(self): "refresh": RefreshConfig(cron="*/5 * * * *"), "query": QueryConfig(query="select * from foo"), "tags": TagsConfig(set_tags={"a": "b", "c": "d"}), + "row_filter": RowFilterConfig(), } ) diff --git a/tests/unit/relation_configs/test_metric_view_config.py b/tests/unit/relation_configs/test_metric_view_config.py new file mode 100644 index 000000000..4f197fcfa --- /dev/null +++ b/tests/unit/relation_configs/test_metric_view_config.py @@ -0,0 +1,109 @@ +from unittest.mock import Mock + +import pytest +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.databricks.relation_configs.metric_view import ( + MetricViewQueryConfig, + MetricViewQueryProcessor, +) + +yaml_content = """version: 0.1 +source: my_table +dimensions: + - name: dim1 + expr: col1 +measures: + - name: count + expr: count(1)""" + + +class TestMetricViewQueryConfig: + def test_get_diff__same_query(self): + config1 = MetricViewQueryConfig(query=yaml_content) + config2 = MetricViewQueryConfig(query=yaml_content) + assert config1.get_diff(config2) is None + + def test_get_diff__different_query(self): + config1 = MetricViewQueryConfig(query=yaml_content) + config2 = MetricViewQueryConfig(query="version: 0.1\nsource: other_table") + assert config1.get_diff(config2) is config1 + + def test_get_diff__whitespace_normalization(self): + config1 = MetricViewQueryConfig(query="version: 0.1\nsource: my_table") + config2 = MetricViewQueryConfig(query="version: 0.1\n source: my_table ") + assert config1.get_diff(config2) is None + + def test_get_diff__different_whitespace_content(self): + config1 = MetricViewQueryConfig(query="version: 0.1 source: my_table") + config2 = MetricViewQueryConfig(query="version: 0.1 source: other_table") + assert config1.get_diff(config2) is config1 + + +class TestMetricViewQueryProcessor: + def test_from_relation_results__with_dollar_delimiters(self): + describe_extended = [ + ("col_name", "data_type", "comment"), + ("View Text", f"$${yaml_content}$$", None), + ] + results = {"describe_extended": describe_extended} + spec = MetricViewQueryProcessor.from_relation_results(results) + assert spec == MetricViewQueryConfig(query=yaml_content) + + def test_from_relation_results__with_whitespace_around_delimiters(self): + describe_extended = [ + ("col_name", "data_type", "comment"), + ("View Text", f" $$ {yaml_content} $$ ", None), + ] + results = {"describe_extended": describe_extended} + spec = MetricViewQueryProcessor.from_relation_results(results) + assert spec == MetricViewQueryConfig(query=yaml_content) + + def test_from_relation_results__without_delimiters(self): + describe_extended = [ + ("col_name", "data_type", "comment"), + ("View Text", yaml_content, None), + ] + results = {"describe_extended": describe_extended} + spec = MetricViewQueryProcessor.from_relation_results(results) + assert spec == MetricViewQueryConfig(query=yaml_content) + + def test_from_relation_results__missing_describe_extended(self): + results = {} + with pytest.raises(DbtRuntimeError, match="Cannot find metric view description"): + MetricViewQueryProcessor.from_relation_results(results) + + def test_from_relation_results__missing_view_text(self): + describe_extended = [ + ("col_name", "data_type", "comment"), + ("Other Field", "some_value", None), + ] + results = {"describe_extended": describe_extended} + with pytest.raises(DbtRuntimeError, match="no 'View Text' in DESCRIBE EXTENDED"): + MetricViewQueryProcessor.from_relation_results(results) + + def test_from_relation_config__with_query(self): + model = Mock() + model.compiled_code = yaml_content + spec = MetricViewQueryProcessor.from_relation_config(model) + assert spec == MetricViewQueryConfig(query=yaml_content) + + def test_from_relation_config__with_whitespace(self): + model = Mock() + model.compiled_code = f" {yaml_content} " + spec = MetricViewQueryProcessor.from_relation_config(model) + assert spec == MetricViewQueryConfig(query=yaml_content) + + def test_from_relation_config__without_query(self): + model = Mock() + model.compiled_code = None + model.identifier = "test_metric_view" + with pytest.raises(DbtRuntimeError, match="no YAML definition"): + MetricViewQueryProcessor.from_relation_config(model) + + def test_from_relation_config__empty_query(self): + model = Mock() + model.compiled_code = "" + model.identifier = "test_metric_view" + with pytest.raises(DbtRuntimeError, match="no YAML definition"): + MetricViewQueryProcessor.from_relation_config(model) diff --git a/tests/unit/relation_configs/test_row_filter.py b/tests/unit/relation_configs/test_row_filter.py new file mode 100644 index 000000000..419cb7acf --- /dev/null +++ b/tests/unit/relation_configs/test_row_filter.py @@ -0,0 +1,229 @@ +from unittest.mock import Mock + +import pytest +from agate import Table + +from dbt.adapters.databricks.relation_configs.row_filter import ( + RowFilterConfig, + RowFilterProcessor, +) + + +class TestRowFilterConfig: + def test_no_change_when_both_none(self): + desired = RowFilterConfig() + existing = RowFilterConfig() + assert desired.get_diff(existing) is None + + def test_unset_when_removed(self): + """When filter is removed from config, return RowFilterConfig with should_unset=True.""" + desired = RowFilterConfig() + existing = RowFilterConfig(function="cat.schema.fn", columns=("col1",)) + diff = desired.get_diff(existing) + assert diff is not None + assert isinstance(diff, RowFilterConfig) + assert diff.should_unset is True + assert diff.is_change is True # Critical: marks as actual change + assert diff.function is None + + def test_set_when_new(self): + """When filter is added, return RowFilterConfig with the filter to apply.""" + desired = RowFilterConfig(function="cat.schema.fn", columns=("col1",)) + existing = RowFilterConfig() + diff = desired.get_diff(existing) + assert diff is not None + assert isinstance(diff, RowFilterConfig) + assert diff.function == "cat.schema.fn" + assert diff.should_unset is False + assert diff.is_change is True # Critical: marks as actual change + + def test_no_change_when_equal_case_insensitive(self): + desired = RowFilterConfig(function="CAT.SCHEMA.FN", columns=("COL1",)) + existing = RowFilterConfig(function="cat.schema.fn", columns=("col1",)) + assert desired.get_diff(existing) is None + + def test_change_when_different_function(self): + """When filter function changes, return RowFilterConfig with new function.""" + desired = RowFilterConfig(function="cat.schema.fn2", columns=("col1",)) + existing = RowFilterConfig(function="cat.schema.fn1", columns=("col1",)) + diff = desired.get_diff(existing) + assert diff is not None + assert isinstance(diff, RowFilterConfig) + assert diff.function == "cat.schema.fn2" + assert diff.should_unset is False + assert diff.is_change is True # Critical: marks as actual change + + def test_change_when_different_columns(self): + """When filter columns change, return RowFilterConfig with new columns.""" + desired = RowFilterConfig(function="cat.schema.fn", columns=("col1", "col2")) + existing = RowFilterConfig(function="cat.schema.fn", columns=("col1",)) + diff = desired.get_diff(existing) + assert diff is not None + assert isinstance(diff, RowFilterConfig) + assert diff.function == "cat.schema.fn" + assert diff.columns == ("col1", "col2") + assert diff.should_unset is False + assert diff.is_change is True # Critical: marks as actual change + + def test_is_change_false_by_default(self): + """RowFilterConfig.is_change should be False by default (current state objects).""" + config = RowFilterConfig(function="cat.schema.fn", columns=("col1",)) + assert config.is_change is False + assert config.should_unset is False + + +class TestRowFilterProcessor: + def test_parse_target_columns_simple(self): + result = RowFilterProcessor._parse_target_columns("col1, col2") + assert result == ["col1", "col2"] + + def test_parse_target_columns_quoted(self): + result = RowFilterProcessor._parse_target_columns('"col1", "col2"') + assert result == ["col1", "col2"] + + def test_parse_target_columns_empty(self): + result = RowFilterProcessor._parse_target_columns("") + assert result == [] + + def test_parse_target_columns_none(self): + result = RowFilterProcessor._parse_target_columns(None) + assert result == [] + + def test_qualify_function_name_already_qualified(self): + model = Mock() + model.database = "mycat" + model.schema = "myschema" + + result = RowFilterProcessor._qualify_function_name("cat.schema.fn", model) + assert result == "cat.schema.fn" + + def test_qualify_function_name_unqualified(self): + model = Mock() + model.database = "mycat" + model.schema = "myschema" + + result = RowFilterProcessor._qualify_function_name("my_fn", model) + assert result == "mycat.myschema.my_fn" + + def test_qualify_function_name_two_part_raises(self): + model = Mock() + model.database = "mycat" + model.schema = "myschema" + + with pytest.raises(ValueError) as exc_info: + RowFilterProcessor._qualify_function_name("schema.fn", model) + assert "ambiguous" in str(exc_info.value).lower() + + def test_qualify_function_name_four_part_raises(self): + model = Mock() + model.database = "mycat" + model.schema = "myschema" + + with pytest.raises(ValueError) as exc_info: + RowFilterProcessor._qualify_function_name("a.b.c.d", model) + assert "too many parts" in str(exc_info.value).lower() + + def test_from_relation_config_empty_columns_raises(self): + model = Mock() + model.database = "mycat" + model.schema = "myschema" + model.config.extra = {"row_filter": {"function": "fn", "columns": []}} + + with pytest.raises(ValueError) as exc_info: + RowFilterProcessor.from_relation_config(model) + assert "non-empty 'columns' value" in str(exc_info.value) + + def test_from_relation_config_with_valid_filter(self): + model = Mock() + model.database = "mycat" + model.schema = "myschema" + model.config.extra = {"row_filter": {"function": "my_filter", "columns": ["col1"]}} + + spec = RowFilterProcessor.from_relation_config(model) + assert spec.function == "mycat.myschema.my_filter" + assert spec.columns == ("col1",) + + def test_from_relation_config_string_columns_normalized(self): + """String columns should be normalized to list.""" + model = Mock() + model.database = "mycat" + model.schema = "myschema" + model.config.extra = {"row_filter": {"function": "my_filter", "columns": "region"}} + + spec = RowFilterProcessor.from_relation_config(model) + assert spec.function == "mycat.myschema.my_filter" + assert spec.columns == ("region",) + + def test_from_relation_config_empty_string_column_raises(self): + model = Mock() + model.database = "mycat" + model.schema = "myschema" + model.config.extra = {"row_filter": {"function": "fn", "columns": ["col1", ""]}} + + with pytest.raises(ValueError) as exc_info: + RowFilterProcessor.from_relation_config(model) + assert "non-empty string" in str(exc_info.value) + + def test_from_relation_config_whitespace_only_column_raises(self): + model = Mock() + model.database = "mycat" + model.schema = "myschema" + model.config.extra = {"row_filter": {"function": "fn", "columns": ["col1", " "]}} + + with pytest.raises(ValueError) as exc_info: + RowFilterProcessor.from_relation_config(model) + assert "non-empty string" in str(exc_info.value) + + def test_from_relation_results_empty(self): + results = { + "row_filters": Table( + rows=[], + column_names=[ + "table_catalog", + "table_schema", + "table_name", + "filter_name", + "target_columns", + ], + ) + } + spec = RowFilterProcessor.from_relation_results(results) + assert spec == RowFilterConfig() + + def test_from_relation_results_one_row(self): + # filter_name contains fully qualified function name (catalog.schema.function) + results = { + "row_filters": Table( + rows=[["cat", "schema", "my_table", "cat.schema.my_filter", "col1, col2"]], + column_names=[ + "table_catalog", + "table_schema", + "table_name", + "filter_name", + "target_columns", + ], + ) + } + spec = RowFilterProcessor.from_relation_results(results) + assert spec.function == "cat.schema.my_filter" + assert spec.columns == ("col1", "col2") + + def test_from_relation_results_multiple_rows_raises(self): + results = { + "row_filters": Table( + rows=[ + ["cat", "schema", "my_table", "cat.schema.filter1", "col1"], + ["cat", "schema", "my_table", "cat.schema.filter2", "col2"], + ], + column_names=[ + "table_catalog", + "table_schema", + "table_name", + "filter_name", + "target_columns", + ], + ) + } + with pytest.raises(ValueError) as exc_info: + RowFilterProcessor.from_relation_results(results) + assert "Multiple row filters found" in str(exc_info.value) diff --git a/tests/unit/relation_configs/test_streaming_table_config.py b/tests/unit/relation_configs/test_streaming_table_config.py index 41f7b6290..af78f17a8 100644 --- a/tests/unit/relation_configs/test_streaming_table_config.py +++ b/tests/unit/relation_configs/test_streaming_table_config.py @@ -7,6 +7,7 @@ from dbt.adapters.databricks.relation_configs.partitioning import PartitionedByConfig from dbt.adapters.databricks.relation_configs.query import QueryConfig from dbt.adapters.databricks.relation_configs.refresh import RefreshConfig +from dbt.adapters.databricks.relation_configs.row_filter import RowFilterConfig from dbt.adapters.databricks.relation_configs.streaming_table import ( StreamingTableConfig, ) @@ -47,6 +48,7 @@ def test_from_results(self): "refresh": RefreshConfig(), "tags": TagsConfig(set_tags={"a": "b", "c": "d"}), "query": QueryConfig(query="select * from foo"), + "row_filter": RowFilterConfig(), } ) @@ -75,6 +77,7 @@ def test_from_model_node(self): "refresh": RefreshConfig(), "tags": TagsConfig(set_tags={"a": "b", "c": "d"}), "query": QueryConfig(query="select * from foo"), + "row_filter": RowFilterConfig(), } ) @@ -88,6 +91,7 @@ def test_get_changeset__no_changes(self): "refresh": RefreshConfig(), "tags": TagsConfig(set_tags={"a": "b", "c": "d"}), "query": QueryConfig(query="select * from foo"), + "row_filter": RowFilterConfig(), } ) new = StreamingTableConfig( @@ -99,6 +103,7 @@ def test_get_changeset__no_changes(self): "refresh": RefreshConfig(), "tags": TagsConfig(set_tags={"a": "b", "c": "d"}), "query": QueryConfig(query="select * from foo"), + "row_filter": RowFilterConfig(), } ) @@ -116,6 +121,7 @@ def test_get_changeset__some_changes(self): "refresh": RefreshConfig(), "tags": TagsConfig(set_tags={}), "query": QueryConfig(query="select * from foo"), + "row_filter": RowFilterConfig(), } ) new = StreamingTableConfig( @@ -127,6 +133,7 @@ def test_get_changeset__some_changes(self): "refresh": RefreshConfig(cron="*/5 * * * *"), "tags": TagsConfig(set_tags={"a": "b", "c": "d"}), "query": QueryConfig(query="select * from foo"), + "row_filter": RowFilterConfig(), } ) @@ -142,4 +149,5 @@ def test_get_changeset__some_changes(self): "refresh": RefreshConfig(cron="*/5 * * * *"), "tags": TagsConfig(set_tags={"a": "b", "c": "d"}), "query": QueryConfig(query="select * from foo"), + "row_filter": RowFilterConfig(), } diff --git a/tests/unit/relation_configs/test_tags.py b/tests/unit/relation_configs/test_tags.py index e465739b8..a9e4bdc0b 100644 --- a/tests/unit/relation_configs/test_tags.py +++ b/tests/unit/relation_configs/test_tags.py @@ -24,6 +24,15 @@ def test_from_relation_results__some(self): spec = TagsProcessor.from_relation_results(results) assert spec == TagsConfig(set_tags={"a": "valA", "b": "valB"}) + def test_from_relation_results__key_only(self): + results = { + "information_schema.tags": Table( + rows=[["a", ""]], column_names=["tag_name", "tag_value"] + ) + } + spec = TagsProcessor.from_relation_results(results) + assert spec == TagsConfig(set_tags={"a": ""}) + def test_from_relation_config__without_tags(self): model = Mock() model.config.extra = {} @@ -36,6 +45,18 @@ def test_from_relation_config__with_tags(self): spec = TagsProcessor.from_relation_config(model) assert spec == TagsConfig(set_tags={"a": "valA", "b": "1"}) + def test_from_relation_config__with_key_only_tags(self): + model = Mock() + model.config.extra = {"databricks_tags": {"a": "", "b": None}} + spec = TagsProcessor.from_relation_config(model) + assert spec == TagsConfig(set_tags={"a": "", "b": ""}) + + def test_from_relation_config__with_falsy_tags(self): + model = Mock() + model.config.extra = {"databricks_tags": {"priority": 0, "enabled": False}} + spec = TagsProcessor.from_relation_config(model) + assert spec == TagsConfig(set_tags={"priority": "0", "enabled": "False"}) + def test_from_relation_config__with_incorrect_tags(self): model = Mock() model.config.extra = {"databricks_tags": ["a", "b"]} @@ -52,25 +73,25 @@ def test_get_diff__empty_and_some_exist(self): # Tags are "set only" - when config has no tags and relation has tags, # we don't unset the existing tags config = TagsConfig(set_tags={}) - other = TagsConfig(set_tags={"tag": "value"}) - diff = config.get_diff(other) + config_old = TagsConfig(set_tags={"tag": "value"}) + diff = config.get_diff(config_old) assert diff is None # No changes needed since we don't unset tags def test_get_diff__some_new_and_empty_existing(self): config = TagsConfig(set_tags={"tag": "value"}) - other = TagsConfig(set_tags={}) - diff = config.get_diff(other) + config_old = TagsConfig(set_tags={}) + diff = config.get_diff(config_old) assert diff == TagsConfig(set_tags={"tag": "value"}) def test_get_diff__mixed_case(self): # Tags are "set only" - only the new/updated tags are included config = TagsConfig(set_tags={"a": "value", "b": "value"}) - other = TagsConfig(set_tags={"b": "other_value", "c": "value"}) - diff = config.get_diff(other) + config_old = TagsConfig(set_tags={"b": "other_value", "c": "value"}) + diff = config.get_diff(config_old) assert diff == TagsConfig(set_tags={"a": "value", "b": "value"}) def test_get_diff__no_changes(self): config = TagsConfig(set_tags={"tag": "value"}) - other = TagsConfig(set_tags={"tag": "value"}) - diff = config.get_diff(other) + config_old = TagsConfig(set_tags={"tag": "value"}) + diff = config.get_diff(config_old) assert diff is None diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 23ce774d8..603969063 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -1228,6 +1228,54 @@ def test_get_columns_reraises_other_database_errors( ) +class TestIsDescribeAsJsonSupported(DatabricksAdapterBase): + @pytest.fixture + def adapter(self): + with patch("dbt.adapters.databricks.connections.DatabricksConnectionManager"): + adapter = DatabricksAdapter(self._get_config(), get_context("spawn")) + yield adapter + + def test_supported_for_uc_table_with_capability(self, adapter): + relation = DatabricksRelation.create( + database="catalog", + schema="schema", + identifier="table", + type=DatabricksRelation.Table, + ) + with patch.object(adapter, "has_capability", return_value=True): + assert adapter.is_describe_as_json_supported(relation) is True + + def test_not_supported_without_capability(self, adapter): + relation = DatabricksRelation.create( + database="catalog", + schema="schema", + identifier="table", + type=DatabricksRelation.Table, + ) + with patch.object(adapter, "has_capability", return_value=False): + assert adapter.is_describe_as_json_supported(relation) is False + + def test_not_supported_for_hive_metastore(self, adapter): + relation = DatabricksRelation.create( + database="hive_metastore", + schema="schema", + identifier="table", + type=DatabricksRelation.Table, + ) + with patch.object(adapter, "has_capability", return_value=True): + assert adapter.is_describe_as_json_supported(relation) is False + + def test_not_supported_for_foreign_table(self, adapter): + relation = DatabricksRelation.create( + database="catalog", + schema="schema", + identifier="table", + type=DatabricksRelationType.Foreign, + ) + with patch.object(adapter, "has_capability", return_value=True): + assert adapter.is_describe_as_json_supported(relation) is False + + class TestManagedIcebergBehaviorFlag(DatabricksAdapterBase): @pytest.fixture def adapter(self): diff --git a/tests/unit/test_dbr_capabilities.py b/tests/unit/test_dbr_capabilities.py index 7f02d8dcf..0ba052b34 100644 --- a/tests/unit/test_dbr_capabilities.py +++ b/tests/unit/test_dbr_capabilities.py @@ -14,40 +14,57 @@ class TestDBRCapabilities: def test_capability_enum_values(self): """Test that all capabilities have the expected values.""" - assert DBRCapability.TIMESTAMPDIFF.value == "timestampdiff" - assert DBRCapability.ICEBERG.value == "iceberg" assert DBRCapability.COMMENT_ON_COLUMN.value == "comment_on_column" + assert ( + DBRCapability.DESCRIBE_TABLE_EXTENDED_AS_JSON.value == "describe_table_extended_as_json" + ) + assert DBRCapability.ICEBERG.value == "iceberg" + assert DBRCapability.INSERT_BY_NAME.value == "insert_by_name" assert DBRCapability.JSON_COLUMN_METADATA.value == "json_column_metadata" + assert DBRCapability.REPLACE_ON.value == "replace_on" + assert DBRCapability.STREAMING_TABLE_JSON_METADATA.value == "streaming_table_json_metadata" + assert DBRCapability.TIMESTAMPDIFF.value == "timestampdiff" def test_old_dbr_version(self): """Test capabilities with old DBR version.""" capabilities = DBRCapabilities(dbr_version=(10, 0)) # Should not have newer features - assert not capabilities.has_capability(DBRCapability.TIMESTAMPDIFF) - assert not capabilities.has_capability(DBRCapability.ICEBERG) assert not capabilities.has_capability(DBRCapability.COMMENT_ON_COLUMN) + assert not capabilities.has_capability(DBRCapability.DESCRIBE_TABLE_EXTENDED_AS_JSON) + assert not capabilities.has_capability(DBRCapability.ICEBERG) + assert not capabilities.has_capability(DBRCapability.INSERT_BY_NAME) assert not capabilities.has_capability(DBRCapability.JSON_COLUMN_METADATA) + assert not capabilities.has_capability(DBRCapability.REPLACE_ON) + assert not capabilities.has_capability(DBRCapability.STREAMING_TABLE_JSON_METADATA) + assert not capabilities.has_capability(DBRCapability.TIMESTAMPDIFF) def test_modern_dbr_version(self): """Test capabilities with modern DBR version.""" - capabilities = DBRCapabilities(dbr_version=(16, 2)) + capabilities = DBRCapabilities(dbr_version=(17, 3)) # Should have all features up to 16.2 - assert capabilities.has_capability(DBRCapability.TIMESTAMPDIFF) - assert capabilities.has_capability(DBRCapability.ICEBERG) assert capabilities.has_capability(DBRCapability.COMMENT_ON_COLUMN) + assert capabilities.has_capability(DBRCapability.DESCRIBE_TABLE_EXTENDED_AS_JSON) + assert capabilities.has_capability(DBRCapability.ICEBERG) + assert capabilities.has_capability(DBRCapability.INSERT_BY_NAME) assert capabilities.has_capability(DBRCapability.JSON_COLUMN_METADATA) + assert capabilities.has_capability(DBRCapability.REPLACE_ON) + assert capabilities.has_capability(DBRCapability.STREAMING_TABLE_JSON_METADATA) + assert capabilities.has_capability(DBRCapability.TIMESTAMPDIFF) def test_sql_warehouse(self): """Test that SQL warehouses are assumed to have latest features.""" capabilities = DBRCapabilities(is_sql_warehouse=True) # SQL warehouses should have all supported features - assert capabilities.has_capability(DBRCapability.TIMESTAMPDIFF) - assert capabilities.has_capability(DBRCapability.ICEBERG) assert capabilities.has_capability(DBRCapability.COMMENT_ON_COLUMN) + assert capabilities.has_capability(DBRCapability.DESCRIBE_TABLE_EXTENDED_AS_JSON) + assert capabilities.has_capability(DBRCapability.ICEBERG) + assert capabilities.has_capability(DBRCapability.INSERT_BY_NAME) assert capabilities.has_capability(DBRCapability.JSON_COLUMN_METADATA) + assert capabilities.has_capability(DBRCapability.REPLACE_ON) + assert capabilities.has_capability(DBRCapability.TIMESTAMPDIFF) def test_sql_warehouse_unsupported_features(self): """Test that some features are not supported on SQL warehouses.""" @@ -58,17 +75,36 @@ def test_sql_warehouse_unsupported_features(self): def test_get_required_version(self): """Test getting required version strings.""" - assert DBRCapabilities.get_required_version(DBRCapability.TIMESTAMPDIFF) == "DBR 10.4+" - assert DBRCapabilities.get_required_version(DBRCapability.ICEBERG) == "DBR 14.3+" assert DBRCapabilities.get_required_version(DBRCapability.COMMENT_ON_COLUMN) == "DBR 16.1+" + assert ( + DBRCapabilities.get_required_version(DBRCapability.DESCRIBE_TABLE_EXTENDED_AS_JSON) + == "DBR 17.3+" + ) + assert DBRCapabilities.get_required_version(DBRCapability.ICEBERG) == "DBR 14.3+" + assert DBRCapabilities.get_required_version(DBRCapability.INSERT_BY_NAME) == "DBR 12.2+" + assert ( + DBRCapabilities.get_required_version(DBRCapability.JSON_COLUMN_METADATA) == "DBR 16.2+" + ) + assert DBRCapabilities.get_required_version(DBRCapability.REPLACE_ON) == "DBR 17.1+" + assert ( + DBRCapabilities.get_required_version(DBRCapability.STREAMING_TABLE_JSON_METADATA) + == "DBR 17.1+" + ) + assert DBRCapabilities.get_required_version(DBRCapability.TIMESTAMPDIFF) == "DBR 10.4+" def test_no_connection(self): """Test behavior when not connected (no version info).""" capabilities = DBRCapabilities(dbr_version=None) # Without connection info, assume no capabilities - assert not capabilities.has_capability(DBRCapability.TIMESTAMPDIFF) + assert not capabilities.has_capability(DBRCapability.COMMENT_ON_COLUMN) + assert not capabilities.has_capability(DBRCapability.DESCRIBE_TABLE_EXTENDED_AS_JSON) assert not capabilities.has_capability(DBRCapability.ICEBERG) + assert not capabilities.has_capability(DBRCapability.INSERT_BY_NAME) + assert not capabilities.has_capability(DBRCapability.JSON_COLUMN_METADATA) + assert not capabilities.has_capability(DBRCapability.REPLACE_ON) + assert not capabilities.has_capability(DBRCapability.STREAMING_TABLE_JSON_METADATA) + assert not capabilities.has_capability(DBRCapability.TIMESTAMPDIFF) def test_enabled_capabilities_property(self): """Test the enabled_capabilities method.""" @@ -78,15 +114,18 @@ def test_enabled_capabilities_property(self): # Should include all capabilities supported by DBR 16.2 expected = { - DBRCapability.TIMESTAMPDIFF, - DBRCapability.ICEBERG, DBRCapability.COMMENT_ON_COLUMN, + DBRCapability.ICEBERG, + DBRCapability.INSERT_BY_NAME, DBRCapability.JSON_COLUMN_METADATA, + DBRCapability.TIMESTAMPDIFF, } assert expected.issubset(enabled) # Should not include capabilities requiring newer versions + assert DBRCapability.DESCRIBE_TABLE_EXTENDED_AS_JSON not in enabled + assert DBRCapability.REPLACE_ON not in enabled assert DBRCapability.STREAMING_TABLE_JSON_METADATA not in enabled @@ -137,6 +176,15 @@ def test_version_requirements(self): assert specs[DBRCapability.ICEBERG].min_version == (14, 3) assert specs[DBRCapability.COMMENT_ON_COLUMN].min_version == (16, 1) assert specs[DBRCapability.JSON_COLUMN_METADATA].min_version == (16, 2) + assert specs[DBRCapability.DESCRIBE_TABLE_EXTENDED_AS_JSON].min_version == (17, 3) + + def test_describe_json_boundary(self): + """Test DESCRIBE_TABLE_EXTENDED_AS_JSON is available at 17.3 but not 17.2.""" + unsupported = DBRCapabilities(dbr_version=(17, 2)) + assert not unsupported.has_capability(DBRCapability.DESCRIBE_TABLE_EXTENDED_AS_JSON) + + supported = DBRCapabilities(dbr_version=(17, 3)) + assert supported.has_capability(DBRCapability.DESCRIBE_TABLE_EXTENDED_AS_JSON) def test_sql_warehouse_support_flags(self): """Test that SQL warehouse support is correctly specified.""" diff --git a/tests/unit/test_describe_json_metadata.py b/tests/unit/test_describe_json_metadata.py new file mode 100644 index 000000000..e39d7af38 --- /dev/null +++ b/tests/unit/test_describe_json_metadata.py @@ -0,0 +1,986 @@ +""" +Unit tests for DatabricksDescribeJsonMetadata parser. + +Tests the parsing of DESCRIBE TABLE EXTENDED ... AS JSON responses into +agate Tables that match the format expected by existing processors. +""" + +from dbt_common.contracts.constraints import ConstraintType + +from dbt.adapters.databricks.constraints import ( + ForeignKeyConstraint, + PrimaryKeyConstraint, +) +from dbt.adapters.databricks.impl import DatabricksDescribeJsonMetadata +from dbt.adapters.databricks.relation_configs.column_mask import ( + ColumnMaskConfig, + ColumnMaskProcessor, +) +from dbt.adapters.databricks.relation_configs.constraints import ( + ConstraintsConfig, + ConstraintsProcessor, +) +from dbt.adapters.databricks.relation_configs.query import QueryConfig, QueryProcessor + +# Fixtures: minimal JSON samples with only fields relevant to parsing. + + +EMAIL_ADDRESSES_JSON = { + "columns": [ + {"name": "address_id", "nullable": False}, + {"name": "remote_user_id", "nullable": True}, + {"name": "email_address", "nullable": True}, + ], + "table_constraints": ( + "[(email_ad_pk,PRIMARY KEY (`address_id`))," + " (email_fk,FOREIGN KEY (`remote_user_id`)" + " REFERENCES `main`.`default`.`users` (`user_id`))]" + ), +} + +COLUMN_MASK_JSON = { + "column_masks": [ + { + "column_name": "phone_number", + "function_name": { + "catalog_name": "main", + "schema_name": "db", + "function_name": "mask_phone", + }, + "using_column_names": ["city"], + } + ], +} + + +MATERIALIZED_VIEW_JSON = { + "view_text": "SELECT id, name FROM main.default.source_table", +} + +REGULAR_VIEW_JSON = { + "view_text": "SELECT id, name FROM main.default.other_table", +} + +PLAIN_TABLE_JSON = { + "columns": [ + {"name": "id", "nullable": True}, + {"name": "value", "nullable": True}, + ], +} + + +COMPOSITE_PK_JSON = { + "columns": [ + {"name": "id", "nullable": False}, + {"name": "name", "nullable": False}, + {"name": "value", "nullable": True}, + ], + "table_constraints": "[(id_name_pk,PRIMARY KEY (`id`, `name`))]", +} + +COMPOSITE_FK_JSON = { + "columns": [ + {"name": "id", "nullable": True}, + {"name": "ref_id", "nullable": True}, + {"name": "ref_name", "nullable": True}, + ], + "table_constraints": ( + "[(fk_pk,PRIMARY KEY (`id`))," + " (child_fk,FOREIGN KEY (`ref_id`, `ref_name`)" + " REFERENCES `main`.`default`.`parents` (`id`, `name`))]" + ), +} + +ALL_FIELDS_JSON = { + "columns": [ + {"name": "id", "nullable": False}, + {"name": "secret", "nullable": True}, + ], + "table_constraints": ( + "[(pk1,PRIMARY KEY (`id`))," + " (fk1,FOREIGN KEY (`id`)" + " REFERENCES `main`.`default`.`other` (`other_id`))]" + ), + "column_masks": [ + { + "column_name": "secret", + "function_name": { + "catalog_name": "main", + "schema_name": "db", + "function_name": "mask_secret", + }, + "using_column_names": ["id"], + } + ], + "view_text": "SELECT id, secret FROM main.default.source", +} + +MIXED_PK_FK_JSON = { + "columns": [ + {"name": "id", "nullable": False}, + {"name": "ref_id", "nullable": True}, + ], + "table_constraints": ( + "[(pk1,PRIMARY KEY (`id`))," + " (fk1,FOREIGN KEY (`ref_id`)" + " REFERENCES `main`.`default`.`other` (`other_id`))]" + ), +} + + +class TestParsePrimaryKeyConstraints: + def test_single_primary_key(self): + """Test PRIMARY KEY parsing with a single primary key constraint.""" + json_metadata = {"table_constraints": "[(pk1,PRIMARY KEY (`address_id`))]"} + result = DatabricksDescribeJsonMetadata.parse_primary_key_constraints(json_metadata) + assert len(result.rows) == 1 + assert result.rows[0][0] == "pk1" + assert result.rows[0]["constraint_name"] == "pk1" + assert result.rows[0][1] == "address_id" + assert result.rows[0]["column_name"] == "address_id" + + def test_no_primary_key(self): + """Test PRIMARY KEY parsing with no primary key constraints.""" + json_metadata = { + "table_constraints": ( + "[(fk1,FOREIGN KEY (`ref_id`) REFERENCES `main`.`default`.`t` (`id`))]" + ) + } + result = DatabricksDescribeJsonMetadata.parse_primary_key_constraints(json_metadata) + assert len(result.rows) == 0 + + def test_no_table_constraints_field(self): + """Test PRIMARY KEY parsing with no table_constraints field.""" + result = DatabricksDescribeJsonMetadata.parse_primary_key_constraints({}) + assert len(result.rows) == 0 + + def test_empty_string(self): + """Test PRIMARY KEY parsing with an empty string.""" + result = DatabricksDescribeJsonMetadata.parse_primary_key_constraints( + {"table_constraints": ""} + ) + assert len(result.rows) == 0 + + def test_spaces(self): + """Test PRIMARY KEY parsing is robust to excessive spaces between 'PRIMARY' and 'KEY'.""" + for num_extra_spaces in range(0, 40): + es = " " * num_extra_spaces # extra spaces + constraint_entry = f"{es}({es}pk1{es},{es}PRIMARY {es}KEY{es}({es}`col_1`{es}){es}){es}" + json_metadata = {"table_constraints": f"[{constraint_entry}]"} + result = DatabricksDescribeJsonMetadata.parse_primary_key_constraints(json_metadata) + assert len(result.rows) == 1 + row = result.rows[0] + assert row[0] == "pk1" + assert row["constraint_name"] == "pk1" + assert row[1] == "col_1" + assert row["column_name"] == "col_1" + + def test_many_constraints(self): + """Test PRIMARY KEY constraint parsing with many constraints in one string.""" + constraint_count = 20 + constraint_entries = [ + f"(pk{i},PRIMARY KEY (`col_{i}`))" for i in range(1, constraint_count + 1) + ] + json_metadata = {"table_constraints": f"[{', '.join(constraint_entries)}]"} + result = DatabricksDescribeJsonMetadata.parse_primary_key_constraints(json_metadata) + assert len(result.rows) == constraint_count + for row_index in range(constraint_count): + expected_constraint_name = f"pk{row_index + 1}" + expected_column_name = f"col_{row_index + 1}" + row = result.rows[row_index] + assert row[0] == expected_constraint_name + assert row["constraint_name"] == expected_constraint_name + assert row[1] == expected_column_name + assert row["column_name"] == expected_column_name + + def test_composite_with_many_columns(self): + """Test composite PRIMARY KEY with 1 to 20 columns.""" + for num_cols in range(1, 21): + cols = ", ".join(f"`col_{i}`" for i in range(1, num_cols + 1)) + json_metadata = {"table_constraints": f"[(pk1,PRIMARY KEY ({cols}))]"} + result = DatabricksDescribeJsonMetadata.parse_primary_key_constraints(json_metadata) + assert len(result.rows) == num_cols + for i in range(num_cols): + assert result.rows[i][0] == "pk1" + assert result.rows[i][1] == f"col_{i + 1}" + + def test_underscores_on_names(self): + """ + Test that PRIMARY KEY parsing works when table/column names in constraints + are qualified with varying numbers of leading/trailing underscores. + """ + for i in range(0, 20): + usc = "_" * i # underscores + column_name = f"{usc}id{usc}" + constraint_entry = f"(pk1,PRIMARY KEY (`{column_name}`))" + + json_metadata = {"table_constraints": f"[{constraint_entry}]"} + result = DatabricksDescribeJsonMetadata.parse_primary_key_constraints(json_metadata) + assert len(result.rows) == 1 + row = result.rows[0] + assert row[0] == "pk1" + assert row["constraint_name"] == "pk1" + assert row[1] == column_name + assert row["column_name"] == column_name + + +class TestParseForeignKeyConstraints: + def test_single_column_foreign_key(self): + """Test FOREIGN KEY parsing with a single foreign key constraint.""" + json_metadata = { + "table_constraints": ( + "[(fk1,FOREIGN KEY (`ref_id`) REFERENCES `main`.`default`.`users` (`user_id`))]" + ) + } + result = DatabricksDescribeJsonMetadata.parse_foreign_key_constraints(json_metadata) + assert len(result.rows) == 1 + row = result.rows[0] + assert row[0] == "fk1" + assert row["constraint_name"] == "fk1" + assert row[1] == "ref_id" + assert row["from_column"] == "ref_id" + assert row[2] == "main" + assert row["to_catalog"] == "main" + assert row[3] == "default" + assert row["to_schema"] == "default" + assert row[4] == "users" + assert row["to_table"] == "users" + assert row[5] == "user_id" + assert row["to_column"] == "user_id" + + def test_composite_foreign_key(self): + """Test FOREIGN KEY parsing many columns.""" + for num_cols in range(1, 21): + from_cols = ", ".join(f"`from_{i}`" for i in range(1, num_cols + 1)) + to_cols = ", ".join(f"`to_{i}`" for i in range(1, num_cols + 1)) + json_metadata = { + "table_constraints": ( + f"[(cfk,FOREIGN KEY ({from_cols})" + f" REFERENCES `main`.`default`.`parents` ({to_cols}))]" + ) + } + result = DatabricksDescribeJsonMetadata.parse_foreign_key_constraints(json_metadata) + assert len(result.rows) == num_cols + for i in range(num_cols): + row = result.rows[i] + assert row[0] == "cfk" + assert row["constraint_name"] == "cfk" + assert row[1] == f"from_{i + 1}" + assert row["from_column"] == f"from_{i + 1}" + assert row[2] == "main" + assert row[3] == "default" + assert row[4] == "parents" + assert row[5] == f"to_{i + 1}" + assert row["to_column"] == f"to_{i + 1}" + + def test_schema_with_hyphens(self): + """Test FOREIGN KEY parsing when the referenced schema contains hyphens.""" + json_metadata = { + "table_constraints": ( + "[(fk1,FOREIGN KEY (`ref_id`) REFERENCES `main`.`my-schema`.`users` (`user_id`))]" + ) + } + result = DatabricksDescribeJsonMetadata.parse_foreign_key_constraints(json_metadata) + assert len(result.rows) == 1 + row = result.rows[0] + assert row[3] == "my-schema" + assert row["to_schema"] == "my-schema" + + def test_foreign_key_with_primary_key(self): + """Test FOREIGN KEY parsing with mixed primary and foreign key constraints.""" + result = DatabricksDescribeJsonMetadata.parse_foreign_key_constraints(MIXED_PK_FK_JSON) + assert len(result.rows) == 1 + row = result.rows[0] + assert row[0] == "fk1" + assert row["constraint_name"] == "fk1" + assert row[1] == "ref_id" + assert row["from_column"] == "ref_id" + + def test_no_foreign_key(self): + """Test FOREIGN KEY parsing with no foreign key constraints.""" + json_metadata = {"table_constraints": "[(pk1,PRIMARY KEY (`id`))]"} + result = DatabricksDescribeJsonMetadata.parse_foreign_key_constraints(json_metadata) + assert len(result.rows) == 0 + + def test_no_table_constraints_field(self): + """Test FOREIGN KEY parsing with no table_constraints field.""" + result = DatabricksDescribeJsonMetadata.parse_foreign_key_constraints({}) + assert len(result.rows) == 0 + + def test_empty_string(self): + """Test FOREIGN KEY parsing with an empty string.""" + result = DatabricksDescribeJsonMetadata.parse_foreign_key_constraints( + {"table_constraints": ""} + ) + assert len(result.rows) == 0 + + def test_spaces(self): + """Test FOREIGN KEY parsing is robust to excessive spaces between keywords.""" + for num_extra_spaces in range(0, 40): + es = " " * num_extra_spaces + constraint_entry = ( + f"{es}({es}fk1{es},{es}FOREIGN {es}KEY{es}({es}`ref_id`{es})" + f"{es}REFERENCES{es}`main`{es}.{es}`default`{es}.{es}`users`{es}" + f"({es}`user_id`{es}){es}){es}" + ) + json_metadata = {"table_constraints": f"[{constraint_entry}]"} + result = DatabricksDescribeJsonMetadata.parse_foreign_key_constraints(json_metadata) + assert len(result.rows) == 1 + row = result.rows[0] + assert row[0] == "fk1" + assert row["constraint_name"] == "fk1" + assert row[1] == "ref_id" + assert row["from_column"] == "ref_id" + assert row[2] == "main" + assert row["to_catalog"] == "main" + assert row[3] == "default" + assert row["to_schema"] == "default" + assert row[4] == "users" + assert row["to_table"] == "users" + assert row[5] == "user_id" + assert row["to_column"] == "user_id" + + def test_many_constraints(self): + """Test FOREIGN KEY parsing with many constraints in one string.""" + constraint_count = 20 + constraint_entries = [ + ( + f"(fk{i},FOREIGN KEY (`ref_col_{i}`)" + f" REFERENCES `main`.`default`.`users_{i}` (`user_col_{i}`))" + ) + for i in range(1, constraint_count + 1) + ] + json_metadata = {"table_constraints": f"[{', '.join(constraint_entries)}]"} + result = DatabricksDescribeJsonMetadata.parse_foreign_key_constraints(json_metadata) + assert len(result.rows) == constraint_count + for row_index in range(constraint_count): + expected_constraint_name = f"fk{row_index + 1}" + expected_from_column = f"ref_col_{row_index + 1}" + expected_to_table = f"users_{row_index + 1}" + expected_to_column = f"user_col_{row_index + 1}" + row = result.rows[row_index] + assert row[0] == expected_constraint_name + assert row["constraint_name"] == expected_constraint_name + assert row[1] == expected_from_column + assert row["from_column"] == expected_from_column + assert row[2] == "main" + assert row["to_catalog"] == "main" + assert row[3] == "default" + assert row["to_schema"] == "default" + assert row[4] == expected_to_table + assert row["to_table"] == expected_to_table + assert row[5] == expected_to_column + assert row["to_column"] == expected_to_column + + def test_underscores_on_names(self): + """Test FOREIGN KEY parsing with varying leading and trailing underscores.""" + for i in range(0, 20): + underscores = "_" * i + from_column = f"{underscores}ref_id{underscores}" + to_catalog = f"{underscores}main{underscores}" + to_schema = f"{underscores}default{underscores}" + to_table = f"{underscores}users{underscores}" + to_column = f"{underscores}user_id{underscores}" + constraint_entry = ( + f"(fk1,FOREIGN KEY (`{from_column}`)" + f" REFERENCES `{to_catalog}`.`{to_schema}`.`{to_table}` (`{to_column}`))" + ) + + json_metadata = {"table_constraints": f"[{constraint_entry}]"} + result = DatabricksDescribeJsonMetadata.parse_foreign_key_constraints(json_metadata) + assert len(result.rows) == 1 + row = result.rows[0] + assert row[0] == "fk1" + assert row["constraint_name"] == "fk1" + assert row[1] == from_column + assert row["from_column"] == from_column + assert row[2] == to_catalog + assert row["to_catalog"] == to_catalog + assert row[3] == to_schema + assert row["to_schema"] == to_schema + assert row[4] == to_table + assert row["to_table"] == to_table + assert row[5] == to_column + assert row["to_column"] == to_column + + +class TestParseNonNullConstraints: + def test_mixed_nullable(self): + """Test parsing of non-null constraints when some columns are nullable and some are not.""" + json_metadata = { + "columns": [ + {"name": "id", "nullable": False}, + {"name": "email", "nullable": True}, + ] + } + result = DatabricksDescribeJsonMetadata.parse_non_null_constraints(json_metadata) + assert len(result.rows) == 1 + assert result.rows[0][0] == "id" + assert result.rows[0]["column_name"] == "id" + + def test_all_nullable(self): + """Test parsing of non-null constraints when all columns are nullable.""" + json_metadata = { + "columns": [ + {"name": "a", "nullable": True}, + {"name": "b", "nullable": True}, + ] + } + result = DatabricksDescribeJsonMetadata.parse_non_null_constraints(json_metadata) + assert len(result.rows) == 0 + + def test_multiple_non_null(self): + """Test parsing of non-null constraints when multiple columns are non-nullable.""" + json_metadata = { + "columns": [ + {"name": "id", "nullable": False}, + {"name": "email", "nullable": False}, + {"name": "msg", "nullable": True}, + ] + } + result = DatabricksDescribeJsonMetadata.parse_non_null_constraints(json_metadata) + assert len(result.rows) == 2 + assert result.rows[0][0] == "id" + assert result.rows[0]["column_name"] == "id" + assert result.rows[1][0] == "email" + assert result.rows[1]["column_name"] == "email" + + def test_no_columns_key(self): + """Test parsing of non-null constraints when there is no 'columns' key in the input.""" + result = DatabricksDescribeJsonMetadata.parse_non_null_constraints({}) + assert len(result.rows) == 0 + + +class TestParseColumnMasks: + def test_mask_with_using_columns(self): + result = DatabricksDescribeJsonMetadata.parse_column_masks(COLUMN_MASK_JSON) + assert len(result.rows) == 1 + assert result.rows[0][0] == "phone_number" + assert result.rows[0]["column_name"] == "phone_number" + assert result.rows[0][1] == "main.db.mask_phone" + assert result.rows[0]["mask_name"] == "main.db.mask_phone" + assert result.rows[0][2] == "city" + assert result.rows[0]["using_columns"] == "city" + + def test_mask_without_using_columns(self): + json_metadata = { + "column_masks": [ + { + "column_name": "ssn", + "function_name": { + "catalog_name": "main", + "schema_name": "db", + "function_name": "mask_ssn", + }, + "using_column_names": [], + } + ] + } + result = DatabricksDescribeJsonMetadata.parse_column_masks(json_metadata) + assert len(result.rows) == 1 + assert result.rows[0][0] == "ssn" + assert result.rows[0]["column_name"] == "ssn" + assert result.rows[0][1] == "main.db.mask_ssn" + assert result.rows[0]["mask_name"] == "main.db.mask_ssn" + assert result.rows[0][2] is None + assert result.rows[0]["using_columns"] is None + + def test_multiple_masks(self): + json_metadata = { + "column_masks": [ + { + "column_name": "col_a", + "function_name": { + "catalog_name": "c", + "schema_name": "s", + "function_name": "fn_a", + }, + "using_column_names": ["x"], + }, + { + "column_name": "col_b", + "function_name": { + "catalog_name": "c", + "schema_name": "s", + "function_name": "fn_b", + }, + "using_column_names": [], + }, + ] + } + result = DatabricksDescribeJsonMetadata.parse_column_masks(json_metadata) + assert len(result.rows) == 2 + assert result.rows[0][0] == "col_a" + assert result.rows[0]["column_name"] == "col_a" + assert result.rows[0][1] == "c.s.fn_a" + assert result.rows[0]["mask_name"] == "c.s.fn_a" + assert result.rows[0][2] == "x" + assert result.rows[0]["using_columns"] == "x" + assert result.rows[1][0] == "col_b" + assert result.rows[1]["column_name"] == "col_b" + assert result.rows[1][1] == "c.s.fn_b" + assert result.rows[1]["mask_name"] == "c.s.fn_b" + assert result.rows[1][2] is None + assert result.rows[1]["using_columns"] is None + + def test_no_column_masks_field(self): + result = DatabricksDescribeJsonMetadata.parse_column_masks({}) + assert len(result.rows) == 0 + + def test_empty_column_masks(self): + result = DatabricksDescribeJsonMetadata.parse_column_masks({"column_masks": []}) + assert len(result.rows) == 0 + + def test_mask_with_multiple_using_columns(self): + json_input = { + "column_masks": [ + { + "column_name": "secret", + "function_name": { + "catalog_name": "main", + "schema_name": "db", + "function_name": "mask_fn", + }, + "using_column_names": ["col1", "col2", "col3"], + } + ] + } + result = DatabricksDescribeJsonMetadata.parse_column_masks(json_input) + assert len(result.rows) == 1 + assert result.rows[0][0] == "secret" + assert result.rows[0]["column_name"] == "secret" + assert result.rows[0][1] == "main.db.mask_fn" + assert result.rows[0]["mask_name"] == "main.db.mask_fn" + assert result.rows[0][2] == "col1,col2,col3" + assert result.rows[0]["using_columns"] == "col1,col2,col3" + + def test_mask_missing_using_column_names_key(self): + json_input = { + "column_masks": [ + { + "column_name": "secret", + "function_name": { + "catalog_name": "main", + "schema_name": "db", + "function_name": "mask_fn", + }, + } + ] + } + result = DatabricksDescribeJsonMetadata.parse_column_masks(json_input) + assert len(result.rows) == 1 + assert result.rows[0][0] == "secret" + assert result.rows[0]["column_name"] == "secret" + assert result.rows[0][1] == "main.db.mask_fn" + assert result.rows[0]["mask_name"] == "main.db.mask_fn" + assert result.rows[0][2] is None + assert result.rows[0]["using_columns"] is None + + +class TestParseViewDescription: + def test_with_view_text(self): + json_metadata = {"view_text": "SELECT id, name FROM main.default.source_table"} + result = DatabricksDescribeJsonMetadata.parse_view_description(json_metadata) + assert result["view_definition"] == "SELECT id, name FROM main.default.source_table" + + def test_without_view_text(self): + json_metadata = { + "columns": [ + {"name": "id", "nullable": True}, + {"name": "value", "nullable": True}, + ], + } + result = DatabricksDescribeJsonMetadata.parse_view_description(json_metadata) + assert len(result.values()) == 0 + + def test_null_view_text(self): + result = DatabricksDescribeJsonMetadata.parse_view_description({"view_text": None}) + assert len(result.values()) == 0 + + +class TestFromJsonMetadata: + def test_table_with_column_masks(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(COLUMN_MASK_JSON) + assert len(metadata.column_masks.rows) == 1 + assert metadata.column_masks.rows[0][0] == "phone_number" + assert metadata.column_masks.rows[0]["column_name"] == "phone_number" + assert metadata.column_masks.rows[0][1] == "main.db.mask_phone" + assert metadata.column_masks.rows[0]["mask_name"] == "main.db.mask_phone" + assert metadata.column_masks.rows[0][2] == "city" + assert metadata.column_masks.rows[0]["using_columns"] == "city" + assert len(metadata.primary_key_constraints.rows) == 0 + assert len(metadata.foreign_key_constraints.rows) == 0 + + def test_materialized_view(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(MATERIALIZED_VIEW_JSON) + assert metadata.view_description["view_definition"] == ( + "SELECT id, name FROM main.default.source_table" + ) + assert len(metadata.primary_key_constraints.rows) == 0 + assert len(metadata.column_masks.rows) == 0 + + def test_all_fields_populated(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(ALL_FIELDS_JSON) + # PK + assert len(metadata.primary_key_constraints.rows) == 1 + assert metadata.primary_key_constraints.rows[0]["constraint_name"] == "pk1" + assert metadata.primary_key_constraints.rows[0]["column_name"] == "id" + # FK + assert len(metadata.foreign_key_constraints.rows) == 1 + fk = metadata.foreign_key_constraints.rows[0] + assert fk["constraint_name"] == "fk1" + assert fk["from_column"] == "id" + assert fk["to_catalog"] == "main" + assert fk["to_schema"] == "default" + assert fk["to_table"] == "other" + assert fk["to_column"] == "other_id" + # Non-null + assert len(metadata.non_null_constraints.rows) == 1 + assert metadata.non_null_constraints.rows[0]["column_name"] == "id" + # Column masks + assert len(metadata.column_masks.rows) == 1 + assert metadata.column_masks.rows[0]["column_name"] == "secret" + assert metadata.column_masks.rows[0]["mask_name"] == "main.db.mask_secret" + assert metadata.column_masks.rows[0]["using_columns"] == "id" + # View description + assert metadata.view_description["view_definition"] == ( + "SELECT id, secret FROM main.default.source" + ) + + def test_plain_table(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(PLAIN_TABLE_JSON) + assert len(metadata.primary_key_constraints.rows) == 0 + assert len(metadata.foreign_key_constraints.rows) == 0 + assert len(metadata.non_null_constraints.rows) == 0 + assert len(metadata.column_masks.rows) == 0 + assert len(metadata.view_description.values()) == 0 + + +class TestParserToConstraintsProcessor: + @staticmethod + def _build_results(metadata): + return { + "non_null_constraint_columns": metadata.non_null_constraints, + "primary_key_constraints": metadata.primary_key_constraints, + "foreign_key_constraints": metadata.foreign_key_constraints, + } + + def test_single_pk_roundtrip(self): + json_metadata = { + "columns": [{"name": "id", "nullable": False}], + "table_constraints": "[(pk1,PRIMARY KEY (`id`))]", + } + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(json_metadata) + config = ConstraintsProcessor.from_relation_results(self._build_results(metadata)) + assert config == ConstraintsConfig( + set_non_nulls={"id"}, + set_constraints={ + PrimaryKeyConstraint(type=ConstraintType.primary_key, name="pk1", columns=["id"]), + }, + ) + + def test_composite_pk_roundtrip(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(COMPOSITE_PK_JSON) + config = ConstraintsProcessor.from_relation_results(self._build_results(metadata)) + assert config == ConstraintsConfig( + set_non_nulls={"id", "name"}, + set_constraints={ + PrimaryKeyConstraint( + type=ConstraintType.primary_key, + name="id_name_pk", + columns=["id", "name"], + ) + }, + ) + + def test_single_fk_roundtrip(self): + json_metadata = { + "columns": [{"name": "ref_id", "nullable": True}], + "table_constraints": ( + "[(fk1,FOREIGN KEY (`ref_id`) REFERENCES `main`.`default`.`other` (`other_id`))]" + ), + } + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(json_metadata) + config = ConstraintsProcessor.from_relation_results(self._build_results(metadata)) + assert config == ConstraintsConfig( + set_non_nulls=set(), + set_constraints={ + ForeignKeyConstraint( + type=ConstraintType.foreign_key, + name="fk1", + columns=["ref_id"], + to="`main`.`default`.`other`", + to_columns=["other_id"], + ) + }, + ) + + def test_composite_fk_roundtrip(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(COMPOSITE_FK_JSON) + config = ConstraintsProcessor.from_relation_results(self._build_results(metadata)) + assert config == ConstraintsConfig( + set_non_nulls=set(), + set_constraints={ + PrimaryKeyConstraint( + type=ConstraintType.primary_key, + name="fk_pk", + columns=["id"], + ), + ForeignKeyConstraint( + type=ConstraintType.foreign_key, + name="child_fk", + columns=["ref_id", "ref_name"], + to="`main`.`default`.`parents`", + to_columns=["id", "name"], + ), + }, + ) + + def test_mixed_constraints_roundtrip(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(EMAIL_ADDRESSES_JSON) + config = ConstraintsProcessor.from_relation_results(self._build_results(metadata)) + assert config.set_non_nulls == {"address_id"} + assert any( + isinstance(c, PrimaryKeyConstraint) and c.name == "email_ad_pk" + for c in config.set_constraints + ) + assert any( + isinstance(c, ForeignKeyConstraint) + and c.name == "email_fk" + and c.to == "`main`.`default`.`users`" + for c in config.set_constraints + ) + + def test_no_constraints_roundtrip(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(PLAIN_TABLE_JSON) + config = ConstraintsProcessor.from_relation_results(self._build_results(metadata)) + assert config == ConstraintsConfig(set_non_nulls=set(), set_constraints=set()) + + +class TestParserToColumnMaskProcessor: + def test_mask_roundtrip(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(COLUMN_MASK_JSON) + config = ColumnMaskProcessor.from_relation_results({"column_masks": metadata.column_masks}) + assert config == ColumnMaskConfig( + set_column_masks={ + "phone_number": { + "function": "main.db.mask_phone", + "using_columns": "city", + } + } + ) + + def test_no_masks_roundtrip(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(PLAIN_TABLE_JSON) + config = ColumnMaskProcessor.from_relation_results({"column_masks": metadata.column_masks}) + assert config == ColumnMaskConfig(set_column_masks={}) + + def test_mask_no_false_diff(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(COLUMN_MASK_JSON) + existing = ColumnMaskProcessor.from_relation_results( + {"column_masks": metadata.column_masks} + ) + model = ColumnMaskConfig( + set_column_masks={ + "phone_number": { + "function": "main.db.mask_phone", + "using_columns": "city", + } + } + ) + assert model.get_diff(existing) is None + + def test_mask_diff_change_function(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(COLUMN_MASK_JSON) + existing = ColumnMaskProcessor.from_relation_results( + {"column_masks": metadata.column_masks} + ) + model = ColumnMaskConfig( + set_column_masks={ + "phone_number": { + "function": "main.db.new_mask_fn", + "using_columns": "city", + } + } + ) + diff = model.get_diff(existing) + assert diff is not None + assert diff.set_column_masks == { + "phone_number": { + "function": "main.db.new_mask_fn", + "using_columns": "city", + } + } + + def test_mask_diff_add_new_mask(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(COLUMN_MASK_JSON) + existing = ColumnMaskProcessor.from_relation_results( + {"column_masks": metadata.column_masks} + ) + model = ColumnMaskConfig( + set_column_masks={ + "phone_number": { + "function": "main.db.mask_phone", + "using_columns": "city", + }, + "ssn": { + "function": "main.db.mask_ssn", + }, + } + ) + diff = model.get_diff(existing) + assert diff is not None + assert "ssn" in diff.set_column_masks + assert "phone_number" not in diff.set_column_masks + + +class TestParserToQueryProcessor: + def test_mv_view_text_roundtrip(self): + view_desc = DatabricksDescribeJsonMetadata.parse_view_description(MATERIALIZED_VIEW_JSON) + config = QueryProcessor.from_relation_results({"information_schema.views": view_desc}) + assert config == QueryConfig(query="SELECT id, name FROM main.default.source_table") + + def test_view_text_roundtrip(self): + view_desc = DatabricksDescribeJsonMetadata.parse_view_description(REGULAR_VIEW_JSON) + config = QueryProcessor.from_relation_results({"information_schema.views": view_desc}) + assert config == QueryConfig(query="SELECT id, name FROM main.default.other_table") + + def test_view_text_with_outer_parens(self): + view_desc = DatabricksDescribeJsonMetadata.parse_view_description( + {"view_text": "(SELECT id FROM t)"} + ) + config = QueryProcessor.from_relation_results({"information_schema.views": view_desc}) + assert config == QueryConfig(query="SELECT id FROM t") + + +class TestParserToQueryDiff: + def test_no_false_diff_on_identical_query(self): + view_desc = DatabricksDescribeJsonMetadata.parse_view_description(MATERIALIZED_VIEW_JSON) + existing = QueryProcessor.from_relation_results({"information_schema.views": view_desc}) + model = QueryConfig(query="SELECT id, name FROM main.default.source_table") + assert model.get_diff(existing) is None + + def test_detects_real_query_change(self): + view_desc = DatabricksDescribeJsonMetadata.parse_view_description(MATERIALIZED_VIEW_JSON) + existing = QueryProcessor.from_relation_results({"information_schema.views": view_desc}) + model = QueryConfig(query="SELECT id FROM different_table") + diff = model.get_diff(existing) + assert diff is not None + assert diff.query == "SELECT id FROM different_table" + + +class TestParserToConstraintsDiff: + @staticmethod + def _build_results(metadata): + return { + "non_null_constraint_columns": metadata.non_null_constraints, + "primary_key_constraints": metadata.primary_key_constraints, + "foreign_key_constraints": metadata.foreign_key_constraints, + } + + def test_composite_pk_no_false_diff(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(COMPOSITE_PK_JSON) + existing = ConstraintsProcessor.from_relation_results(self._build_results(metadata)) + model = ConstraintsConfig( + set_non_nulls={"id", "name"}, + set_constraints={ + PrimaryKeyConstraint( + type=ConstraintType.primary_key, + name="id_name_pk", + columns=["id", "name"], + ) + }, + ) + assert model.get_diff(existing) is None + + def test_composite_pk_diff_add_column(self): + """Model adds a column to PK — diff should set new PK, unset old PK, set new non-null.""" + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(COMPOSITE_PK_JSON) + existing = ConstraintsProcessor.from_relation_results(self._build_results(metadata)) + model = ConstraintsConfig( + set_non_nulls={"id", "name", "value"}, + set_constraints={ + PrimaryKeyConstraint( + type=ConstraintType.primary_key, + name="new_pk", + columns=["id", "name", "value"], + ) + }, + ) + diff = model.get_diff(existing) + assert diff is not None + assert diff.set_non_nulls == {"value"} + assert diff.unset_non_nulls == set() + assert len(diff.unset_constraints) == 1 + unset = next(iter(diff.unset_constraints)) + assert isinstance(unset, PrimaryKeyConstraint) + assert unset.name == "id_name_pk" + assert unset.columns == ["id", "name"] + assert len(diff.set_constraints) == 1 + added = next(iter(diff.set_constraints)) + assert isinstance(added, PrimaryKeyConstraint) + assert added.name == "new_pk" + assert added.columns == ["id", "name", "value"] + + def test_composite_fk_diff_change_target(self): + """Model changes FK target — diff should unset old FK, set new FK.""" + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(COMPOSITE_FK_JSON) + existing = ConstraintsProcessor.from_relation_results(self._build_results(metadata)) + model = ConstraintsConfig( + set_non_nulls=set(), + set_constraints={ + PrimaryKeyConstraint( + type=ConstraintType.primary_key, + name="fk_pk", + columns=["id"], + ), + ForeignKeyConstraint( + type=ConstraintType.foreign_key, + name="new_fk", + columns=["ref_id"], + to="`main`.`default`.`other_table`", + to_columns=["other_id"], + ), + }, + ) + diff = model.get_diff(existing) + assert diff is not None + assert diff.set_non_nulls == set() + assert diff.unset_non_nulls == set() + # Old FK unset + unset_fks = {c for c in diff.unset_constraints if isinstance(c, ForeignKeyConstraint)} + assert len(unset_fks) == 1 + unset_fk = next(iter(unset_fks)) + assert unset_fk.name == "child_fk" + assert unset_fk.columns == ["ref_id", "ref_name"] + assert unset_fk.to == "`main`.`default`.`parents`" + # New FK set + set_fks = {c for c in diff.set_constraints if isinstance(c, ForeignKeyConstraint)} + assert len(set_fks) == 1 + set_fk = next(iter(set_fks)) + assert set_fk.name == "new_fk" + assert set_fk.columns == ["ref_id"] + assert set_fk.to == "`main`.`default`.`other_table`" + assert set_fk.to_columns == ["other_id"] + + def test_composite_fk_no_false_diff(self): + metadata = DatabricksDescribeJsonMetadata.from_json_metadata(COMPOSITE_FK_JSON) + existing = ConstraintsProcessor.from_relation_results(self._build_results(metadata)) + model = ConstraintsConfig( + set_non_nulls=set(), + set_constraints={ + PrimaryKeyConstraint( + type=ConstraintType.primary_key, + name="fk_pk", + columns=["id"], + ), + ForeignKeyConstraint( + type=ConstraintType.foreign_key, + name="child_fk", + columns=["ref_id", "ref_name"], + to="`main`.`default`.`parents`", + to_columns=["id", "name"], + ), + }, + ) + assert model.get_diff(existing) is None diff --git a/tests/unit/test_relation.py b/tests/unit/test_relation.py index 9a2ef2af8..b80355147 100644 --- a/tests/unit/test_relation.py +++ b/tests/unit/test_relation.py @@ -1,4 +1,5 @@ import pytest +from dbt.adapters.base.relation import FunctionConfig from dbt_common.contracts.constraints import ConstraintType from dbt_common.exceptions import DbtRuntimeError @@ -392,6 +393,66 @@ def test_none_identifier_is_allowed(self): rel = DatabricksRelation.create(identifier=None, type="table") assert rel.identifier is None +class TestGetFunctionConfig: + @pytest.fixture + def relation(self): + return DatabricksRelation.create() + + def test_python_udf_defaults_injected_when_omitted(self, relation): + """runtime_version and entry_point should be defaulted for Python UDFs.""" + model = { + "resource_type": "function", + "language": "python", + "name": "my_func", + "config": {"type": "scalar"}, + } + result = relation.get_function_config(model) + assert isinstance(result, FunctionConfig) + assert result.runtime_version == "3.11" + assert result.entry_point == "my_func" + assert result.language == "python" + assert result.type == "scalar" + + def test_python_udf_explicit_values_preserved(self, relation): + """Explicitly provided runtime_version and entry_point should be kept.""" + model = { + "resource_type": "function", + "language": "python", + "name": "my_func", + "config": { + "type": "scalar", + "runtime_version": "3.10", + "entry_point": "custom_handler", + }, + } + result = relation.get_function_config(model) + assert isinstance(result, FunctionConfig) + assert result.runtime_version == "3.10" + assert result.entry_point == "custom_handler" + + def test_sql_udf_delegates_to_super(self, relation): + """SQL UDFs should use the base class implementation.""" + model = { + "resource_type": "function", + "language": "sql", + "config": {"type": "scalar"}, + } + result = relation.get_function_config(model) + assert isinstance(result, FunctionConfig) + assert result.language == "sql" + # SQL functions don't need runtime_version/entry_point + assert result.runtime_version is None + assert result.entry_point is None + + def test_non_function_returns_none(self, relation): + """Non-function resource types should return None.""" + model = { + "resource_type": "model", + "language": "sql", + "config": {}, + } + result = relation.get_function_config(model) + assert result is None class TestDatabricksRenderLimited: def test_render_limited_with_empty_no_alias(self):