Skip to content

Commit

Permalink
refactor: Remove type parameter from decodeTerm
Browse files Browse the repository at this point in the history
  • Loading branch information
Iltotore committed Sep 13, 2024
1 parent 51d8c46 commit 209c864
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions main/src/io/github/iltotore/iron/macros/ReflectUtil.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.github.iltotore.iron.macros

import scala.quoted.*
import io.github.iltotore.iron.compileTime.NumConstant

/**
* Low AST related utils.
Expand All @@ -20,13 +21,17 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):

import _quotes.reflect.*

type DecodingResult[+T] = Either[DecodingFailure, T]
extension [T](result: DecodingResult[T])
private def as[U]: DecodingResult[U] = result.asInstanceOf[Either[DecodingFailure, U]]

extension [T: Type](expr: Expr[T])
/**
* Decode this expression.
*
* @return the value of this expression found at compile time or a [[DecodingFailure]]
*/
def decode: Either[DecodingFailure, T] = ExprDecoder.decodeTerm(expr.asTerm, Map.empty)
def decode: Either[DecodingFailure, T] = ExprDecoder.decodeTerm(expr.asTerm, Map.empty).as[T]

/**
* A decoding failure.
Expand Down Expand Up @@ -192,10 +197,9 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
*
* @param tree the term to decode
* @param definitions the decoded definitions in scope
* @tparam T the expected type of this term used as implicit cast for convenience
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
def decodeTerm[T](tree: Term, definitions: Map[String, ?]): Either[DecodingFailure, T] =
def decodeTerm(tree: Term, definitions: Map[String, ?]): Either[DecodingFailure, ?] =
val specializedResult = enhancedDecoders
.collectFirst:
case (k, v) if k =:= tree.tpe => v
Expand All @@ -204,7 +208,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):

specializedResult match
case Left(DecodingFailure.Unknown) => decodeUnspecializedTerm(tree, definitions)
case result => result.asInstanceOf[Either[DecodingFailure, T]]
case result => result

/**
* Decode a term using only unspecialized cases.
Expand All @@ -214,7 +218,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
* @tparam T the expected type of this term used as implicit cast for convenience
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
def decodeUnspecializedTerm[T](tree: Term, definitions: Map[String, ?]): Either[DecodingFailure, T] =
def decodeUnspecializedTerm(tree: Term, definitions: Map[String, ?]): Either[DecodingFailure, ?] =
tree match
case block @ Block(stats, e) => if stats.isEmpty then decodeTerm(e, definitions) else Left(DecodingFailure.HasStatements(block))

Expand All @@ -225,14 +229,14 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
case (name, Right(value)) => Right((name, value))
case (name, Left(failure)) => Left((name, failure))

(failures, decodeTerm[T](e, definitions ++ values.toMap)) match
(failures, decodeTerm(e, definitions ++ values.toMap)) match
case (_, Right(value)) =>
Right(value)
case (Nil, Left(failure)) => Left(failure)
case (failures, Left(_)) => Left(DecodingFailure.HasBindings(failures))

case Apply(Select(left, "=="), List(right)) => (decodeTerm[Any](left, definitions), decodeTerm[Any](right, definitions)) match
case (Right(leftValue), Right(rightValue)) => Right((leftValue == rightValue).asInstanceOf[T])
case Apply(Select(left, "=="), List(right)) => (decodeTerm(left, definitions), decodeTerm(right, definitions)) match
case (Right(leftValue), Right(rightValue)) => Right((leftValue == rightValue))
case (leftResult, rightResult) => Left(DecodingFailure.ApplyNotInlined("==", List(leftResult, rightResult)))

case Apply(Select(leftOperand, name), operands) =>
Expand All @@ -255,18 +259,17 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
result

if hasFailure then Left(DecodingFailure.VarArgsNotInlined(results))
else Right(results.map(_.getOrElse((???): String)).asInstanceOf[T])
else Right(results.map(_.getOrElse((???): String)))

case Typed(e, _) => decodeTerm(e, definitions)

case _ =>
tree.tpe.widenTermRefByName match
case ConstantType(c) => Right(c.value.asInstanceOf[T])
case ConstantType(c) => Right(c.value)
case _ => tree match
case Ident(name) => definitions
.get(name)
.toRight(DecodingFailure.NotInlined(tree))
.asInstanceOf[Either[DecodingFailure, T]]

case _ => Left(DecodingFailure.NotInlined(tree))

Expand All @@ -278,7 +281,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
* @tparam T the expected type of this term used as implicit cast for convenience
* @return the value of the given definition found at compile time or a [[DecodingFailure]]
*/
def decodeBinding[T](definition: Definition, definitions: Map[String, ?]): Either[DecodingFailure, T] = definition match
def decodeBinding(definition: Definition, definitions: Map[String, ?]): Either[DecodingFailure, ?] = definition match
case ValDef(name, tpeTree, Some(term)) => decodeTerm(term, definitions)
case DefDef(name, Nil, tpeTree, Some(term)) => decodeTerm(term, definitions)
case _ => Left(DecodingFailure.DefinitionNotInlined(definition.name))
Expand All @@ -290,16 +293,16 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
* @param definitions the decoded definitions in scope
* @return the value of the given term found at compile time or a [[DecodingFailure]]
*/
def decodeBoolean(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, Boolean] = term match
def decodeBoolean(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, ?] = term match
case Apply(Select(left, "||"), List(right)) if left.tpe <:< TypeRepr.of[Boolean] && right.tpe <:< TypeRepr.of[Boolean] => // OR
(decodeTerm[Boolean](left, definitions), decodeTerm[Boolean](right, definitions)) match
(decodeTerm(left, definitions).as[Boolean], decodeTerm(right, definitions).as[Boolean]) match
case (Right(true), _) => Right(true)
case (_, Right(true)) => Right(true)
case (Right(leftValue), Right(rightValue)) => Right(leftValue || rightValue)
case (leftResult, rightResult) => Left(DecodingFailure.OrNotInlined(leftResult, rightResult))

case Apply(Select(left, "&&"), List(right)) if left.tpe <:< TypeRepr.of[Boolean] && right.tpe <:< TypeRepr.of[Boolean] => // AND
(decodeTerm[Boolean](left, definitions), decodeTerm[Boolean](right, definitions)) match
(decodeTerm(left, definitions).as[Boolean], decodeTerm(right, definitions).as[Boolean]) match
case (Right(false), _) => Right(false)
case (_, Right(false)) => Right(false)
case (Right(leftValue), Right(rightValue)) => Right(leftValue && rightValue)
Expand All @@ -316,7 +319,7 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
*/
def decodeString(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, String] = term match
case Apply(Select(left, "+"), List(right)) if left.tpe <:< TypeRepr.of[String] && right.tpe <:< TypeRepr.of[String] =>
(decodeTerm[String](left, definitions), decodeTerm[String](right, definitions)) match
(decodeTerm(left, definitions).as[String], decodeTerm(right, definitions).as[String]) match
case (Right(leftValue), Right(rightValue)) => Right(leftValue + rightValue)
case (Left(DecodingFailure.StringPartsNotInlined(lparts)), Left(DecodingFailure.StringPartsNotInlined(rparts))) =>
Left(DecodingFailure.StringPartsNotInlined(lparts ++ rparts))
Expand All @@ -338,9 +341,9 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
def decodeBigInt(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, BigInt] =
term match
case Apply(Select(Ident("BigInt"), "apply"), List(value)) =>
if value.tpe <:< TypeRepr.of[Int] then decodeTerm[Int](value, definitions).map(BigInt.apply)
else if value.tpe <:< TypeRepr.of[Long] then decodeTerm[Long](value, definitions).map(BigInt.apply)
else Left(DecodingFailure.Unknown)
decodeTerm(value, definitions).as[Int | Long].map:
case x: Int => BigInt(x)
case x: Long => BigInt(x)
case _ => Left(DecodingFailure.Unknown)

/**
Expand All @@ -353,9 +356,10 @@ class ReflectUtil[Q <: Quotes & Singleton](using val _quotes: Q):
def decodeBigDecimal(term: Term, definitions: Map[String, ?]): Either[DecodingFailure, BigDecimal] =
term match
case Apply(Select(Ident("BigDecimal"), "apply"), List(value)) =>
if value.tpe <:< TypeRepr.of[Int] then decodeTerm[Int](value, definitions).map(BigDecimal.apply)
else if value.tpe <:< TypeRepr.of[Long] then decodeTerm[Long](value, definitions).map(BigDecimal.apply)
else if value.tpe <:< TypeRepr.of[Double] then decodeTerm[Double](value, definitions).map(BigDecimal.apply)
else if value.tpe <:< TypeRepr.of[BigInt] then decodeTerm[BigInt](value, definitions).map(BigDecimal.apply)
else Left(DecodingFailure.Unknown)
decodeTerm(value, definitions).as[NumConstant].map:
case x: Int => BigDecimal(x)
case x: Long => BigDecimal(x)
case x: Float => BigDecimal(x)
case x: Double => BigDecimal(x)

case _ => Left(DecodingFailure.Unknown)

0 comments on commit 209c864

Please sign in to comment.