From 209c864676bc7293057b77996d63c7484e88e5c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Fromentin?= Date: Fri, 13 Sep 2024 15:22:59 +0200 Subject: [PATCH] refactor: Remove type parameter from decodeTerm --- .../iltotore/iron/macros/ReflectUtil.scala | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/main/src/io/github/iltotore/iron/macros/ReflectUtil.scala b/main/src/io/github/iltotore/iron/macros/ReflectUtil.scala index 8975e17..d50db97 100644 --- a/main/src/io/github/iltotore/iron/macros/ReflectUtil.scala +++ b/main/src/io/github/iltotore/iron/macros/ReflectUtil.scala @@ -1,6 +1,7 @@ package io.github.iltotore.iron.macros import scala.quoted.* +import io.github.iltotore.iron.compileTime.NumConstant /** * Low AST related utils. @@ -20,13 +21,17 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): import _quotes.reflect.* + type DecodingResult[+T] = Either[DecodingFailure, T] + extension [T](result: DecodingResult[T]) + private def as[U]: DecodingResult[U] = result.asInstanceOf[Either[DecodingFailure, U]] + extension [T: Type](expr: Expr[T]) /** * Decode this expression. * * @return the value of this expression found at compile time or a [[DecodingFailure]] */ - def decode: Either[DecodingFailure, T] = ExprDecoder.decodeTerm(expr.asTerm, Map.empty) + def decode: Either[DecodingFailure, T] = ExprDecoder.decodeTerm(expr.asTerm, Map.empty).as[T] /** * A decoding failure. @@ -192,10 +197,9 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): * * @param tree the term to decode * @param definitions the decoded definitions in scope - * @tparam T the expected type of this term used as implicit cast for convenience * @return the value of the given term found at compile time or a [[DecodingFailure]] */ - def decodeTerm[T](tree: Term, definitions: Map[String, ?]): Either[DecodingFailure, T] = + def decodeTerm(tree: Term, definitions: Map[String, ?]): Either[DecodingFailure, ?] = val specializedResult = enhancedDecoders .collectFirst: case (k, v) if k =:= tree.tpe => v @@ -204,7 +208,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): specializedResult match case Left(DecodingFailure.Unknown) => decodeUnspecializedTerm(tree, definitions) - case result => result.asInstanceOf[Either[DecodingFailure, T]] + case result => result /** * Decode a term using only unspecialized cases. @@ -214,7 +218,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): * @tparam T the expected type of this term used as implicit cast for convenience * @return the value of the given term found at compile time or a [[DecodingFailure]] */ - def decodeUnspecializedTerm[T](tree: Term, definitions: Map[String, ?]): Either[DecodingFailure, T] = + def decodeUnspecializedTerm(tree: Term, definitions: Map[String, ?]): Either[DecodingFailure, ?] = tree match case block @ Block(stats, e) => if stats.isEmpty then decodeTerm(e, definitions) else Left(DecodingFailure.HasStatements(block)) @@ -225,14 +229,14 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): case (name, Right(value)) => Right((name, value)) case (name, Left(failure)) => Left((name, failure)) - (failures, decodeTerm[T](e, definitions ++ values.toMap)) match + (failures, decodeTerm(e, definitions ++ values.toMap)) match case (_, Right(value)) => Right(value) case (Nil, Left(failure)) => Left(failure) case (failures, Left(_)) => Left(DecodingFailure.HasBindings(failures)) - case Apply(Select(left, "=="), List(right)) => (decodeTerm[Any](left, definitions), decodeTerm[Any](right, definitions)) match - case (Right(leftValue), Right(rightValue)) => Right((leftValue == rightValue).asInstanceOf[T]) + case Apply(Select(left, "=="), List(right)) => (decodeTerm(left, definitions), decodeTerm(right, definitions)) match + case (Right(leftValue), Right(rightValue)) => Right((leftValue == rightValue)) case (leftResult, rightResult) => Left(DecodingFailure.ApplyNotInlined("==", List(leftResult, rightResult))) case Apply(Select(leftOperand, name), operands) => @@ -255,18 +259,17 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): result if hasFailure then Left(DecodingFailure.VarArgsNotInlined(results)) - else Right(results.map(_.getOrElse((???): String)).asInstanceOf[T]) + else Right(results.map(_.getOrElse((???): String))) case Typed(e, _) => decodeTerm(e, definitions) case _ => tree.tpe.widenTermRefByName match - case ConstantType(c) => Right(c.value.asInstanceOf[T]) + case ConstantType(c) => Right(c.value) case _ => tree match case Ident(name) => definitions .get(name) .toRight(DecodingFailure.NotInlined(tree)) - .asInstanceOf[Either[DecodingFailure, T]] case _ => Left(DecodingFailure.NotInlined(tree)) @@ -278,7 +281,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): * @tparam T the expected type of this term used as implicit cast for convenience * @return the value of the given definition found at compile time or a [[DecodingFailure]] */ - def decodeBinding[T](definition: Definition, definitions: Map[String, ?]): Either[DecodingFailure, T] = definition match + def decodeBinding(definition: Definition, definitions: Map[String, ?]): Either[DecodingFailure, ?] = definition match case ValDef(name, tpeTree, Some(term)) => decodeTerm(term, definitions) case DefDef(name, Nil, tpeTree, Some(term)) => decodeTerm(term, definitions) case _ => Left(DecodingFailure.DefinitionNotInlined(definition.name)) @@ -290,16 +293,16 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): * @param definitions the decoded definitions in scope * @return the value of the given term found at compile time or a [[DecodingFailure]] */ - def decodeBoolean(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, Boolean] = term match + def decodeBoolean(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, ?] = term match case Apply(Select(left, "||"), List(right)) if left.tpe <:< TypeRepr.of[Boolean] && right.tpe <:< TypeRepr.of[Boolean] => // OR - (decodeTerm[Boolean](left, definitions), decodeTerm[Boolean](right, definitions)) match + (decodeTerm(left, definitions).as[Boolean], decodeTerm(right, definitions).as[Boolean]) match case (Right(true), _) => Right(true) case (_, Right(true)) => Right(true) case (Right(leftValue), Right(rightValue)) => Right(leftValue || rightValue) case (leftResult, rightResult) => Left(DecodingFailure.OrNotInlined(leftResult, rightResult)) case Apply(Select(left, "&&"), List(right)) if left.tpe <:< TypeRepr.of[Boolean] && right.tpe <:< TypeRepr.of[Boolean] => // AND - (decodeTerm[Boolean](left, definitions), decodeTerm[Boolean](right, definitions)) match + (decodeTerm(left, definitions).as[Boolean], decodeTerm(right, definitions).as[Boolean]) match case (Right(false), _) => Right(false) case (_, Right(false)) => Right(false) case (Right(leftValue), Right(rightValue)) => Right(leftValue && rightValue) @@ -316,7 +319,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): */ def decodeString(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, String] = term match case Apply(Select(left, "+"), List(right)) if left.tpe <:< TypeRepr.of[String] && right.tpe <:< TypeRepr.of[String] => - (decodeTerm[String](left, definitions), decodeTerm[String](right, definitions)) match + (decodeTerm(left, definitions).as[String], decodeTerm(right, definitions).as[String]) match case (Right(leftValue), Right(rightValue)) => Right(leftValue + rightValue) case (Left(DecodingFailure.StringPartsNotInlined(lparts)), Left(DecodingFailure.StringPartsNotInlined(rparts))) => Left(DecodingFailure.StringPartsNotInlined(lparts ++ rparts)) @@ -338,9 +341,9 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): def decodeBigInt(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, BigInt] = term match case Apply(Select(Ident("BigInt"), "apply"), List(value)) => - if value.tpe <:< TypeRepr.of[Int] then decodeTerm[Int](value, definitions).map(BigInt.apply) - else if value.tpe <:< TypeRepr.of[Long] then decodeTerm[Long](value, definitions).map(BigInt.apply) - else Left(DecodingFailure.Unknown) + decodeTerm(value, definitions).as[Int | Long].map: + case x: Int => BigInt(x) + case x: Long => BigInt(x) case _ => Left(DecodingFailure.Unknown) /** @@ -353,9 +356,10 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q): def decodeBigDecimal(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, BigDecimal] = term match case Apply(Select(Ident("BigDecimal"), "apply"), List(value)) => - if value.tpe <:< TypeRepr.of[Int] then decodeTerm[Int](value, definitions).map(BigDecimal.apply) - else if value.tpe <:< TypeRepr.of[Long] then decodeTerm[Long](value, definitions).map(BigDecimal.apply) - else if value.tpe <:< TypeRepr.of[Double] then decodeTerm[Double](value, definitions).map(BigDecimal.apply) - else if value.tpe <:< TypeRepr.of[BigInt] then decodeTerm[BigInt](value, definitions).map(BigDecimal.apply) - else Left(DecodingFailure.Unknown) + decodeTerm(value, definitions).as[NumConstant].map: + case x: Int => BigDecimal(x) + case x: Long => BigDecimal(x) + case x: Float => BigDecimal(x) + case x: Double => BigDecimal(x) + case _ => Left(DecodingFailure.Unknown)