diff --git a/build.sc b/build.sc index b0de745..af8ba82 100644 --- a/build.sc +++ b/build.sc @@ -18,9 +18,15 @@ object squery extends CommonScalaModule with SqueryPublishModule { ivy"ch.qos.logback:logback-classic:1.4.6", ivy"org.scalameta::munit:1.0.0-M7", ivy"com.zaxxer:HikariCP:4.0.3", - ivy"org.postgresql:postgresql:42.5.4", ivy"org.testcontainers:testcontainers:1.17.6", - ivy"org.testcontainers:postgresql:1.17.6" + ivy"org.testcontainers:postgresql:1.17.6", + ivy"org.postgresql:postgresql:42.5.4", + ivy"org.testcontainers:mysql:1.19.3", + ivy"mysql:mysql-connector-java:8.0.33", + ivy"org.testcontainers:mariadb:1.19.3", + ivy"org.mariadb.jdbc:mariadb-java-client:3.3.2" + + ) } } diff --git a/docs/src/files/howtos/HowToPage.scala b/docs/src/files/howtos/HowToPage.scala index d26fa52..f61a3e2 100644 --- a/docs/src/files/howtos/HowToPage.scala +++ b/docs/src/files/howtos/HowToPage.scala @@ -22,7 +22,7 @@ trait HowToPage extends DocPage { InterpolateValues, InterpolateQueries, DynamicQueries, - Transactions, + Transactions ) override def pageCategory = Some("How-Tos") diff --git a/squery/src/ba/sake/squery/mariadb/reads.scala b/squery/src/ba/sake/squery/mariadb/reads.scala new file mode 100644 index 0000000..6cbe865 --- /dev/null +++ b/squery/src/ba/sake/squery/mariadb/reads.scala @@ -0,0 +1,13 @@ +package ba.sake.squery.mariadb + +import java.{sql => jsql} +import java.util.UUID +import ba.sake.squery.read.* + +given SqlRead[UUID] with { + def readByName(jRes: jsql.ResultSet, colName: String): Option[UUID] = + Option(jRes.getString(colName)).map(UUID.fromString) + + def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[UUID] = + Option(jRes.getString(colIdx)).map(UUID.fromString) +} diff --git a/squery/src/ba/sake/squery/mariadb/writes.scala b/squery/src/ba/sake/squery/mariadb/writes.scala new file mode 100644 index 0000000..28893f2 --- /dev/null +++ b/squery/src/ba/sake/squery/mariadb/writes.scala @@ -0,0 +1,15 @@ +package ba.sake.squery.mariadb + +import java.{sql => jsql} +import java.util.UUID +import ba.sake.squery.write.* + +given SqlWrite[UUID] with { + def write( + ps: jsql.PreparedStatement, + idx: Int, + valueOpt: Option[UUID] + ): Unit = valueOpt match + case Some(value) => ps.setString(idx, value.toString) + case None => ps.setString(idx, null) +} diff --git a/squery/src/ba/sake/squery/mysql/reads.scala b/squery/src/ba/sake/squery/mysql/reads.scala new file mode 100644 index 0000000..b3f81db --- /dev/null +++ b/squery/src/ba/sake/squery/mysql/reads.scala @@ -0,0 +1,13 @@ +package ba.sake.squery.mysql + +import java.{sql => jsql} +import java.util.UUID +import ba.sake.squery.read.* + +given SqlRead[UUID] with { + def readByName(jRes: jsql.ResultSet, colName: String): Option[UUID] = + Option(jRes.getString(colName)).map(UUID.fromString) + + def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[UUID] = + Option(jRes.getString(colIdx)).map(UUID.fromString) +} diff --git a/squery/src/ba/sake/squery/mysql/writes.scala b/squery/src/ba/sake/squery/mysql/writes.scala new file mode 100644 index 0000000..6f7f25a --- /dev/null +++ b/squery/src/ba/sake/squery/mysql/writes.scala @@ -0,0 +1,15 @@ +package ba.sake.squery.mysql + +import java.{sql => jsql} +import java.util.UUID +import ba.sake.squery.write.* + +given SqlWrite[UUID] with { + def write( + ps: jsql.PreparedStatement, + idx: Int, + valueOpt: Option[UUID] + ): Unit = valueOpt match + case Some(value) => ps.setString(idx, value.toString) + case None => ps.setString(idx, null) +} diff --git a/squery/src/ba/sake/squery/postgres/reads.scala b/squery/src/ba/sake/squery/postgres/reads.scala new file mode 100644 index 0000000..3601a1a --- /dev/null +++ b/squery/src/ba/sake/squery/postgres/reads.scala @@ -0,0 +1,13 @@ +package ba.sake.squery.postgres + +import java.{sql => jsql} +import java.util.UUID +import ba.sake.squery.read.* + +given SqlRead[UUID] with { + def readByName(jRes: jsql.ResultSet, colName: String): Option[UUID] = + Option(jRes.getObject(colName, classOf[UUID])) + + def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[UUID] = + Option(jRes.getObject(colIdx, classOf[UUID])) +} diff --git a/squery/src/ba/sake/squery/postgres/writes.scala b/squery/src/ba/sake/squery/postgres/writes.scala new file mode 100644 index 0000000..8d2c1c9 --- /dev/null +++ b/squery/src/ba/sake/squery/postgres/writes.scala @@ -0,0 +1,15 @@ +package ba.sake.squery.postgres + +import java.{sql => jsql} +import java.util.UUID +import ba.sake.squery.write.* + +given SqlWrite[UUID] with { + def write( + ps: jsql.PreparedStatement, + idx: Int, + valueOpt: Option[UUID] + ): Unit = valueOpt match + case Some(value) => ps.setObject(idx, value) + case None => ps.setNull(idx, jsql.Types.OTHER) +} diff --git a/squery/src/ba/sake/squery/read/SqlRead.scala b/squery/src/ba/sake/squery/read/SqlRead.scala index 0f79ee3..b89bab2 100644 --- a/squery/src/ba/sake/squery/read/SqlRead.scala +++ b/squery/src/ba/sake/squery/read/SqlRead.scala @@ -82,14 +82,6 @@ object SqlRead { Option(jRes.getTimestamp(colIdx)).map(_.toLocalDateTime()) } - given SqlRead[UUID] with { - def readByName(jRes: jsql.ResultSet, colName: String): Option[UUID] = - Option(jRes.getObject(colName, classOf[UUID])) - - def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[UUID] = - Option(jRes.getObject(colIdx, classOf[UUID])) - } - // this "cannot fail" given [T](using sr: SqlRead[T]): SqlRead[Option[T]] with { def readByName(jRes: jsql.ResultSet, colName: String): Option[Option[T]] = diff --git a/squery/src/ba/sake/squery/write/SqlWrite.scala b/squery/src/ba/sake/squery/write/SqlWrite.scala index 850827d..020c109 100644 --- a/squery/src/ba/sake/squery/write/SqlWrite.scala +++ b/squery/src/ba/sake/squery/write/SqlWrite.scala @@ -118,16 +118,6 @@ object SqlWrite { case None => ps.setNull(idx, jsql.Types.TIMESTAMP) } - given SqlWrite[UUID] with { - def write( - ps: jsql.PreparedStatement, - idx: Int, - valueOpt: Option[UUID] - ): Unit = valueOpt match - case Some(value) => ps.setObject(idx, value) - case None => ps.setNull(idx, jsql.Types.OTHER) - } - given [T](using sw: SqlWrite[T]): SqlWrite[Option[T]] with { def write( ps: jsql.PreparedStatement, @@ -175,8 +165,8 @@ object SqlWrite { case Some(value) => val index = $m.ordinal(value) val label = $labels(index) - ps.setObject(idx, label, jsql.Types.OTHER) - case None => ps.setNull(idx, jsql.Types.OTHER) + ps.setString(idx, label) + case None => ps.setString(idx, null) } } diff --git a/squery/test/src/ba/sake/squery/dataTypes.scala b/squery/test/src/ba/sake/squery/dataTypes.scala index bb90201..3beca28 100644 --- a/squery/test/src/ba/sake/squery/dataTypes.scala +++ b/squery/test/src/ba/sake/squery/dataTypes.scala @@ -16,16 +16,4 @@ case class CustomerWithPhoneOpt(c: Customer, p: Option[Phone]) derives SqlReadRo case class Address(id: Int, name: Option[String]) derives SqlReadRow case class CustomerWithAddressOpt(c: Customer, a: Option[Address]) derives SqlReadRow -case class Datatypes( - int: Option[Int], - long: Option[Long], - double: Option[Double], - boolean: Option[Boolean], - string: Option[String], - uuid: Option[UUID], - tstz: Option[Instant], - clr: Option[Color] -) derives SqlReadRow -enum Color derives SqlRead, SqlWrite: - case red, green, blue diff --git a/squery/test/src/ba/sake/squery/mariadb/MariaDbSuite.scala b/squery/test/src/ba/sake/squery/mariadb/MariaDbSuite.scala new file mode 100644 index 0000000..6fc6f9d --- /dev/null +++ b/squery/test/src/ba/sake/squery/mariadb/MariaDbSuite.scala @@ -0,0 +1,366 @@ +package ba.sake.squery +package mariadb + +import java.util.UUID +import java.time.Instant +import java.time.temporal.ChronoUnit +import scala.collection.decorators._ +import org.testcontainers.containers.MariaDBContainer + +// UUID, enum.. MariaDB specific +case class Datatypes( + d_int: Option[Int], + d_long: Option[Long], + d_double: Option[Double], + d_boolean: Option[Boolean], + d_string: Option[String], + d_uuid: Option[UUID], + d_tstz: Option[Instant], + d_clr: Option[Color] +) derives SqlReadRow: + def insertTuple = sql"(${d_int}, ${d_long}, ${d_double}, ${d_boolean}, ${d_string}, ${d_uuid}, ${d_tstz}, ${d_clr})" + +object Datatypes: + inline val * = "d_int, d_long, d_double, d_boolean, d_string, d_uuid, d_tstz, d_clr" + +enum Color derives SqlRead, SqlWrite: + case red, green, blue + +class MariaDbSuite extends munit.FunSuite { + + var customer1 = Customer(1, "a_customer", None) + var customer2 = Customer(1, "b_customer", Some("str1")) + val customers = Seq(customer1, customer2) + + var phone1 = Phone(1, "061 123 456") + var phone2 = Phone(1, "062 225 883") + val phones = Seq(phone1, phone2) + + var address1 = Address(1, Some("a1")) + var address2 = Address(1, None) + val addresses = Seq(address1, address2) + + val initDb = new Fixture[SqueryContext]("database") { + private var ctx: SqueryContext = null + private var container: MariaDBContainer[?] = null + + def apply() = ctx + + override def beforeAll(): Unit = { + container = MariaDBContainer("mariadb:11.2.2") + container.withUrlParam("returnMultiValuesGeneratedIds", "true") // TODO document + container.start() + + val ds = com.zaxxer.hikari.HikariDataSource() + ds.setJdbcUrl(container.getJdbcUrl()) + ds.setUsername(container.getUsername()) + ds.setPassword(container.getPassword()) + + ctx = SqueryContext(ds) + + ctx.run { + sql""" + CREATE TABLE customers( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + street VARCHAR(20) + ) + """.update() + + sql""" + CREATE TABLE phones( + id SERIAL PRIMARY KEY, + customer_id BIGINT(20) UNSIGNED REFERENCES customers(id), + number TEXT + ) + """.update() + + sql""" + CREATE TABLE addresses( + id SERIAL PRIMARY KEY, + name TEXT + ) + """.update() + + sql""" + CREATE TABLE customer_address( + customer_id BIGINT(20) UNSIGNED REFERENCES customers(id), + address_id BIGINT(20) UNSIGNED REFERENCES addresses(id), + PRIMARY KEY (customer_id, address_id) + ) + """.update() + + val customerIds = sql""" + INSERT INTO customers(name, street) + VALUES (${customer1.name}, ${customer1.street}), + (${customer2.name}, ${customer2.street}) + """.insertReturningGenKeys[Int]() + + + println(customerIds) + + customer1 = customer1.copy(id = customerIds(0)) + customer2 = customer2.copy(id = customerIds(1)) + + val phoneIds = sql""" + INSERT INTO phones(customer_id, number) VALUES + (${customer1.id}, ${phone1.number}), + (${customer1.id}, ${phone2.number}) + """.insertReturningGenKeys[Int]() + phone1 = phone1.copy(id = phoneIds(0)) + phone2 = phone2.copy(id = phoneIds(1)) + + val addressIds = sql""" + INSERT INTO addresses(name) VALUES + (${address1.name}), + (${address2.name}) + """.insertReturningGenKeys[Int]() + address1 = address1.copy(id = addressIds(0)) + address2 = address2.copy(id = addressIds(1)) + + sql""" + INSERT INTO customer_address(customer_id, address_id) + VALUES + (${customer1.id}, ${address1.id}), + (${customer1.id}, ${address2.id}) + """.insert() + } + } + override def afterAll(): Unit = + if container != null then container.close() + + } + + override def munitFixtures = List(initDb) + + /* TESTS */ + test("SELECT plain values") { + val ctx = initDb() + ctx.run { + assertEquals( + sql"SELECT name FROM customers".readValues[String](), + customers.map(_.name) + ) + + assertEquals( + sql"SELECT number FROM phones WHERE customer_id = ${customer1.id}" + .readValues[String](), + phones.map(_.number) + ) + + val q1 = sql"" + val q2 = sql" ${q1} " + } + } + + test("SELECT rows") { + val ctx = initDb() + ctx.run { + // full join + assertEquals( + sql""" + SELECT c.id, c.name, c.street, + p.id, p.number + FROM customers c + JOIN phones p ON p.customer_id = c.id + WHERE c.id = ${customer1.id} + ORDER BY c.id ASC, p.id ASC + """.readRows[CustomerWithPhone](), + Seq( + CustomerWithPhone(customer1, phone1), + CustomerWithPhone(customer1, phone2) + ) + ) + + // outer/optional join + assertEquals( + sql""" + SELECT c.id, c.name, c.street, + p.id, p.number + FROM customers c + LEFT JOIN phones p ON p.customer_id = c.id + ORDER BY c.id ASC, p.id ASC + """.readRows[CustomerWithPhoneOpt](), + Seq( + CustomerWithPhoneOpt(customer1, Some(phone1)), + CustomerWithPhoneOpt(customer1, Some(phone2)), + CustomerWithPhoneOpt(customer2, None) + ) + ) + + // outer/optional join with many-to-many + assertEquals( + sql""" + SELECT c.id, c.name, c.street, + a.id, a.name + FROM customers c + LEFT JOIN customer_address ca ON ca.customer_id = c.id + LEFT JOIN addresses a ON a.id = ca.address_id + """.readRows[CustomerWithAddressOpt](), + Seq( + CustomerWithAddressOpt(customer1, Some(address1)), + CustomerWithAddressOpt(customer1, Some(address2)), + CustomerWithAddressOpt(customer2, None) + ) + ) + + } + } + + test("BAD SELECT throws") { + val ctx = initDb() + intercept[SqueryException] { + // street is nullable, but CustomerBad says it's mandatory + ctx.run { + sql""" + SELECT id, name, street + FROM customers + """.insertReturningRows[CustomerBad]() + } + } + } + + test("INSERT returning generated keys") { + val ctx = initDb() + ctx.run { + val customerIds = sql""" + INSERT INTO customers(name) + VALUES ('abc'), ('def'), ('ghi') + """.insertReturningGenKeys[Int]() + assertEquals( + customerIds.toSet, + (customer2.id + 1 to customer2.id + 3).toSet + ) + } + } + + test("INSERT returning columns") { + val ctx = initDb() + ctx.run { + val customers = sql""" + INSERT INTO customers(name) + VALUES ('abc'), ('def'), ('ghi') + RETURNING id, name, street + """.insertReturningRows[Customer]() + assertEquals(customers.map(_.name).toSet, Set("abc", "def", "ghi")) + } + } + + test("UPDATE should return number of affected rows") { + val ctx = initDb() + ctx.run { + sql""" + INSERT INTO customers(name) + VALUES ('xyz_1'), ('xyz_2'), ('b_1') + """.insert() + val affected = sql""" + UPDATE customers + SET name = 'whatever' + WHERE name LIKE 'xyz_%' + """.update() + assertEquals(affected, 2) + } + } + + test("Data types") { + val ctx = initDb() + ctx.run { + sql""" + CREATE TABLE datatypes( + d_int INT, + d_long BIGINT, + d_double DOUBLE PRECISION, + d_boolean BOOLEAN, + d_string VARCHAR(255), + d_uuid UUID, + d_tstz TIMESTAMP, + d_clr ENUM('red', 'green', 'blue') + ) + """.update() + val dt1 = Datatypes( + Some(123), + Some(Int.MaxValue + 100), + Some(0.54543), + Some(true), + Some("abc"), + Some(UUID.randomUUID), + Some(Instant.now.truncatedTo(ChronoUnit.SECONDS)), + Some(Color.red) + ) + val dt2 = Datatypes(None, None, None, None, None, None, None, None) + + val values = Seq(dt1, dt2) + .map(_.insertTuple) + .intersperse(sql",") + .reduce(_ ++ _) + sql""" + INSERT INTO datatypes(${Datatypes.*}) + VALUES ${values} + """.insert() + + val storedRows = sql""" + SELECT ${Datatypes.*} + FROM datatypes + """.readRows[Datatypes]() + assertEquals( + storedRows, + Seq(dt1) + ) + } + } + + test("Transaction") { + val ctx = initDb() + // create table normally + ctx.run { + sql""" + CREATE TABLE test_transactions( + name TEXT, + UNIQUE(name) + ) + """.update() + } + // all succeeds, + // or nothing succeeds! + intercept[Exception] { + ctx.runTransaction { + sql""" + INSERT INTO test_transactions(name) + VALUES ('abc') + """.insert() + // fail coz unique name + sql""" + INSERT INTO test_transactions(name) + VALUES ('abc') + """.insert() + } + } + intercept[Exception] { + ctx.runTransactionWithIsolation(TransactionIsolation.Serializable) { + sql""" + INSERT INTO test_transactions(name) + VALUES ('abc') + """.insert() + // fail coz unique name + sql""" + INSERT INTO test_transactions(name) + VALUES ('abc') + """.insert() + } + } + // check there is NO ENTRIES, coz transaction failed + ctx.run { + val values = sql"SELECT name FROM test_transactions".readValues[String]() + assertEquals(values, Seq.empty) + } + } + + test("Log warnings") { + val ctx = initDb() + ctx.run { + // custom squery warnings + sql"UPDATE customers SET name='bla'".update() + } + } + +} diff --git a/squery/test/src/ba/sake/squery/mysql/MySqlSuite.scala b/squery/test/src/ba/sake/squery/mysql/MySqlSuite.scala new file mode 100644 index 0000000..cf28d62 --- /dev/null +++ b/squery/test/src/ba/sake/squery/mysql/MySqlSuite.scala @@ -0,0 +1,352 @@ +package ba.sake.squery +package mysql + +import java.util.UUID +import java.time.Instant +import java.time.temporal.ChronoUnit +import scala.collection.decorators._ +import org.testcontainers.containers.MySQLContainer + +// UUID, enum.. MySql specific +case class Datatypes( + d_int: Option[Int], + d_long: Option[Long], + d_double: Option[Double], + d_boolean: Option[Boolean], + d_string: Option[String], + d_uuid: Option[UUID], + d_tstz: Option[Instant], + d_clr: Option[Color] +) derives SqlReadRow: + def insertTuple = sql"(${d_int}, ${d_long}, ${d_double}, ${d_boolean}, ${d_string}, ${d_uuid}, ${d_tstz}, ${d_clr})" + +object Datatypes: + inline val * = "d_int, d_long, d_double, d_boolean, d_string, d_uuid, d_tstz, d_clr" + +enum Color derives SqlRead, SqlWrite: + case red, green, blue + +class MySqlSuite extends munit.FunSuite { + + var customer1 = Customer(1, "a_customer", None) + var customer2 = Customer(1, "b_customer", Some("str1")) + val customers = Seq(customer1, customer2) + + var phone1 = Phone(1, "061 123 456") + var phone2 = Phone(1, "062 225 883") + val phones = Seq(phone1, phone2) + + var address1 = Address(1, Some("a1")) + var address2 = Address(1, None) + val addresses = Seq(address1, address2) + + val initDb = new Fixture[SqueryContext]("database") { + private var ctx: SqueryContext = null + private var container: MySQLContainer[?] = null + + def apply() = ctx + + override def beforeAll(): Unit = { + container = MySQLContainer("mysql:8.2.0") + container.start() + + val ds = com.zaxxer.hikari.HikariDataSource() + ds.setJdbcUrl(container.getJdbcUrl()) + ds.setUsername(container.getUsername()) + ds.setPassword(container.getPassword()) + + ctx = SqueryContext(ds) + + ctx.run { + sql""" + CREATE TABLE customers( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + street VARCHAR(20) + ) + """.update() + + sql""" + CREATE TABLE phones( + id SERIAL PRIMARY KEY, + customer_id INTEGER REFERENCES customers(id), + number TEXT + ) + """.update() + + sql""" + CREATE TABLE addresses( + id SERIAL PRIMARY KEY, + name TEXT + ) + """.update() + + sql""" + CREATE TABLE customer_address( + customer_id INTEGER REFERENCES customers(id), + address_id INTEGER REFERENCES addresses(id), + PRIMARY KEY (customer_id, address_id) + ) + """.update() + + val customerIds = sql""" + INSERT INTO customers(name, street) + VALUES (${customer1.name}, ${customer1.street}), + (${customer2.name}, ${customer2.street}) + """.insertReturningGenKeys[Int]() + customer1 = customer1.copy(id = customerIds(0)) + customer2 = customer2.copy(id = customerIds(1)) + + val phoneIds = sql""" + INSERT INTO phones(customer_id, number) VALUES + (${customer1.id}, ${phone1.number}), + (${customer1.id}, ${phone2.number}) + """.insertReturningGenKeys[Int]() + phone1 = phone1.copy(id = phoneIds(0)) + phone2 = phone2.copy(id = phoneIds(1)) + + val addressIds = sql""" + INSERT INTO addresses(name) VALUES + (${address1.name}), + (${address2.name}) + """.insertReturningGenKeys[Int]() + address1 = address1.copy(id = addressIds(0)) + address2 = address2.copy(id = addressIds(1)) + + sql""" + INSERT INTO customer_address(customer_id, address_id) + VALUES + (${customer1.id}, ${address1.id}), + (${customer1.id}, ${address2.id}) + """.insert() + } + } + override def afterAll(): Unit = + if container != null then container.close() + + } + + override def munitFixtures = List(initDb) + + /* TESTS */ + test("SELECT plain values") { + val ctx = initDb() + ctx.run { + assertEquals( + sql"SELECT name FROM customers".readValues[String](), + customers.map(_.name) + ) + + assertEquals( + sql"SELECT number FROM phones WHERE customer_id = ${customer1.id}" + .readValues[String](), + phones.map(_.number) + ) + + val q1 = sql"" + val q2 = sql" ${q1} " + } + } + + test("SELECT rows") { + val ctx = initDb() + ctx.run { + // full join + assertEquals( + sql""" + SELECT c.id, c.name, c.street, + p.id, p.number + FROM customers c + JOIN phones p ON p.customer_id = c.id + WHERE c.id = ${customer1.id} + ORDER BY c.id ASC, p.id ASC + """.readRows[CustomerWithPhone](), + Seq( + CustomerWithPhone(customer1, phone1), + CustomerWithPhone(customer1, phone2) + ) + ) + + // outer/optional join + assertEquals( + sql""" + SELECT c.id, c.name, c.street, + p.id, p.number + FROM customers c + LEFT JOIN phones p ON p.customer_id = c.id + ORDER BY c.id ASC, p.id ASC + """.readRows[CustomerWithPhoneOpt](), + Seq( + CustomerWithPhoneOpt(customer1, Some(phone1)), + CustomerWithPhoneOpt(customer1, Some(phone2)), + CustomerWithPhoneOpt(customer2, None) + ) + ) + + // outer/optional join with many-to-many + assertEquals( + sql""" + SELECT c.id, c.name, c.street, + a.id, a.name + FROM customers c + LEFT JOIN customer_address ca ON ca.customer_id = c.id + LEFT JOIN addresses a ON a.id = ca.address_id + """.readRows[CustomerWithAddressOpt](), + Seq( + CustomerWithAddressOpt(customer1, Some(address1)), + CustomerWithAddressOpt(customer1, Some(address2)), + CustomerWithAddressOpt(customer2, None) + ) + ) + + } + } + + test("BAD SELECT throws") { + val ctx = initDb() + intercept[SqueryException] { + // street is nullable, but CustomerBad says it's mandatory + ctx.run { + sql""" + SELECT id, name, street + FROM customers + """.insertReturningRows[CustomerBad]() + } + } + } + + test("INSERT returning generated keys") { + val ctx = initDb() + ctx.run { + val customerIds = sql""" + INSERT INTO customers(name) + VALUES ('abc'), ('def'), ('ghi') + """.insertReturningGenKeys[Int]() + assertEquals( + customerIds.toSet, + (customer2.id + 1 to customer2.id + 3).toSet + ) + } + } + + test("UPDATE should return number of affected rows") { + val ctx = initDb() + ctx.run { + sql""" + INSERT INTO customers(name) + VALUES ('xyz_1'), ('xyz_2'), ('b_1') + """.insert() + val affected = sql""" + UPDATE customers + SET name = 'whatever' + WHERE name LIKE 'xyz_%' + """.update() + assertEquals(affected, 2) + } + } + + test("Data types") { + val ctx = initDb() + ctx.run { + sql""" + CREATE TABLE datatypes( + d_int INT, + d_long BIGINT, + d_double DOUBLE PRECISION, + d_boolean BOOLEAN, + d_string VARCHAR(255), + d_uuid VARCHAR(36), + d_tstz TIMESTAMP, + d_clr ENUM('red', 'green', 'blue') + ) + """.update() + val dt1 = Datatypes( + Some(123), + Some(Int.MaxValue + 100), + Some(0.54543), + Some(true), + Some("abc"), + Some(UUID.randomUUID), + Some(Instant.now.truncatedTo(ChronoUnit.SECONDS)), + Some(Color.red) + ) + val dt2 = Datatypes(None, None, None, None, None, None, None, None) + + val values = Seq(dt1, dt2) + .map(_.insertTuple) + .intersperse(sql",") + .reduce(_ ++ _) + sql""" + INSERT INTO datatypes(${Datatypes.*}) + VALUES ${values} + """.insert() + + val storedRows = sql""" + SELECT ${Datatypes.*} + FROM datatypes + """.readRows[Datatypes]() + assertEquals( + storedRows, + Seq(dt1) + ) + } + } + + test("Transaction") { + val ctx = initDb() + // create table normally + ctx.run { + sql""" + CREATE TABLE test_transactions( + name VARCHAR(32), + UNIQUE(name) + ) + """.update() + } + // all succeeds, + // or nothing succeeds! + intercept[Exception] { + ctx.runTransaction { + sql""" + INSERT INTO test_transactions(name) + VALUES ('abc') + """.insert() + // fail coz unique name + sql""" + INSERT INTO test_transactions(name) + VALUES ('abc') + """.insert() + } + } + intercept[Exception] { + ctx.runTransactionWithIsolation(TransactionIsolation.Serializable) { + sql""" + INSERT INTO test_transactions(name) + VALUES ('abc') + """.insert() + // fail coz unique name + sql""" + INSERT INTO test_transactions(name) + VALUES ('abc') + """.insert() + } + } + // check there is NO ENTRIES, coz transaction failed + ctx.run { + val values = sql"SELECT name FROM test_transactions".readValues[String]() + assertEquals(values, Seq.empty) + } + } + + test("Log warnings") { + val ctx = initDb() + ctx.run { + // custom squery warnings + // no WHERE .. conditions + sql"UPDATE customers SET name='bla'".update() + + sql"DELETE FROM customers".update() + } + } + +} diff --git a/squery/test/src/ba/sake/squery/PostgresSuite.scala b/squery/test/src/ba/sake/squery/postgres/PostgresSuite.scala similarity index 88% rename from squery/test/src/ba/sake/squery/PostgresSuite.scala rename to squery/test/src/ba/sake/squery/postgres/PostgresSuite.scala index 8111a32..3fac28d 100644 --- a/squery/test/src/ba/sake/squery/PostgresSuite.scala +++ b/squery/test/src/ba/sake/squery/postgres/PostgresSuite.scala @@ -1,4 +1,5 @@ package ba.sake.squery +package postgres import java.util.UUID import java.time.Instant @@ -6,6 +7,25 @@ import java.time.temporal.ChronoUnit import scala.collection.decorators._ import org.testcontainers.containers.PostgreSQLContainer +// UUID, enum.. Postgres specific +case class Datatypes( + d_int: Option[Int], + d_long: Option[Long], + d_double: Option[Double], + d_boolean: Option[Boolean], + d_string: Option[String], + d_uuid: Option[UUID], + d_tstz: Option[Instant], + d_clr: Option[Color] +) derives SqlReadRow: + def insertTuple = sql"(${d_int}, ${d_long}, ${d_double}, ${d_boolean}, ${d_string}, ${d_uuid}, ${d_tstz}, ${d_clr})" + +object Datatypes: + inline val * = "d_int, d_long, d_double, d_boolean, d_string, d_uuid, d_tstz, d_clr" + +enum Color derives SqlRead, SqlWrite: + case red, green, blue + class PostgresSuite extends munit.FunSuite { var customer1 = Customer(1, "a_customer", None) @@ -28,6 +48,9 @@ class PostgresSuite extends munit.FunSuite { override def beforeAll(): Unit = { container = PostgreSQLContainer("postgres:9.6.12") + // let PG to figure out that a setString is actually an enum + // https://stackoverflow.com/a/43125099/4496364 + container.withUrlParam("stringtype", "unspecified") // TODO document container.start() val ds = com.zaxxer.hikari.HikariDataSource() @@ -246,14 +269,14 @@ class PostgresSuite extends munit.FunSuite { // postgres has MICROseconds precision sql""" CREATE TABLE datatypes( - int INTEGER, - long BIGINT, - double DOUBLE PRECISION, - boolean BOOLEAN, - string VARCHAR(255), - uuid UUID, - tstz TIMESTAMPTZ, - clr color + d_int INTEGER, + d_long BIGINT, + d_double DOUBLE PRECISION, + d_boolean BOOLEAN, + d_string VARCHAR(255), + d_uuid UUID, + d_tstz TIMESTAMPTZ, + d_clr color ) """.update() val dt1 = Datatypes( @@ -269,18 +292,16 @@ class PostgresSuite extends munit.FunSuite { val dt2 = Datatypes(None, None, None, None, None, None, None, None) val values = Seq(dt1, dt2) - .map(dt => - sql"(${dt.int}, ${dt.long}, ${dt.double}, ${dt.boolean}, ${dt.string}, ${dt.uuid}, ${dt.tstz}, ${dt.clr})" - ) + .map(_.insertTuple) .intersperse(sql",") .reduce(_ ++ _) sql""" - INSERT INTO datatypes(int, long, double, boolean, string, uuid, tstz, clr) + INSERT INTO datatypes(${Datatypes.*}) VALUES ${values} """.insert() val storedRows = sql""" - SELECT int, long, double, boolean, string, uuid, tstz, clr + SELECT ${Datatypes.*} FROM datatypes """.readRows[Datatypes]() assertEquals( @@ -340,6 +361,8 @@ class PostgresSuite extends munit.FunSuite { val ctx = initDb() ctx.run { // custom squery warnings + // no WHERE clause + // TODO how to test these..? sql"UPDATE customers SET name='bla'".update() intercept[Exception] {