diff --git a/frontends/benchmarks/dotty-specific/valid/ParametricExtensionMetho.scala b/frontends/benchmarks/dotty-specific/valid/ParametricExtensionMetho.scala new file mode 100644 index 000000000..c802167f5 --- /dev/null +++ b/frontends/benchmarks/dotty-specific/valid/ParametricExtensionMetho.scala @@ -0,0 +1,16 @@ +object ParametricExtensionMetho { + sealed trait Opt[+T] + case object Non extends Opt[Nothing] // Where is Oui? + final case class Som[+T](content: T) extends Opt[T] + + extension[T](m: Opt[T]) + def flatMap[U](f: T => Opt[U]): Opt[U] = + m match + case Non => Non + case Som(t) => f(t) + + def test[A, B](a: A, f: A => B): Unit = + assert(Som(a).flatMap(a => Som(f(a))) == Som(f(a))) + assert(Som(a).flatMap(_ => Non) == Non) + assert((Non : Opt[A]).flatMap(a => Som(f(a))) == Non) +} \ No newline at end of file diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala index e7068314e..a0b5872b5 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/ASTExtractors.scala @@ -801,10 +801,34 @@ trait ASTExtractors { canExtractSynthetic(dd.symbol) && !(getAnnotations(tpt.symbol) exists (_._1 == "ignore")) )) => - Some((dd.symbol, dd.leadingTypeParams, dd.termParamss.flatten, tpt.tpe, dd.rhs)) + Some((dd.symbol, allTypeParams(dd), dd.termParamss.flatten, tpt.tpe, dd.rhs)) case _ => None } + + // Get all type parameters of a DefDef. Note that dd.leadingTypeParams will only retrieve the leading ones + // which is insufficient for parametric extension methods since these have type parameters in the "middle" of `paramss`. + // For instance, for the following extension method: + // extension[T](m: Option[T]) + // def map[U](f: U => T): Option[U] = ... + // `paramss` will be as follows: + // List( + // List(TypeDef(T)), + // List(ValDef(m)), + // List(TypeDef(U)), + // List(ValDef(f)), + // ) + // and `d.leadingTypeParams` will only get `T` and miss `U`. + private def allTypeParams(dd: tpd.DefDef): Seq[tpd.TypeDef] = { + def go(paramss: List[tpd.ParamClause], acc: List[tpd.TypeDef]): List[tpd.TypeDef] = { + paramss match { + case Nil => acc + case (tparams@(tparam: tpd.TypeDef) :: _) :: rest => go(rest, acc ++ tparams.asInstanceOf[List[tpd.TypeDef]]) + case _ :: rest => go(rest, acc) + } + } + go(dd.paramss, Nil) + } } /**