diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 5ace49d815..1522414e28 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -1540,28 +1540,32 @@ def months_between_sql(self: Generator, expression: exp.MonthsBetween) -> str: def build_formatted_time( - exp_class: Type[E], dialect: str, default: bool | str | None = None -) -> t.Callable[[BuilderArgs], E]: + exp_class: Type[E], dialect_override: str | None = None, default: bool | str | None = None +) -> t.Callable[[BuilderArgs, Dialect], E]: """Helper used for time expressions. Args: exp_class: the expression class to instantiate. - dialect: target sql dialect. + dialect_override: optional sql dialect to override the parser's one. default: the default format, True being time. Returns: A callable that can be used to return the appropriately formatted time expression. """ - def _builder(args: BuilderArgs) -> E: - return exp_class( - this=seq_get(args, 0), - format=Dialect[dialect].format_time( - seq_get(args, 1) - or (Dialect[dialect].TIME_FORMAT if default is True else default or None) - ), + def _builder(args: BuilderArgs, dialect: Dialect) -> E: + target_dialect = ( + t.cast(Dialect, Dialect[dialect_override]) + if isinstance(dialect_override, str) + else dialect ) + fmt = seq_get(args, 1) + if not fmt: + fmt = target_dialect.TIME_FORMAT if default is True else default or None + + return exp_class(this=seq_get(args, 0), format=target_dialect.format_time(fmt)) + return _builder @@ -2465,8 +2469,7 @@ def build_timetostr_or_tochar( annotate_types(this, dialect=dialect) if this.is_type(*exp.DataType.TEMPORAL_TYPES): - dialect_name = dialect.__class__.__name__.lower() - return build_formatted_time(exp.TimeToStr, dialect_name, default=True)(args) + return build_formatted_time(exp.TimeToStr, default=True)(args, t.cast(Dialect, dialect)) return exp.ToChar.from_arg_list(args) diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index d46e1f1fab..44fbc54c3b 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -24,8 +24,12 @@ class Redshift(Postgres): # ref: https://docs.aws.amazon.com/redshift/latest/dg/r_FORMAT_strings.html TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" - TIME_MAPPING = {**Postgres.TIME_MAPPING, "MON": "%b", "MONTH": "%B"} - INVERSE_TIME_MAPPING = {**Postgres.INVERSE_TIME_MAPPING, "%b": "MON", "%B": "MONTH"} + + TIME_MAPPING = { + **Postgres.TIME_MAPPING, + "MON": "%b", + "MONTH": "%B", + } Parser = RedshiftParser diff --git a/sqlglot/parsers/bigquery.py b/sqlglot/parsers/bigquery.py index 860b08301d..7be0ba837d 100644 --- a/sqlglot/parsers/bigquery.py +++ b/sqlglot/parsers/bigquery.py @@ -14,6 +14,7 @@ if t.TYPE_CHECKING: from sqlglot._typing import E + from sqlglot.dialects.dialect import Dialect def _build_contains_substring(args: list) -> exp.Contains: @@ -53,7 +54,7 @@ def _build_datetime(args: list) -> exp.Func: def _build_extract_json_with_default_path( expr_type: type[E], ) -> t.Callable: - def _builder(args: list, dialect: t.Any) -> E: + def _builder(args: list, dialect: Dialect) -> E: if len(args) == 1: args.append(exp.Literal.string("$")) return parser.build_extract_json_with_path(expr_type)(args, dialect) @@ -61,10 +62,10 @@ def _builder(args: list, dialect: t.Any) -> E: return _builder -def _build_format_time(expr_type: type[exp.Expr]) -> t.Callable[[list], exp.TimeToStr]: - def _builder(args: list) -> exp.TimeToStr: - formatted_time = build_formatted_time(exp.TimeToStr, "bigquery")( - [expr_type(this=seq_get(args, 1)), seq_get(args, 0)] +def _build_format_time(expr_type: type[exp.Expr]) -> t.Callable[[list, Dialect], exp.TimeToStr]: + def _builder(args: list, dialect: Dialect) -> exp.TimeToStr: + formatted_time = build_formatted_time(exp.TimeToStr)( + [expr_type(this=seq_get(args, 1)), seq_get(args, 0)], dialect ) formatted_time.set("zone", seq_get(args, 2)) return formatted_time @@ -91,14 +92,14 @@ def _build_levenshtein(args: list) -> exp.Levenshtein: ) -def _build_parse_timestamp(args: list) -> exp.StrToTime: - this = build_formatted_time(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)]) +def _build_parse_timestamp(args: list, dialect: Dialect) -> exp.StrToTime: + this = build_formatted_time(exp.StrToTime)([seq_get(args, 1), seq_get(args, 0)], dialect) this.set("zone", seq_get(args, 2)) return this def _build_regexp_extract(expr_type: type[E], default_group: exp.Expr | None = None) -> t.Callable: - def _builder(args: list, dialect: t.Any) -> E: + def _builder(args: list, dialect: Dialect) -> E: try: group = re.compile(args[1].name).groups == 1 except re.error: @@ -220,15 +221,15 @@ class BigQueryParser(parser.Parser): ), "OCTET_LENGTH": exp.ByteLength.from_arg_list, "TO_HEX": _build_to_hex, - "PARSE_DATE": lambda args: build_formatted_time(exp.StrToDate, "bigquery")( - [seq_get(args, 1), seq_get(args, 0)] + "PARSE_DATE": lambda args, dialect: build_formatted_time(exp.StrToDate)( + [seq_get(args, 1), seq_get(args, 0)], dialect ), - "PARSE_TIME": lambda args: build_formatted_time(exp.ParseTime, "bigquery")( - [seq_get(args, 1), seq_get(args, 0)] + "PARSE_TIME": lambda args, dialect: build_formatted_time(exp.ParseTime)( + [seq_get(args, 1), seq_get(args, 0)], dialect ), "PARSE_TIMESTAMP": _build_parse_timestamp, - "PARSE_DATETIME": lambda args: build_formatted_time(exp.ParseDatetime, "bigquery")( - [seq_get(args, 1), seq_get(args, 0)] + "PARSE_DATETIME": lambda args, dialect: build_formatted_time(exp.ParseDatetime)( + [seq_get(args, 1), seq_get(args, 0)], dialect ), "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, "REGEXP_EXTRACT": _build_regexp_extract(exp.RegexpExtract), diff --git a/sqlglot/parsers/clickhouse.py b/sqlglot/parsers/clickhouse.py index b159d98d32..8d72bbfeeb 100644 --- a/sqlglot/parsers/clickhouse.py +++ b/sqlglot/parsers/clickhouse.py @@ -22,9 +22,9 @@ def _build_datetime_format( expr_type: Type[E], -) -> t.Callable[[list], E]: - def _builder(args: list) -> E: - expr = build_formatted_time(expr_type, "clickhouse")(args) +) -> t.Callable: + def _builder(args: list, dialect: t.Any) -> E: + expr = build_formatted_time(expr_type)(args, dialect) timezone = seq_get(args, 2) if timezone: diff --git a/sqlglot/parsers/databricks.py b/sqlglot/parsers/databricks.py index a2d274adc1..c04def15ad 100644 --- a/sqlglot/parsers/databricks.py +++ b/sqlglot/parsers/databricks.py @@ -19,7 +19,7 @@ class DatabricksParser(SparkParser): "DATEDIFF": build_date_delta(exp.DateDiff), "DATE_DIFF": build_date_delta(exp.DateDiff), "NOW": exp.CurrentTimestamp.from_arg_list, - "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "databricks"), + "TO_DATE": build_formatted_time(exp.TsOrDsToDate), "UNIFORM": lambda args: exp.Uniform( this=seq_get(args, 0), expression=seq_get(args, 1), seed=seq_get(args, 2) ), diff --git a/sqlglot/parsers/dremio.py b/sqlglot/parsers/dremio.py index 2d42f8522e..43474a7e39 100644 --- a/sqlglot/parsers/dremio.py +++ b/sqlglot/parsers/dremio.py @@ -101,12 +101,12 @@ class DremioParser(parser.Parser): "BIT_AND": exp.BitwiseAndAgg.from_arg_list, "BIT_OR": exp.BitwiseOrAgg.from_arg_list, "DATE_ADD": build_date_delta_with_cast_interval(exp.DateAdd), - "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "dremio"), + "DATE_FORMAT": build_formatted_time(exp.TimeToStr), "DATE_SUB": build_date_delta_with_cast_interval(exp.DateSub), "REGEXP_MATCHES": exp.RegexpLike.from_arg_list, "REPEATSTR": exp.Repeat.from_arg_list, "TO_CHAR": to_char_is_numeric_handler, - "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "dremio"), + "TO_DATE": build_formatted_time(exp.TsOrDsToDate), "DATE_PART": exp.Extract.from_arg_list, "DATETYPE": datetype_handler, } diff --git a/sqlglot/parsers/drill.py b/sqlglot/parsers/drill.py index 14647c5761..5cc3059697 100644 --- a/sqlglot/parsers/drill.py +++ b/sqlglot/parsers/drill.py @@ -17,7 +17,7 @@ class DrillParser(parser.Parser): **parser.Parser.FUNCTIONS, "REPEATED_COUNT": exp.ArraySize.from_arg_list, "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, - "TO_CHAR": build_formatted_time(exp.TimeToStr, "drill"), + "TO_CHAR": build_formatted_time(exp.TimeToStr), "LEVENSHTEIN_DISTANCE": exp.Levenshtein.from_arg_list, } diff --git a/sqlglot/parsers/duckdb.py b/sqlglot/parsers/duckdb.py index cf8ee5e74f..40d144d26d 100644 --- a/sqlglot/parsers/duckdb.py +++ b/sqlglot/parsers/duckdb.py @@ -172,11 +172,11 @@ class DuckDBParser(parser.Parser): single_replace=True, ), "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)), - "STRFTIME": build_formatted_time(exp.TimeToStr, "duckdb"), + "STRFTIME": build_formatted_time(exp.TimeToStr), "STRING_SPLIT": exp.Split.from_arg_list, "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "STRING_TO_ARRAY": exp.Split.from_arg_list, - "STRPTIME": build_formatted_time(exp.StrToTime, "duckdb"), + "STRPTIME": build_formatted_time(exp.StrToTime), "STRUCT_PACK": exp.Struct.from_arg_list, "STR_SPLIT": exp.Split.from_arg_list, "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, diff --git a/sqlglot/parsers/exasol.py b/sqlglot/parsers/exasol.py index 884087e692..0d21560352 100644 --- a/sqlglot/parsers/exasol.py +++ b/sqlglot/parsers/exasol.py @@ -86,7 +86,7 @@ class ExasolParser(parser.Parser): "TRUNC": build_trunc, "TRUNCATE": build_trunc, "TO_CHAR": build_timetostr_or_tochar, - "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "exasol"), + "TO_DATE": build_formatted_time(exp.TsOrDsToDate), "USER": exp.CurrentUser.from_arg_list, "VAR_POP": exp.VariancePop.from_arg_list, "ZEROIFNULL": _build_zeroifnull, diff --git a/sqlglot/parsers/hive.py b/sqlglot/parsers/hive.py index 2349c3c2fa..9b3e8f85f1 100644 --- a/sqlglot/parsers/hive.py +++ b/sqlglot/parsers/hive.py @@ -9,6 +9,7 @@ if t.TYPE_CHECKING: from sqlglot._typing import F + from sqlglot.dialects.dialect import Dialect def build_with_ignore_nulls( @@ -23,8 +24,8 @@ def _parse(args: list[exp.Expr]) -> exp.Expr: return _parse -def _build_to_date(args: list) -> exp.TsOrDsToDate: - expr = build_formatted_time(exp.TsOrDsToDate, "hive")(args) +def _build_to_date(args: list, dialect: Dialect) -> exp.TsOrDsToDate: + expr = build_formatted_time(exp.TsOrDsToDate)(args, dialect) expr.set("safe", True) return expr @@ -74,11 +75,12 @@ class HiveParser(parser.Parser): "DATE_ADD": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") ), - "DATE_FORMAT": lambda args: build_formatted_time(exp.TimeToStr, "hive")( + "DATE_FORMAT": lambda args, dialect: build_formatted_time(exp.TimeToStr)( [ exp.TimeStrToTime(this=seq_get(args, 0)), seq_get(args, 1), - ] + ], + dialect, ), "DATE_SUB": _build_date_add, "DATEDIFF": lambda args: exp.DateDiff( @@ -88,7 +90,7 @@ class HiveParser(parser.Parser): "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "FIRST": build_with_ignore_nulls(exp.First), "FIRST_VALUE": build_with_ignore_nulls(exp.FirstValue), - "FROM_UNIXTIME": build_formatted_time(exp.UnixToStr, "hive", True), + "FROM_UNIXTIME": build_formatted_time(exp.UnixToStr, default=True), "GET_JSON_OBJECT": lambda args, dialect: exp.JSONExtractScalar( this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) ), @@ -111,8 +113,8 @@ class HiveParser(parser.Parser): "TO_JSON": exp.JSONFormat.from_arg_list, "TRUNC": exp.TimestampTrunc.from_arg_list, "UNBASE64": exp.FromBase64.from_arg_list, - "UNIX_TIMESTAMP": lambda args: build_formatted_time(exp.StrToUnix, "hive", True)( - args or [exp.CurrentTimestamp()] + "UNIX_TIMESTAMP": lambda args, dialect: build_formatted_time(exp.StrToUnix, default=True)( + args or [exp.CurrentTimestamp()], dialect ), "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), } diff --git a/sqlglot/parsers/mysql.py b/sqlglot/parsers/mysql.py index 183ded1e63..9e12db7478 100644 --- a/sqlglot/parsers/mysql.py +++ b/sqlglot/parsers/mysql.py @@ -122,7 +122,7 @@ class MySQLParser(parser.Parser): "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "FORMAT": exp.NumberToStr.from_arg_list, - "FROM_UNIXTIME": build_formatted_time(exp.UnixToTime, "mysql"), + "FROM_UNIXTIME": build_formatted_time(exp.UnixToTime), "ISNULL": isnull_to_is_null, "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True), "MAKETIME": exp.TimeFromParts.from_arg_list, diff --git a/sqlglot/parsers/oracle.py b/sqlglot/parsers/oracle.py index 0aede77c30..f869866ebc 100644 --- a/sqlglot/parsers/oracle.py +++ b/sqlglot/parsers/oracle.py @@ -10,13 +10,14 @@ if t.TYPE_CHECKING: from sqlglot._typing import E + from sqlglot.dialects.dialect import Dialect -def _build_to_timestamp(args: list) -> exp.StrToTime | exp.Anonymous: +def _build_to_timestamp(args: list, dialect: Dialect) -> exp.StrToTime | exp.Anonymous: if len(args) == 1: return exp.Anonymous(this="TO_TIMESTAMP", expressions=args) - return build_formatted_time(exp.StrToTime, "oracle")(args) + return build_formatted_time(exp.StrToTime)(args, dialect) class OracleParser(parser.Parser): @@ -31,7 +32,7 @@ class OracleParser(parser.Parser): "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TO_CHAR": build_timetostr_or_tochar, "TO_TIMESTAMP": _build_to_timestamp, - "TO_DATE": build_formatted_time(exp.StrToDate, "oracle"), + "TO_DATE": build_formatted_time(exp.StrToDate), "TRUNC": lambda args, dialect: build_trunc( args, dialect, date_trunc_unabbreviate=False, default_date_trunc_unit="DD" ), @@ -82,7 +83,7 @@ class OracleParser(parser.Parser): exp.DType.DATE: lambda self, this, _: self.expression(exp.DateStrToDate(this=this)), # https://docs.oracle.com/en/database/oracle/oracle-database/19/refrn/NLS_TIMESTAMP_FORMAT.html exp.DType.TIMESTAMP: lambda self, this, _: _build_to_timestamp( - [this, '"%Y-%m-%d %H:%M:%S.%f"'] + [this, '"%Y-%m-%d %H:%M:%S.%f"'], self.dialect ), } diff --git a/sqlglot/parsers/postgres.py b/sqlglot/parsers/postgres.py index 2fac622421..81562be9c6 100644 --- a/sqlglot/parsers/postgres.py +++ b/sqlglot/parsers/postgres.py @@ -14,6 +14,9 @@ from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import Dialect + def _build_generate_series(args: list) -> exp.ExplodingGenerateSeries: # The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day @@ -28,14 +31,14 @@ def _build_generate_series(args: list) -> exp.ExplodingGenerateSeries: return exp.ExplodingGenerateSeries.from_arg_list(args) -def _build_to_timestamp(args: list) -> exp.UnixToTime | exp.StrToTime: +def _build_to_timestamp(args: list, dialect: Dialect) -> exp.UnixToTime | exp.StrToTime: # TO_TIMESTAMP accepts either a single double argument or (text, text) if len(args) == 1: # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE return exp.UnixToTime.from_arg_list(args) # https://www.postgresql.org/docs/current/functions-formatting.html - return build_formatted_time(exp.StrToTime, "postgres")(args) + return build_formatted_time(exp.StrToTime)(args, dialect) def _build_regexp_replace(args: list, dialect: DialectType = None) -> exp.RegexpReplace: @@ -116,8 +119,8 @@ class PostgresParser(parser.Parser): "MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list, "NOW": exp.CurrentTimestamp.from_arg_list, "REGEXP_REPLACE": _build_regexp_replace, - "TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"), - "TO_DATE": build_formatted_time(exp.StrToDate, "postgres"), + "TO_CHAR": build_formatted_time(exp.TimeToStr), + "TO_DATE": build_formatted_time(exp.StrToDate), "TO_TIMESTAMP": _build_to_timestamp, "UNNEST": exp.Explode.from_arg_list, "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)), diff --git a/sqlglot/parsers/presto.py b/sqlglot/parsers/presto.py index f27e54f579..b655e0e8d4 100644 --- a/sqlglot/parsers/presto.py +++ b/sqlglot/parsers/presto.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing as t from sqlglot import exp, parser from sqlglot.dialects.dialect import ( @@ -12,6 +13,9 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import Dialect + def _build_approx_percentile(args: list) -> exp.Expr: if len(args) == 4: @@ -41,7 +45,7 @@ def _build_from_unixtime(args: list) -> exp.Expr: return exp.UnixToTime.from_arg_list(args) -def _build_to_char(args: list) -> exp.TimeToStr: +def _build_to_char(args: list, dialect: Dialect) -> exp.TimeToStr: fmt = seq_get(args, 1) if isinstance(fmt, exp.Literal): # We uppercase this to match Teradata's format mapping keys @@ -49,7 +53,7 @@ def _build_to_char(args: list) -> exp.TimeToStr: # We use "teradata" on purpose here, because the time formats are different in Presto. # See https://prestodb.io/docs/current/functions/teradata.html?highlight=to_char#to_char - return build_formatted_time(exp.TimeToStr, "teradata")(args) + return build_formatted_time(exp.TimeToStr, "teradata")(args, dialect) class PrestoParser(parser.Parser): @@ -84,8 +88,8 @@ class PrestoParser(parser.Parser): "DATE_DIFF": lambda args: exp.DateDiff( this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), - "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "presto"), - "DATE_PARSE": build_formatted_time(exp.StrToTime, "presto"), + "DATE_FORMAT": build_formatted_time(exp.TimeToStr), + "DATE_PARSE": build_formatted_time(exp.StrToTime), "DATE_TRUNC": date_trunc_to_time, "DAY_OF_WEEK": exp.DayOfWeekIso.from_arg_list, "DOW": exp.DayOfWeekIso.from_arg_list, diff --git a/sqlglot/parsers/redshift.py b/sqlglot/parsers/redshift.py index 5722123799..699ba673dc 100644 --- a/sqlglot/parsers/redshift.py +++ b/sqlglot/parsers/redshift.py @@ -7,7 +7,7 @@ from sqlglot.parsers.postgres import PostgresParser from sqlglot.parser import build_convert_timezone from sqlglot.tokens import TokenType -from sqlglot.dialects.dialect import build_formatted_time, map_date_part +from sqlglot.dialects.dialect import map_date_part from builtins import type as Type if t.TYPE_CHECKING: @@ -63,13 +63,6 @@ class RedshiftParser(PostgresParser): ), "STRTOL": exp.FromBase.from_arg_list, "TEXTLEN": exp.Length.from_arg_list, - "TO_CHAR": build_formatted_time(exp.TimeToStr, "redshift"), - "TO_DATE": build_formatted_time(exp.StrToDate, "redshift"), - "TO_TIMESTAMP": lambda args: ( - exp.UnixToTime.from_arg_list(args) - if len(args) == 1 - else build_formatted_time(exp.StrToTime, "redshift")(args) - ), } NO_PAREN_FUNCTIONS = { diff --git a/sqlglot/parsers/singlestore.py b/sqlglot/parsers/singlestore.py index 203eba1f2b..8d3815ef1e 100644 --- a/sqlglot/parsers/singlestore.py +++ b/sqlglot/parsers/singlestore.py @@ -23,9 +23,9 @@ def cast_to_time6(expression: exp.Expr | None, time_type: exp.DType = exp.DType. class SingleStoreParser(MySQLParser): FUNCTIONS = { **MySQLParser.FUNCTIONS, - "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "singlestore"), - "TO_TIMESTAMP": build_formatted_time(exp.StrToTime, "singlestore"), - "TO_CHAR": build_formatted_time(exp.ToChar, "singlestore"), + "TO_DATE": build_formatted_time(exp.TsOrDsToDate), + "TO_TIMESTAMP": build_formatted_time(exp.StrToTime), + "TO_CHAR": build_formatted_time(exp.ToChar), "STR_TO_DATE": build_formatted_time(exp.StrToDate, "mysql"), "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "mysql"), # The first argument of following functions is converted to TIME(6) diff --git a/sqlglot/parsers/snowflake.py b/sqlglot/parsers/snowflake.py index b8445bad76..5f7897269a 100644 --- a/sqlglot/parsers/snowflake.py +++ b/sqlglot/parsers/snowflake.py @@ -20,8 +20,9 @@ from sqlglot.tokens import TokenType if t.TYPE_CHECKING: - from sqlglot._typing import B, E from collections.abc import Collection + from sqlglot._typing import B, E + from sqlglot.dialects.dialect import Dialect def _build_approx_top_k(args: list) -> exp.ApproxTopK: @@ -77,8 +78,8 @@ def _build_date_from_parts(args: list) -> exp.DateFromParts: } -def _build_datetime(name: str, kind: exp.DType, safe: bool = False) -> t.Callable[[list], exp.Func]: - def _builder(args: list) -> exp.Func: +def _build_datetime(name: str, kind: exp.DType, safe: bool = False) -> t.Callable: + def _builder(args: list, dialect: Dialect) -> exp.Func: value = seq_get(args, 0) scale_or_fmt = seq_get(args, 1) @@ -107,7 +108,7 @@ def _builder(args: list) -> exp.Func: return unix_expr if scale_or_fmt and not int_scale_or_fmt: # Format string provided (e.g., 'YYYY-MM-DD'), use StrToTime - strtotime_expr = build_formatted_time(exp.StrToTime, "snowflake")(args) + strtotime_expr = build_formatted_time(exp.StrToTime)(args, dialect) strtotime_expr.set("safe", safe) strtotime_expr.set("target_type", kind.into_expr()) return strtotime_expr @@ -116,7 +117,7 @@ def _builder(args: list) -> exp.Func: has_format_string = scale_or_fmt and not int_scale_or_fmt if kind in (exp.DType.DATE, exp.DType.TIME) and (not int_value or has_format_string): klass = exp.TsOrDsToDate if kind == exp.DType.DATE else exp.TsOrDsToTime - formatted_exp = build_formatted_time(klass, "snowflake")(args) + formatted_exp = build_formatted_time(klass)(args, dialect) formatted_exp.set("safe", safe) return formatted_exp diff --git a/sqlglot/parsers/spark2.py b/sqlglot/parsers/spark2.py index 5bb3e625de..c9354eb30a 100644 --- a/sqlglot/parsers/spark2.py +++ b/sqlglot/parsers/spark2.py @@ -59,10 +59,10 @@ class Spark2Parser(HiveParser): "STRING": build_as_cast("string"), "SLICE": exp.ArraySlice.from_arg_list, "TIMESTAMP": build_as_cast("timestamp"), - "TO_TIMESTAMP": lambda args: ( + "TO_TIMESTAMP": lambda args, dialect: ( build_as_cast("timestamp")(args) if len(args) == 1 - else build_formatted_time(exp.StrToTime, "spark")(args) + else build_formatted_time(exp.StrToTime)(args, dialect) ), "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, "TO_UTC_TIMESTAMP": lambda args, dialect: exp.FromTimeZone(