From ffeaa5f584a40a3ca0089f56780f334b5f54f15d Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Sat, 20 Mar 2021 15:50:13 -0700 Subject: [PATCH 01/10] Port Scala 2 f-interpolator --- .../dotty/tools/dotc/parsing/Scanners.scala | 2 +- .../tools/dotc/printing/PlainPrinter.scala | 2 +- .../transform/localopt/FormatChecker.scala | 247 ++++++++++++++++++ .../FormatInterpolatorTransform.scala | 195 ++++++++++++++ .../localopt/StringContextChecker.scala | 26 +- .../localopt/StringInterpolatorOpt.scala | 12 +- .../runtime/impl/printers/SourceCode.scala | 2 +- .../dotc/transform/FormatCheckerTest.scala | 82 ++++++ tests/neg/f-interpolator-tests.scala | 11 + tests/neg/fEscapes.check | 4 + tests/neg/fEscapes.scala | 5 + tests/run-macros/f-interpolator-tests.scala | 205 +++++++++++++++ tests/run/t6476.check | 13 +- tests/run/t6476.scala | 10 +- 14 files changed, 786 insertions(+), 30 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala create mode 100644 compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala create mode 100644 compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala create mode 100644 tests/neg/f-interpolator-tests.scala create mode 100644 tests/neg/fEscapes.check create mode 100644 tests/neg/fEscapes.scala diff --git a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala index 1a2f3cd3d86a..c2c13e899ef4 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala @@ -1251,7 +1251,7 @@ object Scanners { nextChar() } } - val alt = if oct == LF then raw"\n" else f"\u$oct%04x" + val alt = if oct == LF then raw"\n" else f"${"\\"}u$oct%04x" error(s"octal escape literals are unsupported: use $alt instead", start) putChar(oct.toChar) } diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index d2efbeff2901..197a2e6ded9c 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -542,7 +542,7 @@ class PlainPrinter(_ctx: Context) extends Printer { case '"' => "\\\"" case '\'' => "\\\'" case '\\' => "\\\\" - case _ => if (ch.isControl) f"\u${ch.toInt}%04x" else String.valueOf(ch) + case _ => if (ch.isControl) f"${"\\"}u${ch.toInt}%04x" else String.valueOf(ch) } def toText(const: Constant): Text = const.tag match { diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala new file mode 100644 index 000000000000..1cbd73db2ba0 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala @@ -0,0 +1,247 @@ +package dotty.tools.dotc +package transform.localopt + +import scala.annotation.tailrec +import scala.collection.mutable.{ListBuffer, Stack} +import scala.reflect.{ClassTag, classTag} +import scala.util.chaining.* +import scala.util.matching.Regex.Match + +import java.util.{Calendar, Date, Formattable} + +import StringContextChecker.InterpolationReporter + +/** Formatter string checker. */ +abstract class FormatChecker(using reporter: InterpolationReporter): + + // Pick the first runtime type which the i'th arg can satisfy. + // If conversion is required, implementation must emit it. + def argType(argi: Int, types: ClassTag[?]*): ClassTag[?] + + // count of args, for checking indexes + def argc: Int + + val allFlags = "-#+ 0,(<" + val formatPattern = """%(?:(\d+)\$)?([-#+ 0,(<]+)?(\d+)?(\.\d+)?([tT]?[%a-zA-Z])?""".r + + // ordinal is the regex group index in the format pattern + enum SpecGroup: + case Spec, Index, Flags, Width, Precision, CC + import SpecGroup.* + + /** For N part strings and N-1 args to interpolate, normalize parts and check arg types. + * + * Returns parts, possibly updated with explicit leading "%s", + * and conversions for each arg. + * + * Implementation must emit conversions required by invocations of `argType`. + */ + def checked(parts0: List[String]): (List[String], List[Conversion]) = + val amended = ListBuffer.empty[String] + val convert = ListBuffer.empty[Conversion] + + @tailrec + def loop(parts: List[String], n: Int): Unit = + parts match + case part0 :: more => + def badPart(t: Throwable): String = "".tap(_ => reporter.partError(t.getMessage, index = n, offset = 0)) + val part = try StringContext.processEscapes(part0) catch badPart + val matches = formatPattern.findAllMatchIn(part) + + def insertStringConversion(): Unit = + amended += "%s" + part + convert += Conversion(formatPattern.findAllMatchIn("%s").next(), n) // improve + argType(n-1, classTag[Any]) + def errorLeading(op: Conversion) = op.errorAt(Spec)(s"conversions must follow a splice; ${Conversion.literalHelp}") + def accept(op: Conversion): Unit = + if !op.isLeading then errorLeading(op) + op.accepts(argType(n-1, op.acceptableVariants*)) + amended += part + convert += op + + // after the first part, a leading specifier is required for the interpolated arg; %s is supplied if needed + if n == 0 then amended += part + else if !matches.hasNext then insertStringConversion() + else + val cv = Conversion(matches.next(), n) + if cv.isLiteral then insertStringConversion() + else if cv.isIndexed then + if cv.index.getOrElse(-1) == n then accept(cv) + else + // either some other arg num, or '<' + //c.warning(op.groupPos(Index), "Index is not this arg") + insertStringConversion() + else if !cv.isError then accept(cv) + + // any remaining conversions in this part must be either literals or indexed + while matches.hasNext do + val cv = Conversion(matches.next(), n) + if n == 0 && cv.hasFlag('<') then cv.badFlag('<', "No last arg") + else if !cv.isLiteral && !cv.isIndexed then errorLeading(cv) + + loop(more, n + 1) + case Nil => () + end loop + + loop(parts0, n = 0) + (amended.toList, convert.toList) + end checked + + extension (descriptor: Match) + def at(g: SpecGroup): Int = descriptor.start(g.ordinal) + def offset(g: SpecGroup, i: Int = 0): Int = at(g) + i + def group(g: SpecGroup): Option[String] = Option(descriptor.group(g.ordinal)) + def stringOf(g: SpecGroup): String = group(g).getOrElse("") + def intOf(g: SpecGroup): Option[Int] = group(g).map(_.toInt) + + extension (inline value: Boolean) + inline def or(inline body: => Unit): Boolean = value || { body ; false } + inline def orElse(inline body: => Unit): Boolean = value || { body ; true } + inline def but(inline body: => Unit): Boolean = value && { body ; false } + inline def and(inline body: => Unit): Boolean = value && { body ; true } + + /** A conversion specifier matched in the argi'th string part, + * with `argc` arguments to interpolate. + */ + sealed abstract class Conversion: + // the match for this descriptor + def descriptor: Match + // the part number for reporting errors + def argi: Int + + // the descriptor fields + val index: Option[Int] = descriptor.intOf(Index) + val flags: String = descriptor.stringOf(Flags) + val width: Option[Int] = descriptor.intOf(Width) + val precision: Option[Int] = descriptor.group(Precision).map(_.drop(1).toInt) + val op: String = descriptor.stringOf(CC) + + // the conversion char is the head of the op string (but see DateTimeXn) + val cc: Char = if isError then '?' else op(0) + + def isError: Boolean = false + def isIndexed: Boolean = index.nonEmpty || hasFlag('<') + def isLiteral: Boolean = false + + // descriptor is at index 0 of the part string + def isLeading: Boolean = descriptor.at(Spec) == 0 + + // true if passes. Default checks flags and index + def verify: Boolean = goodFlags && goodIndex + + // is the specifier OK with the given arg + def accepts(arg: ClassTag[?]): Boolean = true + + // what arg type if any does the conversion accept + def acceptableVariants: List[ClassTag[?]] + + // what flags does the conversion accept? defaults to all + protected def okFlags: String = allFlags + + def hasFlag(f: Char) = flags.contains(f) + def hasAnyFlag(fs: String) = fs.exists(hasFlag) + + def badFlag(f: Char, msg: String) = + val i = flags.indexOf(f) match { case -1 => 0 case j => j } + errorAt(Flags, i)(msg) + + def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partError(msg, argi, descriptor.offset(g, i)) + def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partWarning(msg, argi, descriptor.offset(g, i)) + + def noFlags = flags.isEmpty or errorAt(Flags)("flags not allowed") + def noWidth = width.isEmpty or errorAt(Width)("width not allowed") + def noPrecision = precision.isEmpty or errorAt(Precision)("precision not allowed") + def only_-(msg: String) = + val badFlags = flags.filterNot { case '-' | '<' => true case _ => false } + badFlags.isEmpty or badFlag(badFlags(0), s"Only '-' allowed for $msg") + def goodFlags = + val badFlags = flags.filterNot(okFlags.contains) + for f <- badFlags do badFlag(f, s"Illegal flag '$f'") + badFlags.isEmpty + def goodIndex = + if index.nonEmpty && hasFlag('<') then warningAt(Index)("Argument index ignored if '<' flag is present") + val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true) + okRange || hasFlag('<') or errorAt(Index)("Argument index out of range") + object Conversion: + def apply(m: Match, i: Int): Conversion = + def badCC(msg: String) = ErrorXn(m, i).tap(error => error.errorAt(if (error.op.isEmpty) Spec else CC)(msg)) + def cv(cc: Char) = cc match + case 's' | 'S' => StringXn(m, i) + case 'h' | 'H' => HashXn(m, i) + case 'b' | 'B' => BooleanXn(m, i) + case 'c' | 'C' => CharacterXn(m, i) + case 'd' | 'o' | + 'x' | 'X' => IntegralXn(m, i) + case 'e' | 'E' | + 'f' | + 'g' | 'G' | + 'a' | 'A' => FloatingPointXn(m, i) + case 't' | 'T' => DateTimeXn(m, i) + case '%' | 'n' => LiteralXn(m, i) + case _ => badCC(s"illegal conversion character '$cc'") + end cv + m.group(CC) match + case Some(cc) => cv(cc(0)).tap(_.verify) + case None => badCC(s"Missing conversion operator in '${m.matched}'; $literalHelp") + end apply + val literalHelp = "use %% for literal %, %n for newline" + end Conversion + abstract class GeneralXn extends Conversion + // s | S + class StringXn(val descriptor: Match, val argi: Int) extends GeneralXn: + val acceptableVariants = + if hasFlag('#') then classTag[Formattable] :: Nil + else classTag[Any] :: Nil + override protected def okFlags = "-#<" + // b | B + class BooleanXn(val descriptor: Match, val argi: Int) extends GeneralXn: + val FakeNullTag: ClassTag[?] = null + val acceptableVariants = classTag[Boolean] :: FakeNullTag :: Nil + override def accepts(arg: ClassTag[?]): Boolean = + arg == classTag[Boolean] orElse warningAt(CC)("Boolean format is null test for non-Boolean") + override protected def okFlags = "-<" + // h | H + class HashXn(val descriptor: Match, val argi: Int) extends GeneralXn: + val acceptableVariants = classTag[Any] :: Nil + override protected def okFlags = "-<" + // %% | %n + class LiteralXn(val descriptor: Match, val argi: Int) extends Conversion: + override def isLiteral = true + override def verify = op match + case "%" => super.verify && noPrecision and width.foreach(_ => warningAt(Width)("width ignored on literal")) + case "n" => noFlags && noWidth && noPrecision + override protected val okFlags = "-" + override def acceptableVariants = Nil + class CharacterXn(val descriptor: Match, val argi: Int) extends Conversion: + override def verify = super.verify && noPrecision && only_-("c conversion") + val acceptableVariants = classTag[Char] :: classTag[Byte] :: classTag[Short] :: classTag[Int] :: Nil + class IntegralXn(val descriptor: Match, val argi: Int) extends Conversion: + override def verify = + def d_# = cc == 'd' && hasFlag('#') and badFlag('#', "# not allowed for d conversion") + def x_comma = cc != 'd' && hasFlag(',') and badFlag(',', "',' only allowed for d conversion of integral types") + super.verify && noPrecision && !d_# && !x_comma + val acceptableVariants = classTag[Int] :: classTag[Long] :: classTag[Byte] :: classTag[Short] :: classTag[BigInt] :: Nil + override def accepts(arg: ClassTag[?]): Boolean = + arg == classTag[BigInt] || { + cc match + case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; false + case _ => true + } + class FloatingPointXn(val descriptor: Match, val argi: Int) extends Conversion: + override def verify = super.verify && (cc match { + case 'a' | 'A' => + val badFlags = ",(".filter(hasFlag) + noPrecision && badFlags.isEmpty or badFlags.foreach(badf => badFlag(badf, s"'$badf' not allowed for a, A")) + case _ => true + }) + val acceptableVariants = classTag[Double] :: classTag[Float] :: classTag[BigDecimal] :: Nil + class DateTimeXn(val descriptor: Match, val argi: Int) extends Conversion: + override val cc: Char = if op.length > 1 then op(1) else '?' + def hasCC = op.length == 2 or errorAt(CC)("Date/time conversion must have two characters") + def goodCC = "HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc".contains(cc) or errorAt(CC, 1)(s"'$cc' doesn't seem to be a date or time conversion") + override def verify = super.verify && hasCC && goodCC && noPrecision && only_-("date/time conversions") + val acceptableVariants = classTag[Long] :: classTag[Calendar] :: classTag[Date] :: Nil + class ErrorXn(val descriptor: Match, val argi: Int) extends Conversion: + override def isError = true + override def verify = false + override def acceptableVariants = Nil diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala new file mode 100644 index 000000000000..820b42d2c225 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala @@ -0,0 +1,195 @@ +package dotty.tools.dotc +package transform.localopt + +import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Decorators._ +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Contexts._ +import dotty.tools.dotc.core.StdNames._ +import dotty.tools.dotc.core.NameKinds._ +import dotty.tools.dotc.core.Symbols._ +import dotty.tools.dotc.core.Types._ +import dotty.tools.dotc.core.Phases.typerPhase +import dotty.tools.dotc.typer.ProtoTypes._ + +import scala.StringContext.processEscapes +import scala.annotation.tailrec +import scala.collection.mutable.{ListBuffer, Stack} +import scala.reflect.{ClassTag, classTag} +import scala.util.chaining._ +import scala.util.matching.Regex.Match + +object FormatInterpolatorTransform: + import tpd._ + import StringContextChecker.InterpolationReporter + + /* + /** This trait defines a tool to report errors/warnings that do not depend on Position. */ + trait InterpolationReporter: + + /** Reports error/warning of size 1 linked with a part of the StringContext. + * + * @param message the message to report as error/warning + * @param index the index of the part inside the list of parts of the StringContext + * @param offset the index in the part String where the error is + * @return an error/warning depending on the function + */ + def partError(message: String, index: Int, offset: Int): Unit + def partWarning(message: String, index: Int, offset: Int): Unit + + /** Reports error linked with an argument to format. + * + * @param message the message to report as error/warning + * @param index the index of the argument inside the list of arguments of the format function + * @return an error depending on the function + */ + def argError(message: String, index: Int): Unit + + /** Reports error linked with the list of arguments or the StringContext. + * + * @param message the message to report in the error + * @return an error + */ + def strCtxError(message: String): Unit + def argsError(message: String): Unit + + /** Claims whether an error or a warning has been reported + * + * @return true if an error/warning has been reported, false + */ + def hasReported: Boolean + + /** Stores the old value of the reported and reset it to false */ + def resetReported(): Unit + + /** Restores the value of the reported boolean that has been reset */ + def restoreReported(): Unit + end InterpolationReporter + */ + class PartsReporter(fun: Tree, args0: Tree, parts: List[Tree], args: List[Tree])(using Context) extends InterpolationReporter: + private var reported = false + private var oldReported = false + private def partPosAt(index: Int, offset: Int) = + val pos = parts(index).sourcePos + pos.withSpan(pos.span.shift(offset)) + def partError(message: String, index: Int, offset: Int): Unit = + reported = true + report.error(message, partPosAt(index, offset)) + def partWarning(message: String, index: Int, offset: Int): Unit = + reported = true + report.warning(message, partPosAt(index, offset)) + def argError(message: String, index: Int): Unit = + reported = true + report.error(message, args(index).srcPos) + def strCtxError(message: String): Unit = + reported = true + report.error(message, fun.srcPos) + def argsError(message: String): Unit = + reported = true + report.error(message, args0.srcPos) + def hasReported: Boolean = reported + def resetReported(): Unit = + oldReported = reported + reported = false + def restoreReported(): Unit = reported = oldReported + end PartsReporter + object tags: + import java.util.{Calendar, Date, Formattable} + val StringTag = classTag[String] + val FormattableTag = classTag[Formattable] + val BigIntTag = classTag[BigInt] + val BigDecimalTag = classTag[BigDecimal] + val CalendarTag = classTag[Calendar] + val DateTag = classTag[Date] + class FormattableTypes(using Context): + val FormattableType = requiredClassRef("java.util.Formattable") + val BigIntType = requiredClassRef("scala.math.BigInt") + val BigDecimalType = requiredClassRef("scala.math.BigDecimal") + val CalendarType = requiredClassRef("java.util.Calendar") + val DateType = requiredClassRef("java.util.Date") + class TypedFormatChecker(val args: List[Tree])(using Context, InterpolationReporter) extends FormatChecker: + val reporter = summon[InterpolationReporter] + val argTypes = args.map(_.tpe) + val actuals = ListBuffer.empty[Tree] + val argc = argTypes.length + def argType(argi: Int, types: Seq[ClassTag[?]]) = + require(argi < argc, s"$argi out of range picking from $types") + val tpe = argTypes(argi) + types.find(t => argConformsTo(argi, tpe, argTypeOf(t))) + .orElse(types.find(t => argConvertsTo(argi, tpe, argTypeOf(t)))) + .getOrElse { + reporter.argError(s"Found: ${tpe.show}, Required: ${types.mkString(", ")}", argi) + actuals += args(argi) + types.head + } + final lazy val fmtTypes = FormattableTypes() + import tags.*, fmtTypes.* + def argConformsTo(argi: Int, arg: Type, target: Type): Boolean = + (arg <:< target).tap(if _ then actuals += args(argi)) + def argConvertsTo(argi: Int, arg: Type, target: Type): Boolean = + import typer.Implicits.SearchSuccess + atPhase(typerPhase) { + ctx.typer.inferView(args(argi), target) match + case SearchSuccess(view, ref, _, _) => actuals += view ; true + case _ => false + } + def argTypeOf(tag: ClassTag[?]): Type = tag match + case StringTag => defn.StringType + case ClassTag.Boolean => defn.BooleanType + case ClassTag.Byte => defn.ByteType + case ClassTag.Char => defn.CharType + case ClassTag.Short => defn.ShortType + case ClassTag.Int => defn.IntType + case ClassTag.Long => defn.LongType + case ClassTag.Float => defn.FloatType + case ClassTag.Double => defn.DoubleType + case ClassTag.Any => defn.AnyType + case ClassTag.AnyRef => defn.AnyRefType + case FormattableTag => FormattableType + case BigIntTag => BigIntType + case BigDecimalTag => BigDecimalType + case CalendarTag => CalendarType + case DateTag => DateType + case null => defn.NullType + case _ => reporter.strCtxError(s"Unknown type for format $tag") + defn.AnyType + end TypedFormatChecker + + /** For f"${arg}%xpart", check format conversions and return (format, args) + * suitable for String.format(format, args). + */ + def checked(fun: Tree, args0: Tree)(using Context): (Tree, Tree) = + val (partsExpr, parts) = fun match + case TypeApply(Select(Apply(_, (parts: SeqLiteral) :: Nil), _), _) => + (parts.elems, parts.elems.map { case Literal(Constant(s: String)) => s }) + case _ => + report.error("Expected statically known StringContext", fun.srcPos) + (Nil, Nil) + val (args, elemtpt) = args0 match + case seqlit: SeqLiteral => (seqlit.elems, seqlit.elemtpt) + case _ => + report.error("Expected statically known argument list", args0.srcPos) + (Nil, EmptyTree) + given reporter: InterpolationReporter = PartsReporter(fun, args0, partsExpr, args) + + def literally(s: String) = Literal(Constant(s)) + inline val skip = false + if parts.lengthIs != args.length + 1 then + reporter.strCtxError { + if parts.isEmpty then "there are no parts" + else s"too ${if parts.lengthIs > args.length + 1 then "few" else "many"} arguments for interpolated string" + } + (literally(""), args0) + else if skip then + val checked = parts.head :: parts.tail.map(p => if p.startsWith("%") then p else "%s" + p) + (literally(checked.mkString), args0) + else + val checker = TypedFormatChecker(args) + val (checked, cvs) = checker.checked(parts) + if reporter.hasReported then (literally(parts.mkString), args0) + else + assert(checker.argc == checker.actuals.size, s"Expected ${checker.argc}, actuals size is ${checker.actuals.size} for [${parts.mkString(", ")}]") + (literally(checked.mkString), tpd.SeqLiteral(checker.actuals.toList, elemtpt)) + end checked +end FormatInterpolatorTransform diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala b/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala index fbd09f43b853..18f296af8c4e 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala @@ -49,7 +49,7 @@ object StringContextChecker { * * @return true if an error/warning has been reported, false */ - def hasReported() : Boolean + def hasReported : Boolean /** Stores the old value of the reported and reset it to false */ def resetReported() : Unit @@ -106,7 +106,7 @@ object StringContextChecker { report.error(message, args0.srcPos) } - def hasReported() : Boolean = { + def hasReported : Boolean = { reported } @@ -160,7 +160,7 @@ object StringContextChecker { case p :: parts1 => p :: parts1.map((part : String) => { if (!part.startsWith("%")) { val index = part.indexOf('%') - if (!reporter.hasReported() && index != -1) { + if (!reporter.hasReported && index != -1) { reporter.partError("conversions must follow a splice; use %% for literal %, %n for newline", parts.indexOf(part), index) "%s" + part } else "%s" + part @@ -296,7 +296,7 @@ object StringContextChecker { } //conversion - if((conversion >= l || (!part.charAt(conversion).isLetter && part.charAt(conversion) != '%')) && !reporter.hasReported()) + if((conversion >= l || (!part.charAt(conversion).isLetter && part.charAt(conversion) != '%')) && !reporter.hasReported) reporter.partError("Missing conversion operator in '" + part.substring(pos, conversion) + "'; use %% for literal %, %n for newline", partIndex, pos) val hasWidth = (hasWidth1 && !hasArgumentIndex) || hasWidth2 @@ -337,7 +337,7 @@ object StringContextChecker { if (argumentIndex > maxArgumentIndex || argumentIndex <= 0) reporter.partError("Argument index out of range", partIndex, offset) - if (expected && expectedArgumentIndex != argumentIndex && !reporter.hasReported()) + if (expected && expectedArgumentIndex != argumentIndex && !reporter.hasReported) reporter.partWarning("Index is not this arg", partIndex, offset) } @@ -376,10 +376,10 @@ object StringContextChecker { def checkUniqueFlags(partIndex : Int, flags : List[(Char, Int)], notAllowedFlagOnCondition : List[(Char, Boolean, String)]) = { reporter.resetReported() for {flag <- flags ; (nonAllowedFlag, condition, message) <- notAllowedFlagOnCondition ; if (flag._1 == nonAllowedFlag && condition)} { - if (!reporter.hasReported()) + if (!reporter.hasReported) reporter.partError(message, partIndex, flag._2) } - if (!reporter.hasReported()) + if (!reporter.hasReported) reporter.restoreReported() } @@ -656,7 +656,7 @@ object StringContextChecker { argument match { case Some(argIndex, arg) => { val (hasArgumentIndex, argumentIndex, flags, hasWidth, width, hasPrecision, precision, hasRelative, relativeIndex, conversion) = getFormatSpecifiers(part, argIndex, argIndex + 1, false, formattingStart) - if (!reporter.hasReported()){ + if (!reporter.hasReported){ val conversionWithType = checkFormatSpecifiers(argIndex + 1, hasArgumentIndex, argumentIndex, Some(argIndex + 1), start == 0, maxArgumentIndex, hasRelative, hasWidth, width, hasPrecision, precision, flags, conversion, Some(arg.tpe), part) nextStart = conversion + 1 conversionWithType :: checkPart(part, nextStart, argument, maxArgumentIndex) @@ -668,10 +668,10 @@ object StringContextChecker { reporter.partError("Argument index out of range", 0, argumentIndex) if (hasRelative) reporter.partError("No last arg", 0, relativeIndex) - if (!reporter.hasReported()){ + if (!reporter.hasReported){ val conversionWithType = checkFormatSpecifiers(0, hasArgumentIndex, argumentIndex, None, start == 0, maxArgumentIndex, hasRelative, hasWidth, width, hasPrecision, precision, flags, conversion, None, part) nextStart = conversion + 1 - if (!reporter.hasReported() && part.charAt(conversion) != '%' && part.charAt(conversion) != 'n' && !hasArgumentIndex && !hasRelative) + if (!reporter.hasReported && part.charAt(conversion) != '%' && part.charAt(conversion) != 'n' && !hasArgumentIndex && !hasRelative) reporter.partError("conversions must follow a splice; use %% for literal %, %n for newline", 0, part.indexOf('%')) conversionWithType :: checkPart(part, nextStart, argument, maxArgumentIndex) } else checkPart(part, conversion + 1, argument, maxArgumentIndex) @@ -691,10 +691,10 @@ object StringContextChecker { // add default format val parts = addDefaultFormat(parts0) - if (!parts.isEmpty && !reporter.hasReported()) { + if (!parts.isEmpty && !reporter.hasReported) { if (parts.size == 1 && args.size == 0 && parts.head.size != 0){ val argTypeWithConversion = checkPart(parts.head, 0, None, None) - if (!reporter.hasReported()) + if (!reporter.hasReported) for ((argument, conversionChar, flags) <- argTypeWithConversion) checkArgTypeWithConversion(0, conversionChar, argument, flags, parts.head.indexOf('%')) } else { @@ -702,7 +702,7 @@ object StringContextChecker { for (i <- (0 until args.size)){ val (part, arg) = partWithArgs(i) val argTypeWithConversion = checkPart(part, 0, Some((i, arg)), Some(args.size)) - if (!reporter.hasReported()) + if (!reporter.hasReported) for ((argument, conversionChar, flags) <- argTypeWithConversion) checkArgTypeWithConversion(i + 1, conversionChar, argument, flags, parts(i).indexOf('%')) } diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala index 741803d41a66..e42a98080242 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala @@ -32,7 +32,7 @@ class StringInterpolatorOpt extends MiniPhase { tree match { case tree: RefTree => val sym = tree.symbol - assert(sym != defn.StringContext_raw && sym != defn.StringContext_s, + assert(sym != defn.StringContext_raw && sym != defn.StringContext_s && sym != defn.StringContext_f, i"$tree in ${ctx.owner.showLocated} should have been rewritten by phase $phaseName") case _ => } @@ -122,6 +122,11 @@ class StringInterpolatorOpt extends MiniPhase { (sym.name == nme.raw_ && sym.eq(defn.StringContext_raw)) || (sym.name == nme.f && sym.eq(defn.StringContext_f)) || (sym.name == nme.s && sym.eq(defn.StringContext_s)) + def transformF(fun: Tree, args: Tree): Tree = + val (parts1, args1) = FormatInterpolatorTransform.checked(fun, args) + resolveConstructor(defn.StringOps.typeRef, List(parts1)) + .select(nme.format) + .appliedTo(args1) if (isInterpolatedMethod) (tree: @unchecked) match { case StringContextIntrinsic(strs: List[Literal], elems: List[Tree]) => @@ -138,10 +143,7 @@ class StringInterpolatorOpt extends MiniPhase { } result case Apply(intp, args :: Nil) if sym.eq(defn.StringContext_f) => - val partsStr = StringContextChecker.checkedParts(intp, args).mkString - resolveConstructor(defn.StringOps.typeRef, List(Literal(Constant(partsStr)))) - .select(nme.format) - .appliedTo(args) + transformF(intp, args) // Starting with Scala 2.13, s and raw are macros in the standard // library, so we need to expand them manually. // sc.s(args) --> standardInterpolator(processEscapes, args, sc.parts) diff --git a/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala b/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala index d92c16cfe54d..b259b3f21b86 100644 --- a/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala +++ b/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala @@ -1423,7 +1423,7 @@ object SourceCode { case '"' => "\\\"" case '\'' => "\\\'" case '\\' => "\\\\" - case _ => if (ch.isControl) f"\u${ch.toInt}%04x" else String.valueOf(ch) + case _ => if (ch.isControl) f"${"\\"}u${ch.toInt}%04x" else String.valueOf(ch) } private def escapedString(str: String): String = str flatMap escapedChar diff --git a/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala b/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala new file mode 100644 index 000000000000..56bc47d172b2 --- /dev/null +++ b/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala @@ -0,0 +1,82 @@ +package dotty.tools +package dotc.transform + +import org.junit.{Test, Assert}, Assert.{assertEquals, assertFalse, assertTrue} + +import scala.collection.mutable.ListBuffer +import scala.language.implicitConversions +import scala.reflect.{ClassTag, classTag} +import scala.util.chaining._ + +import java.util.{Calendar, Date, Formattable} + +import localopt.{FormatChecker, StringContextChecker} + +// TDD for just the Checker +class FormatCheckerTest: + class TestReporter extends StringContextChecker.InterpolationReporter: + private var reported = false + private var oldReported = false + val reports = ListBuffer.empty[(String, Int, Int)] + + def partError(message: String, index: Int, offset: Int): Unit = + reported = true + reports += ((message, index, offset)) + def partWarning(message: String, index: Int, offset: Int): Unit = + reported = true + reports += ((message, index, offset)) + def argError(message: String, index: Int): Unit = + reported = true + reports += ((message, index, 0)) + def strCtxError(message: String): Unit = + reported = true + reports += ((message, 0, 0)) + def argsError(message: String): Unit = + reports += ((message, 0, 0)) + + def hasReported: Boolean = reported + + def resetReported(): Unit = + oldReported = reported + reported = false + + def restoreReported(): Unit = + reported = oldReported + end TestReporter + given TestReporter = TestReporter() + + /* + enum ArgTypeTag: + case BooleanArg, ByteArg, CharArg, ShortArg, IntArg, LongArg, FloatArg, DoubleArg, AnyArg, + StringArg, FormattableArg, BigIntArg, BigDecimalArg, CalendarArg, DateArg + given Conversion[ArgTypeTag, Int] = _.ordinal + def argTypeString(tag: Int) = + if tag < 0 then "Null" + else if tag >= ArgTypeTag.values.length then throw RuntimeException(s"Bad tag $tag") + else ArgTypeTag.values(tag) + */ + + class TestChecker(args: ClassTag[?]*)(using val reporter: TestReporter) extends FormatChecker: + def argType(argi: Int, types: ClassTag[?]*): ClassTag[?] = types.find(_ == args(argi)).getOrElse(types.head) + val argc = args.length + + def checked(parts: String*)(args: ClassTag[?]*): String = + val checker = TestChecker(args*) + val (amended, _) = checker.checked(parts.toList) + assertFalse(checker.reporter.hasReported) + amended.mkString + def assertChecked(parts: String*)(args: ClassTag[?]*)(p: TestReporter => Boolean = _ => true): Unit = + val checker = TestChecker(args*) + checker.checked(parts.toList) + assertTrue(p(checker.reporter)) + def errorIs(msg: String): (TestReporter => Boolean) = _.reports.head._1.contains(msg) + + @Test def `simple string` = assertEquals("xyz", checked("xyz")()) + @Test def `one string` = assertEquals("xyz%s123", checked("xyz", "123")(classTag[String])) + @Test def `in first part` = assertEquals("x%ny%%z%s123", checked("x%ny%%z", "123")(classTag[String])) + @Test def `one int` = assertEquals("xyz%d123", checked("xyz", "%d123")(classTag[Int])) + //@Test def `one bad int`: Unit = assertChecked("xyz", "%d123")(classTag[String])(errorIs("Type error")) + @Test def `extra descriptor` = assertChecked("xyz", "%s12%d3")(classTag[String])(errorIs("conversions must follow")) + @Test def `bad leader`: Unit = assertChecked("%dxyz")()(_.reports.head._1.contains("conversions must follow")) + @Test def `in second part`: Unit = assertEquals("xyz%s1%n2%%3", checked("xyz", "1%n2%%3")(classTag[String])) + @Test def `something weird`: Unit = assertEquals("xyz%tH123", checked("xyz", "%tH123")(classTag[Calendar])) diff --git a/tests/neg/f-interpolator-tests.scala b/tests/neg/f-interpolator-tests.scala new file mode 100644 index 000000000000..ca4b4ad11e39 --- /dev/null +++ b/tests/neg/f-interpolator-tests.scala @@ -0,0 +1,11 @@ + +trait T { + val s = "um" + def `um uh` = f"$s%d" // error + + // "% y" looks like a format conversion because ' ' is a legal flag + def i11256 = + val x = 42 + val y = 27 + f"x % y = ${x % y}%d" // error: illegal conversion character 'y' +} diff --git a/tests/neg/fEscapes.check b/tests/neg/fEscapes.check new file mode 100644 index 000000000000..29fe1412e075 --- /dev/null +++ b/tests/neg/fEscapes.check @@ -0,0 +1,4 @@ +-- Error: tests/neg/fEscapes.scala:5:18 -------------------------------------------------------------------------------- +5 | val fEscape = f"\u$octal%04x" // error + | ^^ + | invalid unicode escape at index 1 of \u diff --git a/tests/neg/fEscapes.scala b/tests/neg/fEscapes.scala new file mode 100644 index 000000000000..c4a9a6ffb200 --- /dev/null +++ b/tests/neg/fEscapes.scala @@ -0,0 +1,5 @@ + +// f-interpolator wasn't doing any escape processing +class C: + val octal = 8 + val fEscape = f"\u$octal%04x" // error diff --git a/tests/run-macros/f-interpolator-tests.scala b/tests/run-macros/f-interpolator-tests.scala index a50d22a4c022..73ad9b8852fa 100755 --- a/tests/run-macros/f-interpolator-tests.scala +++ b/tests/run-macros/f-interpolator-tests.scala @@ -23,6 +23,7 @@ object Test { dateArgsTests specificLiteralsTests argumentsTests + unitTests } def multilineTests = { @@ -199,5 +200,209 @@ object Test { def argumentsTests = { println(f"${"a"}%s ${"b"}%s % "false", + f"${b_true}%b" -> "true", + + f"${null}%b" -> "false", + f"${false}%b" -> "false", + f"${true}%b" -> "true", + f"${true && false}%b" -> "false", + f"${java.lang.Boolean.valueOf(false)}%b" -> "false", + f"${java.lang.Boolean.valueOf(true)}%b" -> "true", + + f"${null}%B" -> "FALSE", + f"${false}%B" -> "FALSE", + f"${true}%B" -> "TRUE", + f"${java.lang.Boolean.valueOf(false)}%B" -> "FALSE", + f"${java.lang.Boolean.valueOf(true)}%B" -> "TRUE", + + f"${"true"}%b" -> "true", + f"${"false"}%b"-> "false", + + // 'h' | 'H' (category: general) + // ----------------------------- + f"${null}%h" -> "null", + f"${f_zero}%h" -> "0", + f"${f_zero_-}%h" -> "80000000", + f"${s}%h" -> "4c01926", + + f"${null}%H" -> "NULL", + f"${s}%H" -> "4C01926", + + // 's' | 'S' (category: general) + // ----------------------------- + f"${null}%s" -> "null", + f"${null}%S" -> "NULL", + f"${s}%s" -> "Scala", + f"${s}%S" -> "SCALA", + f"${5}" -> "5", + f"${i}" -> "42", + f"${Symbol("foo")}" -> "Symbol(foo)", + + f"${Thread.State.NEW}" -> "NEW", + + // 'c' | 'C' (category: character) + // ------------------------------- + f"${120:Char}%c" -> "x", + f"${120:Byte}%c" -> "x", + f"${120:Short}%c" -> "x", + f"${120:Int}%c" -> "x", + f"${java.lang.Character.valueOf('x')}%c" -> "x", + f"${java.lang.Byte.valueOf(120:Byte)}%c" -> "x", + f"${java.lang.Short.valueOf(120:Short)}%c" -> "x", + f"${java.lang.Integer.valueOf(120)}%c" -> "x", + + f"${'x' : java.lang.Character}%c" -> "x", + f"${(120:Byte) : java.lang.Byte}%c" -> "x", + f"${(120:Short) : java.lang.Short}%c" -> "x", + f"${120 : java.lang.Integer}%c" -> "x", + + f"${"Scala"}%c" -> "S", + + // 'd' | 'o' | 'x' | 'X' (category: integral) + // ------------------------------------------ + f"${120:Byte}%d" -> "120", + f"${120:Short}%d" -> "120", + f"${120:Int}%d" -> "120", + f"${120:Long}%d" -> "120", + f"${60 * 2}%d" -> "120", + f"${java.lang.Byte.valueOf(120:Byte)}%d" -> "120", + f"${java.lang.Short.valueOf(120:Short)}%d" -> "120", + f"${java.lang.Integer.valueOf(120)}%d" -> "120", + f"${java.lang.Long.valueOf(120)}%d" -> "120", + f"${120 : java.lang.Integer}%d" -> "120", + f"${120 : java.lang.Long}%d" -> "120", + f"${BigInt(120)}%d" -> "120", + + f"${new java.math.BigInteger("120")}%d" -> "120", + + f"${4}%#10X" -> " 0X4", + + f"She is ${fff}%#s feet tall." -> "She is 4 feet tall.", + + f"Just want to say ${"hello, world"}%#s..." -> "Just want to say hello, world...", + + //{ implicit val strToShort: Conversion[String, Short] = java.lang.Short.parseShort ; f"${"120"}%d" } -> "120", + //{ implicit val strToInt = (s: String) => 42 ; f"${"120"}%d" } -> "42", + + // 'e' | 'E' | 'g' | 'G' | 'f' | 'a' | 'A' (category: floating point) + // ------------------------------------------------------------------ + f"${3.4f}%e" -> locally"3.400000e+00", + f"${3.4}%e" -> locally"3.400000e+00", + f"${3.4f : java.lang.Float}%e" -> locally"3.400000e+00", + f"${3.4 : java.lang.Double}%e" -> locally"3.400000e+00", + + f"${BigDecimal(3.4)}%e" -> locally"3.400000e+00", + + f"${new java.math.BigDecimal(3.4)}%e" -> locally"3.400000e+00", + + f"${3}%e" -> locally"3.000000e+00", + f"${3L}%e" -> locally"3.000000e+00", + + // 't' | 'T' (category: date/time) + // ------------------------------- + f"${cal}%TD" -> "05/26/12", + f"${cal.getTime}%TD" -> "05/26/12", + f"${cal.getTime.getTime}%TD" -> "05/26/12", + f"""${"1234"}%TD""" -> "05/26/12", + + // literals and arg indexes + f"%%" -> "%", + f" mind%n------%nmatter" -> + """| mind + |------ + |matter""".stripMargin.linesIterator.mkString(System.lineSeparator), + f"${i}%d % "42 42 9", + f"${7}%d % "7 7 9", + f"${7}%d %2$$d ${9}%d" -> "7 9 9", + + f"${null}%d % "null FALSE", + + f"${5: Any}" -> "5", + f"${5}%s% "55", + f"${3.14}%s,% locally"3.14,${"3.140000"}", + + f"z" -> "z" + ) + + for ((f, s) <- ss) assertEquals(s, f) + end `f interpolator baseline` + + def fIf = + val res = f"${if true then 2.5 else 2.5}%.2f" + val expected = locally"2.50" + assertEquals(expected, res) + + def fIfNot = + val res = f"${if false then 2.5 else 3.5}%.2f" + val expected = locally"3.50" + assertEquals(expected, res) + + // in Scala 2, [A >: Any] forced not to convert 3 to 3.0; Scala 3 harmonics should also respect lower bound. + def fHeteroArgs() = + val res = f"${3.14}%.2f rounds to ${3}%d" + val expected = locally"${"3.14"} rounds to 3" + assertEquals(expected, res) } +object StringContextTestUtils: + private val decimalSeparator: Char = new DecimalFormat().getDecimalFormatSymbols().getDecimalSeparator() + private val numberPattern = """(\d+)\.(\d+.*)""".r + private def applyProperLocale(number: String): String = + val numberPattern(intPart, fractionalPartAndSuffix) = number + s"$intPart$decimalSeparator$fractionalPartAndSuffix" + + extension (sc: StringContext) + // Use this String interpolator to avoid problems with a locale-dependent decimal mark. + def locally(numbers: String*): String = + val numbersWithCorrectLocale = numbers.map(applyProperLocale) + sc.s(numbersWithCorrectLocale: _*) + + // Handles cases like locally"3.14" - it's prettier than locally"${"3.14"}". + def locally(): String = sc.parts.map(applyProperLocale).mkString diff --git a/tests/run/t6476.check b/tests/run/t6476.check index 69bf68978177..e2a080bcf6dc 100644 --- a/tests/run/t6476.check +++ b/tests/run/t6476.check @@ -1,13 +1,18 @@ "Hello", Alice "Hello", Alice + +"Hello", Alice +"Hello", Alice + \"Hello\", Alice \"Hello\", Alice -\"Hello\", Alice -\"Hello\", Alice + \TILT\ -\\TILT\\ -\\TILT\\ \TILT\ \\TILT\\ + +\TILT\ +\TILT\ \\TILT\\ + \TILT\ diff --git a/tests/run/t6476.scala b/tests/run/t6476.scala index a04645065a2a..25a1d5f03ec1 100644 --- a/tests/run/t6476.scala +++ b/tests/run/t6476.scala @@ -3,21 +3,21 @@ object Test { val person = "Alice" println(s"\"Hello\", $person") println(s"""\"Hello\", $person""") - + println() println(f"\"Hello\", $person") println(f"""\"Hello\", $person""") - + println() println(raw"\"Hello\", $person") println(raw"""\"Hello\", $person""") - + println() println(s"\\TILT\\") println(f"\\TILT\\") println(raw"\\TILT\\") - + println() println(s"""\\TILT\\""") println(f"""\\TILT\\""") println(raw"""\\TILT\\""") - + println() println(raw"""\TILT\""") } } From fa9f19e29ccf89aff5b4544afdc8a299371be9d9 Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Thu, 26 Aug 2021 16:26:44 -0700 Subject: [PATCH 02/10] Cleanup dispatch to interpolations --- .../localopt/StringInterpolatorOpt.scala | 84 ++++++++++--------- 1 file changed, 43 insertions(+), 41 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala index e42a98080242..1fc357ab2bfd 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala @@ -117,55 +117,57 @@ class StringInterpolatorOpt extends MiniPhase { } override def transformApply(tree: Apply)(using Context): Tree = { + def mkConcat(strs: List[Literal], elems: List[Tree]): Tree = + val stri = strs.iterator + val elemi = elems.iterator + var result: Tree = stri.next + def concat(tree: Tree): Unit = + result = result.select(defn.String_+).appliedTo(tree).withSpan(tree.span) + while elemi.hasNext + do + concat(elemi.next) + val str = stri.next + if !str.const.stringValue.isEmpty then concat(str) + result + end mkConcat val sym = tree.symbol - val isInterpolatedMethod = // Test names first to avoid loading scala.StringContext if not used - (sym.name == nme.raw_ && sym.eq(defn.StringContext_raw)) || - (sym.name == nme.f && sym.eq(defn.StringContext_f)) || - (sym.name == nme.s && sym.eq(defn.StringContext_s)) + // Test names first to avoid loading scala.StringContext if not used, and common names first + val isInterpolatedMethod = + sym.name match + case nme.s => sym eq defn.StringContext_s + case nme.raw_ => sym eq defn.StringContext_raw + case nme.f => sym eq defn.StringContext_f + case _ => false def transformF(fun: Tree, args: Tree): Tree = val (parts1, args1) = FormatInterpolatorTransform.checked(fun, args) resolveConstructor(defn.StringOps.typeRef, List(parts1)) .select(nme.format) .appliedTo(args1) - if (isInterpolatedMethod) - (tree: @unchecked) match { + // Starting with Scala 2.13, s and raw are macros in the standard + // library, so we need to expand them manually. + // sc.s(args) --> standardInterpolator(processEscapes, args, sc.parts) + // sc.raw(args) --> standardInterpolator(x => x, args, sc.parts) + def transformS(fun: Tree, args: Tree, isRaw: Boolean): Tree = + val pre = fun match + case Select(pre, _) => pre + case intp: Ident => tpd.desugarIdentPrefix(intp) + val stringToString = defn.StringContextModule_processEscapes.info.asInstanceOf[MethodType] + val process = tpd.Lambda(stringToString, args => + if isRaw then args.head else ref(defn.StringContextModule_processEscapes).appliedToTermArgs(args) + ) + evalOnce(pre) { sc => + val parts = sc.select(defn.StringContext_parts) + ref(defn.StringContextModule_standardInterpolator) + .appliedToTermArgs(List(process, args, parts)) + } + end transformS + if isInterpolatedMethod then + (tree: @unchecked) match case StringContextIntrinsic(strs: List[Literal], elems: List[Tree]) => - val stri = strs.iterator - val elemi = elems.iterator - var result: Tree = stri.next - def concat(tree: Tree): Unit = { - result = result.select(defn.String_+).appliedTo(tree).withSpan(tree.span) - } - while (elemi.hasNext) { - concat(elemi.next) - val str = stri.next - if (!str.const.stringValue.isEmpty) concat(str) - } - result - case Apply(intp, args :: Nil) if sym.eq(defn.StringContext_f) => - transformF(intp, args) - // Starting with Scala 2.13, s and raw are macros in the standard - // library, so we need to expand them manually. - // sc.s(args) --> standardInterpolator(processEscapes, args, sc.parts) - // sc.raw(args) --> standardInterpolator(x => x, args, sc.parts) + mkConcat(strs, elems) case Apply(intp, args :: Nil) => - val pre = intp match { - case Select(pre, _) => pre - case intp: Ident => tpd.desugarIdentPrefix(intp) - } - val isRaw = sym eq defn.StringContext_raw - val stringToString = defn.StringContextModule_processEscapes.info.asInstanceOf[MethodType] - - val process = tpd.Lambda(stringToString, args => - if (isRaw) args.head else ref(defn.StringContextModule_processEscapes).appliedToTermArgs(args)) - - evalOnce(pre) { sc => - val parts = sc.select(defn.StringContext_parts) - - ref(defn.StringContextModule_standardInterpolator) - .appliedToTermArgs(List(process, args, parts)) - } - } + if sym eq defn.StringContext_f then transformF(intp, args) + else transformS(intp, args, isRaw = sym eq defn.StringContext_raw) else tree.tpe match case _: ConstantType => tree From 67cddfa21b04e0aeac3243847fe1b5a3b784c41d Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Mon, 22 Nov 2021 20:53:53 -0800 Subject: [PATCH 03/10] Brace reduction and remove dead code --- compiler/src/dotty/tools/dotc/Compiler.scala | 2 +- .../transform/localopt/FormatChecker.scala | 2 - .../FormatInterpolatorTransform.scala | 86 ++- .../localopt/StringContextChecker.scala | 714 ------------------ .../localopt/StringInterpolatorOpt.scala | 118 ++- .../dotc/transform/FormatCheckerTest.scala | 16 +- tests/run-macros/f-interpolator-tests.scala | 13 +- 7 files changed, 96 insertions(+), 855 deletions(-) delete mode 100644 compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index cbcc62b7fb6b..1ddc626d2646 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -85,7 +85,7 @@ class Compiler { new sjs.ExplicitJSClasses, // Make all JS classes explicit (Scala.js only) new ExplicitOuter, // Add accessors to outer classes from nested ones. new ExplicitSelf, // Make references to non-trivial self types explicit as casts - new StringInterpolatorOpt) :: // Optimizes raw and s string interpolators by rewriting them to string concatenations + new StringInterpolatorOpt) :: // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions new UninitializedDefs, // Replaces `compiletime.uninitialized` by `_` new InlinePatterns, // Remove placeholders of inlined patterns diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala index 1cbd73db2ba0..54493c552473 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala @@ -9,8 +9,6 @@ import scala.util.matching.Regex.Match import java.util.{Calendar, Date, Formattable} -import StringContextChecker.InterpolationReporter - /** Formatter string checker. */ abstract class FormatChecker(using reporter: InterpolationReporter): diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala index 820b42d2c225..436b56710370 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala @@ -22,51 +22,7 @@ import scala.util.matching.Regex.Match object FormatInterpolatorTransform: import tpd._ - import StringContextChecker.InterpolationReporter - /* - /** This trait defines a tool to report errors/warnings that do not depend on Position. */ - trait InterpolationReporter: - - /** Reports error/warning of size 1 linked with a part of the StringContext. - * - * @param message the message to report as error/warning - * @param index the index of the part inside the list of parts of the StringContext - * @param offset the index in the part String where the error is - * @return an error/warning depending on the function - */ - def partError(message: String, index: Int, offset: Int): Unit - def partWarning(message: String, index: Int, offset: Int): Unit - - /** Reports error linked with an argument to format. - * - * @param message the message to report as error/warning - * @param index the index of the argument inside the list of arguments of the format function - * @return an error depending on the function - */ - def argError(message: String, index: Int): Unit - - /** Reports error linked with the list of arguments or the StringContext. - * - * @param message the message to report in the error - * @return an error - */ - def strCtxError(message: String): Unit - def argsError(message: String): Unit - - /** Claims whether an error or a warning has been reported - * - * @return true if an error/warning has been reported, false - */ - def hasReported: Boolean - - /** Stores the old value of the reported and reset it to false */ - def resetReported(): Unit - - /** Restores the value of the reported boolean that has been reset */ - def restoreReported(): Unit - end InterpolationReporter - */ class PartsReporter(fun: Tree, args0: Tree, parts: List[Tree], args: List[Tree])(using Context) extends InterpolationReporter: private var reported = false private var oldReported = false @@ -193,3 +149,45 @@ object FormatInterpolatorTransform: (literally(checked.mkString), tpd.SeqLiteral(checker.actuals.toList, elemtpt)) end checked end FormatInterpolatorTransform + +/** This trait defines a tool to report errors/warnings that do not depend on Position. */ +trait InterpolationReporter: + + /** Reports error/warning of size 1 linked with a part of the StringContext. + * + * @param message the message to report as error/warning + * @param index the index of the part inside the list of parts of the StringContext + * @param offset the index in the part String where the error is + * @return an error/warning depending on the function + */ + def partError(message: String, index: Int, offset: Int): Unit + def partWarning(message: String, index: Int, offset: Int): Unit + + /** Reports error linked with an argument to format. + * + * @param message the message to report as error/warning + * @param index the index of the argument inside the list of arguments of the format function + * @return an error depending on the function + */ + def argError(message: String, index: Int): Unit + + /** Reports error linked with the list of arguments or the StringContext. + * + * @param message the message to report in the error + * @return an error + */ + def strCtxError(message: String): Unit + def argsError(message: String): Unit + + /** Claims whether an error or a warning has been reported + * + * @return true if an error/warning has been reported, false + */ + def hasReported: Boolean + + /** Stores the old value of the reported and reset it to false */ + def resetReported(): Unit + + /** Restores the value of the reported boolean that has been reset */ + def restoreReported(): Unit +end InterpolationReporter diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala b/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala deleted file mode 100644 index 18f296af8c4e..000000000000 --- a/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala +++ /dev/null @@ -1,714 +0,0 @@ -package dotty.tools.dotc -package transform.localopt - -import dotty.tools.dotc.ast.Trees._ -import dotty.tools.dotc.ast.tpd -import dotty.tools.dotc.core.Decorators._ -import dotty.tools.dotc.core.Constants.Constant -import dotty.tools.dotc.core.Contexts._ -import dotty.tools.dotc.core.StdNames._ -import dotty.tools.dotc.core.NameKinds._ -import dotty.tools.dotc.core.Symbols._ -import dotty.tools.dotc.core.Types._ - -// Ported from old dotty.internal.StringContextMacro -// TODO: port Scala 2 logic? (see https://github.com/scala/scala/blob/2.13.x/src/compiler/scala/tools/reflect/FormatInterpolator.scala#L74) -object StringContextChecker { - import tpd._ - - /** This trait defines a tool to report errors/warnings that do not depend on Position. */ - trait InterpolationReporter { - - /** Reports error/warning of size 1 linked with a part of the StringContext. - * - * @param message the message to report as error/warning - * @param index the index of the part inside the list of parts of the StringContext - * @param offset the index in the part String where the error is - * @return an error/warning depending on the function - */ - def partError(message : String, index : Int, offset : Int) : Unit - def partWarning(message : String, index : Int, offset : Int) : Unit - - /** Reports error linked with an argument to format. - * - * @param message the message to report as error/warning - * @param index the index of the argument inside the list of arguments of the format function - * @return an error depending on the function - */ - def argError(message : String, index : Int) : Unit - - /** Reports error linked with the list of arguments or the StringContext. - * - * @param message the message to report in the error - * @return an error - */ - def strCtxError(message : String) : Unit - def argsError(message : String) : Unit - - /** Claims whether an error or a warning has been reported - * - * @return true if an error/warning has been reported, false - */ - def hasReported : Boolean - - /** Stores the old value of the reported and reset it to false */ - def resetReported() : Unit - - /** Restores the value of the reported boolean that has been reset */ - def restoreReported() : Unit - } - - /** Check the format of the parts of the f".." arguments and returns the string parts of the StringContext */ - def checkedParts(strContext_f: Tree, args0: Tree)(using Context): String = { - - val (partsExpr, parts) = strContext_f match { - case TypeApply(Select(Apply(_, (parts: SeqLiteral) :: Nil), _), _) => - (parts.elems, parts.elems.map { case Literal(Constant(str: String)) => str } ) - case _ => - report.error("Expected statically known String Context", strContext_f.srcPos) - return "" - } - - val args = args0 match { - case args: SeqLiteral => args.elems - case _ => - report.error("Expected statically known argument list", args0.srcPos) - return "" - } - - val reporter = new InterpolationReporter{ - private[this] var reported = false - private[this] var oldReported = false - def partError(message : String, index : Int, offset : Int) : Unit = { - reported = true - val pos = partsExpr(index).sourcePos - val posOffset = pos.withSpan(pos.span.shift(offset)) - report.error(message, posOffset) - } - def partWarning(message : String, index : Int, offset : Int) : Unit = { - reported = true - val pos = partsExpr(index).sourcePos - val posOffset = pos.withSpan(pos.span.shift(offset)) - report.warning(message, posOffset) - } - - def argError(message : String, index : Int) : Unit = { - reported = true - report.error(message, args(index).srcPos) - } - - def strCtxError(message : String) : Unit = { - reported = true - report.error(message, strContext_f.srcPos) - } - def argsError(message : String) : Unit = { - reported = true - report.error(message, args0.srcPos) - } - - def hasReported : Boolean = { - reported - } - - def resetReported() : Unit = { - oldReported = reported - reported = false - } - - def restoreReported() : Unit = { - reported = oldReported - } - } - - checked(parts, args, reporter) - } - - def checked(parts0: List[String], args: List[Tree], reporter: InterpolationReporter)(using Context): String = { - - - /** Checks if the number of arguments are the same as the number of formatting strings - * - * @param format the number of formatting parts in the StringContext - * @param argument the number of arguments to interpolate in the string - * @return reports an error if the number of arguments does not match with the number of formatting strings, - * nothing otherwise - */ - def checkSizes(format : Int, argument : Int) : Unit = { - if (format > argument && !(format == -1 && argument == 0)) - if (argument == 0) - reporter.argsError("too few arguments for interpolated string") - else - reporter.argError("too few arguments for interpolated string", argument - 1) - if (format < argument && !(format == -1 && argument == 0)) - if (argument == 0) - reporter.argsError("too many arguments for interpolated string") - else - reporter.argError("too many arguments for interpolated string", format) - if (format == -1) - reporter.strCtxError("there are no parts") - } - - /** Adds the default "%s" to the Strings that do not have any given format - * - * @param parts the list of parts contained in the StringContext - * @return a new list of string with all a defined formatting or reports an error if the '%' and - * formatting parameter are too far away from the argument that they refer to - * For example : f2"${d}random-leading-junk%d" will lead to an error - */ - def addDefaultFormat(parts : List[String]) : List[String] = parts match { - case Nil => Nil - case p :: parts1 => p :: parts1.map((part : String) => { - if (!part.startsWith("%")) { - val index = part.indexOf('%') - if (!reporter.hasReported && index != -1) { - reporter.partError("conversions must follow a splice; use %% for literal %, %n for newline", parts.indexOf(part), index) - "%s" + part - } else "%s" + part - } else part - }) - } - - /** Checks whether a part contains a formatting substring - * - * @param part the part to check - * @param l the length of the given part - * @param index the index where to start to look for a potential new formatting string - * @return an Option containing the index in the part where a new formatting String starts, None otherwise - */ - def getFormattingSubstring(part : String, l : Int, index : Int) : Option[Int] = { - var i = index - var result : Option[Int] = None - while (i < l){ - if (part.charAt(i) == '%' && result.isEmpty) - result = Some(i) - i += 1 - } - result - } - - /** Finds all the flags that are inside a formatting String from a given index - * - * @param i the index in the String s where to start to check - * @param l the length of s - * @param s the String to check - * @return a list containing all the flags that are inside the formatting String, - * and their index in the String - */ - def getFlags(i : Int, l : Int, s : String) : List[(Char, Int)] = { - def isFlag(c : Char) : Boolean = c match { - case '-' | '#' | '+' | ' ' | '0' | ',' | '(' => true - case _ => false - } - if (i < l && isFlag(s.charAt(i))) (s.charAt(i), i) :: getFlags(i + 1, l, s) - else Nil - } - - /** Skips the Characters that are width or argumentIndex parameters - * - * @param i the index where to start checking in the given String - * @param s the String to check - * @param l the length of s - * @return a tuple containing the index in the String after skipping - * the parameters, true if it has a width parameter and its value, false otherwise - */ - def skipWidth(i : Int, s : String, l : Int) = { - var j = i - var width = (false, 0) - while (j < l && Character.isDigit(s.charAt(j))){ - width = (true, j) - j += 1 - } - (j, width._1, width._2) - } - - /** Retrieves all the formatting parameters from a part and their index in it - * - * @param part the String containing the formatting parameters - * @param argIndex the index of the current argument inside the list of arguments to interpolate - * @param partIndex the index of the current part inside the list of parts in the StringContext - * @param noArg true if there is no arg, i.e. "%%" or "%n" - * @param pos the initial index where to start checking the part - * @return reports an error if any of the size of the arguments and the parts do not match or if a conversion - * parameter is missing. Otherwise, - * the index where the format specifier substring is, - * hasArgumentIndex (true and the index of its corresponding argumentIndex if there is an argument index, false and 0 otherwise) and - * flags that contains the list of flags (empty if there is none), - * hasWidth (true and the index of the width parameter if there is a width, false and 0 otherwise), - * hasPrecision (true and the index of the precision if there is a precision, false and 0 otherwise), - * hasRelative (true if the specifiers use relative indexing, false otherwise) and - * conversion character index - */ - def getFormatSpecifiers(part : String, argIndex : Int, partIndex : Int, noArg : Boolean, pos : Int) : (Boolean, Int, List[(Char, Int)], Boolean, Int, Boolean, Int, Boolean, Int, Int) = { - var conversion = pos - var hasArgumentIndex = false - var argumentIndex = pos - var hasPrecision = false - var precision = pos - val l = part.length - - if (l >= 1 && part.charAt(conversion) == '%') - conversion += 1 - else if (!noArg) - reporter.argError("too many arguments for interpolated string", argIndex) - - //argument index or width - val (i, hasWidth1, width1) = skipWidth(conversion, part, l) - conversion = i - - //argument index - if (conversion < l && part.charAt(conversion) == '$'){ - if (hasWidth1){ - hasArgumentIndex = true - argumentIndex = width1 - conversion += 1 - } else { - reporter.partError("Missing conversion operator in '" + part.substring(0, conversion) + "'; use %% for literal %, %n for newline", partIndex, 0) - } - } - - //relative indexing - val hasRelative = conversion < l && part.charAt(conversion) == '<' - val relativeIndex = conversion - if (hasRelative) - conversion += 1 - - //flags - val flags = getFlags(conversion, l, part) - conversion += flags.size - - //width - val (j, hasWidth2, width2) = skipWidth(conversion, part, l) - conversion = j - - //precision - if (conversion < l && part.charAt(conversion) == '.') { - precision = conversion - conversion += 1 - hasPrecision = true - val oldConversion = conversion - while (conversion < l && Character.isDigit(part.charAt(conversion))) { - conversion += 1 - } - if (oldConversion == conversion) { - reporter.partError("Missing conversion operator in '" + part.substring(pos, oldConversion - 1) + "'; use %% for literal %, %n for newline", partIndex, pos) - hasPrecision = false - } - } - - //conversion - if((conversion >= l || (!part.charAt(conversion).isLetter && part.charAt(conversion) != '%')) && !reporter.hasReported) - reporter.partError("Missing conversion operator in '" + part.substring(pos, conversion) + "'; use %% for literal %, %n for newline", partIndex, pos) - - val hasWidth = (hasWidth1 && !hasArgumentIndex) || hasWidth2 - val width = if (hasWidth1 && !hasArgumentIndex) width1 else width2 - (hasArgumentIndex, argumentIndex, flags, hasWidth, width, hasPrecision, precision, hasRelative, relativeIndex, conversion) - } - - /** Checks if a given type is a subtype of any of the possibilities - * - * @param actualType the given type - * @param expectedType the type we are expecting - * @param argIndex the index of the argument that should type check - * @param possibilities all the types within which we want to find a super type of the actualType - * @return reports a type mismatch error if the actual type is not a subtype of any of the possibilities, - * nothing otherwise - */ - def checkSubtype(actualType: Type, expectedType: String, argIndex: Int, possibilities: List[Type]) = { - if !possibilities.exists(actualType <:< _) then - reporter.argError("type mismatch;\n found : " + actualType.widen.show.stripPrefix("scala.Predef.").stripPrefix("java.lang.").stripPrefix("scala.") + "\n required: " + expectedType, argIndex) - } - - /** Checks whether a given argument index, relative or not, is in the correct bounds - * - * @param partIndex the index of the part we are checking - * @param offset the index in the part where there might be an error - * @param relative true if relative indexing is used, false otherwise - * @param argumentIndex the argument index parameter in the formatting String - * @param expected true if we have an expectedArgumentIndex, false otherwise - * @param expectedArgumentIndex the expected argument index parameter - * @param maxArgumentIndex the maximum argument index parameter that can be used - * @return reports a warning if relative indexing is used but an argument is still given, - * an error is the argument index is not in the bounds [1, number of arguments] - */ - def checkArgumentIndex(partIndex : Int, offset : Int, relative : Boolean, argumentIndex : Int, expected : Boolean, expectedArgumentIndex : Int, maxArgumentIndex : Int) = { - if (relative) - reporter.partWarning("Argument index ignored if '<' flag is present", partIndex, offset) - - if (argumentIndex > maxArgumentIndex || argumentIndex <= 0) - reporter.partError("Argument index out of range", partIndex, offset) - - if (expected && expectedArgumentIndex != argumentIndex && !reporter.hasReported) - reporter.partWarning("Index is not this arg", partIndex, offset) - } - - /** Checks if a parameter is specified whereas it is not allowed - * - * @param hasParameter true if parameter is specified, false otherwise - * @param partIndex the index of the part inside the parts - * @param offset the index in the part where to report an error - * @param parameter the parameter that is not allowed - * @return reports an error if hasParameter is true, nothing otherwise - */ - def checkNotAllowedParameter(hasParameter : Boolean, partIndex : Int, offset : Int, parameter : String) = { - if (hasParameter) - reporter.partError(parameter + " not allowed", partIndex, offset) - } - - /** Checks if the flags are allowed for the conversion - * - * @param partIndex the index of the part in the String Context - * @param flags the specified flags to check - * @param notAllowedFlagsOnCondition a list that maps which flags are allowed depending on the conversion Char - * @return reports an error if the flag is not allowed, nothing otherwise - */ - def checkFlags(partIndex : Int, flags : List[(Char, Int)], notAllowedFlagOnCondition: List[(Char, Boolean, String)]) = { - for {flag <- flags ; (nonAllowedFlag, condition, message) <- notAllowedFlagOnCondition ; if (flag._1 == nonAllowedFlag && condition)} - reporter.partError(message, partIndex, flag._2) - } - - /** Checks if the flags are allowed for the conversion - * - * @param partIndex the index of the part in the String Context - * @param flags the specified flags to check - * @param notAllowedFlagsOnCondition a list that maps which flags are allowed depending on the conversion Char - * @return reports an error only once if at least one of the flags is not allowed, nothing otherwise - */ - def checkUniqueFlags(partIndex : Int, flags : List[(Char, Int)], notAllowedFlagOnCondition : List[(Char, Boolean, String)]) = { - reporter.resetReported() - for {flag <- flags ; (nonAllowedFlag, condition, message) <- notAllowedFlagOnCondition ; if (flag._1 == nonAllowedFlag && condition)} { - if (!reporter.hasReported) - reporter.partError(message, partIndex, flag._2) - } - if (!reporter.hasReported) - reporter.restoreReported() - } - - /** Checks all the formatting parameters for a Character conversion - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param flags the flags parameters inside the formatting part - * @param hasPrecision true if precision parameter is specified, false otherwise - * @param precision the index of the precision parameter inside the part - * @return reports an error - * if precision is specified or if the used flags are different from '-' - */ - def checkCharacterConversion(partIndex : Int, flags : List[(Char, Int)], hasPrecision : Boolean, precisionIndex : Int) = { - val notAllowedFlagOnCondition = for (flag <- List('#', '+', ' ', '0', ',', '(')) yield (flag, true, "Only '-' allowed for c conversion") - checkUniqueFlags(partIndex, flags, notAllowedFlagOnCondition) - checkNotAllowedParameter(hasPrecision, partIndex, precisionIndex, "precision") - } - - /** Checks all the formatting parameters for an Integral conversion - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param argType the type of the argument matching with the given part - * @param conversionChar the Char used for the formatting conversion - * @param flags the flags parameters inside the formatting part - * @param hasPrecision true if precision parameter is specified, false otherwise - * @param precision the index of the precision parameter inside the part - * @return reports an error - * if precision is specified or if the used flags are not allowed : - * ’d’: only ’#’ is allowed, - * ’o’, ’x’, ’X’: ’-’, ’#’, ’0’ are always allowed, depending on the type, this will be checked in the type check step - */ - def checkIntegralConversion(partIndex : Int, argType : Option[Type], conversionChar : Char, flags : List[(Char, Int)], hasPrecision : Boolean, precision : Int) = { - if (conversionChar == 'd') - checkFlags(partIndex, flags, List(('#', true, "# not allowed for d conversion"))) - - checkNotAllowedParameter(hasPrecision, partIndex, precision, "precision") - } - - /** Checks all the formatting parameters for a Floating Point conversion - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param conversionChar the Char used for the formatting conversion - * @param flags the flags parameters inside the formatting part - * @param hasPrecision true if precision parameter is specified, false otherwise - * @param precision the index of the precision parameter inside the part - * @return reports an error - * if precision is specified for 'a', 'A' conversion or if the used flags are '(' and ',' for 'a', 'A' - */ - def checkFloatingPointConversion(partIndex: Int, conversionChar : Char, flags : List[(Char, Int)], hasPrecision : Boolean, precision : Int) = { - if(conversionChar == 'a' || conversionChar == 'A'){ - for {flag <- flags ; if (flag._1 == ',' || flag._1 == '(')} - reporter.partError("'" + flag._1 + "' not allowed for a, A", partIndex, flag._2) - checkNotAllowedParameter(hasPrecision, partIndex, precision, "precision") - } - } - - /** Checks all the formatting parameters for a Time conversion - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param part the part that we are checking - * @param conversionIndex the index of the conversion Char used in the part - * @param flags the flags parameters inside the formatting part - * @param hasPrecision true if precision parameter is specified, false otherwise - * @param precision the index of the precision parameter inside the part - * @return reports an error - * if precision is specified, if the time suffix is not given/incorrect or if the used flags are - * different from '-' - */ - def checkTimeConversion(partIndex : Int, part : String, conversionIndex : Int, flags : List[(Char, Int)], hasPrecision : Boolean, precision : Int) = { - /** Checks whether a time suffix is given and whether it is allowed - * - * @param part the part that we are checking - * @param partIndex the index of the part inside of the parts of the StringContext - * @param conversionIndex the index of the conversion Char inside the part - * @param return reports an error if no suffix is specified or if the given suffix is not - * part of the allowed ones - */ - def checkTime(part : String, partIndex : Int, conversionIndex : Int) : Unit = { - if (conversionIndex + 1 >= part.size) - reporter.partError("Date/time conversion must have two characters", partIndex, conversionIndex) - else { - part.charAt(conversionIndex + 1) match { - case 'H' | 'I' | 'k' | 'l' | 'M' | 'S' | 'L' | 'N' | 'p' | 'z' | 'Z' | 's' | 'Q' => //times - case 'B' | 'b' | 'h' | 'A' | 'a' | 'C' | 'Y' | 'y' | 'j' | 'm' | 'd' | 'e' => //dates - case 'R' | 'T' | 'r' | 'D' | 'F' | 'c' => //dates and times - case c => reporter.partError("'" + c + "' doesn't seem to be a date or time conversion", partIndex, conversionIndex + 1) - } - } - } - - val notAllowedFlagOnCondition = for (flag <- List('#', '+', ' ', '0', ',', '(')) yield (flag, true, "Only '-' allowed for date/time conversions") - checkUniqueFlags(partIndex, flags, notAllowedFlagOnCondition) - checkNotAllowedParameter(hasPrecision, partIndex, precision, "precision") - checkTime(part, partIndex, conversionIndex) - } - - /** Checks all the formatting parameters for a General conversion - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param argType the type of the argument matching with the given part - * @param conversionChar the Char used for the formatting conversion - * @param flags the flags parameters inside the formatting part - * @return reports an error - * if '#' flag is used or if any other flag is used - */ - def checkGeneralConversion(partIndex : Int, argType : Option[Type], conversionChar : Char, flags : List[(Char, Int)]) = { - for {flag <- flags ; if (flag._1 != '-' && flag._1 != '#')} - reporter.partError("Illegal flag '" + flag._1 + "'", partIndex, flag._2) - } - - /** Checks all the formatting parameters for a special Char such as '%' and end of line - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param conversionChar the Char used for the formatting conversion - * @param hasPrecision true if precision parameter is specified, false otherwise - * @param precision the index of the precision parameter inside the part - * @param hasWidth true if width parameter is specified, false otherwise - * @param width the index of the width parameter inside the part - * @return reports an error if precision or width is specified for '%' or - * if precision is specified for end of line - */ - def checkSpecials(partIndex : Int, conversionChar : Char, hasPrecision : Boolean, precision : Int, hasWidth : Boolean, width : Int, flags : List[(Char, Int)]) = conversionChar match { - case 'n' => { - checkNotAllowedParameter(hasPrecision, partIndex, precision, "precision") - checkNotAllowedParameter(hasWidth, partIndex, width, "width") - val notAllowedFlagOnCondition = for (flag <- List('-', '#', '+', ' ', '0', ',', '(')) yield (flag, true, "flags not allowed") - checkUniqueFlags(partIndex, flags, notAllowedFlagOnCondition) - } - case '%' => { - checkNotAllowedParameter(hasPrecision, partIndex, precision, "precision") - val notAllowedFlagOnCondition = for (flag <- List('#', '+', ' ', '0', ',', '(')) yield (flag, true, "Illegal flag '" + flag + "'") - checkFlags(partIndex, flags, notAllowedFlagOnCondition) - } - case _ => // OK - } - - /** Checks whether the format specifiers are correct depending on the conversion parameter - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param part the part to check - * The rest of the inputs correspond to the output of the function getFormatSpecifiers - * @param hasArgumentIndex - * @param actualArgumentIndex - * @param expectedArgumentIndex - * @param firstFormattingSubstring true if it is the first in the list, i.e. not an indexed argument - * @param maxArgumentIndex - * @param hasRelative - * @param hasWidth - * @param hasPrecision - * @param precision - * @param flags - * @param conversion - * @param argType - * @return the argument index and its type if there is an argument, the flags and the conversion parameter - * reports an error/warning if the formatting parameters are not allowed/wrong, nothing otherwise - */ - def checkFormatSpecifiers(partIndex : Int, hasArgumentIndex : Boolean, actualArgumentIndex : Int, expectedArgumentIndex : Option[Int], firstFormattingSubstring : Boolean, maxArgumentIndex : Option[Int], - hasRelative : Boolean, hasWidth : Boolean, width : Int, hasPrecision : Boolean, precision : Int, flags : List[(Char, Int)], conversion : Int, argType : Option[Type], part : String) : (Option[(Type, Int)], Char, List[(Char, Int)])= { - val conversionChar = part.charAt(conversion) - - if (hasArgumentIndex && expectedArgumentIndex.nonEmpty && maxArgumentIndex.nonEmpty && firstFormattingSubstring) - checkArgumentIndex(partIndex, actualArgumentIndex, hasRelative, part.charAt(actualArgumentIndex).asDigit, true, expectedArgumentIndex.get, maxArgumentIndex.get) - else if(hasArgumentIndex && maxArgumentIndex.nonEmpty && !firstFormattingSubstring) - checkArgumentIndex(partIndex, actualArgumentIndex, hasRelative, part.charAt(actualArgumentIndex).asDigit, false, 0, maxArgumentIndex.get) - - conversionChar match { - case 'c' | 'C' => checkCharacterConversion(partIndex, flags, hasPrecision, precision) - case 'd' | 'o' | 'x' | 'X' => checkIntegralConversion(partIndex, argType, conversionChar, flags, hasPrecision, precision) - case 'e' | 'E' |'f' | 'g' | 'G' | 'a' | 'A' => checkFloatingPointConversion(partIndex, conversionChar, flags, hasPrecision, precision) - case 't' | 'T' => checkTimeConversion(partIndex, part, conversion, flags, hasPrecision, precision) - case 'b' | 'B' | 'h' | 'H' | 'S' | 's' => checkGeneralConversion(partIndex, argType, conversionChar, flags) - case 'n' | '%' => checkSpecials(partIndex, conversionChar, hasPrecision, precision, hasWidth, width, flags) - case illegal => reporter.partError("illegal conversion character '" + illegal + "'", partIndex, conversion) - } - - (if (argType.isEmpty) None else Some(argType.get, (partIndex - 1)), conversionChar, flags) - } - - /** Checks whether the argument type, if there is one, type checks with the formatting parameters - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param conversionChar the character used for the conversion - * @param argument an option containing the type and index of the argument, None if there is no argument - * @param flags the flags used for the formatting - * @param formattingStart the index in the part where the formatting substring starts, i.e. where the '%' is - * @return reports an error/warning if the formatting parameters are not allowed/wrong depending on the type, nothing otherwise - */ - def checkArgTypeWithConversion(partIndex : Int, conversionChar : Char, argument : Option[(Type, Int)], flags : List[(Char, Int)], formattingStart : Int) = { - if (argument.nonEmpty) - checkTypeWithArgs(argument.get, conversionChar, partIndex, flags) - else - checkTypeWithoutArgs(conversionChar, partIndex, flags, formattingStart) - } - - /** Checks whether the argument type checks with the formatting parameters - * - * @param argument the given argument to check - * @param conversionChar the conversion parameter inside the formatting String - * @param partIndex index of the part inside the String Context - * @param flags the list of flags, and their index, used inside the formatting String - * @return reports an error if the argument type does not correspond with the conversion character, - * nothing otherwise - */ - def checkTypeWithArgs(argument : (Type, Int), conversionChar : Char, partIndex : Int, flags : List[(Char, Int)]) = { - def booleans = List(defn.BooleanType, defn.NullType) - def dates = List(defn.LongType, defn.JavaCalendarClass.typeRef, defn.JavaDateClass.typeRef) - def floatingPoints = List(defn.DoubleType, defn.FloatType, defn.JavaBigDecimalClass.typeRef) - def integral = List(defn.IntType, defn.LongType, defn.ShortType, defn.ByteType, defn.JavaBigIntegerClass.typeRef) - def character = List(defn.CharType, defn.ByteType, defn.ShortType, defn.IntType) - - val (argType, argIndex) = argument - conversionChar match { - case 'c' | 'C' => checkSubtype(argType, "Char", argIndex, character) - case 'd' | 'o' | 'x' | 'X' => { - checkSubtype(argType, "Int", argIndex, integral) - if (conversionChar != 'd') { - val notAllowedFlagOnCondition = List(('+', !(argType <:< defn.JavaBigIntegerClass.typeRef), "only use '+' for BigInt conversions to o, x, X"), - (' ', !(argType <:< defn.JavaBigIntegerClass.typeRef), "only use ' ' for BigInt conversions to o, x, X"), - ('(', !(argType <:< defn.JavaBigIntegerClass.typeRef), "only use '(' for BigInt conversions to o, x, X"), - (',', true, "',' only allowed for d conversion of integral types")) - checkFlags(partIndex, flags, notAllowedFlagOnCondition) - } - } - case 'e' | 'E' |'f' | 'g' | 'G' | 'a' | 'A' => checkSubtype(argType, "Double", argIndex, floatingPoints) - case 't' | 'T' => checkSubtype(argType, "Date", argIndex, dates) - case 'b' | 'B' => checkSubtype(argType, "Boolean", argIndex, booleans) - case 'h' | 'H' | 'S' | 's' => - if !(argType <:< defn.JavaFormattableClass.typeRef) then - for flag <- flags; if flag._1 == '#' do - reporter.argError("type mismatch;\n found : " + argType.widen.show.stripPrefix("scala.Predef.").stripPrefix("java.lang.").stripPrefix("scala.") + "\n required: java.util.Formattable", argIndex) - case 'n' | '%' => - case illegal => - } - } - - /** Reports error when the formatting parameter require a specific type but no argument is given - * - * @param conversionChar the conversion parameter inside the formatting String - * @param partIndex index of the part inside the String Context - * @param flags the list of flags, and their index, used inside the formatting String - * @param formattingStart the index in the part where the formatting substring starts, i.e. where the '%' is - * @return reports an error if the formatting parameter refer to the type of the parameter but no parameter is given - * nothing otherwise - */ - def checkTypeWithoutArgs(conversionChar : Char, partIndex : Int, flags : List[(Char, Int)], formattingStart : Int) = { - conversionChar match { - case 'o' | 'x' | 'X' => { - val notAllowedFlagOnCondition = List(('+', true, "only use '+' for BigInt conversions to o, x, X"), - (' ', true, "only use ' ' for BigInt conversions to o, x, X"), - ('(', true, "only use '(' for BigInt conversions to o, x, X"), - (',', true, "',' only allowed for d conversion of integral types")) - checkFlags(partIndex, flags, notAllowedFlagOnCondition) - } - case _ => //OK - } - } - - /** Checks that a given part of the String Context respects every formatting constraint per parameter - * - * @param part a particular part of the String Context - * @param start the index from which we start checking the part - * @param argument an Option containing the argument corresponding to the part and its index in the list of args, - * None if no args are specified. - * @param maxArgumentIndex an Option containing the maximum argument index possible, None if no args are specified - * @return a list with all the elements of the conversion per formatting string - */ - def checkPart(part : String, start : Int, argument : Option[(Int, Tree)], maxArgumentIndex : Option[Int]) : List[(Option[(Type, Int)], Char, List[(Char, Int)])] = { - reporter.resetReported() - val hasFormattingSubstring = getFormattingSubstring(part, part.size, start) - if (hasFormattingSubstring.nonEmpty) { - val formattingStart = hasFormattingSubstring.get - var nextStart = formattingStart - - argument match { - case Some(argIndex, arg) => { - val (hasArgumentIndex, argumentIndex, flags, hasWidth, width, hasPrecision, precision, hasRelative, relativeIndex, conversion) = getFormatSpecifiers(part, argIndex, argIndex + 1, false, formattingStart) - if (!reporter.hasReported){ - val conversionWithType = checkFormatSpecifiers(argIndex + 1, hasArgumentIndex, argumentIndex, Some(argIndex + 1), start == 0, maxArgumentIndex, hasRelative, hasWidth, width, hasPrecision, precision, flags, conversion, Some(arg.tpe), part) - nextStart = conversion + 1 - conversionWithType :: checkPart(part, nextStart, argument, maxArgumentIndex) - } else checkPart(part, conversion + 1, argument, maxArgumentIndex) - } - case None => { - val (hasArgumentIndex, argumentIndex, flags, hasWidth, width, hasPrecision, precision, hasRelative, relativeIndex, conversion) = getFormatSpecifiers(part, 0, 0, true, formattingStart) - if (hasArgumentIndex && !(part.charAt(argumentIndex).asDigit == 1 && (part.charAt(conversion) == 'n' || part.charAt(conversion) == '%'))) - reporter.partError("Argument index out of range", 0, argumentIndex) - if (hasRelative) - reporter.partError("No last arg", 0, relativeIndex) - if (!reporter.hasReported){ - val conversionWithType = checkFormatSpecifiers(0, hasArgumentIndex, argumentIndex, None, start == 0, maxArgumentIndex, hasRelative, hasWidth, width, hasPrecision, precision, flags, conversion, None, part) - nextStart = conversion + 1 - if (!reporter.hasReported && part.charAt(conversion) != '%' && part.charAt(conversion) != 'n' && !hasArgumentIndex && !hasRelative) - reporter.partError("conversions must follow a splice; use %% for literal %, %n for newline", 0, part.indexOf('%')) - conversionWithType :: checkPart(part, nextStart, argument, maxArgumentIndex) - } else checkPart(part, conversion + 1, argument, maxArgumentIndex) - } - } - } else { - reporter.restoreReported() - Nil - } - } - - val argument = args.size - - // check validity of formatting - checkSizes(parts0.size - 1, argument) - - // add default format - val parts = addDefaultFormat(parts0) - - if (!parts.isEmpty && !reporter.hasReported) { - if (parts.size == 1 && args.size == 0 && parts.head.size != 0){ - val argTypeWithConversion = checkPart(parts.head, 0, None, None) - if (!reporter.hasReported) - for ((argument, conversionChar, flags) <- argTypeWithConversion) - checkArgTypeWithConversion(0, conversionChar, argument, flags, parts.head.indexOf('%')) - } else { - val partWithArgs = parts.tail.zip(args) - for (i <- (0 until args.size)){ - val (part, arg) = partWithArgs(i) - val argTypeWithConversion = checkPart(part, 0, Some((i, arg)), Some(args.size)) - if (!reporter.hasReported) - for ((argument, conversionChar, flags) <- argTypeWithConversion) - checkArgTypeWithConversion(i + 1, conversionChar, argument, flags, parts(i).indexOf('%')) - } - } - } - - parts.mkString - } -} diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala index 1fc357ab2bfd..06e6a56484fc 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala @@ -13,110 +13,86 @@ import dotty.tools.dotc.core.Types._ import dotty.tools.dotc.transform.MegaPhase.MiniPhase import dotty.tools.dotc.typer.ConstFold -/** - * MiniPhase to transform s and raw string interpolators from using StringContext to string - * concatenation. Since string concatenation uses the Java String builder, we get a performance - * improvement in terms of these two interpolators. - * - * More info here: - * https://medium.com/@dkomanov/scala-string-interpolation-performance-21dc85e83afd - */ -class StringInterpolatorOpt extends MiniPhase { - import tpd._ +/** MiniPhase to transform s and raw string interpolators from using StringContext to string + * concatenation. Since string concatenation uses the Java String builder, we get a performance + * improvement in terms of these two interpolators. + * + * More info here: + * https://medium.com/@dkomanov/scala-string-interpolation-performance-21dc85e83afd + */ +class StringInterpolatorOpt extends MiniPhase: + import tpd.* override def phaseName: String = StringInterpolatorOpt.name override def description: String = StringInterpolatorOpt.description - override def checkPostCondition(tree: tpd.Tree)(using Context): Unit = { - tree match { + override def checkPostCondition(tree: tpd.Tree)(using Context): Unit = + tree match case tree: RefTree => val sym = tree.symbol assert(sym != defn.StringContext_raw && sym != defn.StringContext_s && sym != defn.StringContext_f, i"$tree in ${ctx.owner.showLocated} should have been rewritten by phase $phaseName") case _ => - } - } /** Matches a list of constant literals */ - private object Literals { - def unapply(tree: SeqLiteral)(using Context): Option[List[Literal]] = { - tree.elems match { - case literals if literals.forall(_.isInstanceOf[Literal]) => - Some(literals.map(_.asInstanceOf[Literal])) + private object Literals: + def unapply(tree: SeqLiteral)(using Context): Option[List[Literal]] = + tree.elems match + case literals if literals.forall(_.isInstanceOf[Literal]) => Some(literals.map(_.asInstanceOf[Literal])) case _ => None - } - } - } - private object StringContextApply { - def unapply(tree: Select)(using Context): Boolean = { - tree.symbol.eq(defn.StringContextModule_apply) && - tree.qualifier.symbol.eq(defn.StringContextModule) - } - } + private object StringContextApply: + def unapply(tree: Select)(using Context): Boolean = + (tree.symbol eq defn.StringContextModule_apply) && (tree.qualifier.symbol eq defn.StringContextModule) /** Matches an s or raw string interpolator */ - private object SOrRawInterpolator { - def unapply(tree: Tree)(using Context): Option[(List[Literal], List[Tree])] = { - tree match { - case Apply(Select(Apply(StringContextApply(), List(Literals(strs))), _), - List(SeqLiteral(elems, _))) if elems.length == strs.length - 1 => - Some(strs, elems) + private object SOrRawInterpolator: + def unapply(tree: Tree)(using Context): Option[(List[Literal], List[Tree])] = + tree match + case Apply(Select(Apply(StringContextApply(), List(Literals(strs))), _), List(SeqLiteral(elems, _))) + if elems.length == strs.length - 1 => Some(strs, elems) case _ => None - } - } - } //Extract the position from InvalidUnicodeEscapeException //which due to bincompat reasons is unaccessible. //TODO: remove once there is less restrictive bincompat - private object InvalidEscapePosition { - def unapply(t: Throwable): Option[Int] = t match { + private object InvalidEscapePosition: + def unapply(t: Throwable): Option[Int] = t match case iee: StringContext.InvalidEscapeException => Some(iee.index) - case il: IllegalArgumentException => il.getMessage() match { - case s"""invalid unicode escape at index $index of $_""" => index.toIntOption - case _ => None - } + case iae: IllegalArgumentException => iae.getMessage() match + case s"""invalid unicode escape at index $index of $_""" => index.toIntOption + case _ => None case _ => None - } - } - /** - * Match trees that resemble s and raw string interpolations. In the case of the s - * interpolator, escapes the string constants. Exposes the string constants as well as - * the variable references. - */ - private object StringContextIntrinsic { - def unapply(tree: Apply)(using Context): Option[(List[Literal], List[Tree])] = { - tree match { + /** Match trees that resemble s and raw string interpolations. In the case of the s + * interpolator, escapes the string constants. Exposes the string constants as well as + * the variable references. + */ + private object StringContextIntrinsic: + def unapply(tree: Apply)(using Context): Option[(List[Literal], List[Tree])] = + tree match case SOrRawInterpolator(strs, elems) => - if (tree.symbol == defn.StringContext_raw) Some(strs, elems) - else { // tree.symbol == defn.StringContextS + if tree.symbol == defn.StringContext_raw then Some(strs, elems) + else // tree.symbol == defn.StringContextS import dotty.tools.dotc.util.SourcePosition var stringPosition: SourcePosition = null - try { - val escapedStrs = strs.map(str => { + try + val escapedStrs = strs.map { str => stringPosition = str.sourcePos val escaped = StringContext.processEscapes(str.const.stringValue) cpy.Literal(str)(Constant(escaped)) - }) + } Some(escapedStrs, elems) - } catch { - case t @ InvalidEscapePosition(p) => { + catch + case t @ InvalidEscapePosition(p) => val errorSpan = stringPosition.span.startPos.shift(p) val errorPosition = stringPosition.withSpan(errorSpan) report.error(t.getMessage() + "\n", errorPosition) None - } - } - } case _ => None - } - } - } - override def transformApply(tree: Apply)(using Context): Tree = { + override def transformApply(tree: Apply)(using Context): Tree = def mkConcat(strs: List[Literal], elems: List[Tree]): Tree = val stri = strs.iterator val elemi = elems.iterator @@ -175,16 +151,12 @@ class StringInterpolatorOpt extends MiniPhase { ConstFold.Apply(tree).tpe match case ConstantType(x) => Literal(x).withSpan(tree.span).ensureConforms(tree.tpe) case _ => tree - } - override def transformSelect(tree: Select)(using Context): Tree = { + override def transformSelect(tree: Select)(using Context): Tree = ConstFold.Select(tree).tpe match case ConstantType(x) => Literal(x).withSpan(tree.span).ensureConforms(tree.tpe) case _ => tree - } - -} object StringInterpolatorOpt: val name: String = "stringInterpolatorOpt" - val description: String = "optimize raw and s string interpolators" + val description: String = "optimize s, f and raw string interpolators" diff --git a/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala b/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala index 56bc47d172b2..984008c1f8ce 100644 --- a/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala +++ b/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala @@ -6,15 +6,14 @@ import org.junit.{Test, Assert}, Assert.{assertEquals, assertFalse, assertTrue} import scala.collection.mutable.ListBuffer import scala.language.implicitConversions import scala.reflect.{ClassTag, classTag} -import scala.util.chaining._ import java.util.{Calendar, Date, Formattable} -import localopt.{FormatChecker, StringContextChecker} +import localopt.{FormatChecker, InterpolationReporter} // TDD for just the Checker class FormatCheckerTest: - class TestReporter extends StringContextChecker.InterpolationReporter: + class TestReporter extends InterpolationReporter: private var reported = false private var oldReported = false val reports = ListBuffer.empty[(String, Int, Int)] @@ -45,17 +44,6 @@ class FormatCheckerTest: end TestReporter given TestReporter = TestReporter() - /* - enum ArgTypeTag: - case BooleanArg, ByteArg, CharArg, ShortArg, IntArg, LongArg, FloatArg, DoubleArg, AnyArg, - StringArg, FormattableArg, BigIntArg, BigDecimalArg, CalendarArg, DateArg - given Conversion[ArgTypeTag, Int] = _.ordinal - def argTypeString(tag: Int) = - if tag < 0 then "Null" - else if tag >= ArgTypeTag.values.length then throw RuntimeException(s"Bad tag $tag") - else ArgTypeTag.values(tag) - */ - class TestChecker(args: ClassTag[?]*)(using val reporter: TestReporter) extends FormatChecker: def argType(argi: Int, types: ClassTag[?]*): ClassTag[?] = types.find(_ == args(argi)).getOrElse(types.head) val argc = args.length diff --git a/tests/run-macros/f-interpolator-tests.scala b/tests/run-macros/f-interpolator-tests.scala index 73ad9b8852fa..8c59ae19a187 100755 --- a/tests/run-macros/f-interpolator-tests.scala +++ b/tests/run-macros/f-interpolator-tests.scala @@ -1,10 +1,9 @@ -/** - * These tests test all the possible formats the f interpolator has to deal with. - * The tests are sorted by argument category as the arguments are on https://docs.oracle.com/javase/6/docs/api/java/util/Formatter.html#detail - * - * - * Some tests come from https://github.com/lampepfl/dotty/pull/3894/files - */ +/** These tests test all the possible formats the f interpolator has to deal with. + * + * The tests are sorted by argument category as the arguments are on https://docs.oracle.com/javase/6/docs/api/java/util/Formatter.html#detail + * + * Some tests come from https://github.com/lampepfl/dotty/pull/3894/files + */ object Test { def main(args: Array[String]) = { println(f"integer: ${5}%d") From 9826c446500d6959a09cb34ee82bd4d0273a6692 Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Thu, 20 Jan 2022 22:44:26 -0800 Subject: [PATCH 04/10] Simplify Conversion --- .../transform/localopt/FormatChecker.scala | 184 +++++++++--------- .../localopt/StringInterpolatorOpt.scala | 6 +- 2 files changed, 97 insertions(+), 93 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala index 54493c552473..069eac80e458 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala @@ -9,6 +9,8 @@ import scala.util.matching.Regex.Match import java.util.{Calendar, Date, Formattable} +import PartialFunction.cond + /** Formatter string checker. */ abstract class FormatChecker(using reporter: InterpolationReporter): @@ -19,7 +21,7 @@ abstract class FormatChecker(using reporter: InterpolationReporter): // count of args, for checking indexes def argc: Int - val allFlags = "-#+ 0,(<" + // match a conversion specifier val formatPattern = """%(?:(\d+)\$)?([-#+ 0,(<]+)?(\d+)?(\.\d+)?([tT]?[%a-zA-Z])?""".r // ordinal is the regex group index in the format pattern @@ -95,18 +97,16 @@ abstract class FormatChecker(using reporter: InterpolationReporter): extension (inline value: Boolean) inline def or(inline body: => Unit): Boolean = value || { body ; false } inline def orElse(inline body: => Unit): Boolean = value || { body ; true } - inline def but(inline body: => Unit): Boolean = value && { body ; false } inline def and(inline body: => Unit): Boolean = value && { body ; true } + inline def but(inline body: => Unit): Boolean = value && { body ; false } - /** A conversion specifier matched in the argi'th string part, - * with `argc` arguments to interpolate. - */ - sealed abstract class Conversion: - // the match for this descriptor - def descriptor: Match - // the part number for reporting errors - def argi: Int + enum Kind: + case StringXn, HashXn, BooleanXn, CharacterXn, IntegralXn, FloatingPointXn, DateTimeXn, LiteralXn, ErrorXn + import Kind.* + /** A conversion specifier matched in the argi'th string part, with `argc` arguments to interpolate. + */ + final class Conversion(val descriptor: Match, val argi: Int, val kind: Kind): // the descriptor fields val index: Option[Int] = descriptor.intOf(Index) val flags: String = descriptor.stringOf(Flags) @@ -115,26 +115,86 @@ abstract class FormatChecker(using reporter: InterpolationReporter): val op: String = descriptor.stringOf(CC) // the conversion char is the head of the op string (but see DateTimeXn) - val cc: Char = if isError then '?' else op(0) + val cc: Char = + kind match + case ErrorXn => '?' + case DateTimeXn => if op.length > 1 then op(1) else '?' + case _ => op(0) - def isError: Boolean = false def isIndexed: Boolean = index.nonEmpty || hasFlag('<') - def isLiteral: Boolean = false + def isError: Boolean = kind == ErrorXn + def isLiteral: Boolean = kind == LiteralXn // descriptor is at index 0 of the part string def isLeading: Boolean = descriptor.at(Spec) == 0 + // flags and index in specifier are ok + private def goodies = goodFlags && goodIndex + // true if passes. Default checks flags and index - def verify: Boolean = goodFlags && goodIndex + def verify: Boolean = + kind match { + case StringXn => goodies + case BooleanXn => goodies + case HashXn => goodies + case CharacterXn => goodies && noPrecision && only_-("c conversion") + case IntegralXn => + def d_# = cc == 'd' && hasFlag('#') and badFlag('#', "# not allowed for d conversion") + def x_comma = cc != 'd' && hasFlag(',') and badFlag(',', "',' only allowed for d conversion of integral types") + goodies && noPrecision && !d_# && !x_comma + case FloatingPointXn => + goodies && (cc match { + case 'a' | 'A' => + val badFlags = ",(".filter(hasFlag) + noPrecision && badFlags.isEmpty or badFlags.foreach(badf => badFlag(badf, s"'$badf' not allowed for a, A")) + case _ => true + }) + case DateTimeXn => + def hasCC = op.length == 2 or errorAt(CC)("Date/time conversion must have two characters") + def goodCC = "HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc".contains(cc) or errorAt(CC, 1)(s"'$cc' doesn't seem to be a date or time conversion") + goodies && hasCC && goodCC && noPrecision && only_-("date/time conversions") + case LiteralXn => + op match { + case "%" => goodies && noPrecision and width.foreach(_ => warningAt(Width)("width ignored on literal")) + case "n" => noFlags && noWidth && noPrecision + } + case ErrorXn => + errorAt(CC)(s"illegal conversion character '$cc'") + false + } // is the specifier OK with the given arg - def accepts(arg: ClassTag[?]): Boolean = true + def accepts(arg: ClassTag[?]): Boolean = + kind match + case BooleanXn => arg == classTag[Boolean] orElse warningAt(CC)("Boolean format is null test for non-Boolean") + case IntegralXn => + arg == classTag[BigInt] || !cond(cc) { + case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; true + } + case _ => true // what arg type if any does the conversion accept - def acceptableVariants: List[ClassTag[?]] + def acceptableVariants: List[ClassTag[?]] = + kind match { + case StringXn => if hasFlag('#') then classTag[Formattable] :: Nil else classTag[Any] :: Nil + case BooleanXn => classTag[Boolean] :: Conversion.FakeNullTag :: Nil + case HashXn => classTag[Any] :: Nil + case CharacterXn => classTag[Char] :: classTag[Byte] :: classTag[Short] :: classTag[Int] :: Nil + case IntegralXn => classTag[Int] :: classTag[Long] :: classTag[Byte] :: classTag[Short] :: classTag[BigInt] :: Nil + case FloatingPointXn => classTag[Double] :: classTag[Float] :: classTag[BigDecimal] :: Nil + case DateTimeXn => classTag[Long] :: classTag[Calendar] :: classTag[Date] :: Nil + case LiteralXn => Nil + case ErrorXn => Nil + } - // what flags does the conversion accept? defaults to all - protected def okFlags: String = allFlags + // what flags does the conversion accept? + private def okFlags: String = + kind match { + case StringXn => "-#<" + case BooleanXn | HashXn => "-<" + case LiteralXn => "-" + case _ => "-#+ 0,(<" + } def hasFlag(f: Char) = flags.contains(f) def hasAnyFlag(fs: String) = fs.exists(hasFlag) @@ -146,6 +206,7 @@ abstract class FormatChecker(using reporter: InterpolationReporter): def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partError(msg, argi, descriptor.offset(g, i)) def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partWarning(msg, argi, descriptor.offset(g, i)) + // various assertions def noFlags = flags.isEmpty or errorAt(Flags)("flags not allowed") def noWidth = width.isEmpty or errorAt(Width)("width not allowed") def noPrecision = precision.isEmpty or errorAt(Precision)("precision not allowed") @@ -162,84 +223,25 @@ abstract class FormatChecker(using reporter: InterpolationReporter): okRange || hasFlag('<') or errorAt(Index)("Argument index out of range") object Conversion: def apply(m: Match, i: Int): Conversion = - def badCC(msg: String) = ErrorXn(m, i).tap(error => error.errorAt(if (error.op.isEmpty) Spec else CC)(msg)) - def cv(cc: Char) = cc match - case 's' | 'S' => StringXn(m, i) - case 'h' | 'H' => HashXn(m, i) - case 'b' | 'B' => BooleanXn(m, i) - case 'c' | 'C' => CharacterXn(m, i) + def kindOf(cc: Char) = cc match + case 's' | 'S' => StringXn + case 'h' | 'H' => HashXn + case 'b' | 'B' => BooleanXn + case 'c' | 'C' => CharacterXn case 'd' | 'o' | - 'x' | 'X' => IntegralXn(m, i) + 'x' | 'X' => IntegralXn case 'e' | 'E' | 'f' | 'g' | 'G' | - 'a' | 'A' => FloatingPointXn(m, i) - case 't' | 'T' => DateTimeXn(m, i) - case '%' | 'n' => LiteralXn(m, i) - case _ => badCC(s"illegal conversion character '$cc'") - end cv + 'a' | 'A' => FloatingPointXn + case 't' | 'T' => DateTimeXn + case '%' | 'n' => LiteralXn + case _ => ErrorXn + end kindOf m.group(CC) match - case Some(cc) => cv(cc(0)).tap(_.verify) - case None => badCC(s"Missing conversion operator in '${m.matched}'; $literalHelp") + case Some(cc) => new Conversion(m, i, kindOf(cc(0))).tap(_.verify) + case None => new Conversion(m, i, ErrorXn).tap(_.errorAt(Spec)(s"Missing conversion operator in '${m.matched}'; $literalHelp")) end apply val literalHelp = "use %% for literal %, %n for newline" + private val FakeNullTag: ClassTag[?] = null end Conversion - abstract class GeneralXn extends Conversion - // s | S - class StringXn(val descriptor: Match, val argi: Int) extends GeneralXn: - val acceptableVariants = - if hasFlag('#') then classTag[Formattable] :: Nil - else classTag[Any] :: Nil - override protected def okFlags = "-#<" - // b | B - class BooleanXn(val descriptor: Match, val argi: Int) extends GeneralXn: - val FakeNullTag: ClassTag[?] = null - val acceptableVariants = classTag[Boolean] :: FakeNullTag :: Nil - override def accepts(arg: ClassTag[?]): Boolean = - arg == classTag[Boolean] orElse warningAt(CC)("Boolean format is null test for non-Boolean") - override protected def okFlags = "-<" - // h | H - class HashXn(val descriptor: Match, val argi: Int) extends GeneralXn: - val acceptableVariants = classTag[Any] :: Nil - override protected def okFlags = "-<" - // %% | %n - class LiteralXn(val descriptor: Match, val argi: Int) extends Conversion: - override def isLiteral = true - override def verify = op match - case "%" => super.verify && noPrecision and width.foreach(_ => warningAt(Width)("width ignored on literal")) - case "n" => noFlags && noWidth && noPrecision - override protected val okFlags = "-" - override def acceptableVariants = Nil - class CharacterXn(val descriptor: Match, val argi: Int) extends Conversion: - override def verify = super.verify && noPrecision && only_-("c conversion") - val acceptableVariants = classTag[Char] :: classTag[Byte] :: classTag[Short] :: classTag[Int] :: Nil - class IntegralXn(val descriptor: Match, val argi: Int) extends Conversion: - override def verify = - def d_# = cc == 'd' && hasFlag('#') and badFlag('#', "# not allowed for d conversion") - def x_comma = cc != 'd' && hasFlag(',') and badFlag(',', "',' only allowed for d conversion of integral types") - super.verify && noPrecision && !d_# && !x_comma - val acceptableVariants = classTag[Int] :: classTag[Long] :: classTag[Byte] :: classTag[Short] :: classTag[BigInt] :: Nil - override def accepts(arg: ClassTag[?]): Boolean = - arg == classTag[BigInt] || { - cc match - case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; false - case _ => true - } - class FloatingPointXn(val descriptor: Match, val argi: Int) extends Conversion: - override def verify = super.verify && (cc match { - case 'a' | 'A' => - val badFlags = ",(".filter(hasFlag) - noPrecision && badFlags.isEmpty or badFlags.foreach(badf => badFlag(badf, s"'$badf' not allowed for a, A")) - case _ => true - }) - val acceptableVariants = classTag[Double] :: classTag[Float] :: classTag[BigDecimal] :: Nil - class DateTimeXn(val descriptor: Match, val argi: Int) extends Conversion: - override val cc: Char = if op.length > 1 then op(1) else '?' - def hasCC = op.length == 2 or errorAt(CC)("Date/time conversion must have two characters") - def goodCC = "HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc".contains(cc) or errorAt(CC, 1)(s"'$cc' doesn't seem to be a date or time conversion") - override def verify = super.verify && hasCC && goodCC && noPrecision && only_-("date/time conversions") - val acceptableVariants = classTag[Long] :: classTag[Calendar] :: classTag[Date] :: Nil - class ErrorXn(val descriptor: Match, val argi: Int) extends Conversion: - override def isError = true - override def verify = false - override def acceptableVariants = Nil diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala index 06e6a56484fc..b8e6300f4e04 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala @@ -114,9 +114,10 @@ class StringInterpolatorOpt extends MiniPhase: case nme.raw_ => sym eq defn.StringContext_raw case nme.f => sym eq defn.StringContext_f case _ => false + // Perform format checking and normalization, then make it StringOps(fmt).format(args1) with tweaked args def transformF(fun: Tree, args: Tree): Tree = - val (parts1, args1) = FormatInterpolatorTransform.checked(fun, args) - resolveConstructor(defn.StringOps.typeRef, List(parts1)) + val (fmt, args1) = FormatInterpolatorTransform.checked(fun, args) + resolveConstructor(defn.StringOps.typeRef, List(fmt)) .select(nme.format) .appliedTo(args1) // Starting with Scala 2.13, s and raw are macros in the standard @@ -137,6 +138,7 @@ class StringInterpolatorOpt extends MiniPhase: .appliedToTermArgs(List(process, args, parts)) } end transformS + // begin transformApply if isInterpolatedMethod then (tree: @unchecked) match case StringContextIntrinsic(strs: List[Literal], elems: List[Tree]) => From 838db3ee3d32ed76525bdbfa1974b37336537961 Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Sat, 22 Jan 2022 00:24:01 -0800 Subject: [PATCH 05/10] Simplify TypedFormatChecker --- .../transform/localopt/FormatChecker.scala | 131 +++++++++++------- .../FormatInterpolatorTransform.scala | 83 +---------- .../dotc/transform/FormatCheckerTest.scala | 70 ---------- 3 files changed, 81 insertions(+), 203 deletions(-) delete mode 100644 compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala index 069eac80e458..0f3a3ef4c154 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala @@ -2,8 +2,7 @@ package dotty.tools.dotc package transform.localopt import scala.annotation.tailrec -import scala.collection.mutable.{ListBuffer, Stack} -import scala.reflect.{ClassTag, classTag} +import scala.collection.mutable.ListBuffer import scala.util.chaining.* import scala.util.matching.Regex.Match @@ -11,15 +10,49 @@ import java.util.{Calendar, Date, Formattable} import PartialFunction.cond +import dotty.tools.dotc.ast.tpd.{Match => _, *} +import dotty.tools.dotc.core.Contexts._ +import dotty.tools.dotc.core.Symbols._ +import dotty.tools.dotc.core.Types._ +import dotty.tools.dotc.core.Phases.typerPhase + /** Formatter string checker. */ -abstract class FormatChecker(using reporter: InterpolationReporter): +class TypedFormatChecker(args: List[Tree])(using Context)(using reporter: InterpolationReporter): + + val argTypes = args.map(_.tpe) + val actuals = ListBuffer.empty[Tree] + + // count of args, for checking indexes + val argc = argTypes.length // Pick the first runtime type which the i'th arg can satisfy. // If conversion is required, implementation must emit it. - def argType(argi: Int, types: ClassTag[?]*): ClassTag[?] + def argType(argi: Int, types: Type*): Type = + require(argi < argc, s"$argi out of range picking from $types") + val tpe = argTypes(argi) + types.find(t => argConformsTo(argi, tpe, t)) + .orElse(types.find(t => argConvertsTo(argi, tpe, t))) + .getOrElse { + reporter.argError(s"Found: ${tpe.show}, Required: ${types.mkString(", ")}", argi) + actuals += args(argi) + types.head + } - // count of args, for checking indexes - def argc: Int + object formattableTypes: + val FormattableType = requiredClassRef("java.util.Formattable") + val BigIntType = requiredClassRef("scala.math.BigInt") + val BigDecimalType = requiredClassRef("scala.math.BigDecimal") + val CalendarType = requiredClassRef("java.util.Calendar") + val DateType = requiredClassRef("java.util.Date") + import formattableTypes.* + def argConformsTo(argi: Int, arg: Type, target: Type): Boolean = (arg <:< target).tap(if _ then actuals += args(argi)) + def argConvertsTo(argi: Int, arg: Type, target: Type): Boolean = + import typer.Implicits.SearchSuccess + atPhase(typerPhase) { + ctx.typer.inferView(args(argi), target) match + case SearchSuccess(view, ref, _, _) => actuals += view ; true + case _ => false + } // match a conversion specifier val formatPattern = """%(?:(\d+)\$)?([-#+ 0,(<]+)?(\d+)?(\.\d+)?([tT]?[%a-zA-Z])?""".r @@ -51,7 +84,7 @@ abstract class FormatChecker(using reporter: InterpolationReporter): def insertStringConversion(): Unit = amended += "%s" + part convert += Conversion(formatPattern.findAllMatchIn("%s").next(), n) // improve - argType(n-1, classTag[Any]) + argType(n-1, defn.AnyType) def errorLeading(op: Conversion) = op.errorAt(Spec)(s"conversions must follow a splice; ${Conversion.literalHelp}") def accept(op: Conversion): Unit = if !op.isLeading then errorLeading(op) @@ -66,11 +99,7 @@ abstract class FormatChecker(using reporter: InterpolationReporter): val cv = Conversion(matches.next(), n) if cv.isLiteral then insertStringConversion() else if cv.isIndexed then - if cv.index.getOrElse(-1) == n then accept(cv) - else - // either some other arg num, or '<' - //c.warning(op.groupPos(Index), "Index is not this arg") - insertStringConversion() + if cv.index.getOrElse(-1) == n then accept(cv) else insertStringConversion() else if !cv.isError then accept(cv) // any remaining conversions in this part must be either literals or indexed @@ -128,12 +157,26 @@ abstract class FormatChecker(using reporter: InterpolationReporter): // descriptor is at index 0 of the part string def isLeading: Boolean = descriptor.at(Spec) == 0 - // flags and index in specifier are ok - private def goodies = goodFlags && goodIndex - - // true if passes. Default checks flags and index + // true if passes. def verify: Boolean = - kind match { + // various assertions + def goodies = goodFlags && goodIndex + def noFlags = flags.isEmpty or errorAt(Flags)("flags not allowed") + def noWidth = width.isEmpty or errorAt(Width)("width not allowed") + def noPrecision = precision.isEmpty or errorAt(Precision)("precision not allowed") + def only_-(msg: String) = + val badFlags = flags.filterNot { case '-' | '<' => true case _ => false } + badFlags.isEmpty or badFlag(badFlags(0), s"Only '-' allowed for $msg") + def goodFlags = + val badFlags = flags.filterNot(okFlags.contains) + for f <- badFlags do badFlag(f, s"Illegal flag '$f'") + badFlags.isEmpty + def goodIndex = + if index.nonEmpty && hasFlag('<') then warningAt(Index)("Argument index ignored if '<' flag is present") + val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true) + okRange || hasFlag('<') or errorAt(Index)("Argument index out of range") + // begin verify + kind match case StringXn => goodies case BooleanXn => goodies case HashXn => goodies @@ -143,58 +186,55 @@ abstract class FormatChecker(using reporter: InterpolationReporter): def x_comma = cc != 'd' && hasFlag(',') and badFlag(',', "',' only allowed for d conversion of integral types") goodies && noPrecision && !d_# && !x_comma case FloatingPointXn => - goodies && (cc match { + goodies && (cc match case 'a' | 'A' => val badFlags = ",(".filter(hasFlag) noPrecision && badFlags.isEmpty or badFlags.foreach(badf => badFlag(badf, s"'$badf' not allowed for a, A")) case _ => true - }) + ) case DateTimeXn => def hasCC = op.length == 2 or errorAt(CC)("Date/time conversion must have two characters") def goodCC = "HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc".contains(cc) or errorAt(CC, 1)(s"'$cc' doesn't seem to be a date or time conversion") goodies && hasCC && goodCC && noPrecision && only_-("date/time conversions") case LiteralXn => - op match { + op match case "%" => goodies && noPrecision and width.foreach(_ => warningAt(Width)("width ignored on literal")) case "n" => noFlags && noWidth && noPrecision - } case ErrorXn => errorAt(CC)(s"illegal conversion character '$cc'") false - } + end verify // is the specifier OK with the given arg - def accepts(arg: ClassTag[?]): Boolean = + def accepts(arg: Type): Boolean = kind match - case BooleanXn => arg == classTag[Boolean] orElse warningAt(CC)("Boolean format is null test for non-Boolean") + case BooleanXn => arg == defn.BooleanType orElse warningAt(CC)("Boolean format is null test for non-Boolean") case IntegralXn => - arg == classTag[BigInt] || !cond(cc) { + arg == BigIntType || !cond(cc) { case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; true } case _ => true // what arg type if any does the conversion accept - def acceptableVariants: List[ClassTag[?]] = - kind match { - case StringXn => if hasFlag('#') then classTag[Formattable] :: Nil else classTag[Any] :: Nil - case BooleanXn => classTag[Boolean] :: Conversion.FakeNullTag :: Nil - case HashXn => classTag[Any] :: Nil - case CharacterXn => classTag[Char] :: classTag[Byte] :: classTag[Short] :: classTag[Int] :: Nil - case IntegralXn => classTag[Int] :: classTag[Long] :: classTag[Byte] :: classTag[Short] :: classTag[BigInt] :: Nil - case FloatingPointXn => classTag[Double] :: classTag[Float] :: classTag[BigDecimal] :: Nil - case DateTimeXn => classTag[Long] :: classTag[Calendar] :: classTag[Date] :: Nil + def acceptableVariants: List[Type] = + kind match + case StringXn => if hasFlag('#') then FormattableType :: Nil else defn.AnyType :: Nil + case BooleanXn => defn.BooleanType :: defn.NullType :: Nil + case HashXn => defn.AnyType :: Nil + case CharacterXn => defn.CharType :: defn.ByteType :: defn.ShortType :: defn.IntType :: Nil + case IntegralXn => defn.IntType :: defn.LongType :: defn.ByteType :: defn.ShortType :: BigIntType :: Nil + case FloatingPointXn => defn.DoubleType :: defn.FloatType :: BigDecimalType :: Nil + case DateTimeXn => defn.LongType :: CalendarType :: DateType :: Nil case LiteralXn => Nil case ErrorXn => Nil - } // what flags does the conversion accept? private def okFlags: String = - kind match { + kind match case StringXn => "-#<" case BooleanXn | HashXn => "-<" case LiteralXn => "-" case _ => "-#+ 0,(<" - } def hasFlag(f: Char) = flags.contains(f) def hasAnyFlag(fs: String) = fs.exists(hasFlag) @@ -206,21 +246,6 @@ abstract class FormatChecker(using reporter: InterpolationReporter): def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partError(msg, argi, descriptor.offset(g, i)) def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partWarning(msg, argi, descriptor.offset(g, i)) - // various assertions - def noFlags = flags.isEmpty or errorAt(Flags)("flags not allowed") - def noWidth = width.isEmpty or errorAt(Width)("width not allowed") - def noPrecision = precision.isEmpty or errorAt(Precision)("precision not allowed") - def only_-(msg: String) = - val badFlags = flags.filterNot { case '-' | '<' => true case _ => false } - badFlags.isEmpty or badFlag(badFlags(0), s"Only '-' allowed for $msg") - def goodFlags = - val badFlags = flags.filterNot(okFlags.contains) - for f <- badFlags do badFlag(f, s"Illegal flag '$f'") - badFlags.isEmpty - def goodIndex = - if index.nonEmpty && hasFlag('<') then warningAt(Index)("Argument index ignored if '<' flag is present") - val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true) - okRange || hasFlag('<') or errorAt(Index)("Argument index out of range") object Conversion: def apply(m: Match, i: Int): Conversion = def kindOf(cc: Char) = cc match @@ -243,5 +268,5 @@ abstract class FormatChecker(using reporter: InterpolationReporter): case None => new Conversion(m, i, ErrorXn).tap(_.errorAt(Spec)(s"Missing conversion operator in '${m.matched}'; $literalHelp")) end apply val literalHelp = "use %% for literal %, %n for newline" - private val FakeNullTag: ClassTag[?] = null end Conversion +end TypedFormatChecker diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala index 436b56710370..52869a5877c2 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala @@ -1,27 +1,11 @@ package dotty.tools.dotc package transform.localopt -import dotty.tools.dotc.ast.Trees._ -import dotty.tools.dotc.ast.tpd -import dotty.tools.dotc.core.Decorators._ +import dotty.tools.dotc.ast.tpd.* import dotty.tools.dotc.core.Constants.Constant -import dotty.tools.dotc.core.Contexts._ -import dotty.tools.dotc.core.StdNames._ -import dotty.tools.dotc.core.NameKinds._ -import dotty.tools.dotc.core.Symbols._ -import dotty.tools.dotc.core.Types._ -import dotty.tools.dotc.core.Phases.typerPhase -import dotty.tools.dotc.typer.ProtoTypes._ - -import scala.StringContext.processEscapes -import scala.annotation.tailrec -import scala.collection.mutable.{ListBuffer, Stack} -import scala.reflect.{ClassTag, classTag} -import scala.util.chaining._ -import scala.util.matching.Regex.Match +import dotty.tools.dotc.core.Contexts.* object FormatInterpolatorTransform: - import tpd._ class PartsReporter(fun: Tree, args0: Tree, parts: List[Tree], args: List[Tree])(using Context) extends InterpolationReporter: private var reported = false @@ -50,67 +34,6 @@ object FormatInterpolatorTransform: reported = false def restoreReported(): Unit = reported = oldReported end PartsReporter - object tags: - import java.util.{Calendar, Date, Formattable} - val StringTag = classTag[String] - val FormattableTag = classTag[Formattable] - val BigIntTag = classTag[BigInt] - val BigDecimalTag = classTag[BigDecimal] - val CalendarTag = classTag[Calendar] - val DateTag = classTag[Date] - class FormattableTypes(using Context): - val FormattableType = requiredClassRef("java.util.Formattable") - val BigIntType = requiredClassRef("scala.math.BigInt") - val BigDecimalType = requiredClassRef("scala.math.BigDecimal") - val CalendarType = requiredClassRef("java.util.Calendar") - val DateType = requiredClassRef("java.util.Date") - class TypedFormatChecker(val args: List[Tree])(using Context, InterpolationReporter) extends FormatChecker: - val reporter = summon[InterpolationReporter] - val argTypes = args.map(_.tpe) - val actuals = ListBuffer.empty[Tree] - val argc = argTypes.length - def argType(argi: Int, types: Seq[ClassTag[?]]) = - require(argi < argc, s"$argi out of range picking from $types") - val tpe = argTypes(argi) - types.find(t => argConformsTo(argi, tpe, argTypeOf(t))) - .orElse(types.find(t => argConvertsTo(argi, tpe, argTypeOf(t)))) - .getOrElse { - reporter.argError(s"Found: ${tpe.show}, Required: ${types.mkString(", ")}", argi) - actuals += args(argi) - types.head - } - final lazy val fmtTypes = FormattableTypes() - import tags.*, fmtTypes.* - def argConformsTo(argi: Int, arg: Type, target: Type): Boolean = - (arg <:< target).tap(if _ then actuals += args(argi)) - def argConvertsTo(argi: Int, arg: Type, target: Type): Boolean = - import typer.Implicits.SearchSuccess - atPhase(typerPhase) { - ctx.typer.inferView(args(argi), target) match - case SearchSuccess(view, ref, _, _) => actuals += view ; true - case _ => false - } - def argTypeOf(tag: ClassTag[?]): Type = tag match - case StringTag => defn.StringType - case ClassTag.Boolean => defn.BooleanType - case ClassTag.Byte => defn.ByteType - case ClassTag.Char => defn.CharType - case ClassTag.Short => defn.ShortType - case ClassTag.Int => defn.IntType - case ClassTag.Long => defn.LongType - case ClassTag.Float => defn.FloatType - case ClassTag.Double => defn.DoubleType - case ClassTag.Any => defn.AnyType - case ClassTag.AnyRef => defn.AnyRefType - case FormattableTag => FormattableType - case BigIntTag => BigIntType - case BigDecimalTag => BigDecimalType - case CalendarTag => CalendarType - case DateTag => DateType - case null => defn.NullType - case _ => reporter.strCtxError(s"Unknown type for format $tag") - defn.AnyType - end TypedFormatChecker /** For f"${arg}%xpart", check format conversions and return (format, args) * suitable for String.format(format, args). @@ -146,7 +69,7 @@ object FormatInterpolatorTransform: if reporter.hasReported then (literally(parts.mkString), args0) else assert(checker.argc == checker.actuals.size, s"Expected ${checker.argc}, actuals size is ${checker.actuals.size} for [${parts.mkString(", ")}]") - (literally(checked.mkString), tpd.SeqLiteral(checker.actuals.toList, elemtpt)) + (literally(checked.mkString), SeqLiteral(checker.actuals.toList, elemtpt)) end checked end FormatInterpolatorTransform diff --git a/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala b/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala deleted file mode 100644 index 984008c1f8ce..000000000000 --- a/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala +++ /dev/null @@ -1,70 +0,0 @@ -package dotty.tools -package dotc.transform - -import org.junit.{Test, Assert}, Assert.{assertEquals, assertFalse, assertTrue} - -import scala.collection.mutable.ListBuffer -import scala.language.implicitConversions -import scala.reflect.{ClassTag, classTag} - -import java.util.{Calendar, Date, Formattable} - -import localopt.{FormatChecker, InterpolationReporter} - -// TDD for just the Checker -class FormatCheckerTest: - class TestReporter extends InterpolationReporter: - private var reported = false - private var oldReported = false - val reports = ListBuffer.empty[(String, Int, Int)] - - def partError(message: String, index: Int, offset: Int): Unit = - reported = true - reports += ((message, index, offset)) - def partWarning(message: String, index: Int, offset: Int): Unit = - reported = true - reports += ((message, index, offset)) - def argError(message: String, index: Int): Unit = - reported = true - reports += ((message, index, 0)) - def strCtxError(message: String): Unit = - reported = true - reports += ((message, 0, 0)) - def argsError(message: String): Unit = - reports += ((message, 0, 0)) - - def hasReported: Boolean = reported - - def resetReported(): Unit = - oldReported = reported - reported = false - - def restoreReported(): Unit = - reported = oldReported - end TestReporter - given TestReporter = TestReporter() - - class TestChecker(args: ClassTag[?]*)(using val reporter: TestReporter) extends FormatChecker: - def argType(argi: Int, types: ClassTag[?]*): ClassTag[?] = types.find(_ == args(argi)).getOrElse(types.head) - val argc = args.length - - def checked(parts: String*)(args: ClassTag[?]*): String = - val checker = TestChecker(args*) - val (amended, _) = checker.checked(parts.toList) - assertFalse(checker.reporter.hasReported) - amended.mkString - def assertChecked(parts: String*)(args: ClassTag[?]*)(p: TestReporter => Boolean = _ => true): Unit = - val checker = TestChecker(args*) - checker.checked(parts.toList) - assertTrue(p(checker.reporter)) - def errorIs(msg: String): (TestReporter => Boolean) = _.reports.head._1.contains(msg) - - @Test def `simple string` = assertEquals("xyz", checked("xyz")()) - @Test def `one string` = assertEquals("xyz%s123", checked("xyz", "123")(classTag[String])) - @Test def `in first part` = assertEquals("x%ny%%z%s123", checked("x%ny%%z", "123")(classTag[String])) - @Test def `one int` = assertEquals("xyz%d123", checked("xyz", "%d123")(classTag[Int])) - //@Test def `one bad int`: Unit = assertChecked("xyz", "%d123")(classTag[String])(errorIs("Type error")) - @Test def `extra descriptor` = assertChecked("xyz", "%s12%d3")(classTag[String])(errorIs("conversions must follow")) - @Test def `bad leader`: Unit = assertChecked("%dxyz")()(_.reports.head._1.contains("conversions must follow")) - @Test def `in second part`: Unit = assertEquals("xyz%s1%n2%%3", checked("xyz", "1%n2%%3")(classTag[String])) - @Test def `something weird`: Unit = assertEquals("xyz%tH123", checked("xyz", "%tH123")(classTag[Calendar])) From d05c62bee7afd6d2cbb2ff6df132ef114e32e622 Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Sat, 22 Jan 2022 15:50:40 -0800 Subject: [PATCH 06/10] Drop special reporter --- .../transform/localopt/FormatChecker.scala | 41 +++++---- .../FormatInterpolatorTransform.scala | 89 ++----------------- 2 files changed, 32 insertions(+), 98 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala index 0f3a3ef4c154..94b10a5b538e 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala @@ -17,7 +17,7 @@ import dotty.tools.dotc.core.Types._ import dotty.tools.dotc.core.Phases.typerPhase /** Formatter string checker. */ -class TypedFormatChecker(args: List[Tree])(using Context)(using reporter: InterpolationReporter): +class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List[Tree])(using Context): val argTypes = args.map(_.tpe) val actuals = ListBuffer.empty[Tree] @@ -33,7 +33,7 @@ class TypedFormatChecker(args: List[Tree])(using Context)(using reporter: Interp types.find(t => argConformsTo(argi, tpe, t)) .orElse(types.find(t => argConvertsTo(argi, tpe, t))) .getOrElse { - reporter.argError(s"Found: ${tpe.show}, Required: ${types.mkString(", ")}", argi) + report.argError(s"Found: ${tpe.show}, Required: ${types.mkString(", ")}", argi) actuals += args(argi) types.head } @@ -64,20 +64,17 @@ class TypedFormatChecker(args: List[Tree])(using Context)(using reporter: Interp /** For N part strings and N-1 args to interpolate, normalize parts and check arg types. * - * Returns parts, possibly updated with explicit leading "%s", - * and conversions for each arg. - * - * Implementation must emit conversions required by invocations of `argType`. + * Returns normalized part strings and args, where args correcpond to conversions in tail of parts. */ - def checked(parts0: List[String]): (List[String], List[Conversion]) = + def checked: (List[String], List[Tree]) = val amended = ListBuffer.empty[String] val convert = ListBuffer.empty[Conversion] @tailrec - def loop(parts: List[String], n: Int): Unit = - parts match + def loop(remaining: List[String], n: Int): Unit = + remaining match case part0 :: more => - def badPart(t: Throwable): String = "".tap(_ => reporter.partError(t.getMessage, index = n, offset = 0)) + def badPart(t: Throwable): String = "".tap(_ => report.partError(t.getMessage, index = n, offset = 0)) val part = try StringContext.processEscapes(part0) catch badPart val matches = formatPattern.findAllMatchIn(part) @@ -112,8 +109,11 @@ class TypedFormatChecker(args: List[Tree])(using Context)(using reporter: Interp case Nil => () end loop - loop(parts0, n = 0) - (amended.toList, convert.toList) + loop(parts, n = 0) + if reported then (Nil, Nil) + else + assert(argc == actuals.size, s"Expected ${argc} args but got ${actuals.size} for [${parts.mkString(", ")}]") + (amended.toList, actuals.toList) end checked extension (descriptor: Match) @@ -146,7 +146,7 @@ class TypedFormatChecker(args: List[Tree])(using Context)(using reporter: Interp // the conversion char is the head of the op string (but see DateTimeXn) val cc: Char = kind match - case ErrorXn => '?' + case ErrorXn => if op.isEmpty then '?' else op(0) case DateTimeXn => if op.length > 1 then op(1) else '?' case _ => op(0) @@ -243,8 +243,8 @@ class TypedFormatChecker(args: List[Tree])(using Context)(using reporter: Interp val i = flags.indexOf(f) match { case -1 => 0 case j => j } errorAt(Flags, i)(msg) - def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partError(msg, argi, descriptor.offset(g, i)) - def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partWarning(msg, argi, descriptor.offset(g, i)) + def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = report.partError(msg, argi, descriptor.offset(g, i)) + def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = report.partWarning(msg, argi, descriptor.offset(g, i)) object Conversion: def apply(m: Match, i: Int): Conversion = @@ -269,4 +269,15 @@ class TypedFormatChecker(args: List[Tree])(using Context)(using reporter: Interp end apply val literalHelp = "use %% for literal %, %n for newline" end Conversion + + var reported = false + + private def partPosAt(index: Int, offset: Int) = + val pos = partsElems(index).sourcePos + pos.withSpan(pos.span.shift(offset)) + + extension (r: report.type) + def argError(message: String, index: Int): Unit = r.error(message, args(index).srcPos).tap(_ => reported = true) + def partError(message: String, index: Int, offset: Int): Unit = r.error(message, partPosAt(index, offset)).tap(_ => reported = true) + def partWarning(message: String, index: Int, offset: Int): Unit = r.warning(message, partPosAt(index, offset)).tap(_ => reported = true) end TypedFormatChecker diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala index 52869a5877c2..79d94c26c692 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala @@ -7,34 +7,6 @@ import dotty.tools.dotc.core.Contexts.* object FormatInterpolatorTransform: - class PartsReporter(fun: Tree, args0: Tree, parts: List[Tree], args: List[Tree])(using Context) extends InterpolationReporter: - private var reported = false - private var oldReported = false - private def partPosAt(index: Int, offset: Int) = - val pos = parts(index).sourcePos - pos.withSpan(pos.span.shift(offset)) - def partError(message: String, index: Int, offset: Int): Unit = - reported = true - report.error(message, partPosAt(index, offset)) - def partWarning(message: String, index: Int, offset: Int): Unit = - reported = true - report.warning(message, partPosAt(index, offset)) - def argError(message: String, index: Int): Unit = - reported = true - report.error(message, args(index).srcPos) - def strCtxError(message: String): Unit = - reported = true - report.error(message, fun.srcPos) - def argsError(message: String): Unit = - reported = true - report.error(message, args0.srcPos) - def hasReported: Boolean = reported - def resetReported(): Unit = - oldReported = reported - reported = false - def restoreReported(): Unit = reported = oldReported - end PartsReporter - /** For f"${arg}%xpart", check format conversions and return (format, args) * suitable for String.format(format, args). */ @@ -50,67 +22,18 @@ object FormatInterpolatorTransform: case _ => report.error("Expected statically known argument list", args0.srcPos) (Nil, EmptyTree) - given reporter: InterpolationReporter = PartsReporter(fun, args0, partsExpr, args) def literally(s: String) = Literal(Constant(s)) - inline val skip = false if parts.lengthIs != args.length + 1 then - reporter.strCtxError { + val badParts = if parts.isEmpty then "there are no parts" else s"too ${if parts.lengthIs > args.length + 1 then "few" else "many"} arguments for interpolated string" - } + report.error(badParts, fun.srcPos) (literally(""), args0) - else if skip then - val checked = parts.head :: parts.tail.map(p => if p.startsWith("%") then p else "%s" + p) - (literally(checked.mkString), args0) else - val checker = TypedFormatChecker(args) - val (checked, cvs) = checker.checked(parts) - if reporter.hasReported then (literally(parts.mkString), args0) - else - assert(checker.argc == checker.actuals.size, s"Expected ${checker.argc}, actuals size is ${checker.actuals.size} for [${parts.mkString(", ")}]") - (literally(checked.mkString), SeqLiteral(checker.actuals.toList, elemtpt)) + val checker = TypedFormatChecker(partsExpr, parts, args) + val (format, formatArgs) = checker.checked + if format.isEmpty then (literally(parts.mkString), args0) + else (literally(format.mkString), SeqLiteral(formatArgs.toList, elemtpt)) end checked end FormatInterpolatorTransform - -/** This trait defines a tool to report errors/warnings that do not depend on Position. */ -trait InterpolationReporter: - - /** Reports error/warning of size 1 linked with a part of the StringContext. - * - * @param message the message to report as error/warning - * @param index the index of the part inside the list of parts of the StringContext - * @param offset the index in the part String where the error is - * @return an error/warning depending on the function - */ - def partError(message: String, index: Int, offset: Int): Unit - def partWarning(message: String, index: Int, offset: Int): Unit - - /** Reports error linked with an argument to format. - * - * @param message the message to report as error/warning - * @param index the index of the argument inside the list of arguments of the format function - * @return an error depending on the function - */ - def argError(message: String, index: Int): Unit - - /** Reports error linked with the list of arguments or the StringContext. - * - * @param message the message to report in the error - * @return an error - */ - def strCtxError(message: String): Unit - def argsError(message: String): Unit - - /** Claims whether an error or a warning has been reported - * - * @return true if an error/warning has been reported, false - */ - def hasReported: Boolean - - /** Stores the old value of the reported and reset it to false */ - def resetReported(): Unit - - /** Restores the value of the reported boolean that has been reset */ - def restoreReported(): Unit -end InterpolationReporter From f33198f8db4ed40576677e1cde6050ed9228ee27 Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Sun, 23 Jan 2022 12:01:40 -0800 Subject: [PATCH 07/10] Improve error position and recovery for bad dollar --- compiler/src/dotty/tools/dotc/parsing/Scanners.scala | 4 +++- compiler/test/dotty/tools/repl/ReplTest.scala | 5 +++-- tests/neg/interpolator-dollar.check | 4 ++++ tests/neg/interpolator-dollar.scala | 6 ++++++ 4 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 tests/neg/interpolator-dollar.check create mode 100644 tests/neg/interpolator-dollar.scala diff --git a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala index c2c13e899ef4..1fee0f52f770 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala @@ -1168,7 +1168,9 @@ object Scanners { finishNamedToken(IDENTIFIER, target = next) } else - error("invalid string interpolation: `$$`, `$\"`, `$`ident or `$`BlockExpr expected") + error("invalid string interpolation: `$$`, `$\"`, `$`ident or `$`BlockExpr expected", off = charOffset - 2) + putChar('$') + getStringPart(multiLine) } else { val isUnclosedLiteral = !isUnicodeEscape && (ch == SU || (!multiLine && (ch == CR || ch == LF))) diff --git a/compiler/test/dotty/tools/repl/ReplTest.scala b/compiler/test/dotty/tools/repl/ReplTest.scala index 55f8afa26260..d5256986f874 100644 --- a/compiler/test/dotty/tools/repl/ReplTest.scala +++ b/compiler/test/dotty/tools/repl/ReplTest.scala @@ -63,8 +63,9 @@ extends ReplDriver(options, new PrintStream(out, true, StandardCharsets.UTF_8.na case "" => Nil case nonEmptyLine => nonEmptyLine :: Nil } + def nonBlank(line: String): Boolean = line.exists(!Character.isWhitespace(_)) - val expectedOutput = lines.flatMap(filterEmpties) + val expectedOutput = lines.filter(nonBlank) val actualOutput = { val opts = toolArgsParse(lines.take(1)) val (optsLine, inputLines) = if opts.isEmpty then ("", lines) else (lines.head, lines.drop(1)) @@ -80,7 +81,7 @@ extends ReplDriver(options, new PrintStream(out, true, StandardCharsets.UTF_8.na out.linesIterator.foreach(buf.append) nstate } - (optsLine :: buf.toList).flatMap(filterEmpties) + (optsLine :: buf.toList).filter(nonBlank) } if !FileDiff.matches(actualOutput, expectedOutput) then diff --git a/tests/neg/interpolator-dollar.check b/tests/neg/interpolator-dollar.check new file mode 100644 index 000000000000..2de0c843725e --- /dev/null +++ b/tests/neg/interpolator-dollar.check @@ -0,0 +1,4 @@ +-- Error: tests/neg/interpolator-dollar.scala:5:20 --------------------------------------------------------------------- +5 | def oops = f"$s%s $ Date: Tue, 25 Jan 2022 11:27:24 -0800 Subject: [PATCH 08/10] Pretty print type on f-interpolator, improve caret Tweak vulpix to show difference in expectations. --- .../transform/localopt/FormatChecker.scala | 18 +- .../dotty/tools/vulpix/ParallelTesting.scala | 88 ++++---- tests/neg/f-interpolator-neg.check | 200 ++++++++++++++++++ tests/neg/f-interpolator-neg.scala | 4 +- tests/neg/f-interpolator-tests.check | 8 + 5 files changed, 269 insertions(+), 49 deletions(-) create mode 100644 tests/neg/f-interpolator-neg.check create mode 100644 tests/neg/f-interpolator-tests.check diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala index 94b10a5b538e..0ba7bd14a9b6 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala @@ -15,6 +15,7 @@ import dotty.tools.dotc.core.Contexts._ import dotty.tools.dotc.core.Symbols._ import dotty.tools.dotc.core.Types._ import dotty.tools.dotc.core.Phases.typerPhase +import dotty.tools.dotc.util.Spans.Span /** Formatter string checker. */ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List[Tree])(using Context): @@ -33,7 +34,7 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List types.find(t => argConformsTo(argi, tpe, t)) .orElse(types.find(t => argConvertsTo(argi, tpe, t))) .getOrElse { - report.argError(s"Found: ${tpe.show}, Required: ${types.mkString(", ")}", argi) + report.argError(s"Found: ${tpe.show}, Required: ${types.map(_.show).mkString(", ")}", argi) actuals += args(argi) types.head } @@ -118,6 +119,7 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List extension (descriptor: Match) def at(g: SpecGroup): Int = descriptor.start(g.ordinal) + def end(g: SpecGroup): Int = descriptor.end(g.ordinal) def offset(g: SpecGroup, i: Int = 0): Int = at(g) + i def group(g: SpecGroup): Option[String] = Option(descriptor.group(g.ordinal)) def stringOf(g: SpecGroup): String = group(g).getOrElse("") @@ -243,8 +245,8 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List val i = flags.indexOf(f) match { case -1 => 0 case j => j } errorAt(Flags, i)(msg) - def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = report.partError(msg, argi, descriptor.offset(g, i)) - def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = report.partWarning(msg, argi, descriptor.offset(g, i)) + def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = report.partError(msg, argi, descriptor.offset(g, i), descriptor.end(g)) + def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = report.partWarning(msg, argi, descriptor.offset(g, i), descriptor.end(g)) object Conversion: def apply(m: Match, i: Int): Conversion = @@ -272,12 +274,14 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List var reported = false - private def partPosAt(index: Int, offset: Int) = + private def partPosAt(index: Int, offset: Int, end: Int) = val pos = partsElems(index).sourcePos - pos.withSpan(pos.span.shift(offset)) + val bgn = pos.span.start + offset + val fin = if end < 0 then pos.span.end else pos.span.start + end + pos.withSpan(Span(bgn, fin, bgn)) extension (r: report.type) def argError(message: String, index: Int): Unit = r.error(message, args(index).srcPos).tap(_ => reported = true) - def partError(message: String, index: Int, offset: Int): Unit = r.error(message, partPosAt(index, offset)).tap(_ => reported = true) - def partWarning(message: String, index: Int, offset: Int): Unit = r.warning(message, partPosAt(index, offset)).tap(_ => reported = true) + def partError(message: String, index: Int, offset: Int, end: Int = -1): Unit = r.error(message, partPosAt(index, offset, end)).tap(_ => reported = true) + def partWarning(message: String, index: Int, offset: Int, end: Int = -1): Unit = r.warning(message, partPosAt(index, offset, end)).tap(_ => reported = true) end TypedFormatChecker diff --git a/compiler/test/dotty/tools/vulpix/ParallelTesting.scala b/compiler/test/dotty/tools/vulpix/ParallelTesting.scala index a497eed7072e..b2b4942ccb1b 100644 --- a/compiler/test/dotty/tools/vulpix/ParallelTesting.scala +++ b/compiler/test/dotty/tools/vulpix/ParallelTesting.scala @@ -14,13 +14,14 @@ import java.util.concurrent.{TimeUnit, TimeoutException, Executors => JExecutors import scala.collection.mutable import scala.io.{Codec, Source} +import scala.jdk.CollectionConverters.* import scala.util.{Random, Try, Failure => TryFailure, Success => TrySuccess, Using} import scala.util.control.NonFatal import scala.util.matching.Regex import scala.collection.mutable.ListBuffer import dotc.{Compiler, Driver} -import dotc.core.Contexts._ +import dotc.core.Contexts.* import dotc.decompiler import dotc.report import dotc.interfaces.Diagnostic.ERROR @@ -750,17 +751,26 @@ trait ParallelTesting extends RunnerOrchestration { self => def compilerCrashed = reporters.exists(_.compilerCrashed) lazy val (errorMap, expectedErrors) = getErrorMapAndExpectedCount(testSource.sourceFiles.toIndexedSeq) lazy val actualErrors = reporters.foldLeft(0)(_ + _.errorCount) - def hasMissingAnnotations = getMissingExpectedErrors(errorMap, reporters.iterator.flatMap(_.errors)) + lazy val (expected, unexpected) = getMissingExpectedErrors(errorMap, reporters.iterator.flatMap(_.errors)) + def hasMissingAnnotations = expected.nonEmpty || unexpected.nonEmpty def showErrors = "-> following the errors:\n" + - reporters.flatMap(_.allErrors.map(e => (e.pos.line + 1).toString + ": " + e.message)).mkString(start = "at ", sep = "\n at ", end = "") - - if (compilerCrashed) Some(s"Compiler crashed when compiling: ${testSource.title}") - else if (actualErrors == 0) Some(s"\nNo errors found when compiling neg test $testSource") - else if (expectedErrors == 0) Some(s"\nNo errors expected/defined in $testSource -- use // error or // nopos-error") - else if (expectedErrors != actualErrors) Some(s"\nWrong number of errors encountered when compiling $testSource\nexpected: $expectedErrors, actual: $actualErrors " + showErrors) - else if (hasMissingAnnotations) Some(s"\nErrors found on incorrect row numbers when compiling $testSource\n$showErrors") - else if (!errorMap.isEmpty) Some(s"\nExpected error(s) have {=}: $errorMap") - else None + reporters.flatMap(_.allErrors.sortBy(_.pos.line).map(e => s"${e.pos.line + 1}: ${e.message}")).mkString(" at ", "\n at ", "") + + Option { + if compilerCrashed then s"Compiler crashed when compiling: ${testSource.title}" + else if actualErrors == 0 then s"\nNo errors found when compiling neg test $testSource" + else if expectedErrors == 0 then s"\nNo errors expected/defined in $testSource -- use // error or // nopos-error" + else if expectedErrors != actualErrors then + s"""|Wrong number of errors encountered when compiling $testSource + |expected: $expectedErrors, actual: $actualErrors + |${expected.mkString("Unfulfilled expectations:\n", "\n", "")} + |${unexpected.mkString("Unexpected errors:\n", "\n", "")} + |$showErrors + |""".stripMargin.trim.linesIterator.mkString("\n", "\n", "") + else if hasMissingAnnotations then s"\nErrors found on incorrect row numbers when compiling $testSource\n$showErrors" + else if !errorMap.isEmpty then s"\nExpected error(s) have {=}: $errorMap" + else null + } } override def onSuccess(testSource: TestSource, reporters: Seq[TestReporter], logger: LoggedRunnable): Unit = @@ -783,7 +793,7 @@ trait ParallelTesting extends RunnerOrchestration { self => source.getLines.zipWithIndex.foreach { case (line, lineNbr) => val errors = line.toSeq.sliding("// error".length).count(_.unwrap == "// error") if (errors > 0) - errorMap.put(s"${file.getPath}:$lineNbr", errors) + errorMap.put(s"${file.getPath}:${lineNbr+1}", errors) val noposErrors = line.toSeq.sliding("// nopos-error".length).count(_.unwrap == "// nopos-error") if (noposErrors > 0) { @@ -813,34 +823,32 @@ trait ParallelTesting extends RunnerOrchestration { self => (errorMap, expectedErrors) } - def getMissingExpectedErrors(errorMap: HashMap[String, Integer], reporterErrors: Iterator[Diagnostic]) = !reporterErrors.forall { error => - val pos1 = error.pos.nonInlined - val key = if (pos1.exists) { - def toRelative(path: String): String = // For some reason, absolute paths leak from the compiler itself... - path.split(JFile.separatorChar).dropWhile(_ != "tests").mkString(JFile.separator) - val fileName = toRelative(pos1.source.file.toString) - s"$fileName:${pos1.line}" - - } else "nopos" - - val errors = errorMap.get(key) - - def missing = { echo(s"Error reported in ${pos1.source}, but no annotation found") ; false } - - if (errors ne null) { - if (errors == 1) errorMap.remove(key) - else errorMap.put(key, errors - 1) - true - } - else if key == "nopos" then - missing - else - errorMap.get("anypos") match - case null => missing - case 1 => errorMap.remove("anypos") ; true - case slack => if slack < 1 then missing - else errorMap.put("anypos", slack - 1) ; true - } + // return unfulfilled expected errors and unexpected diagnostics + def getMissingExpectedErrors(errorMap: HashMap[String, Integer], reporterErrors: Iterator[Diagnostic]): (List[String], List[String]) = + val unexpected, unpositioned = ListBuffer.empty[String] + // For some reason, absolute paths leak from the compiler itself... + def relativize(path: String): String = path.split(JFile.separatorChar).dropWhile(_ != "tests").mkString(JFile.separator) + def seenAt(key: String): Boolean = + errorMap.get(key) match + case null => false + case 1 => errorMap.remove(key) ; true + case n => errorMap.put(key, n - 1) ; true + def sawDiagnostic(d: Diagnostic): Unit = + d.pos.nonInlined match + case srcpos if srcpos.exists => + val key = s"${relativize(srcpos.source.file.toString)}:${srcpos.line + 1}" + if !seenAt(key) then unexpected += key + case srcpos => + if !seenAt("nopos") then unpositioned += relativize(srcpos.source.file.toString) + + reporterErrors.foreach(sawDiagnostic) + + errorMap.get("anypos") match + case n if n == unexpected.size => errorMap.remove("anypos") ; unexpected.clear() + case _ => + + (errorMap.asScala.keys.toList, (unexpected ++ unpositioned).toList) + end getMissingExpectedErrors } private final class NoCrashTest(testSources: List[TestSource], times: Int, threadLimit: Option[Int], suppressAllOutput: Boolean)(implicit summaryReport: SummaryReporting) diff --git a/tests/neg/f-interpolator-neg.check b/tests/neg/f-interpolator-neg.check new file mode 100644 index 000000000000..ea8df052589e --- /dev/null +++ b/tests/neg/f-interpolator-neg.check @@ -0,0 +1,200 @@ +-- Error: tests/neg/f-interpolator-neg.scala:4:4 ----------------------------------------------------------------------- +4 | new StringContext().f() // error + | ^^^^^^^^^^^^^^^^^^^^^ + | there are no parts +-- Error: tests/neg/f-interpolator-neg.scala:5:4 ----------------------------------------------------------------------- +5 | new StringContext("", " is ", "%2d years old").f(s) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | too few arguments for interpolated string +-- Error: tests/neg/f-interpolator-neg.scala:6:4 ----------------------------------------------------------------------- +6 | new StringContext("", " is ", "%2d years old").f(s, d, d) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | too many arguments for interpolated string +-- Error: tests/neg/f-interpolator-neg.scala:7:4 ----------------------------------------------------------------------- +7 | new StringContext("", "").f() // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | too few arguments for interpolated string +-- Error: tests/neg/f-interpolator-neg.scala:11:7 ---------------------------------------------------------------------- +11 | f"$s%b" // error + | ^ + | Found: (s : String), Required: Boolean, Null +-- Error: tests/neg/f-interpolator-neg.scala:12:7 ---------------------------------------------------------------------- +12 | f"$s%c" // error + | ^ + | Found: (s : String), Required: Char, Byte, Short, Int +-- Error: tests/neg/f-interpolator-neg.scala:13:7 ---------------------------------------------------------------------- +13 | f"$f%c" // error + | ^ + | Found: (f : Double), Required: Char, Byte, Short, Int +-- Error: tests/neg/f-interpolator-neg.scala:14:7 ---------------------------------------------------------------------- +14 | f"$s%x" // error + | ^ + | Found: (s : String), Required: Int, Long, Byte, Short, BigInt +-- Error: tests/neg/f-interpolator-neg.scala:15:7 ---------------------------------------------------------------------- +15 | f"$b%d" // error + | ^ + | Found: (b : Boolean), Required: Int, Long, Byte, Short, BigInt +-- Error: tests/neg/f-interpolator-neg.scala:16:7 ---------------------------------------------------------------------- +16 | f"$s%d" // error + | ^ + | Found: (s : String), Required: Int, Long, Byte, Short, BigInt +-- Error: tests/neg/f-interpolator-neg.scala:17:7 ---------------------------------------------------------------------- +17 | f"$f%o" // error + | ^ + | Found: (f : Double), Required: Int, Long, Byte, Short, BigInt +-- Error: tests/neg/f-interpolator-neg.scala:18:7 ---------------------------------------------------------------------- +18 | f"$s%e" // error + | ^ + | Found: (s : String), Required: Double, Float, BigDecimal +-- Error: tests/neg/f-interpolator-neg.scala:19:7 ---------------------------------------------------------------------- +19 | f"$b%f" // error + | ^ + | Found: (b : Boolean), Required: Double, Float, BigDecimal +-- Error: tests/neg/f-interpolator-neg.scala:20:9 ---------------------------------------------------------------------- +20 | f"$s%i" // error + | ^ + | illegal conversion character 'i' +-- Error: tests/neg/f-interpolator-neg.scala:24:9 ---------------------------------------------------------------------- +24 | f"$s%+ 0,(s" // error + | ^^^^^ + | Illegal flag '+' +-- Error: tests/neg/f-interpolator-neg.scala:25:9 ---------------------------------------------------------------------- +25 | f"$c%#+ 0,(c" // error + | ^^^^^^ + | Only '-' allowed for c conversion +-- Error: tests/neg/f-interpolator-neg.scala:26:9 ---------------------------------------------------------------------- +26 | f"$d%#d" // error + | ^ + | # not allowed for d conversion +-- Error: tests/neg/f-interpolator-neg.scala:27:9 ---------------------------------------------------------------------- +27 | f"$d%,x" // error + | ^ + | ',' only allowed for d conversion of integral types +-- Error: tests/neg/f-interpolator-neg.scala:28:9 ---------------------------------------------------------------------- +28 | f"$d%+ (x" // error + | ^^^ + | only use '+' for BigInt conversions to o, x, X +-- Error: tests/neg/f-interpolator-neg.scala:29:9 ---------------------------------------------------------------------- +29 | f"$f%,(a" // error + | ^^ + | ',' not allowed for a, A +-- Error: tests/neg/f-interpolator-neg.scala:30:9 ---------------------------------------------------------------------- +30 | f"$t%#+ 0,(tT" // error + | ^^^^^^ + | Only '-' allowed for date/time conversions +-- Error: tests/neg/f-interpolator-neg.scala:31:7 ---------------------------------------------------------------------- +31 | f"%-#+ 0,(n" // error + | ^^^^^^^ + | flags not allowed +-- Error: tests/neg/f-interpolator-neg.scala:32:7 ---------------------------------------------------------------------- +32 | f"%#+ 0,(%" // error + | ^^^^^^ + | Illegal flag '#' +-- Error: tests/neg/f-interpolator-neg.scala:36:9 ---------------------------------------------------------------------- +36 | f"$c%.2c" // error + | ^^ + | precision not allowed +-- Error: tests/neg/f-interpolator-neg.scala:37:9 ---------------------------------------------------------------------- +37 | f"$d%.2d" // error + | ^^ + | precision not allowed +-- Error: tests/neg/f-interpolator-neg.scala:38:7 ---------------------------------------------------------------------- +38 | f"%.2%" // error + | ^^ + | precision not allowed +-- Error: tests/neg/f-interpolator-neg.scala:39:7 ---------------------------------------------------------------------- +39 | f"%.2n" // error + | ^^ + | precision not allowed +-- Error: tests/neg/f-interpolator-neg.scala:40:9 ---------------------------------------------------------------------- +40 | f"$f%.2a" // error + | ^^ + | precision not allowed +-- Error: tests/neg/f-interpolator-neg.scala:41:9 ---------------------------------------------------------------------- +41 | f"$t%.2tT" // error + | ^^ + | precision not allowed +-- Error: tests/neg/f-interpolator-neg.scala:45:7 ---------------------------------------------------------------------- +45 | f"% Date: Thu, 3 Feb 2022 13:02:20 -0800 Subject: [PATCH 09/10] Use normal string escaping for f --- compiler/src/dotty/tools/dotc/parsing/Scanners.scala | 2 +- compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala | 2 +- .../src/scala/quoted/runtime/impl/printers/SourceCode.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala index 1fee0f52f770..18f4f542b86c 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala @@ -1253,7 +1253,7 @@ object Scanners { nextChar() } } - val alt = if oct == LF then raw"\n" else f"${"\\"}u$oct%04x" + val alt = if oct == LF then raw"\n" else f"\\u$oct%04x" error(s"octal escape literals are unsupported: use $alt instead", start) putChar(oct.toChar) } diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 197a2e6ded9c..dd5d55b21f50 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -542,7 +542,7 @@ class PlainPrinter(_ctx: Context) extends Printer { case '"' => "\\\"" case '\'' => "\\\'" case '\\' => "\\\\" - case _ => if (ch.isControl) f"${"\\"}u${ch.toInt}%04x" else String.valueOf(ch) + case _ => if ch.isControl then f"\\u${ch.toInt}%04x" else String.valueOf(ch) } def toText(const: Constant): Text = const.tag match { diff --git a/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala b/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala index b259b3f21b86..88ee3e985277 100644 --- a/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala +++ b/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala @@ -1423,7 +1423,7 @@ object SourceCode { case '"' => "\\\"" case '\'' => "\\\'" case '\\' => "\\\\" - case _ => if (ch.isControl) f"${"\\"}u${ch.toInt}%04x" else String.valueOf(ch) + case _ => if ch.isControl then f"\\u${ch.toInt}%04x" else String.valueOf(ch) } private def escapedString(str: String): String = str flatMap escapedChar From 0c87244bccf993085d75904d425572753bcde704 Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Fri, 28 Jan 2022 04:16:46 -0800 Subject: [PATCH 10/10] Collect error comments Prefer a regex for collecting magic error comments. Allow arbitrary space after line comment but warn that no space `//error` disables that error, which is useful for testing and development. --- .../dotty/tools/vulpix/ParallelTesting.scala | 43 +++++++------------ 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/compiler/test/dotty/tools/vulpix/ParallelTesting.scala b/compiler/test/dotty/tools/vulpix/ParallelTesting.scala index b2b4942ccb1b..bd30e7fff98e 100644 --- a/compiler/test/dotty/tools/vulpix/ParallelTesting.scala +++ b/compiler/test/dotty/tools/vulpix/ParallelTesting.scala @@ -785,43 +785,32 @@ trait ParallelTesting extends RunnerOrchestration { self => // // We collect these in a map `"file:row" -> numberOfErrors`, for // nopos errors we save them in `"file" -> numberOfNoPosErrors` - def getErrorMapAndExpectedCount(files: Seq[JFile]): (HashMap[String, Integer], Int) = { + def getErrorMapAndExpectedCount(files: Seq[JFile]): (HashMap[String, Integer], Int) = + val comment = raw"//( *)(nopos-|anypos-)?error".r val errorMap = new HashMap[String, Integer]() var expectedErrors = 0 + def bump(key: String): Unit = + errorMap.get(key) match + case null => errorMap.put(key, 1) + case n => errorMap.put(key, n+1) + expectedErrors += 1 files.filter(isSourceFile).foreach { file => Using(Source.fromFile(file, StandardCharsets.UTF_8.name)) { source => source.getLines.zipWithIndex.foreach { case (line, lineNbr) => - val errors = line.toSeq.sliding("// error".length).count(_.unwrap == "// error") - if (errors > 0) - errorMap.put(s"${file.getPath}:${lineNbr+1}", errors) - - val noposErrors = line.toSeq.sliding("// nopos-error".length).count(_.unwrap == "// nopos-error") - if (noposErrors > 0) { - val nopos = errorMap.get("nopos") - val existing: Integer = if (nopos eq null) 0 else nopos - errorMap.put("nopos", noposErrors + existing) - } - - val anyposErrors = line.toSeq.sliding("// anypos-error".length).count(_.unwrap == "// anypos-error") - if (anyposErrors > 0) { - val anypos = errorMap.get("anypos") - val existing: Integer = if (anypos eq null) 0 else anypos - errorMap.put("anypos", anyposErrors + existing) - } - - val possibleTypos = List("//error" -> "// error", "//nopos-error" -> "// nopos-error", "//anypos-error" -> "// anypos-error") - for ((possibleTypo, expected) <- possibleTypos) { - if (line.contains(possibleTypo)) - echo(s"Warning: Possible typo in error tag in file ${file.getCanonicalPath}:$lineNbr: found `$possibleTypo` but expected `$expected`") + comment.findAllMatchIn(line).foreach { m => + m.group(2) match + case prefix if m.group(1).isEmpty => + val what = Option(prefix).getOrElse("") + echo(s"Warning: ${file.getCanonicalPath}:${lineNbr}: found `//${what}error` but expected `// ${what}error`, skipping comment") + case "nopos-" => bump("nopos") + case "anypos-" => bump("anypos") + case _ => bump(s"${file.getPath}:${lineNbr+1}") } - - expectedErrors += anyposErrors + noposErrors + errors } }.get } - (errorMap, expectedErrors) - } + end getErrorMapAndExpectedCount // return unfulfilled expected errors and unexpected diagnostics def getMissingExpectedErrors(errorMap: HashMap[String, Integer], reporterErrors: Iterator[Diagnostic]): (List[String], List[String]) =