diff --git a/src/postgres/def/column.rs b/src/postgres/def/column.rs index a17a6b8..4f80f2c 100644 --- a/src/postgres/def/column.rs +++ b/src/postgres/def/column.rs @@ -11,9 +11,9 @@ pub struct ColumnInfo { /// The type of the column with any additional definitions such as the precision or the character /// set pub col_type: ColumnType, - /// The default value experssion for this column, if any - pub default: Option, - /// The generation expression for this column, if it is a generated colum + /// The default value for this column, if any + pub default: Option, + /// The generation expression for this column, if it is a generated column pub generated: Option, pub not_null: Option, pub is_identity: bool, @@ -34,6 +34,20 @@ pub struct ColumnInfo { pub type ColumnType = Type; +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))] +pub enum ColumnDefault { + Int(i64), + Real(f64), + String(String), + Bool(bool), + CurrentTimestamp, + /// A sequence default, e.g. `nextval('table_id_seq'::regclass)` + AutoIncrement(String), + /// Any other expression not covered by the above variants + Expression(String), +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))] pub struct ColumnExpression(pub String); diff --git a/src/postgres/parser/column.rs b/src/postgres/parser/column.rs index 0202dec..8d917ee 100644 --- a/src/postgres/parser/column.rs +++ b/src/postgres/parser/column.rs @@ -13,13 +13,49 @@ pub fn parse_column_query_result(result: ColumnQueryResult, enums: &EnumVariantM ColumnInfo { name: result.column_name.clone(), col_type: parse_column_type(&result, enums), - default: ColumnExpression::from_option_string(result.column_default), + default: parse_column_default(result.column_default), generated: ColumnExpression::from_option_string(result.column_generated), not_null: NotNull::from_bool(!yes_or_no_to_bool(&result.is_nullable)), is_identity: yes_or_no_to_bool(&result.is_identity), } } +pub fn parse_column_default(default: Option) -> Option { + let default = default?; + if default.is_empty() { + return None; + } + // Trim may be redundant + let def_trim = default.trim(); + + Some(if def_trim.starts_with("nextval") { + ColumnDefault::AutoIncrement(default) + } else if def_trim == "now()" || def_trim == "CURRENT_TIMESTAMP" { + ColumnDefault::CurrentTimestamp + } else if def_trim == "true" { + ColumnDefault::Bool(true) + } else if def_trim == "false" { + ColumnDefault::Bool(false) + } else if let Ok(int) = def_trim.parse::() { + ColumnDefault::Int(int) + } else if let Ok(real) = def_trim.parse::() { + ColumnDefault::Real(real) + } else { + // Check for quoted string literals like 'value'::type or plain 'value' + if let Some(inner) = def_trim.strip_prefix('\'') { + // Find the closing quote — handles 'value'::type_cast + if let Some(end) = inner.find('\'') { + let string_value = inner[..end].to_owned(); + let suffix = &inner[end + 1..]; + if suffix.is_empty() { + return Some(ColumnDefault::String(string_value)); + } + } + } + ColumnDefault::Expression(default) + }) +} + pub fn parse_column_type(result: &ColumnQueryResult, enums: &EnumVariantMap) -> ColumnType { let is_enum = result .udt_name diff --git a/src/postgres/writer/column.rs b/src/postgres/writer/column.rs index 212675a..5aca1dd 100644 --- a/src/postgres/writer/column.rs +++ b/src/postgres/writer/column.rs @@ -1,19 +1,15 @@ -use crate::postgres::def::{ColumnInfo, Type}; -use sea_query::{Alias, ColumnDef, ColumnType, DynIden, IntoIden, PgInterval, RcOrArc, StringLen}; -use std::{convert::TryFrom, fmt::Write}; +use crate::postgres::def::{ColumnDefault, ColumnInfo, Type}; +use sea_query::{ + Alias, ColumnDef, ColumnType, DynIden, Expr, IntoIden, Keyword, PgInterval, RcOrArc, + SimpleExpr, StringLen, +}; +use std::convert::TryFrom; impl ColumnInfo { pub fn write(&self) -> ColumnDef { let mut col_info = self.clone(); - let mut extras: Vec = Vec::new(); - if let Some(default) = self.default.as_ref() { - if default.0.starts_with("nextval") { - col_info = Self::convert_to_serial(col_info); - } else { - let mut string = "".to_owned(); - write!(&mut string, "DEFAULT {}", default.0).unwrap(); - extras.push(string); - } + if let Some(ColumnDefault::AutoIncrement(_)) = &self.default { + col_info = Self::convert_to_serial(col_info); } let col_type = col_info.write_col_type(); let mut col_def = ColumnDef::new_with_type(Alias::new(self.name.as_str()), col_type); @@ -29,8 +25,10 @@ impl ColumnInfo { if self.not_null.is_some() { col_def.not_null(); } - if !extras.is_empty() { - col_def.extra(extras.join(" ")); + if let Some(default) = &self.default { + if let Some(default_expr) = default.write() { + col_def.default(default_expr); + } } col_def } @@ -144,3 +142,20 @@ impl ColumnInfo { write_type(&self.col_type) } } + +impl ColumnDefault { + /// Convert to a [SimpleExpr] for use with `col_def.default()`. + /// Returns `None` for [ColumnDefault::AutoIncrement] since those are handled + /// via SERIAL type conversion instead. + pub fn write(&self) -> Option { + match self { + ColumnDefault::Int(int) => Some((*int).into()), + ColumnDefault::Real(real) => Some((*real).into()), + ColumnDefault::String(string) => Some(string.into()), + ColumnDefault::Bool(val) => Some(Expr::val(*val)), + ColumnDefault::CurrentTimestamp => Some(Keyword::CurrentTimestamp.into()), + ColumnDefault::AutoIncrement(_) => None, + ColumnDefault::Expression(expr) => Some(Expr::cust(expr.to_owned())), + } + } +} diff --git a/tests/discovery/postgres/schema.rs b/tests/discovery/postgres/schema.rs index 18dc5fd..8352ff0 100644 --- a/tests/discovery/postgres/schema.rs +++ b/tests/discovery/postgres/schema.rs @@ -11,7 +11,7 @@ Schema { name: "actor_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'actor_actor_id_seq\'::regclass)", ), ), @@ -60,9 +60,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -115,7 +113,7 @@ Schema { name: "film_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'film_film_id_seq\'::regclass)", ), ), @@ -173,8 +171,8 @@ Schema { name: "rental_duration", col_type: SmallInt, default: Some( - ColumnExpression( - "3", + Int( + 3, ), ), generated: None, @@ -195,8 +193,8 @@ Schema { }, ), default: Some( - ColumnExpression( - "4.99", + Real( + 4.99, ), ), generated: None, @@ -224,8 +222,8 @@ Schema { }, ), default: Some( - ColumnExpression( - "19.99", + Real( + 19.99, ), ), generated: None, @@ -239,7 +237,7 @@ Schema { "USER-DEFINED", ), default: Some( - ColumnExpression( + Expression( "\'G\'::mpaa_rating", ), ), @@ -256,9 +254,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -380,7 +376,7 @@ Schema { name: "payment_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'payment_payment_id_seq\'::regclass)", ), ), @@ -552,7 +548,7 @@ Schema { name: "payment_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'payment_payment_id_seq\'::regclass)", ), ), @@ -724,7 +720,7 @@ Schema { name: "payment_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'payment_payment_id_seq\'::regclass)", ), ), @@ -896,7 +892,7 @@ Schema { name: "payment_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'payment_payment_id_seq\'::regclass)", ), ), @@ -1068,7 +1064,7 @@ Schema { name: "payment_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'payment_payment_id_seq\'::regclass)", ), ), @@ -1240,7 +1236,7 @@ Schema { name: "payment_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'payment_payment_id_seq\'::regclass)", ), ), @@ -1412,7 +1408,7 @@ Schema { name: "address_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'address_address_id_seq\'::regclass)", ), ), @@ -1511,9 +1507,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -1593,7 +1587,7 @@ Schema { name: "category_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'category_category_id_seq\'::regclass)", ), ), @@ -1627,9 +1621,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -1677,7 +1669,7 @@ Schema { name: "city_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'city_city_id_seq\'::regclass)", ), ), @@ -1720,9 +1712,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -1792,7 +1782,7 @@ Schema { name: "country_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'country_country_id_seq\'::regclass)", ), ), @@ -1826,9 +1816,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -1876,7 +1864,7 @@ Schema { name: "customer_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'customer_customer_id_seq\'::regclass)", ), ), @@ -1950,8 +1938,8 @@ Schema { name: "activebool", col_type: Boolean, default: Some( - ColumnExpression( - "true", + Bool( + true, ), ), generated: None, @@ -1963,7 +1951,7 @@ Schema { name: "create_date", col_type: Date, default: Some( - ColumnExpression( + Expression( "(\'now\'::text)::date", ), ), @@ -1982,9 +1970,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: None, @@ -2114,9 +2100,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -2222,9 +2206,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -2306,7 +2288,7 @@ Schema { name: "inventory_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'inventory_inventory_id_seq\'::regclass)", ), ), @@ -2343,9 +2325,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -2431,7 +2411,7 @@ Schema { name: "language_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'language_language_id_seq\'::regclass)", ), ), @@ -2465,9 +2445,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -2515,7 +2493,7 @@ Schema { name: "rental_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'rental_rental_id_seq\'::regclass)", ), ), @@ -2589,9 +2567,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -2703,7 +2679,7 @@ Schema { name: "staff_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'staff_staff_id_seq\'::regclass)", ), ), @@ -2777,8 +2753,8 @@ Schema { name: "active", col_type: Boolean, default: Some( - ColumnExpression( - "true", + Bool( + true, ), ), generated: None, @@ -2824,9 +2800,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -2939,7 +2913,7 @@ Schema { name: "store_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'store_store_id_seq\'::regclass)", ), ), @@ -2976,9 +2950,7 @@ Schema { }, ), default: Some( - ColumnExpression( - "now()", - ), + CurrentTimestamp, ), generated: None, not_null: Some( @@ -3064,7 +3036,7 @@ Schema { name: "payment_id", col_type: Integer, default: Some( - ColumnExpression( + AutoIncrement( "nextval(\'payment_payment_id_seq\'::regclass)", ), ), diff --git a/tests/live/postgres/src/main.rs b/tests/live/postgres/src/main.rs index 18b1baf..7c2c6d8 100644 --- a/tests/live/postgres/src/main.rs +++ b/tests/live/postgres/src/main.rs @@ -219,7 +219,9 @@ fn create_order_table() -> TableCreateStatement { ColumnDef::new("updated") .date_time() .not_null() - .extra("DEFAULT '2023-06-07 16:24:00'::timestamp without time zone"), + .default(Expr::cust( + "'2023-06-07 16:24:00'::timestamp without time zone", + )), ) .col( ColumnDef::new("net_weight")