Skip to content

Commit

Permalink
Add upsert support and allow arbitrary queries in INSERT, UPDATE and …
Browse files Browse the repository at this point in the history
…DELETE

This PR makes several changes to our "manipulation" functions (`insert`, `update`, `delete`).

Firstly, we now support `ON CONFLICT DO UPDATE`, aka "upsert".

Secondly, we now allow the insertion of arbitrary queries (not just static `VALUES`). `values` recovers the old behaviour.

Thirdly, the `WHERE` clauses of `UPDATE` and `DELETE` are now also arbitrary queries (allowing joining against other tables and the use of functions like `absent` and `present`). `where_` recovers the old behaviour.

In terms of generating the SQL to implement these features, it was unfortunately significantly less work to roll our own here than to add this upstream to Opaleye proper, because it would have required more refactoring than I felt comfortable doing.
  • Loading branch information
shane-circuithub committed Jul 7, 2021
1 parent 032242f commit 9747d8b
Show file tree
Hide file tree
Showing 18 changed files with 452 additions and 237 deletions.
10 changes: 5 additions & 5 deletions docs/concepts/insert.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ using ``insert``. ``Insert`` takes:
The rows to insert. These are the same as ``into``, but in the ``Expr``
context. You can construct rows from their individual fields::

rows = [ MyTable { myTableA = lit "A", myTableB = lit 42 }
rows = values [ MyTable { myTableA = lit "A", myTableB = lit 42 }

or you can use ``lit`` on a table value in the ``Result`` context::

rows = [ lit MyTable { myTableA = "A", myTableB = 42 }
rows = values [ lit MyTable { myTableA = "A", myTableB = 42 }

``onConflict``
What should happen if an insert clashes with rows that already exist. This
Expand Down Expand Up @@ -99,7 +99,7 @@ For example, if we are inserting orders, we might want the order ids returned::

insert Insert
{ into = orderSchema
, rows = [ order ]
, rows = values [ order ]
, onConflict = Abort
, returning = Projection orderId
}
Expand All @@ -119,7 +119,7 @@ construct the ``DEFAULT`` expression::

insert Insert
{ into = orderSchema
, rows = [ Order { orderId = unsafeDefault, ... } ]
, rows = values [ Order { orderId = unsafeDefault, ... } ]
, onConflict = Abort
, returning = Projection orderId
}
Expand Down Expand Up @@ -148,7 +148,7 @@ them in Rel8, rather than in your database schema.

insert Insert
{ into = orderSchema
, rows = [ Order { orderId = nextval "order_id_seq", ... } ]
, rows = values [ Order { orderId = nextval "order_id_seq", ... } ]
, onConflict = Abort
, returning = Projection orderId
}
3 changes: 3 additions & 0 deletions rel8.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ library
, contravariant
, hasql ^>= 1.4.5.1
, opaleye ^>= 0.7.3.0
, pretty
, profunctors
, scientific
, semialign
Expand Down Expand Up @@ -140,8 +141,10 @@ library
Rel8.Statement.Insert
Rel8.Statement.Returning
Rel8.Statement.Select
Rel8.Statement.SQL
Rel8.Statement.Update
Rel8.Statement.View
Rel8.Statement.Where

Rel8.Table
Rel8.Table.ADT
Expand Down
8 changes: 8 additions & 0 deletions src/Rel8.hs
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,22 @@ module Rel8
-- ** @INSERT@
, Insert(..)
, OnConflict(..)
, Upsert(..)
, insert
, unsafeDefault
, showInsert

-- ** @DELETE@
, Delete(..)
, delete
, showDelete

-- ** @UPDATE@
, Update(..)
, Set
, Where
, update
, showUpdate

-- ** @.. RETURNING@
, Returning(..)
Expand Down Expand Up @@ -334,8 +340,10 @@ import Rel8.Statement.Delete
import Rel8.Statement.Insert
import Rel8.Statement.Returning
import Rel8.Statement.Select
import Rel8.Statement.SQL
import Rel8.Statement.Update
import Rel8.Statement.View
import Rel8.Statement.Where
import Rel8.Table
import Rel8.Table.ADT
import Rel8.Table.Aggregate
Expand Down
64 changes: 5 additions & 59 deletions src/Rel8/Query/SQL.hs
Original file line number Diff line number Diff line change
@@ -1,75 +1,21 @@
{-# language FlexibleContexts #-}
{-# language TypeFamilies #-}
{-# language ViewPatterns #-}
{-# language MonoLocalBinds #-}

module Rel8.Query.SQL
( showQuery
, sqlForQuery, sqlForQueryWithNames
)
where

-- base
import Data.Foldable ( fold )
import Data.Functor.Const ( Const( Const ), getConst )
import Data.Void ( Void )
import Prelude

-- opaleye
import qualified Opaleye.Internal.HaskellDB.Sql as Opaleye
import qualified Opaleye.Internal.PrimQuery as Opaleye
import qualified Opaleye.Internal.Print as Opaleye
import qualified Opaleye.Internal.Optimize as Opaleye
import qualified Opaleye.Internal.QueryArr as Opaleye hiding ( Select )
import qualified Opaleye.Internal.Sql as Opaleye

-- rel8
import Rel8.Expr ( Expr )
import Rel8.Expr.Opaleye ( toPrimExpr )
import Rel8.Query ( Query )
import Rel8.Query.Opaleye ( toOpaleye )
import Rel8.Schema.Name ( Name( Name ), Selects )
import Rel8.Schema.HTable ( htabulateA, hfield )
import Rel8.Table ( Table, toColumns )
import Rel8.Table.Cols ( toCols )
import Rel8.Table.Name ( namesFromLabels )
import Rel8.Table.Opaleye ( castTable )
import Rel8.Statement.Select ( ppSelect )
import Rel8.Table ( Table )


-- | Convert a query to a 'String' containing the query as a @SELECT@
-- statement.
-- | Convert a 'Query' to a 'String' containing a @SELECT@ statement.
showQuery :: Table Expr a => Query a -> String
showQuery = fold . sqlForQuery


sqlForQuery :: Table Expr a
=> Query a -> Maybe String
sqlForQuery = sqlForQueryWithNames namesFromLabels . fmap toCols


sqlForQueryWithNames :: Selects names exprs
=> names -> Query exprs -> Maybe String
sqlForQueryWithNames names query =
show . Opaleye.ppSql . selectFrom names exprs <$> optimize primQuery
where
(exprs, primQuery, _) =
Opaleye.runSimpleQueryArrStart (toOpaleye query) ()


optimize :: Opaleye.PrimQuery' a -> Maybe (Opaleye.PrimQuery' Void)
optimize = Opaleye.removeEmpty . Opaleye.optimize


selectFrom :: Selects names exprs
=> names -> exprs -> Opaleye.PrimQuery' Void -> Opaleye.Select
selectFrom (toColumns -> names) (toColumns . castTable -> exprs) query =
Opaleye.SelectFrom $ Opaleye.newSelect
{ Opaleye.attrs = Opaleye.SelectAttrs attributes
, Opaleye.tables = Opaleye.oneTable select
}
where
select = Opaleye.foldPrimQuery Opaleye.sqlQueryGenerator query
attributes = getConst $ htabulateA $ \field -> case hfield names field of
Name name -> case hfield exprs field of
expr -> Const (pure (makeAttr name (toPrimExpr expr)))
makeAttr label expr =
(Opaleye.sqlExpr expr, Just (Opaleye.SqlColumn label))
showQuery = foldMap show . ppSelect
12 changes: 12 additions & 0 deletions src/Rel8/Schema/Name.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
module Rel8.Schema.Name
( Name(..)
, Selects
, ppColumn
)
where

Expand All @@ -22,6 +23,13 @@ import Data.Kind ( Constraint, Type )
import Data.String ( IsString )
import Prelude

-- opaleye
import qualified Opaleye.Internal.HaskellDB.Sql as Opaleye
import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye

-- pretty
import Text.PrettyPrint ( Doc )

-- rel8
import Rel8.Expr ( Expr )
import qualified Rel8.Schema.Kind as K
Expand Down Expand Up @@ -63,3 +71,7 @@ instance Sql DBType a => Table Name (Name a) where
type Selects :: Type -> Type -> Constraint
class Transposes Name Expr names exprs => Selects names exprs
instance Transposes Name Expr names exprs => Selects names exprs


ppColumn :: String -> Doc
ppColumn = Opaleye.ppSqlExpr . Opaleye.ColumnSqlExpr . Opaleye.SqlColumn
17 changes: 17 additions & 0 deletions src/Rel8/Schema/Table.hs
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
{-# language DeriveFunctor #-}
{-# language DerivingStrategies #-}
{-# language DisambiguateRecordFields #-}
{-# language NamedFieldPuns #-}

module Rel8.Schema.Table
( TableSchema(..)
, ppTable
)
where

-- base
import Prelude

-- opaleye
import qualified Opaleye.Internal.HaskellDB.Sql as Opaleye
import qualified Opaleye.Internal.HaskellDB.Sql.Print as Opaleye

-- pretty
import Text.PrettyPrint ( Doc )


-- | The schema for a table. This is used to specify the name and schema that a
-- table belongs to (the @FROM@ part of a SQL query), along with the schema of
Expand All @@ -27,3 +37,10 @@ data TableSchema names = TableSchema
-- data type here, parameterized by the 'Rel8.ColumnSchema.ColumnSchema' functor.
}
deriving stock Functor


ppTable :: TableSchema a -> Doc
ppTable TableSchema {name, schema} = Opaleye.ppTable Opaleye.SqlTable
{ sqlTableSchemaName = schema
, sqlTableName = name
}
47 changes: 24 additions & 23 deletions src/Rel8/Statement/Delete.hs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
{-# language DuplicateRecordFields #-}
{-# language GADTs #-}
{-# language NamedFieldPuns #-}
{-# language RecordWildCards #-}
{-# language ScopedTypeVariables #-}
{-# language StandaloneKindSignatures #-}
{-# language TypeApplications #-}

module Rel8.Statement.Delete
( Delete(..)
, delete
, ppDelete
)
where

Expand All @@ -23,17 +25,14 @@ import qualified Hasql.Encoders as Hasql
import qualified Hasql.Session as Hasql
import qualified Hasql.Statement as Hasql

-- opaleye
import qualified Opaleye.Internal.Manipulation as Opaleye
-- pretty
import Text.PrettyPrint ( Doc, (<+>), ($$), text )

-- rel8
import Rel8.Expr ( Expr )
import Rel8.Expr.Opaleye ( toColumn, toPrimExpr )
import Rel8.Schema.Name ( Selects )
import Rel8.Schema.Table ( TableSchema )
import Rel8.Statement.Returning ( Returning( NumberOfRowsAffected, Projection ) )
import Rel8.Table.Cols ( fromCols, toCols )
import Rel8.Table.Opaleye ( castTable, table, unpackspec )
import Rel8.Schema.Table ( TableSchema, ppTable )
import Rel8.Statement.Returning ( Returning(..), ppReturning )
import Rel8.Statement.Where ( Where, ppWhere )
import Rel8.Table.Serialize ( Serializable, parse )

-- text
Expand All @@ -47,45 +46,47 @@ data Delete a where
Delete :: Selects names exprs =>
{ from :: TableSchema names
-- ^ Which table to delete from.
, deleteWhere :: exprs -> Expr Bool
, deleteWhere :: Where exprs
-- ^ Which rows should be selected for deletion.
, returning :: Returning names a
-- ^ What to return from the @DELETE@ statement.
}
-> Delete a


ppDelete :: Delete a -> Maybe Doc
ppDelete Delete {..} = do
condition <- ppWhere from deleteWhere
pure $ text "DELETE FROM" <+> ppTable from
$$ condition
$$ ppReturning from returning


-- | Run a @DELETE@ statement.
delete :: Connection -> Delete a -> IO a
delete c Delete {from, deleteWhere, returning} =
case returning of
NumberOfRowsAffected -> Hasql.run session c >>= either throwIO pure
delete c d@Delete {returning} =
case (show <$> ppDelete d, returning) of
(Nothing, NumberOfRowsAffected) -> pure 0
(Nothing, Projection _) -> pure []
(Just sql, NumberOfRowsAffected) ->
Hasql.run session c >>= either throwIO pure
where
session = Hasql.statement () statement
statement = Hasql.Statement bytes params decode prepare
bytes = encodeUtf8 $ Text.pack sql
params = Hasql.noParams
decode = Hasql.rowsAffected
prepare = False
sql = Opaleye.arrangeDeleteSql from' where'
where
from' = table $ toCols <$> from
where' = toColumn . toPrimExpr . deleteWhere . fromCols

Projection project -> Hasql.run session c >>= either throwIO pure
(Just sql, Projection project) ->
Hasql.run session c >>= either throwIO pure
where
session = Hasql.statement () statement
statement = Hasql.Statement bytes params decode prepare
bytes = encodeUtf8 $ Text.pack sql
params = Hasql.noParams
decode = decoder project
prepare = False
sql =
Opaleye.arrangeDeleteReturningSql unpackspec from' where' project'
where
from' = table $ toCols <$> from
where' = toColumn . toPrimExpr . deleteWhere . fromCols
project' = castTable . toCols . project . fromCols
where
decoder :: forall exprs projection a. Serializable projection a
=> (exprs -> projection) -> Hasql.Result [a]
Expand Down
Loading

0 comments on commit 9747d8b

Please sign in to comment.