diff --git a/src/ast/function.rs b/src/ast/function.rs index 64198655c..06e6efac7 100644 --- a/src/ast/function.rs +++ b/src/ast/function.rs @@ -1,10 +1,14 @@ mod aggregate_to_string; +mod average; mod count; mod row_number; +mod sum; pub use aggregate_to_string::*; +pub use average::*; pub use count::*; pub use row_number::*; +pub use sum::*; use super::DatabaseValue; use std::borrow::Cow; @@ -22,6 +26,8 @@ pub(crate) enum FunctionType<'a> { RowNumber(RowNumber<'a>), Count(Count<'a>), AggregateToString(AggregateToString<'a>), + Average(Average<'a>), + Sum(Sum<'a>), } impl<'a> Function<'a> { @@ -58,4 +64,4 @@ macro_rules! function { ); } -function!(RowNumber, Count, AggregateToString); +function!(RowNumber, Count, AggregateToString, Average, Sum); diff --git a/src/ast/function/average.rs b/src/ast/function/average.rs new file mode 100644 index 000000000..d8e1f0c93 --- /dev/null +++ b/src/ast/function/average.rs @@ -0,0 +1,22 @@ +use crate::ast::Column; + +#[derive(Debug, Clone, PartialEq)] +pub struct Average<'a> { + pub(crate) column: Column<'a>, +} + +/// Calculates the average value of a numeric column. +/// +/// ```rust +/// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; +/// let query = Select::from_table("users").value(avg("age")); +/// let (sql, _) = Sqlite::build(query); +/// assert_eq!("SELECT AVG(`age`) FROM `users`", sql); +/// ``` +#[inline] +pub fn avg<'a, C>(col: C) -> Average<'a> +where + C: Into>, +{ + Average { column: col.into() } +} diff --git a/src/ast/function/sum.rs b/src/ast/function/sum.rs new file mode 100644 index 000000000..3b13eecdd --- /dev/null +++ b/src/ast/function/sum.rs @@ -0,0 +1,22 @@ +use crate::ast::Column; + +#[derive(Debug, Clone, PartialEq)] +pub struct Sum<'a> { + pub(crate) column: Column<'a>, +} + +/// Calculates the sum value of a numeric column. +/// +/// ```rust +/// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; +/// let query = Select::from_table("users").value(sum("age")); +/// let (sql, _) = Sqlite::build(query); +/// assert_eq!("SELECT SUM(`age`) FROM `users`", sql); +/// ``` +#[inline] +pub fn sum<'a, C>(col: C) -> Sum<'a> +where + C: Into>, +{ + Sum { column: col.into() } +} diff --git a/src/visitor.rs b/src/visitor.rs index 3ed529c85..439dc9b95 100644 --- a/src/visitor.rs +++ b/src/visitor.rs @@ -673,6 +673,14 @@ pub trait Visitor<'a> { FunctionType::AggregateToString(agg) => { self.visit_aggregate_to_string(agg.value.as_ref().clone())?; } + FunctionType::Average(avg) => { + self.write("AVG")?; + self.surround_with("(", ")", |ref mut s| s.visit_column(avg.column))?; + } + FunctionType::Sum(sum) => { + self.write("SUM")?; + self.surround_with("(", ")", |ref mut s| s.visit_column(sum.column))?; + } }; if let Some(alias) = fun.alias {