Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,28 +1540,32 @@ def months_between_sql(self: Generator, expression: exp.MonthsBetween) -> str:


def build_formatted_time(
exp_class: Type[E], dialect: str, default: bool | str | None = None
) -> t.Callable[[BuilderArgs], E]:
exp_class: Type[E], dialect_override: str | None = None, default: bool | str | None = None
) -> t.Callable[[BuilderArgs, Dialect], E]:
"""Helper used for time expressions.

Args:
exp_class: the expression class to instantiate.
dialect: target sql dialect.
dialect_override: optional sql dialect to override the parser's one.
default: the default format, True being time.

Returns:
A callable that can be used to return the appropriately formatted time expression.
"""

def _builder(args: BuilderArgs) -> E:
return exp_class(
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
seq_get(args, 1)
or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
),
def _builder(args: BuilderArgs, dialect: Dialect) -> E:
target_dialect = (
t.cast(Dialect, Dialect[dialect_override])
if isinstance(dialect_override, str)
else dialect
)

fmt = seq_get(args, 1)
if not fmt:
fmt = target_dialect.TIME_FORMAT if default is True else default or None

return exp_class(this=seq_get(args, 0), format=target_dialect.format_time(fmt))

return _builder


Expand Down Expand Up @@ -2465,8 +2469,7 @@ def build_timetostr_or_tochar(
annotate_types(this, dialect=dialect)

if this.is_type(*exp.DataType.TEMPORAL_TYPES):
dialect_name = dialect.__class__.__name__.lower()
return build_formatted_time(exp.TimeToStr, dialect_name, default=True)(args)
return build_formatted_time(exp.TimeToStr, default=True)(args, t.cast(Dialect, dialect))

return exp.ToChar.from_arg_list(args)

Expand Down
8 changes: 6 additions & 2 deletions sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@ class Redshift(Postgres):

# ref: https://docs.aws.amazon.com/redshift/latest/dg/r_FORMAT_strings.html
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
TIME_MAPPING = {**Postgres.TIME_MAPPING, "MON": "%b", "MONTH": "%B"}
INVERSE_TIME_MAPPING = {**Postgres.INVERSE_TIME_MAPPING, "%b": "MON", "%B": "MONTH"}

TIME_MAPPING = {
**Postgres.TIME_MAPPING,
"MON": "%b",
"MONTH": "%B",
}

Parser = RedshiftParser

Expand Down
29 changes: 15 additions & 14 deletions sqlglot/parsers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

if t.TYPE_CHECKING:
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect


def _build_contains_substring(args: list) -> exp.Contains:
Expand Down Expand Up @@ -53,18 +54,18 @@ def _build_datetime(args: list) -> exp.Func:
def _build_extract_json_with_default_path(
expr_type: type[E],
) -> t.Callable:
def _builder(args: list, dialect: t.Any) -> E:
def _builder(args: list, dialect: Dialect) -> E:
if len(args) == 1:
args.append(exp.Literal.string("$"))
return parser.build_extract_json_with_path(expr_type)(args, dialect)

return _builder


def _build_format_time(expr_type: type[exp.Expr]) -> t.Callable[[list], exp.TimeToStr]:
def _builder(args: list) -> exp.TimeToStr:
formatted_time = build_formatted_time(exp.TimeToStr, "bigquery")(
[expr_type(this=seq_get(args, 1)), seq_get(args, 0)]
def _build_format_time(expr_type: type[exp.Expr]) -> t.Callable[[list, Dialect], exp.TimeToStr]:
def _builder(args: list, dialect: Dialect) -> exp.TimeToStr:
formatted_time = build_formatted_time(exp.TimeToStr)(
[expr_type(this=seq_get(args, 1)), seq_get(args, 0)], dialect
)
formatted_time.set("zone", seq_get(args, 2))
return formatted_time
Expand All @@ -91,14 +92,14 @@ def _build_levenshtein(args: list) -> exp.Levenshtein:
)


def _build_parse_timestamp(args: list) -> exp.StrToTime:
this = build_formatted_time(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)])
def _build_parse_timestamp(args: list, dialect: Dialect) -> exp.StrToTime:
this = build_formatted_time(exp.StrToTime)([seq_get(args, 1), seq_get(args, 0)], dialect)
this.set("zone", seq_get(args, 2))
return this


def _build_regexp_extract(expr_type: type[E], default_group: exp.Expr | None = None) -> t.Callable:
def _builder(args: list, dialect: t.Any) -> E:
def _builder(args: list, dialect: Dialect) -> E:
try:
group = re.compile(args[1].name).groups == 1
except re.error:
Expand Down Expand Up @@ -220,15 +221,15 @@ class BigQueryParser(parser.Parser):
),
"OCTET_LENGTH": exp.ByteLength.from_arg_list,
"TO_HEX": _build_to_hex,
"PARSE_DATE": lambda args: build_formatted_time(exp.StrToDate, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
"PARSE_DATE": lambda args, dialect: build_formatted_time(exp.StrToDate)(
[seq_get(args, 1), seq_get(args, 0)], dialect
),
"PARSE_TIME": lambda args: build_formatted_time(exp.ParseTime, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
"PARSE_TIME": lambda args, dialect: build_formatted_time(exp.ParseTime)(
[seq_get(args, 1), seq_get(args, 0)], dialect
),
"PARSE_TIMESTAMP": _build_parse_timestamp,
"PARSE_DATETIME": lambda args: build_formatted_time(exp.ParseDatetime, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
"PARSE_DATETIME": lambda args, dialect: build_formatted_time(exp.ParseDatetime)(
[seq_get(args, 1), seq_get(args, 0)], dialect
),
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"REGEXP_EXTRACT": _build_regexp_extract(exp.RegexpExtract),
Expand Down
6 changes: 3 additions & 3 deletions sqlglot/parsers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@

def _build_datetime_format(
expr_type: Type[E],
) -> t.Callable[[list], E]:
def _builder(args: list) -> E:
expr = build_formatted_time(expr_type, "clickhouse")(args)
) -> t.Callable:
def _builder(args: list, dialect: t.Any) -> E:
expr = build_formatted_time(expr_type)(args, dialect)

timezone = seq_get(args, 2)
if timezone:
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/parsers/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DatabricksParser(SparkParser):
"DATEDIFF": build_date_delta(exp.DateDiff),
"DATE_DIFF": build_date_delta(exp.DateDiff),
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_DATE": build_formatted_time(exp.TsOrDsToDate, "databricks"),
"TO_DATE": build_formatted_time(exp.TsOrDsToDate),
"UNIFORM": lambda args: exp.Uniform(
this=seq_get(args, 0), expression=seq_get(args, 1), seed=seq_get(args, 2)
),
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/parsers/dremio.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ class DremioParser(parser.Parser):
"BIT_AND": exp.BitwiseAndAgg.from_arg_list,
"BIT_OR": exp.BitwiseOrAgg.from_arg_list,
"DATE_ADD": build_date_delta_with_cast_interval(exp.DateAdd),
"DATE_FORMAT": build_formatted_time(exp.TimeToStr, "dremio"),
"DATE_FORMAT": build_formatted_time(exp.TimeToStr),
"DATE_SUB": build_date_delta_with_cast_interval(exp.DateSub),
"REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
"REPEATSTR": exp.Repeat.from_arg_list,
"TO_CHAR": to_char_is_numeric_handler,
"TO_DATE": build_formatted_time(exp.TsOrDsToDate, "dremio"),
"TO_DATE": build_formatted_time(exp.TsOrDsToDate),
"DATE_PART": exp.Extract.from_arg_list,
"DATETYPE": datetype_handler,
}
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/parsers/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class DrillParser(parser.Parser):
**parser.Parser.FUNCTIONS,
"REPEATED_COUNT": exp.ArraySize.from_arg_list,
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": build_formatted_time(exp.TimeToStr, "drill"),
"TO_CHAR": build_formatted_time(exp.TimeToStr),
"LEVENSHTEIN_DISTANCE": exp.Levenshtein.from_arg_list,
}

Expand Down
4 changes: 2 additions & 2 deletions sqlglot/parsers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,11 @@ class DuckDBParser(parser.Parser):
single_replace=True,
),
"SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
"STRFTIME": build_formatted_time(exp.TimeToStr, "duckdb"),
"STRFTIME": build_formatted_time(exp.TimeToStr),
"STRING_SPLIT": exp.Split.from_arg_list,
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
"STRPTIME": build_formatted_time(exp.StrToTime, "duckdb"),
"STRPTIME": build_formatted_time(exp.StrToTime),
"STRUCT_PACK": exp.Struct.from_arg_list,
"STR_SPLIT": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/parsers/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class ExasolParser(parser.Parser):
"TRUNC": build_trunc,
"TRUNCATE": build_trunc,
"TO_CHAR": build_timetostr_or_tochar,
"TO_DATE": build_formatted_time(exp.TsOrDsToDate, "exasol"),
"TO_DATE": build_formatted_time(exp.TsOrDsToDate),
"USER": exp.CurrentUser.from_arg_list,
"VAR_POP": exp.VariancePop.from_arg_list,
"ZEROIFNULL": _build_zeroifnull,
Expand Down
16 changes: 9 additions & 7 deletions sqlglot/parsers/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

if t.TYPE_CHECKING:
from sqlglot._typing import F
from sqlglot.dialects.dialect import Dialect


def build_with_ignore_nulls(
Expand All @@ -23,8 +24,8 @@ def _parse(args: list[exp.Expr]) -> exp.Expr:
return _parse


def _build_to_date(args: list) -> exp.TsOrDsToDate:
expr = build_formatted_time(exp.TsOrDsToDate, "hive")(args)
def _build_to_date(args: list, dialect: Dialect) -> exp.TsOrDsToDate:
expr = build_formatted_time(exp.TsOrDsToDate)(args, dialect)
expr.set("safe", True)
return expr

Expand Down Expand Up @@ -74,11 +75,12 @@ class HiveParser(parser.Parser):
"DATE_ADD": lambda args: exp.TsOrDsAdd(
this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
),
"DATE_FORMAT": lambda args: build_formatted_time(exp.TimeToStr, "hive")(
"DATE_FORMAT": lambda args, dialect: build_formatted_time(exp.TimeToStr)(
[
exp.TimeStrToTime(this=seq_get(args, 0)),
seq_get(args, 1),
]
],
dialect,
),
"DATE_SUB": _build_date_add,
"DATEDIFF": lambda args: exp.DateDiff(
Expand All @@ -88,7 +90,7 @@ class HiveParser(parser.Parser):
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FIRST": build_with_ignore_nulls(exp.First),
"FIRST_VALUE": build_with_ignore_nulls(exp.FirstValue),
"FROM_UNIXTIME": build_formatted_time(exp.UnixToStr, "hive", True),
"FROM_UNIXTIME": build_formatted_time(exp.UnixToStr, default=True),
"GET_JSON_OBJECT": lambda args, dialect: exp.JSONExtractScalar(
this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
),
Expand All @@ -111,8 +113,8 @@ class HiveParser(parser.Parser):
"TO_JSON": exp.JSONFormat.from_arg_list,
"TRUNC": exp.TimestampTrunc.from_arg_list,
"UNBASE64": exp.FromBase64.from_arg_list,
"UNIX_TIMESTAMP": lambda args: build_formatted_time(exp.StrToUnix, "hive", True)(
args or [exp.CurrentTimestamp()]
"UNIX_TIMESTAMP": lambda args, dialect: build_formatted_time(exp.StrToUnix, default=True)(
args or [exp.CurrentTimestamp()], dialect
),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/parsers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class MySQLParser(parser.Parser):
"DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FORMAT": exp.NumberToStr.from_arg_list,
"FROM_UNIXTIME": build_formatted_time(exp.UnixToTime, "mysql"),
"FROM_UNIXTIME": build_formatted_time(exp.UnixToTime),
"ISNULL": isnull_to_is_null,
"LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True),
"MAKETIME": exp.TimeFromParts.from_arg_list,
Expand Down
9 changes: 5 additions & 4 deletions sqlglot/parsers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@

if t.TYPE_CHECKING:
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect


def _build_to_timestamp(args: list) -> exp.StrToTime | exp.Anonymous:
def _build_to_timestamp(args: list, dialect: Dialect) -> exp.StrToTime | exp.Anonymous:
if len(args) == 1:
return exp.Anonymous(this="TO_TIMESTAMP", expressions=args)

return build_formatted_time(exp.StrToTime, "oracle")(args)
return build_formatted_time(exp.StrToTime)(args, dialect)


class OracleParser(parser.Parser):
Expand All @@ -31,7 +32,7 @@ class OracleParser(parser.Parser):
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TO_CHAR": build_timetostr_or_tochar,
"TO_TIMESTAMP": _build_to_timestamp,
"TO_DATE": build_formatted_time(exp.StrToDate, "oracle"),
"TO_DATE": build_formatted_time(exp.StrToDate),
"TRUNC": lambda args, dialect: build_trunc(
args, dialect, date_trunc_unabbreviate=False, default_date_trunc_unit="DD"
),
Expand Down Expand Up @@ -82,7 +83,7 @@ class OracleParser(parser.Parser):
exp.DType.DATE: lambda self, this, _: self.expression(exp.DateStrToDate(this=this)),
# https://docs.oracle.com/en/database/oracle/oracle-database/19/refrn/NLS_TIMESTAMP_FORMAT.html
exp.DType.TIMESTAMP: lambda self, this, _: _build_to_timestamp(
[this, '"%Y-%m-%d %H:%M:%S.%f"']
[this, '"%Y-%m-%d %H:%M:%S.%f"'], self.dialect
),
}

Expand Down
11 changes: 7 additions & 4 deletions sqlglot/parsers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect


def _build_generate_series(args: list) -> exp.ExplodingGenerateSeries:
# The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
Expand All @@ -28,14 +31,14 @@ def _build_generate_series(args: list) -> exp.ExplodingGenerateSeries:
return exp.ExplodingGenerateSeries.from_arg_list(args)


def _build_to_timestamp(args: list) -> exp.UnixToTime | exp.StrToTime:
def _build_to_timestamp(args: list, dialect: Dialect) -> exp.UnixToTime | exp.StrToTime:
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1:
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
return exp.UnixToTime.from_arg_list(args)

# https://www.postgresql.org/docs/current/functions-formatting.html
return build_formatted_time(exp.StrToTime, "postgres")(args)
return build_formatted_time(exp.StrToTime)(args, dialect)


def _build_regexp_replace(args: list, dialect: DialectType = None) -> exp.RegexpReplace:
Expand Down Expand Up @@ -116,8 +119,8 @@ class PostgresParser(parser.Parser):
"MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list,
"NOW": exp.CurrentTimestamp.from_arg_list,
"REGEXP_REPLACE": _build_regexp_replace,
"TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"),
"TO_DATE": build_formatted_time(exp.StrToDate, "postgres"),
"TO_CHAR": build_formatted_time(exp.TimeToStr),
"TO_DATE": build_formatted_time(exp.StrToDate),
"TO_TIMESTAMP": _build_to_timestamp,
"UNNEST": exp.Explode.from_arg_list,
"SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
Expand Down
12 changes: 8 additions & 4 deletions sqlglot/parsers/presto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import typing as t

from sqlglot import exp, parser
from sqlglot.dialects.dialect import (
Expand All @@ -12,6 +13,9 @@
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect


def _build_approx_percentile(args: list) -> exp.Expr:
if len(args) == 4:
Expand Down Expand Up @@ -41,15 +45,15 @@ def _build_from_unixtime(args: list) -> exp.Expr:
return exp.UnixToTime.from_arg_list(args)


def _build_to_char(args: list) -> exp.TimeToStr:
def _build_to_char(args: list, dialect: Dialect) -> exp.TimeToStr:
fmt = seq_get(args, 1)
if isinstance(fmt, exp.Literal):
# We uppercase this to match Teradata's format mapping keys
fmt.set("this", fmt.this.upper())

# We use "teradata" on purpose here, because the time formats are different in Presto.
# See https://prestodb.io/docs/current/functions/teradata.html?highlight=to_char#to_char
return build_formatted_time(exp.TimeToStr, "teradata")(args)
return build_formatted_time(exp.TimeToStr, "teradata")(args, dialect)


class PrestoParser(parser.Parser):
Expand Down Expand Up @@ -84,8 +88,8 @@ class PrestoParser(parser.Parser):
"DATE_DIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATE_FORMAT": build_formatted_time(exp.TimeToStr, "presto"),
"DATE_PARSE": build_formatted_time(exp.StrToTime, "presto"),
"DATE_FORMAT": build_formatted_time(exp.TimeToStr),
"DATE_PARSE": build_formatted_time(exp.StrToTime),
"DATE_TRUNC": date_trunc_to_time,
"DAY_OF_WEEK": exp.DayOfWeekIso.from_arg_list,
"DOW": exp.DayOfWeekIso.from_arg_list,
Expand Down
Loading
Loading