Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sea-orm-sync/src/entity/prelude.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pub use crate::{
ActiveEnum, ActiveModelBehavior, ActiveModelTrait, ColumnDef, ColumnTrait, ColumnType,
ColumnTypeTrait, ConnectionTrait, CursorTrait, DatabaseConnection, DbConn, EntityName,
ColumnTypeTrait, ConnectionTrait, CountTrait, CursorTrait, DatabaseConnection, DbConn, EntityName,
EntityTrait, EnumIter, ForeignKeyAction, Iden, IdenStatic, Linked, LoaderTrait, ModelTrait,
PaginatorTrait, PrimaryKeyArity, PrimaryKeyToColumn, PrimaryKeyTrait, QueryFilter, QueryResult,
Related, RelatedSelfVia, RelationDef, RelationTrait, Select, SelectExt, Value,
Expand Down
14 changes: 1 addition & 13 deletions sea-orm-sync/src/executor/paginator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,7 @@ where
Some(res) => res,
None => return Ok(0),
};
#[allow(clippy::match_single_binding)]
let num_items = match self.db.get_database_backend() {
_ => result.try_get::<i64>("", "num_items")? as u64,
};
Ok(num_items)
Ok(result.try_get::<i64>("", "num_items")? as u64)
}

/// Get the total number of pages
Expand Down Expand Up @@ -271,14 +267,6 @@ where

/// Paginate the result of a select operation.
fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, Self::Selector>;

/// Perform a count on the paginated results
fn count(self, db: &'db C) -> Result<u64, DbErr>
where
Self: Sized,
{
self.paginate(db, 1).num_items()
}
}

impl<'db, C, S> PaginatorTrait<'db, C> for Selector<S>
Expand Down
69 changes: 67 additions & 2 deletions sea-orm-sync/src/executor/select_ext.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use crate::{
ConnectionTrait, DbErr, EntityTrait, Select, SelectFive, SelectFour, SelectSix, SelectThree,
SelectTwo, Selector, SelectorRaw, SelectorTrait, Topology,
QueryTrait, SelectTwo, Selector, SelectorRaw, SelectorTrait, Statement, Topology,
};
use sea_query::{Expr, SelectStatement};

// TODO: Move count here
/// Helper trait for selectors with convenient methods
pub trait SelectExt {
/// This method is unstable and is only used for internal testing.
Expand All @@ -22,6 +21,14 @@ pub trait SelectExt {
}
}

/// Helper trait for counting rows selected by a query.
pub trait CountTrait {
/// Count the number of rows selected by this query.
fn count(self, db: &impl ConnectionTrait) -> Result<u64, DbErr>
where
Self: Sized;
}

fn into_exists_query(mut stmt: SelectStatement) -> SelectStatement {
stmt.clear_selects();
// Expr::Custom has fewer branches, but this may not have any significant impact on performance.
Expand All @@ -32,6 +39,37 @@ fn into_exists_query(mut stmt: SelectStatement) -> SelectStatement {
stmt
}

fn build_count_query(stmt: SelectStatement) -> SelectStatement {
SelectStatement::new()
.expr(Expr::cust("COUNT(*) AS count"))
.from_subquery(stmt, "sub_query")
.to_owned()
}

fn build_count_query_raw(stmt: Statement) -> SelectStatement {
let sub_query_sql = stmt.sql.trim().trim_end_matches(';').trim();
let count_sql = format!("COUNT(*) AS count FROM ({sub_query_sql}) AS sub_query");

let mut query = SelectStatement::new();
query.expr(if let Some(values) = stmt.values {
Expr::cust_with_values(count_sql, values.0)
} else {
Expr::cust(count_sql)
});
query
}

fn exec_count<C>(db: &C, stmt: SelectStatement) -> Result<u64, DbErr>
where
C: ConnectionTrait,
{
let result = match db.query_one(&stmt)? {
Some(res) => res,
None => return Ok(0),
};
Ok(result.try_get::<i64>("", "count")? as u64)
}

impl<S> SelectExt for Selector<S>
where
S: SelectorTrait,
Expand All @@ -41,6 +79,15 @@ where
}
}

impl<S> CountTrait for Selector<S>
where
S: SelectorTrait,
{
fn count(self, db: &impl ConnectionTrait) -> Result<u64, DbErr> {
exec_count(db, build_count_query(self.query))
}
}

impl<S> SelectExt for SelectorRaw<S>
where
S: SelectorTrait,
Expand All @@ -60,6 +107,15 @@ where
}
}

impl<S> CountTrait for SelectorRaw<S>
where
S: SelectorTrait,
{
fn count(self, db: &impl ConnectionTrait) -> Result<u64, DbErr> {
exec_count(db, build_count_query_raw(self.stmt))
}
}

impl<E> SelectExt for Select<E>
where
E: EntityTrait,
Expand Down Expand Up @@ -133,6 +189,15 @@ where
}
}

impl<T> CountTrait for T
where
T: QueryTrait<QueryStatement = SelectStatement>,
{
fn count(self, db: &impl ConnectionTrait) -> Result<u64, DbErr> {
exec_count(db, build_count_query(self.into_query()))
}
}

#[cfg(test)]
mod tests {
use super::SelectExt;
Expand Down
2 changes: 1 addition & 1 deletion sea-orm-sync/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub use update::*;
pub(crate) use util::*;

pub use crate::{
ConnectionTrait, CursorTrait, InsertResult, PaginatorTrait, SelectExt, Statement,
ConnectionTrait, CountTrait, CursorTrait, InsertResult, PaginatorTrait, SelectExt, Statement,
TransactionTrait, UpdateResult, Value, Values,
};
pub use sea_query::ExprTrait;
Expand Down
8 changes: 4 additions & 4 deletions src/entity/prelude.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pub use crate::{
ActiveEnum, ActiveModelBehavior, ActiveModelTrait, ColumnDef, ColumnTrait, ColumnType,
ColumnTypeTrait, ConnectionTrait, CursorTrait, DatabaseConnection, DbConn, EntityName,
EntityTrait, EnumIter, ForeignKeyAction, Iden, IdenStatic, Linked, LoaderTrait, ModelTrait,
PaginatorTrait, PrimaryKeyArity, PrimaryKeyToColumn, PrimaryKeyTrait, QueryFilter, QueryResult,
Related, RelatedSelfVia, RelationDef, RelationTrait, Select, SelectExt, Value,
ColumnTypeTrait, ConnectionTrait, CountTrait, CursorTrait, DatabaseConnection, DbConn,
EntityName, EntityTrait, EnumIter, ForeignKeyAction, Iden, IdenStatic, Linked, LoaderTrait,
ModelTrait, PaginatorTrait, PrimaryKeyArity, PrimaryKeyToColumn, PrimaryKeyTrait, QueryFilter,
QueryResult, Related, RelatedSelfVia, RelationDef, RelationTrait, Select, SelectExt, Value,
error::*,
sea_query::{DynIden, Expr, RcOrArc, SeaRc, StringLen},
};
Expand Down
14 changes: 1 addition & 13 deletions src/executor/paginator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,7 @@ where
Some(res) => res,
None => return Ok(0),
};
#[allow(clippy::match_single_binding)]
let num_items = match self.db.get_database_backend() {
_ => result.try_get::<i64>("", "num_items")? as u64,
};
Ok(num_items)
Ok(result.try_get::<i64>("", "num_items")? as u64)
}

/// Get the total number of pages
Expand Down Expand Up @@ -278,14 +274,6 @@ where

/// Paginate the result of a select operation.
fn paginate(self, db: &'db C, page_size: u64) -> Paginator<'db, C, Self::Selector>;

/// Perform a count on the paginated results
async fn count(self, db: &'db C) -> Result<u64, DbErr>
where
Self: Send + Sized,
{
self.paginate(db, 1).num_items().await
}
}

impl<'db, C, S> PaginatorTrait<'db, C> for Selector<S>
Expand Down
84 changes: 81 additions & 3 deletions src/executor/select_ext.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use crate::{
ConnectionTrait, DbErr, EntityTrait, Select, SelectFive, SelectFour, SelectSix, SelectThree,
SelectTwo, Selector, SelectorRaw, SelectorTrait, Topology,
ConnectionTrait, DbErr, EntityTrait, QueryTrait, Select, SelectFive, SelectFour, SelectSix,
SelectThree, SelectTwo, Selector, SelectorRaw, SelectorTrait, Statement, Topology,
};
use sea_query::{Expr, SelectStatement};

// TODO: Move count here
#[async_trait::async_trait]
/// Helper trait for selectors with convenient methods
pub trait SelectExt {
Expand All @@ -23,6 +22,15 @@ pub trait SelectExt {
}
}

#[async_trait::async_trait]
/// Helper trait for counting rows selected by a query.
pub trait CountTrait {
/// Count the number of rows selected by this query.
async fn count(self, db: &impl ConnectionTrait) -> Result<u64, DbErr>
where
Self: Send + Sized;
}

fn into_exists_query(mut stmt: SelectStatement) -> SelectStatement {
stmt.clear_selects();
// Expr::Custom has fewer branches, but this may not have any significant impact on performance.
Expand All @@ -33,6 +41,37 @@ fn into_exists_query(mut stmt: SelectStatement) -> SelectStatement {
stmt
}

fn build_count_query(stmt: SelectStatement) -> SelectStatement {
SelectStatement::new()
.expr(Expr::cust("COUNT(*) AS count"))
.from_subquery(stmt, "sub_query")
.to_owned()
Comment thread
Huliiiiii marked this conversation as resolved.
}

fn build_count_query_raw(stmt: Statement) -> SelectStatement {
let sub_query_sql = stmt.sql.trim().trim_end_matches(';').trim();
let count_sql = format!("COUNT(*) AS count FROM ({sub_query_sql}) AS sub_query");

let mut query = SelectStatement::new();
query.expr(if let Some(values) = stmt.values {
Expr::cust_with_values(count_sql, values.0)
} else {
Expr::cust(count_sql)
});
query
}

async fn exec_count<C>(db: &C, stmt: SelectStatement) -> Result<u64, DbErr>
where
C: ConnectionTrait,
{
let result = match db.query_one(&stmt).await? {
Some(res) => res,
None => return Ok(0),
};
Ok(result.try_get::<i64>("", "count")? as u64)
}

impl<S> SelectExt for Selector<S>
where
S: SelectorTrait,
Expand All @@ -42,6 +81,19 @@ where
}
}

#[async_trait::async_trait]
impl<S> CountTrait for Selector<S>
where
S: SelectorTrait,
{
async fn count(self, db: &impl ConnectionTrait) -> Result<u64, DbErr>
where
Self: Send + Sized,
{
exec_count(db, build_count_query(self.query)).await
}
}

#[async_trait::async_trait]
impl<S> SelectExt for SelectorRaw<S>
where
Expand All @@ -62,6 +114,19 @@ where
}
}

#[async_trait::async_trait]
impl<S> CountTrait for SelectorRaw<S>
where
S: SelectorTrait,
{
async fn count(self, db: &impl ConnectionTrait) -> Result<u64, DbErr>
where
Self: Send + Sized,
{
exec_count(db, build_count_query_raw(self.stmt)).await
}
}

impl<E> SelectExt for Select<E>
where
E: EntityTrait,
Expand Down Expand Up @@ -135,6 +200,19 @@ where
}
}

#[async_trait::async_trait]
impl<T> CountTrait for T
where
T: QueryTrait<QueryStatement = SelectStatement> + Send,
{
async fn count(self, db: &impl ConnectionTrait) -> Result<u64, DbErr>
where
Self: Send + Sized,
{
exec_count(db, build_count_query(self.into_query())).await
}
}

#[cfg(test)]
mod tests {
use super::SelectExt;
Expand Down
2 changes: 1 addition & 1 deletion src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub use update::*;
pub(crate) use util::*;

pub use crate::{
ConnectionTrait, CursorTrait, InsertResult, PaginatorTrait, SelectExt, Statement,
ConnectionTrait, CountTrait, CursorTrait, InsertResult, PaginatorTrait, SelectExt, Statement,
TransactionTrait, UpdateResult, Value, Values,
};
pub use sea_query::ExprTrait;
Expand Down
Loading