diff --git a/benchmark/src/main/scala/benchmark/Model.scala b/benchmark/src/main/scala/benchmark/Model.scala index d41dd7971..e648be2e7 100644 --- a/benchmark/src/main/scala/benchmark/Model.scala +++ b/benchmark/src/main/scala/benchmark/Model.scala @@ -17,8 +17,7 @@ case class Model1( ) derives Table object Model1: - given Encoder[Model1] = Encoder[Int].to[Model1] - given Decoder[Model1] = Decoder[Int].to[Model1] + given Codec[Model1] = Codec[Int].to[Model1] case class Model5( c1: Int, @@ -65,17 +64,11 @@ case class Model20( ) derives Table object Model20: - given Encoder[Model20] = ( - Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: - Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: - Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: - Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] - ).to[Model20] - given Decoder[Model20] = ( - Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: - Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: - Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: - Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] + given Codec[Model20] = ( + Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: + Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: + Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: + Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] ).to[Model20] case class Model25( @@ -107,19 +100,12 @@ case class Model25( ) derives Table object Model25: - given Encoder[Model25] = ( - Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: - Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: - Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: - Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: - Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] *: Encoder[Int] - ).to[Model25] - given Decoder[Model25] = ( - Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: - Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: - Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: - Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: - Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] *: Decoder[Int] + given Codec[Model25] = ( + Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: + Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: + Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: + Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: + Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] *: Codec[Int] ).to[Model25] case class City( diff --git a/module/ldbc-codegen/shared/src/main/scala/ldbc/codegen/TableModelGenerator.scala b/module/ldbc-codegen/shared/src/main/scala/ldbc/codegen/TableModelGenerator.scala index 9774000e1..55318a7ce 100644 --- a/module/ldbc-codegen/shared/src/main/scala/ldbc/codegen/TableModelGenerator.scala +++ b/module/ldbc-codegen/shared/src/main/scala/ldbc/codegen/TableModelGenerator.scala @@ -148,7 +148,6 @@ private[ldbc] object TableModelGenerator: Some(s"""enum $enumName extends model.Enum: | case ${ types.mkString(", ") } | object $enumName extends model.EnumDataType[$enumName]: - | given ldbc.dsl.codec.Decoder[$enumName] = ldbc.dsl.codec.Decoder[Int].map($enumName.fromOrdinal) - | given ldbc.dsl.codec.Encoder[$enumName] = ldbc.dsl.codec.Encoder[Int].contramap(_.ordinal) + | given ldbc.dsl.codec.Codec[$enumName] = ldbc.dsl.codec.Codec[Int].imap($enumName.fromOrdinal)(_.ordinal) |""".stripMargin) case _ => None diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/Parameter.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/Parameter.scala index d115f3c10..7f695c595 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/Parameter.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/Parameter.scala @@ -8,7 +8,7 @@ package ldbc.dsl import cats.syntax.all.* -import ldbc.dsl.codec.Encoder +import ldbc.dsl.codec.* /** * Trait for setting Scala and Java values to PreparedStatement. @@ -35,10 +35,18 @@ object Parameter: case Encoder.Encoded.Success(list) => list.map(value => Success(value)) case Encoder.Encoded.Failure(errors) => List(Failure(errors.toList)) - given [A](using encoder: Encoder[A]): Conversion[A, Dynamic] with + given convFromEncoder[A](using encoder: Encoder[A]): Conversion[A, Dynamic] with override def apply(value: A): Dynamic = encoder.encode(value) match case Encoder.Encoded.Success(list) => list match case head :: Nil => Dynamic.Success(head) case _ => Dynamic.Failure(List("Multiple values are not allowed")) case Encoder.Encoded.Failure(errors) => Dynamic.Failure(errors.toList) + + given convFromCodec[A](using codec: Codec[A]): Conversion[A, Dynamic] with + override def apply(value: A): Dynamic = codec.encode(value) match + case Encoder.Encoded.Success(list) => + list match + case head :: Nil => Dynamic.Success(head) + case _ => Dynamic.Failure(List("Multiple values are not allowed")) + case Encoder.Encoded.Failure(errors) => Dynamic.Failure(errors.toList) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Codec.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Codec.scala new file mode 100644 index 000000000..fde32bc5d --- /dev/null +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Codec.scala @@ -0,0 +1,153 @@ +/** + * Copyright (c) 2023-2024 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.dsl.codec + +import java.time.* + +import scala.deriving.Mirror + +import cats.InvariantSemigroupal + +import org.typelevel.twiddles.TwiddleSyntax + +import ldbc.sql.ResultSet + +/** + * Symmetric encoder and decoder of MySQL data to and from Scala types. + * + * @tparam A + * Types handled in Scala + */ +trait Codec[A] extends Encoder[A], Decoder[A]: + self => + + /** Forget this value is a `Codec` and treat it as an `Encoder`. */ + def asEncoder: Encoder[A] = this + + /** Forget this value is a `Codec` and treat it as a `Decoder`. */ + def asDecoder: Decoder[A] = this + + /** `Codec` is semigroupal: a pair of codecs make a codec for a pair. */ + def product[B](fb: Codec[B]): Codec[(A, B)] = new Codec[(A, B)]: + private val pe = self.asEncoder product fb.asEncoder + private val pd = self.asDecoder product fb.asDecoder + + override def offset: Int = self.offset + fb.offset + override def encode(value: (A, B)): Encoder.Encoded = pe.encode(value) + override def decode(resultSet: ResultSet, index: Int): (A, B) = pd.decode(resultSet, index) + + /** Contramap inputs from, and map outputs to, a new type `B`, yielding a `Codec[B]`. */ + def imap[B](f: A => B)(g: B => A): Codec[B] = new Codec[B]: + override def offset: Int = self.offset + override def encode(value: B): Encoder.Encoded = self.encode(g(value)) + override def decode(resultSet: ResultSet, index: Int): B = f(self.decode(resultSet, index)) + + /** Lift this `Codec` into `Option`, where `None` is mapped to and from a vector of `NULL`. */ + override def opt: Codec[Option[A]] = new Codec[Option[A]]: + override def offset: Int = self.offset + override def encode(value: Option[A]): Encoder.Encoded = + value.fold(Encoder.Encoded.success(List(None)))(self.encode) + override def decode(resultSet: ResultSet, index: Int): Option[A] = + val value = self.decode(resultSet, index) + if resultSet.wasNull() then None else Some(value) + +object Codec extends TwiddleSyntax[Codec]: + + def apply[A](using codec: Codec[A]): Codec[A] = codec + + given InvariantSemigroupal[Codec] with + override def imap[A, B](fa: Codec[A])(f: A => B)(g: B => A): Codec[B] = fa.imap(f)(g) + override def product[A, B](fa: Codec[A], fb: Codec[B]): Codec[(A, B)] = fa product fb + + given Codec[Boolean] with + override def offset: Int = 1 + override def encode(value: Boolean): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): Boolean = resultSet.getBoolean(index) + + given Codec[Byte] with + override def offset: Int = 1 + override def encode(value: Byte): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): Byte = resultSet.getByte(index) + + given Codec[Short] with + override def offset: Int = 1 + override def encode(value: Short): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): Short = resultSet.getShort(index) + + given Codec[Int] with + override def offset: Int = 1 + override def encode(value: Int): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): Int = resultSet.getInt(index) + + given Codec[Long] with + override def offset: Int = 1 + override def encode(value: Long): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): Long = resultSet.getLong(index) + + given Codec[Float] with + override def offset: Int = 1 + override def encode(value: Float): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): Float = resultSet.getFloat(index) + + given Codec[Double] with + override def offset: Int = 1 + override def encode(value: Double): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): Double = resultSet.getDouble(index) + + given Codec[BigDecimal] with + override def offset: Int = 1 + override def encode(value: BigDecimal): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): BigDecimal = resultSet.getBigDecimal(index) + + given Codec[String] with + override def offset: Int = 1 + override def encode(value: String): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): String = resultSet.getString(index) + + given Codec[Array[Byte]] with + override def offset: Int = 1 + override def encode(value: Array[Byte]): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): Array[Byte] = resultSet.getBytes(index) + + given Codec[LocalTime] with + override def offset: Int = 1 + override def encode(value: LocalTime): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): LocalTime = resultSet.getTime(index) + + given Codec[LocalDate] with + override def offset: Int = 1 + override def encode(value: LocalDate): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): LocalDate = resultSet.getDate(index) + + given Codec[LocalDateTime] with + override def offset: Int = 1 + override def encode(value: LocalDateTime): Encoder.Encoded = Encoder.Encoded.success(List(value)) + override def decode(resultSet: ResultSet, index: Int): LocalDateTime = resultSet.getTimestamp(index) + + given [A](using codec: Codec[Int]): Codec[Year] = + codec.imap(Year.of)(_.getValue) + + given [A](using codec: Codec[String]): Codec[YearMonth] = + codec.imap(YearMonth.parse)(_.toString) + + given [A](using codec: Codec[String]): Codec[BigInt] = + codec.imap(str => if str == null then null else BigInt(str))(_.toString) + + given Codec[None.type] with + override def offset: Int = 1 + override def encode(value: None.type): Encoder.Encoded = Encoder.Encoded.success(List(None)) + override def decode(resultSet: ResultSet, index: Int): None.type = None + + given [A](using codec: Codec[A]): Codec[Option[A]] = codec.opt + + given [A, B](using ca: Codec[A], cb: Codec[B]): Codec[(A, B)] = ca product cb + + given [H, T <: Tuple](using dh: Codec[H], dt: Codec[T]): Codec[H *: T] = + dh.product(dt).imap { case (h, t) => h *: t }(tuple => (tuple.head, tuple.tail)) + + given [P <: Product](using mirror: Mirror.ProductOf[P], codec: Codec[mirror.MirroredElemTypes]): Codec[P] = + codec.to[P] diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Decoder.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Decoder.scala index ee2559877..de617741b 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Decoder.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Decoder.scala @@ -6,8 +6,6 @@ package ldbc.dsl.codec -import java.time.* - import scala.deriving.Mirror import cats.Applicative @@ -17,7 +15,7 @@ import org.typelevel.twiddles.TwiddleSyntax import ldbc.sql.ResultSet /** - * Class to get the DataType that matches the Scala type information from the ResultSet. + * Trait to get the DataType that matches the Scala type information from the ResultSet. * * @tparam A * Scala types that match SQL DataType @@ -71,28 +69,7 @@ object Decoder extends TwiddleSyntax[Decoder]: override def offset: Int = 0 override def decode(resultSet: ResultSet, index: Int): A = x - given Decoder[String] = (resultSet: ResultSet, index: Int) => resultSet.getString(index) - given Decoder[Boolean] = (resultSet: ResultSet, index: Int) => resultSet.getBoolean(index) - given Decoder[Byte] = (resultSet: ResultSet, index: Int) => resultSet.getByte(index) - given Decoder[Array[Byte]] = (resultSet: ResultSet, index: Int) => resultSet.getBytes(index) - given Decoder[Short] = (resultSet: ResultSet, index: Int) => resultSet.getShort(index) - given Decoder[Int] = (resultSet: ResultSet, index: Int) => resultSet.getInt(index) - given Decoder[Long] = (resultSet: ResultSet, index: Int) => resultSet.getLong(index) - given Decoder[Float] = (resultSet: ResultSet, index: Int) => resultSet.getFloat(index) - given Decoder[Double] = (resultSet: ResultSet, index: Int) => resultSet.getDouble(index) - given Decoder[LocalDate] = (resultSet: ResultSet, index: Int) => resultSet.getDate(index) - given Decoder[LocalTime] = (resultSet: ResultSet, index: Int) => resultSet.getTime(index) - given Decoder[LocalDateTime] = (resultSet: ResultSet, index: Int) => resultSet.getTimestamp(index) - given Decoder[BigDecimal] = (resultSet: ResultSet, index: Int) => resultSet.getBigDecimal(index) - - given (using decoder: Decoder[String]): Decoder[BigInt] = - decoder.map(str => if str == null then null else BigInt(str)) - - given (using decoder: Decoder[Int]): Decoder[Year] = - decoder.map(int => Year.of(int)) - - given (using decoder: Decoder[String]): Decoder[YearMonth] = - decoder.map(str => YearMonth.parse(str)) + given [A](using codec: Codec[A]): Decoder[A] = codec.asDecoder given [A](using decoder: Decoder[A]): Decoder[Option[A]] = decoder.opt diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Encoder.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Encoder.scala index d607e3e2e..4d6a87718 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Encoder.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/codec/Encoder.scala @@ -58,73 +58,7 @@ object Encoder extends TwiddleSyntax[Encoder]: override def contramap[A, B](fa: Encoder[A])(f: B => A): Encoder[B] = fa.contramap(f) override def product[A, B](fa: Encoder[A], fb: Encoder[B]): Encoder[(A, B)] = fa.product(fb) - given Encoder[Boolean] with - override def encode(value: Boolean): Encoded = - Encoded.success(List(value)) - - given Encoder[Byte] with - override def encode(value: Byte): Encoded = - Encoded.success(List(value)) - - given Encoder[Short] with - override def encode(value: Short): Encoded = - Encoded.success(List(value)) - - given Encoder[Int] with - override def encode(value: Int): Encoded = - Encoded.success(List(value)) - - given Encoder[Long] with - override def encode(value: Long): Encoded = - Encoded.success(List(value)) - - given Encoder[Float] with - override def encode(value: Float): Encoded = - Encoded.success(List(value)) - - given Encoder[Double] with - override def encode(value: Double): Encoded = - Encoded.success(List(value)) - - given Encoder[BigDecimal] with - override def encode(value: BigDecimal): Encoded = - Encoded.success(List(value)) - - given Encoder[String] with - override def encode(value: String): Encoded = - Encoded.success(List(value)) - - given Encoder[Array[Byte]] with - override def encode(value: Array[Byte]): Encoded = - Encoded.success(List(value)) - - given Encoder[LocalTime] with - override def encode(value: LocalTime): Encoded = - Encoded.success(List(value)) - - given Encoder[LocalDate] with - override def encode(value: LocalDate): Encoded = - Encoded.success(List(value)) - - given Encoder[LocalDateTime] with - override def encode(value: LocalDateTime): Encoded = - Encoded.success(List(value)) - - given (using encoder: Encoder[String]): Encoder[Year] = encoder.contramap(_.toString) - - given (using encoder: Encoder[String]): Encoder[YearMonth] = encoder.contramap(_.toString) - - given (using encoder: Encoder[String]): Encoder[BigInt] = encoder.contramap(_.toString) - - given Encoder[None.type] with - override def encode(value: None.type): Encoded = - Encoded.success(List(None)) - - given [A](using encoder: Encoder[A]): Encoder[Option[A]] with - override def encode(value: Option[A]): Encoded = - value match - case Some(value) => encoder.encode(value) - case None => Encoded.success(List(None)) + given [A](using codec: Codec[A]): Encoder[A] = codec.asEncoder given [A, B](using ea: Encoder[A], eb: Encoder[B]): Encoder[(A, B)] = ea.product(eb) diff --git a/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/Table.scala b/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/Table.scala index 33761c890..69fed3b92 100644 --- a/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/Table.scala +++ b/module/ldbc-query-builder/src/main/scala/ldbc/query/builder/Table.scala @@ -53,7 +53,7 @@ object Table: private[ldbc] case class Impl[P <: Product]( $name: String, columns: List[Column[?]] - )(using decoder: Decoder[P], encoder: Encoder[P]) + )(using codec: Codec[P]) extends Table[P]: override def statement: String = $name @@ -63,8 +63,8 @@ object Table: Column.Impl[P]( columns.map(_.name).mkString(", "), if alias.isEmpty then None else Some(alias), - decoder, - encoder, + codec.asDecoder, + codec.asEncoder, Some(columns.length), Some(columns.map(column => s"${ column.name } = ?").mkString(", ")) ) @@ -87,12 +87,8 @@ object Table: case Some(naming) => naming case None => '{ Naming.SNAKE } - val decoder = Expr.summon[Decoder[P]].getOrElse { - report.errorAndAbort(s"Decoder for type $tpe not found") - } - - val encoder = Expr.summon[Encoder[P]].getOrElse { - report.errorAndAbort(s"Encoder for type $tpe not found") + val codec = Expr.summon[Codec[P]].getOrElse { + report.errorAndAbort(s"Codec for type $tpe not found") } val labels = symbol.primaryConstructor.paramSymss.flatten @@ -112,13 +108,9 @@ object Table: case ValDef(name, tpt, _) => tpt.tpe.asType match case '[tpe] => - val decoder = Expr.summon[Decoder[tpe]].getOrElse { - report.errorAndAbort(s"Decoder for type $tpe not found") - } - val encoder = Expr.summon[Encoder[tpe]].getOrElse { - report.errorAndAbort(s"Encoder for type $tpe not found") + Expr.summon[Codec[tpe]].getOrElse { + report.errorAndAbort(s"Codec for type $tpe not found") } - '{ ($decoder, $encoder) } case _ => report.errorAndAbort(s"Type $tpt is not a type") } @@ -130,8 +122,8 @@ object Table: ${ Expr.ofSeq(labels) } .zip($codecs) .map { - case (label: String, codec: (Decoder[t], Encoder[?])) => - Column[t](label, $naming.format($name))(using codec._1, codec._2.asInstanceOf[Encoder[t]]) + case (label: String, codec: Codec[t]) => + Column[t](label, $naming.format($name))(using codec.asDecoder, codec.asEncoder) } .toList } @@ -140,7 +132,7 @@ object Table: Impl[P]( $naming.format($name), $columns - )(using $decoder, $encoder) + )(using $codec) } private def derivedWithNameImpl[P <: Product](name: Expr[String])(using @@ -158,12 +150,8 @@ object Table: case Some(naming) => naming case None => '{ Naming.SNAKE } - val decoder = Expr.summon[Decoder[P]].getOrElse { - report.errorAndAbort(s"Decoder for type $tpe not found") - } - - val encoder = Expr.summon[Encoder[P]].getOrElse { - report.errorAndAbort(s"Encoder for type $tpe not found") + val codec = Expr.summon[Codec[P]].getOrElse { + report.errorAndAbort(s"Codec for type $tpe not found") } val labels = symbol.primaryConstructor.paramSymss.flatten @@ -183,13 +171,9 @@ object Table: case ValDef(name, tpt, _) => tpt.tpe.asType match case '[tpe] => - val decoder = Expr.summon[Decoder[tpe]].getOrElse { - report.errorAndAbort(s"Decoder for type $tpe not found") - } - val encoder = Expr.summon[Encoder[tpe]].getOrElse { - report.errorAndAbort(s"Encoder for type $tpe not found") + Expr.summon[Codec[tpe]].getOrElse { + report.errorAndAbort(s"Codec for type $tpe not found") } - '{ ($decoder, $encoder) } case _ => report.errorAndAbort(s"Type $tpt is not a type") } @@ -199,8 +183,8 @@ object Table: ${ Expr.ofSeq(labels) } .zip($codecs) .map { - case (label: String, codec: (Decoder[t], Encoder[?])) => - Column[t](label, $name)(using codec._1, codec._2.asInstanceOf[Encoder[t]]) + case (label: String, codec: Codec[t]) => + Column[t](label, $name)(using codec.asDecoder, codec.asEncoder) } .toList } @@ -209,7 +193,7 @@ object Table: Impl[P]( $name, $columns - )(using $decoder, $encoder) + )(using $codec) } trait Opt[P] extends SharedTable, AbstractTable.Opt[P]: diff --git a/module/ldbc-schema/src/main/scala/ldbc/schema/Table.scala b/module/ldbc-schema/src/main/scala/ldbc/schema/Table.scala index d64855ae1..f16f2c32e 100644 --- a/module/ldbc-schema/src/main/scala/ldbc/schema/Table.scala +++ b/module/ldbc-schema/src/main/scala/ldbc/schema/Table.scala @@ -9,7 +9,7 @@ package ldbc.schema import scala.language.dynamics import scala.deriving.Mirror -import ldbc.dsl.codec.{ Decoder, Encoder } +import ldbc.dsl.codec.Codec import ldbc.statement.{ AbstractTable, Column } import ldbc.schema.interpreter.* import ldbc.schema.attribute.Attribute @@ -18,20 +18,16 @@ trait Table[T](val $name: String) extends AbstractTable[T]: type Column[A] = ldbc.statement.Column[A] - protected final def column[A](name: String)(using decoder: Decoder[A], encoder: Encoder[A]): Column[A] = - ColumnImpl[A](name, Some(s"${ $name }.$name"), decoder, encoder, None, List.empty) + protected final def column[A](name: String)(using codec: Codec[A]): Column[A] = + ColumnImpl[A](name, Some(s"${ $name }.$name"), codec.asDecoder, codec.asEncoder, None, List.empty) - protected final def column[A](name: String, dataType: DataType[A])(using - decoder: Decoder[A], - encoder: Encoder[A] - ): Column[A] = - ColumnImpl[A](name, Some(s"${ $name }.$name"), decoder, encoder, Some(dataType), List.empty) + protected final def column[A](name: String, dataType: DataType[A])(using codec: Codec[A]): Column[A] = + ColumnImpl[A](name, Some(s"${ $name }.$name"), codec.asDecoder, codec.asEncoder, Some(dataType), List.empty) protected final def column[A](name: String, dataType: DataType[A], attributes: Attribute[A]*)(using - decoder: Decoder[A], - encoder: Encoder[A] + codec: Codec[A] ): Column[A] = - ColumnImpl[A](name, Some(s"${ $name }.$name"), decoder, encoder, Some(dataType), attributes.toList) + ColumnImpl[A](name, Some(s"${ $name }.$name"), codec.asDecoder, codec.asEncoder, Some(dataType), attributes.toList) /** * Methods for setting key information for tables. diff --git a/module/ldbc-statement/src/main/scala/ldbc/statement/Column.scala b/module/ldbc-statement/src/main/scala/ldbc/statement/Column.scala index fae44db95..e70b0a84b 100644 --- a/module/ldbc-statement/src/main/scala/ldbc/statement/Column.scala +++ b/module/ldbc-statement/src/main/scala/ldbc/statement/Column.scala @@ -14,7 +14,7 @@ import org.typelevel.twiddles.TwiddleSyntax import ldbc.sql.ResultSet import ldbc.dsl.* -import ldbc.dsl.codec.{ Encoder, Decoder } +import ldbc.dsl.codec.* import ldbc.statement.interpreter.Extract import ldbc.statement.Expression.* @@ -59,7 +59,7 @@ trait Column[A]: def opt: Column[Option[A]] = Column.Opt[A](name, alias, decoder, encoder) - def count(using Decoder[Int]): Column.Count = Column.Count(name, alias) + def count(using Decoder[Int], Encoder[Int]): Column.Count = Column.Count(name, alias) def asc: OrderBy.Order[A] = OrderBy.Order.asc(this) def desc: OrderBy.Order[A] = OrderBy.Order.desc(this) @@ -503,7 +503,7 @@ trait Column[A]: @targetName("_bitFlip") def ~(value: Extract[A])(using Encoder[Extract[A]]): BitFlip[A] = bitFlip(value) - def combine(other: Column[A])(using Decoder[A], Encoder[A]): Column.MultiColumn[A] = + def combine(other: Column[A])(using Codec[A]): Column.MultiColumn[A] = Column.MultiColumn[A]("+", this, other) /** @@ -520,9 +520,9 @@ trait Column[A]: * A query to combine columns in a SELECT statement */ @targetName("_combine") - def ++(other: Column[A])(using Decoder[A], Encoder[A]): Column.MultiColumn[A] = combine(other) + def ++(other: Column[A])(using Codec[A]): Column.MultiColumn[A] = combine(other) - def deduct(other: Column[A])(using Decoder[A], Encoder[A]): Column.MultiColumn[A] = + def deduct(other: Column[A])(using Codec[A]): Column.MultiColumn[A] = Column.MultiColumn[A]("-", this, other) /** @@ -539,9 +539,9 @@ trait Column[A]: * A query to subtract columns in a SELECT statement */ @targetName("_deduct") - def --(other: Column[A])(using Decoder[A], Encoder[A]): Column.MultiColumn[A] = deduct(other) + def --(other: Column[A])(using Codec[A]): Column.MultiColumn[A] = deduct(other) - def multiply(other: Column[A])(using Decoder[A], Encoder[A]): Column.MultiColumn[A] = + def multiply(other: Column[A])(using Codec[A]): Column.MultiColumn[A] = Column.MultiColumn[A]("*", this, other) /** @@ -558,9 +558,9 @@ trait Column[A]: * A query to multiply columns in a SELECT statement */ @targetName("_multiply") - def *(other: Column[A])(using Decoder[A], Encoder[A]): Column.MultiColumn[A] = multiply(other) + def *(other: Column[A])(using Codec[A]): Column.MultiColumn[A] = multiply(other) - def smash(other: Column[A])(using Decoder[A], Encoder[A]): Column.MultiColumn[A] = + def smash(other: Column[A])(using Codec[A]): Column.MultiColumn[A] = Column.MultiColumn[A]("/", this, other) /** @@ -577,7 +577,7 @@ trait Column[A]: * A query to divide columns in a SELECT statement */ @targetName("_smash") - def /(other: Column[A])(using Decoder[A], Encoder[A]): Column.MultiColumn[A] = smash(other) + def /(other: Column[A])(using Codec[A]): Column.MultiColumn[A] = smash(other) /** List of sub query methods */ def _equals(value: SQL): SubQuery[A] = @@ -740,28 +740,25 @@ object Column extends TwiddleSyntax[Column]: flag: String, left: Column[A], right: Column[A] - )(using _decoder: Decoder[A], _encoder: Encoder[A]) + )(using codec: Codec[A]) extends Column[A]: override def name: String = s"${ left.noBagQuotLabel } $flag ${ right.noBagQuotLabel }" override def alias: Option[String] = Some( s"${ left.alias.getOrElse(left.name) } $flag ${ right.alias.getOrElse(right.name) }" ) override def as(name: String): Column[A] = this - override def decoder: Decoder[A] = _decoder - override def encoder: Encoder[A] = _encoder + override def decoder: Decoder[A] = codec.asDecoder + override def encoder: Encoder[A] = codec.asEncoder override def insertStatement: String = "" override def updateStatement: String = "" override def duplicateKeyUpdateStatement: String = "" - private[ldbc] case class Count(_name: String, _alias: Option[String])(using - _decoder: Decoder[Int], - _encoder: Encoder[Int] - ) extends Column[Int]: + private[ldbc] case class Count(_name: String, _alias: Option[String])(using codec: Codec[Int]) extends Column[Int]: override def name: String = s"COUNT($_name)" override def alias: Option[String] = _alias.map(a => s"COUNT($a)") override def as(name: String): Column[Int] = this.copy(s"$name.${ _name }") - override def decoder: Decoder[Int] = _decoder - override def encoder: Encoder[Int] = _encoder + override def decoder: Decoder[Int] = codec.asDecoder + override def encoder: Encoder[Int] = codec.asEncoder override def toString: String = name override def insertStatement: String = "" override def updateStatement: String = "" diff --git a/tests/src/main/scala/ldbc/tests/model/Country.scala b/tests/src/main/scala/ldbc/tests/model/Country.scala index 3a5cecbe8..c3c20298b 100644 --- a/tests/src/main/scala/ldbc/tests/model/Country.scala +++ b/tests/src/main/scala/ldbc/tests/model/Country.scala @@ -7,7 +7,7 @@ package ldbc.tests.model import ldbc.dsl.* -import ldbc.dsl.codec.{ Encoder, Decoder } +import ldbc.dsl.codec.Codec import ldbc.query.builder.Table import ldbc.schema.Table as SchemaTable @@ -42,21 +42,13 @@ object Country: override def toString: String = value - given Encoder[Continent] = Encoder[String].contramap(_.value) + given Codec[Continent] = Codec[String].imap(str => Continent.valueOf(str.replace(" ", "_")))(_.value) - given Decoder[Continent] = Decoder[String].map(str => Continent.valueOf(str.replace(" ", "_"))) - - given Encoder[Country] = ( - Encoder[String] *: Encoder[String] *: Encoder[Continent] *: Encoder[String] *: Encoder[BigDecimal] *: - Encoder[Option[Short]] *: Encoder[Int] *: Encoder[Option[BigDecimal]] *: Encoder[Option[BigDecimal]] *: - Encoder[Option[BigDecimal]] *: Encoder[String] *: Encoder[String] *: Encoder[Option[String]] *: - Encoder[Option[Int]] *: Encoder[String] - ).to[Country] - given Decoder[Country] = ( - Decoder[String] *: Decoder[String] *: Decoder[Continent] *: Decoder[String] *: Decoder[BigDecimal] *: - Decoder[Option[Short]] *: Decoder[Int] *: Decoder[Option[BigDecimal]] *: Decoder[Option[BigDecimal]] *: - Decoder[Option[BigDecimal]] *: Decoder[String] *: Decoder[String] *: Decoder[Option[String]] *: - Decoder[Option[Int]] *: Decoder[String] + given Codec[Country] = ( + Codec[String] *: Codec[String] *: Codec[Continent] *: Codec[String] *: Codec[BigDecimal] *: + Codec[Option[Short]] *: Codec[Int] *: Codec[Option[BigDecimal]] *: Codec[Option[BigDecimal]] *: + Codec[Option[BigDecimal]] *: Codec[String] *: Codec[String] *: Codec[Option[String]] *: + Codec[Option[Int]] *: Codec[String] ).to[Country] given Table[Country] = Table.derived[Country]("country") diff --git a/tests/src/main/scala/ldbc/tests/model/CountryLanguage.scala b/tests/src/main/scala/ldbc/tests/model/CountryLanguage.scala index d1f56a6f9..d6e1611e2 100644 --- a/tests/src/main/scala/ldbc/tests/model/CountryLanguage.scala +++ b/tests/src/main/scala/ldbc/tests/model/CountryLanguage.scala @@ -7,7 +7,7 @@ package ldbc.tests.model import ldbc.dsl.* -import ldbc.dsl.codec.{ Encoder, Decoder } +import ldbc.dsl.codec.Codec import ldbc.query.builder.Table import ldbc.schema.Table as SchemaTable @@ -25,9 +25,7 @@ object CountryLanguage: object IsOfficial - given Encoder[IsOfficial] = Encoder[String].contramap(_.toString) - - given Decoder[IsOfficial] = Decoder[String].map(IsOfficial.valueOf) + given Codec[IsOfficial] = Codec[String].imap(IsOfficial.valueOf)(_.toString) given Table[CountryLanguage] = Table.derived[CountryLanguage]("countrylanguage")