From 9e927c19daf0d6db6456cad6cc91e92824103a1e Mon Sep 17 00:00:00 2001 From: Shane O'Brien Date: Fri, 25 Jun 2021 10:30:07 +0100 Subject: [PATCH] Add upsert support and allow arbitrary queries in INSERT, UPDATE and DELETE This PR makes several changes to our "manipulation" functions (`insert`, `update`, `delete`). Firstly, we now support `ON CONFLICT DO UPDATE`, aka "upsert". Secondly, we now allow the insertion of arbitrary queries (not just static `VALUES`). `values` recovers the old behaviour. Thirdly, our `Update` and `Delete` statements now support `FROM` and `USING` clauses respectively, allowing joining against other tables. Fourthly, `Returning` is now an `Applicative`, which means you can say `returning = pure ()` if you don't care about the number of rows affected. In terms of generating the SQL to implement these features, it was unfortunately significantly less work to roll our own here than to add this upstream to Opaleye proper, because it would have required more refactoring than I felt comfortable doing. --- docs/concepts/insert.rst | 45 ++++++++--- rel8.cabal | 6 ++ src/Rel8.hs | 12 ++- src/Rel8/Query/SQL.hs | 64 ++------------- src/Rel8/Schema/Name.hs | 12 +++ src/Rel8/Schema/Table.hs | 17 ++++ src/Rel8/Statement/Delete.hs | 75 ++++++++---------- src/Rel8/Statement/Insert.hs | 111 +++++++++++--------------- src/Rel8/Statement/OnConflict.hs | 105 ++++++++++++++++++++++++ src/Rel8/Statement/Returning.hs | 132 ++++++++++++++++++++++++++++--- src/Rel8/Statement/SQL.hs | 29 +++++++ src/Rel8/Statement/Select.hs | 97 ++++++++++++++++++----- src/Rel8/Statement/Set.hs | 45 +++++++++++ src/Rel8/Statement/Update.hs | 87 +++++++++----------- src/Rel8/Statement/Using.hs | 36 +++++++++ src/Rel8/Statement/View.hs | 40 +++++----- src/Rel8/Statement/Where.hs | 37 +++++++++ src/Rel8/Table/Name.hs | 20 +---- src/Rel8/Table/Opaleye.hs | 40 +++++++++- src/Rel8/Table/Projection.hs | 8 +- tests/Main.hs | 39 ++++----- 21 files changed, 738 insertions(+), 319 deletions(-) create mode 100644 src/Rel8/Statement/OnConflict.hs create mode 100644 src/Rel8/Statement/SQL.hs create mode 100644 src/Rel8/Statement/Set.hs create mode 100644 src/Rel8/Statement/Using.hs create mode 100644 src/Rel8/Statement/Where.hs diff --git a/docs/concepts/insert.rst b/docs/concepts/insert.rst index dbba1c62..b1b6808e 100644 --- a/docs/concepts/insert.rst +++ b/docs/concepts/insert.rst @@ -21,9 +21,15 @@ using ``delete``. ``Delete`` takes: ``from`` The ``TableSchema`` for the table to delete rows from. +``using`` + This is a simple ``Query`` that forms the ``USING`` clause of the ``UPDATE`` + statement. This can be used to join against other tables, and the results + can be referenced in the ``deleteWhere`` parameters. + ``deleteWhere`` The ``WHERE`` clause of the ``DELETE`` statement. This is a function that - takes a single ``Expr`` table as input. + takes two inputs: the result of the ``using`` query, and the current value + of the row. ``returning`` What to return - see :ref:`returning`. @@ -37,16 +43,22 @@ using ``update``. ``Update`` takes: ``target`` The ``TableSchema`` for the table to update rows in. -``updateWhere`` - The ``WHERE`` clause of the ``UPDATE`` statement. This is a function that - takes a single ``Expr`` table as input. +``from`` + This is a simple ``Query`` that forms the ``FROM`` clause of the ``UPDATE`` + statement. This can be used to join against other tables, and the results + can be referenced in the ``set`` and ``updateWhere`` parameters. ``set`` A row to row transformation function, indicating how to update selected rows. This function takes rows of the same shape as ``target`` but in the ``Expr`` context. One way to write this function is to use record update syntax:: - set = \row -> row { rowName = "new name" } + set = \from row -> row { rowName = "new name" } + +``updateWhere`` + The ``WHERE`` clause of the ``UPDATE`` statement. This is a function that + takes two inputs: the result of the ``from`` query, and the current value of + the row. ``returning`` What to return - see :ref:`returning`. @@ -64,11 +76,11 @@ using ``insert``. ``Insert`` takes: The rows to insert. These are the same as ``into``, but in the ``Expr`` context. You can construct rows from their individual fields:: - rows = [ MyTable { myTableA = lit "A", myTableB = lit 42 } + rows = values [ MyTable { myTableA = lit "A", myTableB = lit 42 } or you can use ``lit`` on a table value in the ``Result`` context:: - rows = [ lit MyTable { myTableA = "A", myTableB = 42 } + rows = values [ lit MyTable { myTableA = "A", myTableB = 42 } ``onConflict`` What should happen if an insert clashes with rows that already exist. This @@ -80,6 +92,10 @@ using ``insert``. ``Insert`` takes: ``DoNothing`` PostgreSQL should not insert the duplicate rows. + ``DoUpdate`` + PostgreSQL should instead try to update any existing rows that conflict + with rows proposed for insertion. + ``returning`` What to return - see :ref:`returning`. @@ -99,11 +115,20 @@ For example, if we are inserting orders, we might want the order ids returned:: insert Insert { into = orderSchema - , rows = [ order ] + , rows = values [ order ] , onConflict = Abort , returning = Projection orderId } +If we don't want to return anything, we can use ``pure ()``:: + + insert Insert + { into = orderSchema + , rows = values [ order ] + , onConflict = Abort + , returning = pure () + } + Default values -------------- @@ -119,7 +144,7 @@ construct the ``DEFAULT`` expression:: insert Insert { into = orderSchema - , rows = [ Order { orderId = unsafeDefault, ... } ] + , rows = values [ Order { orderId = unsafeDefault, ... } ] , onConflict = Abort , returning = Projection orderId } @@ -148,7 +173,7 @@ them in Rel8, rather than in your database schema. insert Insert { into = orderSchema - , rows = [ Order { orderId = nextval "order_id_seq", ... } ] + , rows = values [ Order { orderId = nextval "order_id_seq", ... } ] , onConflict = Abort , returning = Projection orderId } diff --git a/rel8.cabal b/rel8.cabal index 6b5d743a..2a35497c 100644 --- a/rel8.cabal +++ b/rel8.cabal @@ -27,6 +27,7 @@ library , contravariant , hasql ^>= 1.4.5.1 , opaleye ^>= 0.7.3.0 + , pretty , profunctors , scientific , semialign @@ -138,10 +139,15 @@ library Rel8.Statement.Delete Rel8.Statement.Insert + Rel8.Statement.OnConflict Rel8.Statement.Returning Rel8.Statement.Select + Rel8.Statement.Set + Rel8.Statement.SQL Rel8.Statement.Update + Rel8.Statement.Using Rel8.Statement.View + Rel8.Statement.Where Rel8.Table Rel8.Table.ADT diff --git a/src/Rel8.hs b/src/Rel8.hs index 8f1eec08..832e209a 100644 --- a/src/Rel8.hs +++ b/src/Rel8.hs @@ -258,19 +258,25 @@ module Rel8 -- ** @INSERT@ , Insert(..) , OnConflict(..) + , Upsert(..) , insert , unsafeDefault + , showInsert -- ** @DELETE@ , Delete(..) , delete + , showDelete -- ** @UPDATE@ , Update(..) + , Set + , Where , update + , showUpdate -- ** @.. RETURNING@ - , Returning(..) + , Returning( NumberOfRowsAffected, Projection ) -- ** @CREATE VIEW@ , createView @@ -332,10 +338,14 @@ import Rel8.Schema.Result ( Result ) import Rel8.Schema.Table import Rel8.Statement.Delete import Rel8.Statement.Insert +import Rel8.Statement.OnConflict import Rel8.Statement.Returning import Rel8.Statement.Select +import Rel8.Statement.Set +import Rel8.Statement.SQL import Rel8.Statement.Update import Rel8.Statement.View +import Rel8.Statement.Where import Rel8.Table import Rel8.Table.ADT import Rel8.Table.Aggregate diff --git a/src/Rel8/Query/SQL.hs b/src/Rel8/Query/SQL.hs index 68eb04c4..9f168b7a 100644 --- a/src/Rel8/Query/SQL.hs +++ b/src/Rel8/Query/SQL.hs @@ -1,75 +1,21 @@ {-# language FlexibleContexts #-} -{-# language TypeFamilies #-} -{-# language ViewPatterns #-} +{-# language MonoLocalBinds #-} module Rel8.Query.SQL ( showQuery - , sqlForQuery, sqlForQueryWithNames ) where -- base -import Data.Foldable ( fold ) -import Data.Functor.Const ( Const( Const ), getConst ) -import Data.Void ( Void ) import Prelude --- opaleye -import qualified Opaleye.Internal.HaskellDB.Sql as Opaleye -import qualified Opaleye.Internal.PrimQuery as Opaleye -import qualified Opaleye.Internal.Print as Opaleye -import qualified Opaleye.Internal.Optimize as Opaleye -import qualified Opaleye.Internal.QueryArr as Opaleye hiding ( Select ) -import qualified Opaleye.Internal.Sql as Opaleye - -- rel8 import Rel8.Expr ( Expr ) -import Rel8.Expr.Opaleye ( toPrimExpr ) import Rel8.Query ( Query ) -import Rel8.Query.Opaleye ( toOpaleye ) -import Rel8.Schema.Name ( Name( Name ), Selects ) -import Rel8.Schema.HTable ( htabulateA, hfield ) -import Rel8.Table ( Table, toColumns ) -import Rel8.Table.Cols ( toCols ) -import Rel8.Table.Name ( namesFromLabels ) -import Rel8.Table.Opaleye ( castTable ) +import Rel8.Statement.Select ( ppSelect ) +import Rel8.Table ( Table ) --- | Convert a query to a 'String' containing the query as a @SELECT@ --- statement. +-- | Convert a 'Query' to a 'String' containing a @SELECT@ statement. showQuery :: Table Expr a => Query a -> String -showQuery = fold . sqlForQuery - - -sqlForQuery :: Table Expr a - => Query a -> Maybe String -sqlForQuery = sqlForQueryWithNames namesFromLabels . fmap toCols - - -sqlForQueryWithNames :: Selects names exprs - => names -> Query exprs -> Maybe String -sqlForQueryWithNames names query = - show . Opaleye.ppSql . selectFrom names exprs <$> optimize primQuery - where - (exprs, primQuery, _) = - Opaleye.runSimpleQueryArrStart (toOpaleye query) () - - -optimize :: Opaleye.PrimQuery' a -> Maybe (Opaleye.PrimQuery' Void) -optimize = Opaleye.removeEmpty . Opaleye.optimize - - -selectFrom :: Selects names exprs - => names -> exprs -> Opaleye.PrimQuery' Void -> Opaleye.Select -selectFrom (toColumns -> names) (toColumns . castTable -> exprs) query = - Opaleye.SelectFrom $ Opaleye.newSelect - { Opaleye.attrs = Opaleye.SelectAttrs attributes - , Opaleye.tables = Opaleye.oneTable select - } - where - select = Opaleye.foldPrimQuery Opaleye.sqlQueryGenerator query - attributes = getConst $ htabulateA $ \field -> case hfield names field of - Name name -> case hfield exprs field of - expr -> Const (pure (makeAttr name (toPrimExpr expr))) - makeAttr label expr = - (Opaleye.sqlExpr expr, Just (Opaleye.SqlColumn label)) +showQuery = foldMap show . ppSelect diff --git a/src/Rel8/Schema/Name.hs b/src/Rel8/Schema/Name.hs index 0488d3ff..1ceac682 100644 --- a/src/Rel8/Schema/Name.hs +++ b/src/Rel8/Schema/Name.hs @@ -13,6 +13,7 @@ module Rel8.Schema.Name ( Name(..) , Selects + , ppColumn ) where @@ -22,6 +23,13 @@ import Data.Kind ( Constraint, Type ) import Data.String ( IsString ) import Prelude +-- opaleye +import qualified Opaleye.Internal.HaskellDB.Sql as Opaleye +import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye + +-- pretty +import Text.PrettyPrint ( Doc ) + -- rel8 import Rel8.Expr ( Expr ) import qualified Rel8.Schema.Kind as K @@ -63,3 +71,7 @@ instance Sql DBType a => Table Name (Name a) where type Selects :: Type -> Type -> Constraint class Transposes Name Expr names exprs => Selects names exprs instance Transposes Name Expr names exprs => Selects names exprs + + +ppColumn :: String -> Doc +ppColumn = Opaleye.ppSqlExpr . Opaleye.ColumnSqlExpr . Opaleye.SqlColumn diff --git a/src/Rel8/Schema/Table.hs b/src/Rel8/Schema/Table.hs index 22515812..cb32b69d 100644 --- a/src/Rel8/Schema/Table.hs +++ b/src/Rel8/Schema/Table.hs @@ -1,14 +1,24 @@ {-# language DeriveFunctor #-} {-# language DerivingStrategies #-} +{-# language DisambiguateRecordFields #-} +{-# language NamedFieldPuns #-} module Rel8.Schema.Table ( TableSchema(..) + , ppTable ) where -- base import Prelude +-- opaleye +import qualified Opaleye.Internal.HaskellDB.Sql as Opaleye +import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye + +-- pretty +import Text.PrettyPrint ( Doc ) + -- | The schema for a table. This is used to specify the name and schema that a -- table belongs to (the @FROM@ part of a SQL query), along with the schema of @@ -27,3 +37,10 @@ data TableSchema names = TableSchema -- data type here, parameterized by the 'Rel8.ColumnSchema.ColumnSchema' functor. } deriving stock Functor + + +ppTable :: TableSchema a -> Doc +ppTable TableSchema {name, schema} = Opaleye.ppTable Opaleye.SqlTable + { sqlTableSchemaName = schema + , sqlTableName = name + } diff --git a/src/Rel8/Statement/Delete.hs b/src/Rel8/Statement/Delete.hs index e1b9e390..1efe01fb 100644 --- a/src/Rel8/Statement/Delete.hs +++ b/src/Rel8/Statement/Delete.hs @@ -1,13 +1,15 @@ {-# language DuplicateRecordFields #-} {-# language GADTs #-} {-# language NamedFieldPuns #-} -{-# language ScopedTypeVariables #-} +{-# language RankNTypes #-} +{-# language RecordWildCards #-} {-# language StandaloneKindSignatures #-} -{-# language TypeApplications #-} +{-# language StrictData #-} module Rel8.Statement.Delete ( Delete(..) , delete + , ppDelete ) where @@ -18,23 +20,23 @@ import Prelude -- hasql import Hasql.Connection ( Connection ) -import qualified Hasql.Decoders as Hasql import qualified Hasql.Encoders as Hasql import qualified Hasql.Session as Hasql import qualified Hasql.Statement as Hasql --- opaleye -import qualified Opaleye.Internal.Manipulation as Opaleye +-- pretty +import Text.PrettyPrint ( Doc, (<+>), ($$), text ) -- rel8 -import Rel8.Expr ( Expr ) -import Rel8.Expr.Opaleye ( toColumn, toPrimExpr ) +import Rel8.Query ( Query ) import Rel8.Schema.Name ( Selects ) -import Rel8.Schema.Table ( TableSchema ) -import Rel8.Statement.Returning ( Returning( NumberOfRowsAffected, Projection ) ) -import Rel8.Table.Cols ( fromCols, toCols ) -import Rel8.Table.Opaleye ( castTable, table, unpackspec ) -import Rel8.Table.Serialize ( Serializable, parse ) +import Rel8.Schema.Table ( TableSchema, ppTable ) +import Rel8.Statement.Returning + ( Returning + , decodeReturning, emptyReturning, ppReturning + ) +import Rel8.Statement.Using ( ppUsing ) +import Rel8.Statement.Where ( Where, ppWhere ) -- text import qualified Data.Text as Text @@ -47,7 +49,10 @@ data Delete a where Delete :: Selects names exprs => { from :: TableSchema names -- ^ Which table to delete from. - , deleteWhere :: exprs -> Expr Bool + , using :: Query using + -- ^ @USING@ clause — this can be used to join against other tables, + -- and its results can be referenced in the @WHERE@ clause + , deleteWhere :: using -> Where exprs -- ^ Which rows should be selected for deletion. , returning :: Returning names a -- ^ What to return from the @DELETE@ statement. @@ -55,38 +60,26 @@ data Delete a where -> Delete a --- | Run a @DELETE@ statement. -delete :: Connection -> Delete a -> IO a -delete c Delete {from, deleteWhere, returning} = - case returning of - NumberOfRowsAffected -> Hasql.run session c >>= either throwIO pure - where - session = Hasql.statement () statement - statement = Hasql.Statement bytes params decode prepare - bytes = encodeUtf8 $ Text.pack sql - params = Hasql.noParams - decode = Hasql.rowsAffected - prepare = False - sql = Opaleye.arrangeDeleteSql from' where' - where - from' = table $ toCols <$> from - where' = toColumn . toPrimExpr . deleteWhere . fromCols +ppDelete :: Delete a -> Maybe Doc +ppDelete Delete {..} = do + (usingDoc, i) <- ppUsing using + pure $ text "DELETE FROM" <+> ppTable from + $$ usingDoc + $$ ppWhere from (deleteWhere i) + $$ ppReturning from returning - Projection project -> Hasql.run session c >>= either throwIO pure + +-- | Run a 'Delete' statement. +delete :: Connection -> Delete a -> IO a +delete connection d@Delete {returning} = + case show <$> ppDelete d of + Nothing -> pure (emptyReturning returning) + Just sql -> + Hasql.run session connection >>= either throwIO pure where session = Hasql.statement () statement statement = Hasql.Statement bytes params decode prepare bytes = encodeUtf8 $ Text.pack sql params = Hasql.noParams - decode = decoder project + decode = decodeReturning returning prepare = False - sql = - Opaleye.arrangeDeleteReturningSql unpackspec from' where' project' - where - from' = table $ toCols <$> from - where' = toColumn . toPrimExpr . deleteWhere . fromCols - project' = castTable . toCols . project . fromCols - where - decoder :: forall exprs projection a. Serializable projection a - => (exprs -> projection) -> Hasql.Result [a] - decoder _ = Hasql.rowList (parse @projection @a) diff --git a/src/Rel8/Statement/Insert.hs b/src/Rel8/Statement/Insert.hs index a26e88d3..d3eacd15 100644 --- a/src/Rel8/Statement/Insert.hs +++ b/src/Rel8/Statement/Insert.hs @@ -1,64 +1,65 @@ {-# language DuplicateRecordFields #-} +{-# language FlexibleContexts #-} {-# language GADTs #-} {-# language NamedFieldPuns #-} -{-# language ScopedTypeVariables #-} +{-# language RecordWildCards #-} {-# language StandaloneKindSignatures #-} -{-# language TypeApplications #-} +{-# language StrictData #-} module Rel8.Statement.Insert ( Insert(..) - , OnConflict(..) , insert + , ppInsert + , ppInto ) where -- base import Control.Exception ( throwIO ) -import Data.List.NonEmpty ( NonEmpty( (:|) ) ) +import Data.Foldable ( toList ) import Data.Kind ( Type ) import Prelude -- hasql import Hasql.Connection ( Connection ) -import qualified Hasql.Decoders as Hasql import qualified Hasql.Encoders as Hasql import qualified Hasql.Session as Hasql import qualified Hasql.Statement as Hasql -- opaleye -import qualified Opaleye.Internal.Manipulation as Opaleye -import qualified Opaleye.Manipulation as Opaleye +import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye + +-- pretty +import Text.PrettyPrint ( Doc, (<+>), ($$), parens, text ) -- rel8 -import Rel8.Schema.Name ( Selects ) -import Rel8.Schema.Table ( TableSchema ) -import Rel8.Statement.Returning ( Returning( Projection, NumberOfRowsAffected ) ) -import Rel8.Table.Cols ( fromCols, toCols ) -import Rel8.Table.Opaleye ( castTable, table, unpackspec ) -import Rel8.Table.Serialize ( Serializable, parse ) +import Rel8.Query ( Query ) +import Rel8.Schema.Name ( Name, Selects, ppColumn ) +import Rel8.Schema.Table ( TableSchema(..), ppTable ) +import Rel8.Statement.OnConflict ( OnConflict, ppOnConflict ) +import Rel8.Statement.Returning + ( Returning + , decodeReturning, emptyReturning, ppReturning + ) +import Rel8.Statement.Select ( ppSelect ) +import Rel8.Table ( Table ) +import Rel8.Table.Name ( showNames ) -- text import qualified Data.Text as Text ( pack ) import Data.Text.Encoding ( encodeUtf8 ) --- | @OnConflict@ allows you to add an @ON CONFLICT@ clause to an @INSERT@ --- statement. -type OnConflict :: Type -data OnConflict - = Abort -- ^ @ON CONFLICT ABORT@ - | DoNothing -- ^ @ON CONFLICT DO NOTHING@ - - -- | The constituent parts of a SQL @INSERT@ statement. type Insert :: Type -> Type data Insert a where Insert :: Selects names exprs => { into :: TableSchema names -- ^ Which table to insert into. - , rows :: [exprs] - -- ^ The rows to insert. - , onConflict :: OnConflict + , rows :: Query exprs + -- ^ The rows to insert. This can be an arbitrary query — use + -- 'Rel8.values' insert a static list of rows. + , onConflict :: OnConflict names -- ^ What to do if the inserted rows conflict with data already in the -- table. , returning :: Returning names a @@ -67,52 +68,32 @@ data Insert a where -> Insert a --- | Run an @INSERT@ statement -insert :: Connection -> Insert a -> IO a -insert c Insert {into, rows, onConflict, returning} = - case (rows, returning) of - ([], NumberOfRowsAffected) -> pure 0 - ([], Projection _) -> pure [] +ppInsert :: Insert a -> Maybe Doc +ppInsert Insert {..} = do + rows' <- ppSelect rows + pure $ text "INSERT INTO" <+> ppInto into + $$ rows' + $$ ppOnConflict into onConflict + $$ ppReturning into returning - (x:xs, NumberOfRowsAffected) -> Hasql.run session c >>= either throwIO pure - where - session = Hasql.statement () statement - statement = Hasql.Statement bytes params decode prepare - bytes = encodeUtf8 $ Text.pack sql - params = Hasql.noParams - decode = Hasql.rowsAffected - prepare = False - sql = Opaleye.arrangeInsertManySql into' rows' onConflict' - where - into' = table $ toCols <$> into - rows' = toCols <$> x :| xs - (x:xs, Projection project) -> Hasql.run session c >>= either throwIO pure +ppInto :: Table Name a => TableSchema a -> Doc +ppInto table@TableSchema {columns} = + ppTable table <+> + parens (Opaleye.commaV ppColumn (toList (showNames columns))) + + +-- | Run an 'Insert' statement. +insert :: Connection -> Insert a -> IO a +insert connection i@Insert {returning} = + case show <$> ppInsert i of + Nothing -> pure (emptyReturning returning) + Just sql -> + Hasql.run session connection >>= either throwIO pure where session = Hasql.statement () statement statement = Hasql.Statement bytes params decode prepare bytes = encodeUtf8 $ Text.pack sql params = Hasql.noParams - decode = decoder project + decode = decodeReturning returning prepare = False - sql = - Opaleye.arrangeInsertManyReturningSql - unpackspec - into' - rows' - project' - onConflict' - where - into' = table $ toCols <$> into - rows' = toCols <$> x :| xs - project' = castTable . toCols . project . fromCols - - where - onConflict' = - case onConflict of - DoNothing -> Just Opaleye.DoNothing - Abort -> Nothing - - decoder :: forall exprs projection a. Serializable projection a - => (exprs -> projection) -> Hasql.Result [a] - decoder _ = Hasql.rowList (parse @projection @a) diff --git a/src/Rel8/Statement/OnConflict.hs b/src/Rel8/Statement/OnConflict.hs new file mode 100644 index 00000000..22cd1e3d --- /dev/null +++ b/src/Rel8/Statement/OnConflict.hs @@ -0,0 +1,105 @@ +{-# language DuplicateRecordFields #-} +{-# language FlexibleContexts #-} +{-# language GADTs #-} +{-# language LambdaCase #-} +{-# language NamedFieldPuns #-} +{-# language RecordWildCards #-} +{-# language StandaloneKindSignatures #-} +{-# language StrictData #-} + +module Rel8.Statement.OnConflict + ( OnConflict(..) + , Upsert(..) + , ppOnConflict + ) +where + +-- base +import Data.Foldable ( toList ) +import Data.Kind ( Type ) +import Prelude + +-- opaleye +import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye + +-- pretty +import Text.PrettyPrint ( Doc, (<+>), ($$), text ) + +-- rel8 +import Rel8.Schema.Name ( Name, Selects, ppColumn ) +import Rel8.Schema.Table ( TableSchema(..) ) +import Rel8.Statement.Set ( Set, ppSet ) +import Rel8.Statement.Where ( Where, ppWhere ) +import Rel8.Table ( Table, toColumns ) +import Rel8.Table.Cols ( Cols( Cols ) ) +import Rel8.Table.Name ( showNames ) +import Rel8.Table.Opaleye ( attributes ) +import Rel8.Table.Projection ( Projecting, Projection, apply ) + + +-- | 'OnConflict' represents the @ON CONFLICT@ clause of an @INSERT@ +-- statement. This specifies what ought to happen when one or more of the +-- rows proposed for insertion conflict with an existing row in the table. +type OnConflict :: Type -> Type +data OnConflict names + = Abort + -- ^ Abort the transaction if there are conflicting rows (Postgres' default) + | DoNothing + -- ^ @ON CONFLICT DO NOTHING@ + | DoUpdate (Upsert names) + -- ^ @ON CONFLICT DO UPDATE@ + + +-- | The @ON CONFLICT (...) DO UPDATE@ clause of an @INSERT@ statement, also +-- known as \"upsert\". +-- +-- When an existing row conflicts with a row proposed for insertion, +-- @ON CONFLICT DO UPDATE@ allows you to instead update this existing row. The +-- conflicting row proposed for insertion is then \"excluded\", but its values +-- can still be referenced from the @SET@ and @WHERE@ clauses of the @UPDATE@ +-- statement. +-- +-- Upsert in Postgres requires an explicit set of \"conflict targets\" — the +-- set of columns comprising the @UNIQUE@ index from conflicts with which we +-- would like to recover. +type Upsert :: Type -> Type +data Upsert names where + Upsert :: (Selects names exprs, Projecting names index, excluded ~ exprs) => + { index :: Projection names index + -- ^ The set of conflict targets, projected from the set of columns for + -- the whole table + , set :: Set excluded exprs + -- ^ How to update each selected row. + , updateWhere :: excluded -> Where exprs + -- ^ Which rows to select for update. + } + -> Upsert names + + +ppOnConflict :: TableSchema names -> OnConflict names -> Doc +ppOnConflict schema = \case + Abort -> mempty + DoNothing -> text "ON CONFLICT DO NOTHING" + DoUpdate upsert -> ppUpsert schema upsert + + +ppUpsert :: TableSchema names -> Upsert names -> Doc +ppUpsert schema@TableSchema {columns} Upsert {..} = + text "ON CONFLICT" <+> + ppIndex schema index <+> + text "DO UPDATE" $$ + ppSet schema excluded set $$ + ppWhere schema (updateWhere excluded) + where + excluded = attributes TableSchema + { schema = Nothing + , name = "excluded" + , columns + } + + +ppIndex :: (Table Name names, Projecting names index) + => TableSchema names -> Projection names index -> Doc +ppIndex TableSchema {columns} index = + Opaleye.commaV ppColumn $ toList $ + showNames $ Cols $ apply index $ toColumns columns diff --git a/src/Rel8/Statement/Returning.hs b/src/Rel8/Statement/Returning.hs index 15865a93..4c1fb2b8 100644 --- a/src/Rel8/Statement/Returning.hs +++ b/src/Rel8/Statement/Returning.hs @@ -1,29 +1,141 @@ +{-# language DerivingStrategies #-} {-# language GADTs #-} +{-# language GeneralizedNewtypeDeriving #-} +{-# language LambdaCase #-} +{-# language NamedFieldPuns #-} +{-# language PatternSynonyms #-} +{-# language RankNTypes #-} +{-# language ScopedTypeVariables #-} {-# language StandaloneKindSignatures #-} +{-# language StrictData #-} +{-# language TypeApplications #-} module Rel8.Statement.Returning - ( Returning(..) + ( Returning( NumberOfRowsAffected, Projection ) + + , decodeReturning + , emptyReturning + , ppReturning ) where -- base +import Control.Applicative ( liftA2 ) +import Data.Foldable ( toList ) import Data.Int ( Int64 ) import Data.Kind ( Type ) -import Prelude () +import Data.List.NonEmpty ( NonEmpty ) +import Prelude + +-- hasql +import qualified Hasql.Decoders as Hasql + +-- opaleye +import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye +import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye +import qualified Opaleye.Internal.Sql as Opaleye + +-- pretty +import Text.PrettyPrint ( Doc, (<+>), text ) -- rel8 import Rel8.Schema.Name ( Selects ) -import Rel8.Table.Serialize ( Serializable ) +import Rel8.Schema.Table ( TableSchema(..) ) +import Rel8.Table.Opaleye ( castTable, exprs, view ) +import Rel8.Table.Serialize ( Serializable, parse ) + +-- semigropuoids +import Data.Functor.Apply ( Apply, (<.>) ) --- | @INSERT@, @UPDATE@ and @DELETE@ all support returning either the number of --- rows affected, or the actual rows modified. 'Projection' allows you to --- project out of these returned rows, which can be useful if you want to log --- exactly which rows were deleted, or to view a generated id (for example, if --- using a column with an autoincrementing counter as a default value). type Returning :: Type -> Type -> Type data Returning names a where + Pure :: a -> Returning names a + Ap :: Returning names (a -> b) -> Returning names a -> Returning names b + + -- | 'projection' allows you to project out of the affected rows, which can + -- be useful if you want to log exactly which rows were deleted, or to view + -- a generated id (for example, if using a column with an autoincrementing + -- counter via 'Rel8.nextval'). NumberOfRowsAffected :: Returning names Int64 - Projection :: (Selects names exprs, Serializable projection a) - => (exprs -> projection) + + -- | Return the number of rows affected. + Projection :: (Selects names exprs, Serializable returning a) + => (exprs -> returning) -> Returning names [a] + + +instance Functor (Returning names) where + fmap f = \case + Pure a -> Pure (f a) + Ap g a -> Ap (fmap (f .) g) a + m -> Ap (Pure f) m + + +instance Apply (Returning names) where + (<.>) = Ap + + +instance Applicative (Returning names) where + pure = Pure + (<*>) = Ap + + +projections :: () + => TableSchema names -> Returning names a -> Maybe (NonEmpty Opaleye.PrimExpr) +projections schema@TableSchema {columns} = \case + Pure _ -> Nothing + Ap f a -> projections schema f <> projections schema a + NumberOfRowsAffected -> Nothing + Projection f -> Just (exprs (castTable (f (view columns)))) + + +runReturning :: () + => ((Int64 -> a) -> r) + -> (forall x. Hasql.Row x -> ([x] -> a) -> r) + -> Returning names a + -> r +runReturning rowCount rowList = \case + Pure a -> rowCount (const a) + Ap fs as -> + runReturning + (\withCount -> + runReturning + (\withCount' -> rowCount (withCount <*> withCount')) + (\decoder -> rowList decoder . liftA2 withCount length64) + as) + (\decoder withRows -> + runReturning + (\withCount -> rowList decoder $ withRows <*> withCount . length64) + (\decoder' withRows' -> + rowList (liftA2 (,) decoder decoder') $ + withRows <$> fmap fst <*> withRows' . fmap snd) + as) + fs + NumberOfRowsAffected -> rowCount id + Projection (_ :: exprs -> returning) -> rowList decoder' id + where + decoder' = parse @returning + where + length64 :: Foldable f => f x -> Int64 + length64 = fromIntegral . length + + +decodeReturning :: Returning names a -> Hasql.Result a +decodeReturning = runReturning + (<$> Hasql.rowsAffected) + (\decoder withRows -> withRows <$> Hasql.rowList decoder) + + +emptyReturning :: Returning names a -> a +emptyReturning = + runReturning (\withCount -> withCount 0) (\_ withRows -> withRows []) + + +ppReturning :: TableSchema names -> Returning names a -> Doc +ppReturning schema returning = case projections schema returning of + Nothing -> mempty + Just columns -> + text "RETURNING" <+> Opaleye.commaV Opaleye.ppSqlExpr (toList sqlExprs) + where + sqlExprs = Opaleye.sqlExpr <$> columns diff --git a/src/Rel8/Statement/SQL.hs b/src/Rel8/Statement/SQL.hs new file mode 100644 index 00000000..627a0c75 --- /dev/null +++ b/src/Rel8/Statement/SQL.hs @@ -0,0 +1,29 @@ +module Rel8.Statement.SQL + ( showDelete + , showInsert + , showUpdate + ) +where + +-- base +import Prelude + +-- rel8 +import Rel8.Statement.Delete ( Delete, ppDelete ) +import Rel8.Statement.Insert ( Insert, ppInsert ) +import Rel8.Statement.Update ( Update, ppUpdate ) + + +-- | Convert a 'Delete' to a 'String' containing a @DELETE@ statement. +showDelete :: Delete a -> String +showDelete = foldMap show . ppDelete + + +-- | Convert an 'Insert' to a 'String' containing an @INSERT@ statement. +showInsert :: Insert a -> String +showInsert = foldMap show . ppInsert + + +-- | Convert an 'Update' to a 'String' containing an @UPDATE@ statement. +showUpdate :: Update a -> String +showUpdate = foldMap show . ppUpdate diff --git a/src/Rel8/Statement/Select.hs b/src/Rel8/Statement/Select.hs index fac26018..73bf41a8 100644 --- a/src/Rel8/Statement/Select.hs +++ b/src/Rel8/Statement/Select.hs @@ -1,15 +1,22 @@ +{-# language DeriveTraversable #-} +{-# language DerivingStrategies #-} +{-# language FlexibleContexts #-} {-# language MonoLocalBinds #-} {-# language ScopedTypeVariables #-} {-# language TypeApplications #-} module Rel8.Statement.Select ( select - , selectWithNames + , ppSelect + + , Optimized(..) + , ppPrimSelect ) where -- base import Control.Exception ( throwIO ) +import Data.Void ( Void ) import Prelude -- hasql @@ -19,10 +26,26 @@ import qualified Hasql.Encoders as Hasql import qualified Hasql.Session as Hasql import qualified Hasql.Statement as Hasql +-- opaleye +import qualified Opaleye.Internal.HaskellDB.Sql as Opaleye +import qualified Opaleye.Internal.PrimQuery as Opaleye +import qualified Opaleye.Internal.Print as Opaleye +import qualified Opaleye.Internal.Optimize as Opaleye +import qualified Opaleye.Internal.QueryArr as Opaleye hiding ( Select ) +import qualified Opaleye.Internal.Sql as Opaleye + +-- pretty +import Text.PrettyPrint ( Doc ) + -- rel8 +import Rel8.Expr ( Expr ) import Rel8.Query ( Query ) -import Rel8.Query.SQL ( sqlForQuery, sqlForQueryWithNames ) +import Rel8.Query.Opaleye ( toOpaleye ) import Rel8.Schema.Name ( Selects ) +import Rel8.Table ( Table ) +import Rel8.Table.Cols ( toCols ) +import Rel8.Table.Name ( namesFromLabels ) +import Rel8.Table.Opaleye ( castTable, exprsWithNames ) import Rel8.Table.Serialize ( Serializable, parse ) -- text @@ -30,12 +53,12 @@ import qualified Data.Text as Text import Data.Text.Encoding ( encodeUtf8 ) --- | Run a @SELECT@ query, returning all rows. +-- | Run a @SELECT@ statement, returning all rows. select :: forall exprs a. Serializable exprs a => Connection -> Query exprs -> IO [a] -select c query = case sqlForQuery query of +select c query = case ppSelect query of Nothing -> pure [] - Just sql -> Hasql.run session c >>= either throwIO pure + Just doc -> Hasql.run session c >>= either throwIO pure where session = Hasql.statement () statement statement = Hasql.Statement bytes params decode prepare @@ -43,20 +66,54 @@ select c query = case sqlForQuery query of params = Hasql.noParams decode = Hasql.rowList (parse @exprs @a) prepare = False + sql = show doc -selectWithNames :: forall exprs a names. - ( Selects names exprs - , Serializable exprs a - ) - => Connection -> names -> Query exprs -> IO [a] -selectWithNames c names query = case sqlForQueryWithNames names query of - Nothing -> pure [] - Just sql -> Hasql.run session c >>= either throwIO pure - where - session = Hasql.statement () statement - statement = Hasql.Statement bytes params decode prepare - bytes = encodeUtf8 (Text.pack sql) - params = Hasql.noParams - decode = Hasql.rowList (parse @exprs @a) - prepare = False +ppSelect :: Table Expr a => Query a -> Maybe Doc +ppSelect query = do + primQuery' <- case optimize primQuery of + Empty -> Nothing + Unit -> Just Opaleye.Unit + Optimized primQuery' -> Just primQuery' + pure $ Opaleye.ppSql $ primSelectWith names (toCols exprs) primQuery' + where + names = namesFromLabels + (exprs, primQuery, _) = + Opaleye.runSimpleQueryArrStart (toOpaleye query) () + + +ppPrimSelect :: Query a -> (Optimized Doc, a) +ppPrimSelect query = + (Opaleye.ppSql . primSelect <$> optimize primQuery, a) + where + (a, primQuery, _) = + Opaleye.runSimpleQueryArrStart (toOpaleye query) () + + +data Optimized a = Empty | Unit | Optimized a + deriving stock (Functor, Foldable, Traversable) + + +optimize :: Opaleye.PrimQuery' a -> Optimized (Opaleye.PrimQuery' Void) +optimize query = case Opaleye.removeEmpty (Opaleye.optimize query) of + Nothing -> Empty + Just Opaleye.Unit -> Unit + Just query' -> Optimized query' + + +primSelect :: Opaleye.PrimQuery' Void -> Opaleye.Select +primSelect = Opaleye.foldPrimQuery Opaleye.sqlQueryGenerator + + +primSelectWith :: Selects names exprs + => names -> exprs -> Opaleye.PrimQuery' Void -> Opaleye.Select +primSelectWith names exprs query = + Opaleye.SelectFrom $ Opaleye.newSelect + { Opaleye.attrs = Opaleye.SelectAttrs attrs + , Opaleye.tables = Opaleye.oneTable (primSelect query) + } + where + attrs = makeAttr <$> exprsWithNames names (castTable exprs) + where + makeAttr (label, expr) = + (Opaleye.sqlExpr expr, Just (Opaleye.SqlColumn label)) diff --git a/src/Rel8/Statement/Set.hs b/src/Rel8/Statement/Set.hs new file mode 100644 index 00000000..85223d09 --- /dev/null +++ b/src/Rel8/Statement/Set.hs @@ -0,0 +1,45 @@ +{-# language MonoLocalBinds #-} +{-# language NamedFieldPuns #-} + +module Rel8.Statement.Set + ( Set + , ppSet + ) +where + +-- base +import Data.Foldable ( toList ) +import Prelude () + +-- opaleye +import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye +import qualified Opaleye.Internal.Sql as Opaleye + +-- pretty +import Text.PrettyPrint ( Doc, (<+>), equals, text ) + +-- rel8 +import Rel8.Schema.Name ( Selects, ppColumn ) +import Rel8.Schema.Table ( TableSchema(..) ) +import Rel8.Table.Opaleye ( attributes, exprsWithNames ) + + +-- | The @SET@ part of an @UPDATE@ (or @ON CONFLICT DO UPDATE@) statement. +-- +-- The @expr -> expr@ function takes the current value of the existing row and +-- returns the updated values for the row. +-- +-- The additional parameter @from@ is either the result of the query executed +-- in the @FROM@ of an @UPDATE@ stateent, or the @excluded@ row that couldn't +-- be inserted in an @ON CONFLICT DO UPDATE@ statement. +type Set from expr = from -> expr -> expr + + +ppSet :: Selects names exprs + => TableSchema names -> from -> Set from exprs -> Doc +ppSet schema@TableSchema {columns} from f = + text "SET" <+> Opaleye.commaV ppAssign (toList assigns) + where + assigns = exprsWithNames columns (f from (attributes schema)) + ppAssign (column, expr) = + ppColumn column <+> equals <+> Opaleye.ppSqlExpr (Opaleye.sqlExpr expr) diff --git a/src/Rel8/Statement/Update.hs b/src/Rel8/Statement/Update.hs index 12c615eb..863b5519 100644 --- a/src/Rel8/Statement/Update.hs +++ b/src/Rel8/Statement/Update.hs @@ -1,12 +1,14 @@ +{-# language DuplicateRecordFields #-} {-# language GADTs #-} {-# language NamedFieldPuns #-} -{-# language ScopedTypeVariables #-} +{-# language RecordWildCards #-} {-# language StandaloneKindSignatures #-} -{-# language TypeApplications #-} +{-# language StrictData #-} module Rel8.Statement.Update ( Update(..) , update + , ppUpdate ) where @@ -17,23 +19,24 @@ import Prelude -- hasql import Hasql.Connection ( Connection ) -import qualified Hasql.Decoders as Hasql import qualified Hasql.Encoders as Hasql import qualified Hasql.Session as Hasql import qualified Hasql.Statement as Hasql --- opaleye -import qualified Opaleye.Internal.Manipulation as Opaleye +-- pretty +import Text.PrettyPrint ( Doc, (<+>), ($$), text ) -- rel8 -import Rel8.Expr ( Expr ) -import Rel8.Expr.Opaleye ( toColumn, toPrimExpr ) +import Rel8.Query ( Query ) import Rel8.Schema.Name ( Selects ) -import Rel8.Schema.Table ( TableSchema ) -import Rel8.Statement.Returning ( Returning( Projection, NumberOfRowsAffected ) ) -import Rel8.Table.Cols ( fromCols, toCols ) -import Rel8.Table.Opaleye ( castTable, table, unpackspec ) -import Rel8.Table.Serialize ( Serializable, parse ) +import Rel8.Schema.Table ( TableSchema(..), ppTable ) +import Rel8.Statement.Returning + ( Returning + , decodeReturning, emptyReturning, ppReturning + ) +import Rel8.Statement.Set ( Set, ppSet ) +import Rel8.Statement.Using ( ppFrom ) +import Rel8.Statement.Where ( Where, ppWhere ) -- text import qualified Data.Text as Text @@ -46,9 +49,12 @@ data Update a where Update :: Selects names exprs => { target :: TableSchema names -- ^ Which table to update. - , set :: exprs -> exprs + , from :: Query from + -- ^ @FROM@ clause — this can be used to join against other tables, + -- and its results can be referenced in the @SET@ and @WHERE@ clauses. + , set :: Set from exprs -- ^ How to update each selected row. - , updateWhere :: exprs -> Expr Bool + , updateWhere :: from -> Where exprs -- ^ Which rows to select for update. , returning :: Returning names a -- ^ What to return from the @UPDATE@ statement. @@ -56,46 +62,29 @@ data Update a where -> Update a +ppUpdate :: Update a -> Maybe Doc +ppUpdate Update {..} = do + (fromDoc, i) <- ppFrom from + pure $ + text "UPDATE" <+> + ppTable target $$ + ppSet target i set $$ + fromDoc $$ + ppWhere target (updateWhere i) $$ + ppReturning target returning + + -- | Run an @UPDATE@ statement. update :: Connection -> Update a -> IO a -update c Update {target, set, updateWhere, returning} = - case returning of - NumberOfRowsAffected -> Hasql.run session c >>= either throwIO pure +update connection u@Update {returning} = + case show <$> ppUpdate u of + Nothing -> pure (emptyReturning returning) + Just sql -> + Hasql.run session connection >>= either throwIO pure where session = Hasql.statement () statement statement = Hasql.Statement bytes params decode prepare bytes = encodeUtf8 $ Text.pack sql params = Hasql.noParams - decode = Hasql.rowsAffected + decode = decodeReturning returning prepare = False - sql = Opaleye.arrangeUpdateSql target' set' where' - where - target' = table $ toCols <$> target - set' = toCols . set . fromCols - where' = toColumn . toPrimExpr . updateWhere . fromCols - - Projection project -> Hasql.run session c >>= either throwIO pure - where - session = Hasql.statement () statement - statement = Hasql.Statement bytes params decode prepare - bytes = encodeUtf8 $ Text.pack sql - params = Hasql.noParams - decode = decoder project - prepare = False - sql = - Opaleye.arrangeUpdateReturningSql - unpackspec - target' - set' - where' - project' - where - target' = table $ toCols <$> target - set' = toCols . set . fromCols - where' = toColumn . toPrimExpr . updateWhere . fromCols - project' = castTable . toCols . project . fromCols - - where - decoder :: forall exprs projection a. Serializable projection a - => (exprs -> projection) -> Hasql.Result [a] - decoder _ = Hasql.rowList (parse @projection @a) diff --git a/src/Rel8/Statement/Using.hs b/src/Rel8/Statement/Using.hs new file mode 100644 index 00000000..c8dc00cd --- /dev/null +++ b/src/Rel8/Statement/Using.hs @@ -0,0 +1,36 @@ +module Rel8.Statement.Using + ( ppFrom + , ppUsing + ) +where + +-- base +import Prelude + +-- pretty +import Text.PrettyPrint ( Doc, (<+>), parens, text ) + +-- rel8 +import Rel8.Query ( Query ) +import Rel8.Schema.Table ( TableSchema(..), ppTable ) +import Rel8.Statement.Select ( Optimized(..), ppPrimSelect ) + + +ppFrom :: Query a -> Maybe (Doc, a) +ppFrom = ppJoin "FROM" + + +ppUsing :: Query a -> Maybe (Doc, a) +ppUsing = ppJoin "USING" + + +ppJoin :: String -> Query a -> Maybe (Doc, a) +ppJoin clause join = do + doc <- case ofrom of + Empty -> Nothing + Unit -> Just mempty + Optimized doc -> Just $ text clause <+> parens doc <+> ppTable alias + pure (doc, a) + where + alias = TableSchema {name = "T1", schema = Nothing, columns = ()} + (ofrom, a) = ppPrimSelect join diff --git a/src/Rel8/Statement/View.hs b/src/Rel8/Statement/View.hs index 85bea13a..afe802f5 100644 --- a/src/Rel8/Statement/View.hs +++ b/src/Rel8/Statement/View.hs @@ -8,7 +8,6 @@ where -- base import Control.Exception ( throwIO ) -import Control.Monad ( (>=>) ) import Data.Foldable ( fold ) import Data.Maybe ( fromMaybe ) import Prelude @@ -22,11 +21,15 @@ import qualified Hasql.Statement as Hasql -- rel8 import Rel8.Query ( Query ) -import Rel8.Query.SQL ( sqlForQueryWithNames ) import Rel8.Schema.Name ( Selects ) -import Rel8.Schema.Table ( TableSchema( TableSchema ) ) +import Rel8.Schema.Table ( TableSchema ) +import Rel8.Statement.Insert ( ppInto ) +import Rel8.Statement.Select ( ppSelect ) import Rel8.Table.Alternative ( emptyTable ) +-- pretty +import Text.PrettyPrint ( Doc, (<+>), ($$), text ) + -- text import qualified Data.Text as Text import Data.Text.Encoding ( encodeUtf8 ) @@ -36,9 +39,9 @@ import Data.Text.Encoding ( encodeUtf8 ) -- statement that will save the given query as a view. This can be useful if -- you want to share Rel8 queries with other applications. createView :: Selects names exprs - => TableSchema names -> Query exprs -> Connection -> IO () -createView (TableSchema name mschema names) query = - Hasql.run session >=> either throwIO pure + => Connection -> TableSchema names -> Query exprs -> IO () +createView connection schema query = + Hasql.run session connection >>= either throwIO pure where session = Hasql.statement () statement statement = Hasql.Statement bytes params decode prepare @@ -46,18 +49,15 @@ createView (TableSchema name mschema names) query = params = Hasql.noParams decode = Hasql.noResult prepare = False - sql = "CREATE VIEW " <> title <> " AS " <> select - where - title = case mschema of - Nothing -> quote name - Just schema -> quote schema <> "." <> quote name - select = fromMaybe fallback $ sqlForQueryWithNames names query - where - fallback = fold $ sqlForQueryWithNames names emptyTable - - -quote :: String -> String -quote string = "\"" <> concatMap go string <> "\"" + sql = show (ppCreateView schema query) + + +ppCreateView :: Selects names exprs + => TableSchema names -> Query exprs -> Doc +ppCreateView schema query = + text "CREATE VIEW" <+> + ppInto schema $$ + text "AS" <+> + fromMaybe fallback (ppSelect query) where - go '"' = "\"\"" - go c = [c] + fallback = fold (ppSelect (emptyTable `asTypeOf` query)) diff --git a/src/Rel8/Statement/Where.hs b/src/Rel8/Statement/Where.hs new file mode 100644 index 00000000..77460c08 --- /dev/null +++ b/src/Rel8/Statement/Where.hs @@ -0,0 +1,37 @@ +{-# language MonoLocalBinds #-} + +module Rel8.Statement.Where + ( Where + , ppWhere + ) +where + +-- base +import Prelude + +-- opaleye +import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye +import qualified Opaleye.Internal.Sql as Opaleye + +-- pretty +import Text.PrettyPrint ( Doc, (<+>), text ) + +-- rel8 +import Rel8.Expr ( Expr ) +import Rel8.Expr.Opaleye ( toPrimExpr ) +import Rel8.Schema.Name ( Selects ) +import Rel8.Schema.Table ( TableSchema ) +import Rel8.Table.Opaleye ( attributes ) + + +-- | The @WHERE@ condition in a @DELETE@ or @UPDATE@ (or @ON CONFLICT DO +-- UPDATE@) statement. This takes the value of the existing row and decides +-- whether or not it should be modified. +type Where expr = expr -> Expr Bool + + +ppWhere :: Selects names exprs => TableSchema names -> Where exprs -> Doc +ppWhere schema where_ = text "WHERE" <+> ppExpr condition + where + ppExpr = Opaleye.ppSqlExpr . Opaleye.sqlExpr . toPrimExpr + condition = where_ (attributes schema) diff --git a/src/Rel8/Table/Name.hs b/src/Rel8/Table/Name.hs index 65410c53..51c35382 100644 --- a/src/Rel8/Table/Name.hs +++ b/src/Rel8/Table/Name.hs @@ -12,7 +12,6 @@ module Rel8.Table.Name ( namesFromLabels , namesFromLabelsWith - , showExprs , showLabels , showNames ) @@ -25,17 +24,11 @@ import Data.List.NonEmpty ( NonEmpty, intersperse, nonEmpty ) import Data.Maybe ( fromMaybe ) import Prelude --- opaleye -import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye - -- rel8 -import Rel8.Expr ( Expr ) -import Rel8.Expr.Opaleye ( toPrimExpr ) import Rel8.Schema.HTable ( htabulate, htabulateA, hfield, hspecs ) import Rel8.Schema.Name ( Name( Name ) ) import Rel8.Schema.Spec ( Spec(..) ) import Rel8.Table ( Table(..) ) -import Rel8.Table.Cols ( Cols( Cols ) ) -- | Construct a table in the 'Name' context containing the names of all @@ -70,23 +63,16 @@ namesFromLabelsWith f = fromColumns $ htabulate $ \field -> Spec {labels} -> Name (f (renderLabels labels)) -showExprs :: Table Expr a => a -> [(String, Opaleye.PrimExpr)] -showExprs as = case (namesFromLabels, toColumns as) of - (Cols names, exprs) -> getConst $ htabulateA $ \field -> - case (hfield names field, hfield exprs field) of - (Name name, expr) -> Const [(name, toPrimExpr expr)] - - showLabels :: forall a. Table (Context a) a => a -> [NonEmpty String] showLabels _ = getConst $ htabulateA @(Columns a) $ \field -> case hfield hspecs field of - Spec {labels} -> Const [renderLabels labels] + Spec {labels} -> Const (pure (renderLabels labels)) -showNames :: forall a. Table Name a => a -> [String] +showNames :: forall a. Table Name a => a -> NonEmpty String showNames (toColumns -> names) = getConst $ htabulateA @(Columns a) $ \field -> case hfield names field of - Name name -> Const [name] + Name name -> Const (pure name) renderLabels :: [String] -> NonEmpty String diff --git a/src/Rel8/Table/Opaleye.hs b/src/Rel8/Table/Opaleye.hs index b21e90d6..55bfb9b0 100644 --- a/src/Rel8/Table/Opaleye.hs +++ b/src/Rel8/Table/Opaleye.hs @@ -2,7 +2,6 @@ {-# language DataKinds #-} {-# language DisambiguateRecordFields #-} {-# language FlexibleContexts #-} -{-# language LambdaCase #-} {-# language NamedFieldPuns #-} {-# language RankNTypes #-} {-# language TypeFamilies #-} @@ -10,17 +9,23 @@ module Rel8.Table.Opaleye ( aggregator + , attributes , binaryspec , distinctspec + , exprs + , exprsWithNames , table , tableFields , unpackspec , valuesspec + , view , castTable ) where -- base +import Data.Functor.Const ( Const( Const ), getConst ) +import Data.List.NonEmpty ( NonEmpty ) import Prelude hiding ( undefined ) -- opaleye @@ -46,9 +51,9 @@ import Rel8.Expr.Opaleye , scastExpr ) import Rel8.Schema.HTable ( htabulateA, hfield, htraverse, hspecs, htabulate ) -import Rel8.Schema.Name ( Name( Name ), Selects ) +import Rel8.Schema.Name ( Name( Name ), Selects, ppColumn ) import Rel8.Schema.Spec ( Spec(..) ) -import Rel8.Schema.Table ( TableSchema(..) ) +import Rel8.Schema.Table ( TableSchema(..), ppTable ) import Rel8.Table ( Table, fromColumns, toColumns ) import Rel8.Table.Undefined ( undefined ) @@ -64,6 +69,14 @@ aggregator = Opaleye.Aggregator $ Opaleye.PackMap $ \f aggregates -> inner f () +attributes :: Selects names exprs => TableSchema names -> exprs +attributes schema@TableSchema {columns} = fromColumns $ htabulate $ \field -> + case hfield (toColumns columns) field of + Name column -> fromPrimExpr $ Opaleye.ConstExpr $ + Opaleye.OtherLit $ + show (ppTable schema) <> "." <> show (ppColumn column) + + binaryspec :: Table Expr a => Opaleye.Binaryspec a a binaryspec = Opaleye.Binaryspec $ Opaleye.PackMap $ \f (as, bs) -> fmap fromColumns $ unwrapApplicative $ htabulateA $ \field -> @@ -82,7 +95,20 @@ distinctspec = toColumns -table ::Selects names exprs => TableSchema names -> Opaleye.Table exprs exprs +exprs :: Table Expr a => a -> NonEmpty Opaleye.PrimExpr +exprs (toColumns -> as) = getConst $ htabulateA $ \field -> + case hfield as field of + expr -> Const (pure (toPrimExpr expr)) + + +exprsWithNames :: Selects names exprs + => names -> exprs -> NonEmpty (String, Opaleye.PrimExpr) +exprsWithNames names as = getConst $ htabulateA $ \field -> + case (hfield (toColumns names) field, hfield (toColumns as) field) of + (Name name, expr) -> Const (pure (name, toPrimExpr expr)) + + +table :: Selects names exprs => TableSchema names -> Opaleye.Table exprs exprs table (TableSchema name schema columns) = case schema of Nothing -> Opaleye.Table name (tableFields columns) @@ -115,6 +141,12 @@ valuesspec :: Table Expr a => Opaleye.ValuesspecSafe a a valuesspec = Opaleye.ValuesspecSafe (toPackMap undefined) unpackspec +view :: Selects names exprs => names -> exprs +view columns = fromColumns $ htabulate $ \field -> + case hfield (toColumns columns) field of + Name column -> fromPrimExpr $ Opaleye.BaseTableAttrExpr column + + toPackMap :: Table Expr a => a -> Opaleye.PackMap Opaleye.PrimExpr Opaleye.PrimExpr () a toPackMap as = Opaleye.PackMap $ \f () -> diff --git a/src/Rel8/Table/Projection.hs b/src/Rel8/Table/Projection.hs index d9da446c..344c1b9a 100644 --- a/src/Rel8/Table/Projection.hs +++ b/src/Rel8/Table/Projection.hs @@ -29,13 +29,13 @@ import Rel8.Table.Transpose ( Transposes ) -- usable 'Projection'. type Projecting :: Type -> Type -> Constraint class - ( Transposes (Field a) (Context a) a (Transpose (Field a) a) - , Transposes (Field a) (Context a) b (Transpose (Field a) b) + ( Transposes (Context a) (Field a) a (Transpose (Field a) a) + , Transposes (Context a) (Field a) b (Transpose (Field a) b) ) => Projecting a b instance - ( Transposes (Field a) (Context a) a (Transpose (Field a) a) - , Transposes (Field a) (Context a) b (Transpose (Field b) b) + ( Transposes (Context a) (Field a) a (Transpose (Field a) a) + , Transposes (Context a) (Field a) b (Transpose (Field b) b) ) => Projecting a b diff --git a/tests/Main.hs b/tests/Main.hs index dd714db7..e2cf86fc 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -185,14 +185,13 @@ testSelectTestTable = databasePropertyTest "Can SELECT TestTable" \transaction - rows <- forAll $ Gen.list (Range.linear 0 10) genTestTable transaction \connection -> do - void do - liftIO $ Rel8.insert connection - Rel8.Insert - { into = testTableSchema - , rows = map Rel8.lit rows - , onConflict = Rel8.DoNothing - , returning = Rel8.NumberOfRowsAffected - } + liftIO $ Rel8.insert connection + Rel8.Insert + { into = testTableSchema + , rows = Rel8.values $ map Rel8.lit rows + , onConflict = Rel8.DoNothing + , returning = pure () + } selected <- liftIO $ Rel8.select connection do Rel8.each testTableSchema @@ -599,18 +598,19 @@ testUpdate = databasePropertyTest "Can UPDATE TestTable" \transaction -> do rows <- forAll $ Gen.map (Range.linear 0 5) $ liftA2 (,) genTestTable genTestTable transaction \connection -> do - void $ liftIO $ Rel8.insert connection + liftIO $ Rel8.insert connection Rel8.Insert { into = testTableSchema - , rows = map Rel8.lit $ Map.keys rows + , rows = Rel8.values $ map Rel8.lit $ Map.keys rows , onConflict = Rel8.DoNothing - , returning = Rel8.NumberOfRowsAffected + , returning = pure () } - void $ liftIO $ Rel8.update connection + liftIO $ Rel8.update connection Rel8.Update { target = testTableSchema - , set = \r -> + , from = pure () + , set = \_ r -> let updates = map (bimap Rel8.lit Rel8.lit) $ Map.toList rows in foldl @@ -624,8 +624,8 @@ testUpdate = databasePropertyTest "Can UPDATE TestTable" \transaction -> do ) r updates - , updateWhere = \_ -> Rel8.lit True - , returning = Rel8.NumberOfRowsAffected + , updateWhere = \_ _ -> Rel8.lit True + , returning = pure () } selected <- liftIO $ Rel8.select connection do @@ -643,19 +643,20 @@ testDelete = databasePropertyTest "Can DELETE TestTable" \transaction -> do rows <- forAll $ Gen.list (Range.linear 0 5) genTestTable transaction \connection -> do - void $ liftIO $ Rel8.insert connection + liftIO $ Rel8.insert connection Rel8.Insert { into = testTableSchema - , rows = map Rel8.lit rows + , rows = Rel8.values $ map Rel8.lit rows , onConflict = Rel8.DoNothing - , returning = Rel8.NumberOfRowsAffected + , returning = pure () } deleted <- liftIO $ Rel8.delete connection Rel8.Delete { from = testTableSchema - , deleteWhere = testTableColumn2 + , using = pure () + , deleteWhere = const testTableColumn2 , returning = Rel8.Projection id }