diff --git a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala index 61e31ff1dce8..a850e0068bb7 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Scanners.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Scanners.scala @@ -1253,7 +1253,7 @@ object Scanners { nextChar() } } - val alt = if oct == LF then raw"\n" else f"\u$oct%04x" + val alt = if oct == LF then raw"\n" else f"${"\\"}u$oct%04x" error(s"octal escape literals are unsupported: use $alt instead", start) putChar(oct.toChar) } diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index c412afaf0487..b66474f0433b 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -526,7 +526,7 @@ class PlainPrinter(_ctx: Context) extends Printer { case '"' => "\\\"" case '\'' => "\\\'" case '\\' => "\\\\" - case _ => if (ch.isControl) f"\u${ch.toInt}%04x" else String.valueOf(ch) + case _ => if (ch.isControl) f"${"\\"}u${ch.toInt}%04x" else String.valueOf(ch) } def toText(const: Constant): Text = const.tag match { diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala new file mode 100644 index 000000000000..1cbd73db2ba0 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala @@ -0,0 +1,247 @@ +package dotty.tools.dotc +package transform.localopt + +import scala.annotation.tailrec +import scala.collection.mutable.{ListBuffer, Stack} +import scala.reflect.{ClassTag, classTag} +import scala.util.chaining.* +import scala.util.matching.Regex.Match + +import java.util.{Calendar, Date, Formattable} + +import StringContextChecker.InterpolationReporter + +/** Formatter string checker. */ +abstract class FormatChecker(using reporter: InterpolationReporter): + + // Pick the first runtime type which the i'th arg can satisfy. + // If conversion is required, implementation must emit it. + def argType(argi: Int, types: ClassTag[?]*): ClassTag[?] + + // count of args, for checking indexes + def argc: Int + + val allFlags = "-#+ 0,(<" + val formatPattern = """%(?:(\d+)\$)?([-#+ 0,(<]+)?(\d+)?(\.\d+)?([tT]?[%a-zA-Z])?""".r + + // ordinal is the regex group index in the format pattern + enum SpecGroup: + case Spec, Index, Flags, Width, Precision, CC + import SpecGroup.* + + /** For N part strings and N-1 args to interpolate, normalize parts and check arg types. + * + * Returns parts, possibly updated with explicit leading "%s", + * and conversions for each arg. + * + * Implementation must emit conversions required by invocations of `argType`. + */ + def checked(parts0: List[String]): (List[String], List[Conversion]) = + val amended = ListBuffer.empty[String] + val convert = ListBuffer.empty[Conversion] + + @tailrec + def loop(parts: List[String], n: Int): Unit = + parts match + case part0 :: more => + def badPart(t: Throwable): String = "".tap(_ => reporter.partError(t.getMessage, index = n, offset = 0)) + val part = try StringContext.processEscapes(part0) catch badPart + val matches = formatPattern.findAllMatchIn(part) + + def insertStringConversion(): Unit = + amended += "%s" + part + convert += Conversion(formatPattern.findAllMatchIn("%s").next(), n) // improve + argType(n-1, classTag[Any]) + def errorLeading(op: Conversion) = op.errorAt(Spec)(s"conversions must follow a splice; ${Conversion.literalHelp}") + def accept(op: Conversion): Unit = + if !op.isLeading then errorLeading(op) + op.accepts(argType(n-1, op.acceptableVariants*)) + amended += part + convert += op + + // after the first part, a leading specifier is required for the interpolated arg; %s is supplied if needed + if n == 0 then amended += part + else if !matches.hasNext then insertStringConversion() + else + val cv = Conversion(matches.next(), n) + if cv.isLiteral then insertStringConversion() + else if cv.isIndexed then + if cv.index.getOrElse(-1) == n then accept(cv) + else + // either some other arg num, or '<' + //c.warning(op.groupPos(Index), "Index is not this arg") + insertStringConversion() + else if !cv.isError then accept(cv) + + // any remaining conversions in this part must be either literals or indexed + while matches.hasNext do + val cv = Conversion(matches.next(), n) + if n == 0 && cv.hasFlag('<') then cv.badFlag('<', "No last arg") + else if !cv.isLiteral && !cv.isIndexed then errorLeading(cv) + + loop(more, n + 1) + case Nil => () + end loop + + loop(parts0, n = 0) + (amended.toList, convert.toList) + end checked + + extension (descriptor: Match) + def at(g: SpecGroup): Int = descriptor.start(g.ordinal) + def offset(g: SpecGroup, i: Int = 0): Int = at(g) + i + def group(g: SpecGroup): Option[String] = Option(descriptor.group(g.ordinal)) + def stringOf(g: SpecGroup): String = group(g).getOrElse("") + def intOf(g: SpecGroup): Option[Int] = group(g).map(_.toInt) + + extension (inline value: Boolean) + inline def or(inline body: => Unit): Boolean = value || { body ; false } + inline def orElse(inline body: => Unit): Boolean = value || { body ; true } + inline def but(inline body: => Unit): Boolean = value && { body ; false } + inline def and(inline body: => Unit): Boolean = value && { body ; true } + + /** A conversion specifier matched in the argi'th string part, + * with `argc` arguments to interpolate. + */ + sealed abstract class Conversion: + // the match for this descriptor + def descriptor: Match + // the part number for reporting errors + def argi: Int + + // the descriptor fields + val index: Option[Int] = descriptor.intOf(Index) + val flags: String = descriptor.stringOf(Flags) + val width: Option[Int] = descriptor.intOf(Width) + val precision: Option[Int] = descriptor.group(Precision).map(_.drop(1).toInt) + val op: String = descriptor.stringOf(CC) + + // the conversion char is the head of the op string (but see DateTimeXn) + val cc: Char = if isError then '?' else op(0) + + def isError: Boolean = false + def isIndexed: Boolean = index.nonEmpty || hasFlag('<') + def isLiteral: Boolean = false + + // descriptor is at index 0 of the part string + def isLeading: Boolean = descriptor.at(Spec) == 0 + + // true if passes. Default checks flags and index + def verify: Boolean = goodFlags && goodIndex + + // is the specifier OK with the given arg + def accepts(arg: ClassTag[?]): Boolean = true + + // what arg type if any does the conversion accept + def acceptableVariants: List[ClassTag[?]] + + // what flags does the conversion accept? defaults to all + protected def okFlags: String = allFlags + + def hasFlag(f: Char) = flags.contains(f) + def hasAnyFlag(fs: String) = fs.exists(hasFlag) + + def badFlag(f: Char, msg: String) = + val i = flags.indexOf(f) match { case -1 => 0 case j => j } + errorAt(Flags, i)(msg) + + def errorAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partError(msg, argi, descriptor.offset(g, i)) + def warningAt(g: SpecGroup, i: Int = 0)(msg: String) = reporter.partWarning(msg, argi, descriptor.offset(g, i)) + + def noFlags = flags.isEmpty or errorAt(Flags)("flags not allowed") + def noWidth = width.isEmpty or errorAt(Width)("width not allowed") + def noPrecision = precision.isEmpty or errorAt(Precision)("precision not allowed") + def only_-(msg: String) = + val badFlags = flags.filterNot { case '-' | '<' => true case _ => false } + badFlags.isEmpty or badFlag(badFlags(0), s"Only '-' allowed for $msg") + def goodFlags = + val badFlags = flags.filterNot(okFlags.contains) + for f <- badFlags do badFlag(f, s"Illegal flag '$f'") + badFlags.isEmpty + def goodIndex = + if index.nonEmpty && hasFlag('<') then warningAt(Index)("Argument index ignored if '<' flag is present") + val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true) + okRange || hasFlag('<') or errorAt(Index)("Argument index out of range") + object Conversion: + def apply(m: Match, i: Int): Conversion = + def badCC(msg: String) = ErrorXn(m, i).tap(error => error.errorAt(if (error.op.isEmpty) Spec else CC)(msg)) + def cv(cc: Char) = cc match + case 's' | 'S' => StringXn(m, i) + case 'h' | 'H' => HashXn(m, i) + case 'b' | 'B' => BooleanXn(m, i) + case 'c' | 'C' => CharacterXn(m, i) + case 'd' | 'o' | + 'x' | 'X' => IntegralXn(m, i) + case 'e' | 'E' | + 'f' | + 'g' | 'G' | + 'a' | 'A' => FloatingPointXn(m, i) + case 't' | 'T' => DateTimeXn(m, i) + case '%' | 'n' => LiteralXn(m, i) + case _ => badCC(s"illegal conversion character '$cc'") + end cv + m.group(CC) match + case Some(cc) => cv(cc(0)).tap(_.verify) + case None => badCC(s"Missing conversion operator in '${m.matched}'; $literalHelp") + end apply + val literalHelp = "use %% for literal %, %n for newline" + end Conversion + abstract class GeneralXn extends Conversion + // s | S + class StringXn(val descriptor: Match, val argi: Int) extends GeneralXn: + val acceptableVariants = + if hasFlag('#') then classTag[Formattable] :: Nil + else classTag[Any] :: Nil + override protected def okFlags = "-#<" + // b | B + class BooleanXn(val descriptor: Match, val argi: Int) extends GeneralXn: + val FakeNullTag: ClassTag[?] = null + val acceptableVariants = classTag[Boolean] :: FakeNullTag :: Nil + override def accepts(arg: ClassTag[?]): Boolean = + arg == classTag[Boolean] orElse warningAt(CC)("Boolean format is null test for non-Boolean") + override protected def okFlags = "-<" + // h | H + class HashXn(val descriptor: Match, val argi: Int) extends GeneralXn: + val acceptableVariants = classTag[Any] :: Nil + override protected def okFlags = "-<" + // %% | %n + class LiteralXn(val descriptor: Match, val argi: Int) extends Conversion: + override def isLiteral = true + override def verify = op match + case "%" => super.verify && noPrecision and width.foreach(_ => warningAt(Width)("width ignored on literal")) + case "n" => noFlags && noWidth && noPrecision + override protected val okFlags = "-" + override def acceptableVariants = Nil + class CharacterXn(val descriptor: Match, val argi: Int) extends Conversion: + override def verify = super.verify && noPrecision && only_-("c conversion") + val acceptableVariants = classTag[Char] :: classTag[Byte] :: classTag[Short] :: classTag[Int] :: Nil + class IntegralXn(val descriptor: Match, val argi: Int) extends Conversion: + override def verify = + def d_# = cc == 'd' && hasFlag('#') and badFlag('#', "# not allowed for d conversion") + def x_comma = cc != 'd' && hasFlag(',') and badFlag(',', "',' only allowed for d conversion of integral types") + super.verify && noPrecision && !d_# && !x_comma + val acceptableVariants = classTag[Int] :: classTag[Long] :: classTag[Byte] :: classTag[Short] :: classTag[BigInt] :: Nil + override def accepts(arg: ClassTag[?]): Boolean = + arg == classTag[BigInt] || { + cc match + case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; false + case _ => true + } + class FloatingPointXn(val descriptor: Match, val argi: Int) extends Conversion: + override def verify = super.verify && (cc match { + case 'a' | 'A' => + val badFlags = ",(".filter(hasFlag) + noPrecision && badFlags.isEmpty or badFlags.foreach(badf => badFlag(badf, s"'$badf' not allowed for a, A")) + case _ => true + }) + val acceptableVariants = classTag[Double] :: classTag[Float] :: classTag[BigDecimal] :: Nil + class DateTimeXn(val descriptor: Match, val argi: Int) extends Conversion: + override val cc: Char = if op.length > 1 then op(1) else '?' + def hasCC = op.length == 2 or errorAt(CC)("Date/time conversion must have two characters") + def goodCC = "HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc".contains(cc) or errorAt(CC, 1)(s"'$cc' doesn't seem to be a date or time conversion") + override def verify = super.verify && hasCC && goodCC && noPrecision && only_-("date/time conversions") + val acceptableVariants = classTag[Long] :: classTag[Calendar] :: classTag[Date] :: Nil + class ErrorXn(val descriptor: Match, val argi: Int) extends Conversion: + override def isError = true + override def verify = false + override def acceptableVariants = Nil diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala new file mode 100644 index 000000000000..820b42d2c225 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala @@ -0,0 +1,195 @@ +package dotty.tools.dotc +package transform.localopt + +import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Decorators._ +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Contexts._ +import dotty.tools.dotc.core.StdNames._ +import dotty.tools.dotc.core.NameKinds._ +import dotty.tools.dotc.core.Symbols._ +import dotty.tools.dotc.core.Types._ +import dotty.tools.dotc.core.Phases.typerPhase +import dotty.tools.dotc.typer.ProtoTypes._ + +import scala.StringContext.processEscapes +import scala.annotation.tailrec +import scala.collection.mutable.{ListBuffer, Stack} +import scala.reflect.{ClassTag, classTag} +import scala.util.chaining._ +import scala.util.matching.Regex.Match + +object FormatInterpolatorTransform: + import tpd._ + import StringContextChecker.InterpolationReporter + + /* + /** This trait defines a tool to report errors/warnings that do not depend on Position. */ + trait InterpolationReporter: + + /** Reports error/warning of size 1 linked with a part of the StringContext. + * + * @param message the message to report as error/warning + * @param index the index of the part inside the list of parts of the StringContext + * @param offset the index in the part String where the error is + * @return an error/warning depending on the function + */ + def partError(message: String, index: Int, offset: Int): Unit + def partWarning(message: String, index: Int, offset: Int): Unit + + /** Reports error linked with an argument to format. + * + * @param message the message to report as error/warning + * @param index the index of the argument inside the list of arguments of the format function + * @return an error depending on the function + */ + def argError(message: String, index: Int): Unit + + /** Reports error linked with the list of arguments or the StringContext. + * + * @param message the message to report in the error + * @return an error + */ + def strCtxError(message: String): Unit + def argsError(message: String): Unit + + /** Claims whether an error or a warning has been reported + * + * @return true if an error/warning has been reported, false + */ + def hasReported: Boolean + + /** Stores the old value of the reported and reset it to false */ + def resetReported(): Unit + + /** Restores the value of the reported boolean that has been reset */ + def restoreReported(): Unit + end InterpolationReporter + */ + class PartsReporter(fun: Tree, args0: Tree, parts: List[Tree], args: List[Tree])(using Context) extends InterpolationReporter: + private var reported = false + private var oldReported = false + private def partPosAt(index: Int, offset: Int) = + val pos = parts(index).sourcePos + pos.withSpan(pos.span.shift(offset)) + def partError(message: String, index: Int, offset: Int): Unit = + reported = true + report.error(message, partPosAt(index, offset)) + def partWarning(message: String, index: Int, offset: Int): Unit = + reported = true + report.warning(message, partPosAt(index, offset)) + def argError(message: String, index: Int): Unit = + reported = true + report.error(message, args(index).srcPos) + def strCtxError(message: String): Unit = + reported = true + report.error(message, fun.srcPos) + def argsError(message: String): Unit = + reported = true + report.error(message, args0.srcPos) + def hasReported: Boolean = reported + def resetReported(): Unit = + oldReported = reported + reported = false + def restoreReported(): Unit = reported = oldReported + end PartsReporter + object tags: + import java.util.{Calendar, Date, Formattable} + val StringTag = classTag[String] + val FormattableTag = classTag[Formattable] + val BigIntTag = classTag[BigInt] + val BigDecimalTag = classTag[BigDecimal] + val CalendarTag = classTag[Calendar] + val DateTag = classTag[Date] + class FormattableTypes(using Context): + val FormattableType = requiredClassRef("java.util.Formattable") + val BigIntType = requiredClassRef("scala.math.BigInt") + val BigDecimalType = requiredClassRef("scala.math.BigDecimal") + val CalendarType = requiredClassRef("java.util.Calendar") + val DateType = requiredClassRef("java.util.Date") + class TypedFormatChecker(val args: List[Tree])(using Context, InterpolationReporter) extends FormatChecker: + val reporter = summon[InterpolationReporter] + val argTypes = args.map(_.tpe) + val actuals = ListBuffer.empty[Tree] + val argc = argTypes.length + def argType(argi: Int, types: Seq[ClassTag[?]]) = + require(argi < argc, s"$argi out of range picking from $types") + val tpe = argTypes(argi) + types.find(t => argConformsTo(argi, tpe, argTypeOf(t))) + .orElse(types.find(t => argConvertsTo(argi, tpe, argTypeOf(t)))) + .getOrElse { + reporter.argError(s"Found: ${tpe.show}, Required: ${types.mkString(", ")}", argi) + actuals += args(argi) + types.head + } + final lazy val fmtTypes = FormattableTypes() + import tags.*, fmtTypes.* + def argConformsTo(argi: Int, arg: Type, target: Type): Boolean = + (arg <:< target).tap(if _ then actuals += args(argi)) + def argConvertsTo(argi: Int, arg: Type, target: Type): Boolean = + import typer.Implicits.SearchSuccess + atPhase(typerPhase) { + ctx.typer.inferView(args(argi), target) match + case SearchSuccess(view, ref, _, _) => actuals += view ; true + case _ => false + } + def argTypeOf(tag: ClassTag[?]): Type = tag match + case StringTag => defn.StringType + case ClassTag.Boolean => defn.BooleanType + case ClassTag.Byte => defn.ByteType + case ClassTag.Char => defn.CharType + case ClassTag.Short => defn.ShortType + case ClassTag.Int => defn.IntType + case ClassTag.Long => defn.LongType + case ClassTag.Float => defn.FloatType + case ClassTag.Double => defn.DoubleType + case ClassTag.Any => defn.AnyType + case ClassTag.AnyRef => defn.AnyRefType + case FormattableTag => FormattableType + case BigIntTag => BigIntType + case BigDecimalTag => BigDecimalType + case CalendarTag => CalendarType + case DateTag => DateType + case null => defn.NullType + case _ => reporter.strCtxError(s"Unknown type for format $tag") + defn.AnyType + end TypedFormatChecker + + /** For f"${arg}%xpart", check format conversions and return (format, args) + * suitable for String.format(format, args). + */ + def checked(fun: Tree, args0: Tree)(using Context): (Tree, Tree) = + val (partsExpr, parts) = fun match + case TypeApply(Select(Apply(_, (parts: SeqLiteral) :: Nil), _), _) => + (parts.elems, parts.elems.map { case Literal(Constant(s: String)) => s }) + case _ => + report.error("Expected statically known StringContext", fun.srcPos) + (Nil, Nil) + val (args, elemtpt) = args0 match + case seqlit: SeqLiteral => (seqlit.elems, seqlit.elemtpt) + case _ => + report.error("Expected statically known argument list", args0.srcPos) + (Nil, EmptyTree) + given reporter: InterpolationReporter = PartsReporter(fun, args0, partsExpr, args) + + def literally(s: String) = Literal(Constant(s)) + inline val skip = false + if parts.lengthIs != args.length + 1 then + reporter.strCtxError { + if parts.isEmpty then "there are no parts" + else s"too ${if parts.lengthIs > args.length + 1 then "few" else "many"} arguments for interpolated string" + } + (literally(""), args0) + else if skip then + val checked = parts.head :: parts.tail.map(p => if p.startsWith("%") then p else "%s" + p) + (literally(checked.mkString), args0) + else + val checker = TypedFormatChecker(args) + val (checked, cvs) = checker.checked(parts) + if reporter.hasReported then (literally(parts.mkString), args0) + else + assert(checker.argc == checker.actuals.size, s"Expected ${checker.argc}, actuals size is ${checker.actuals.size} for [${parts.mkString(", ")}]") + (literally(checked.mkString), tpd.SeqLiteral(checker.actuals.toList, elemtpt)) + end checked +end FormatInterpolatorTransform diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala b/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala index fbd09f43b853..18f296af8c4e 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala @@ -49,7 +49,7 @@ object StringContextChecker { * * @return true if an error/warning has been reported, false */ - def hasReported() : Boolean + def hasReported : Boolean /** Stores the old value of the reported and reset it to false */ def resetReported() : Unit @@ -106,7 +106,7 @@ object StringContextChecker { report.error(message, args0.srcPos) } - def hasReported() : Boolean = { + def hasReported : Boolean = { reported } @@ -160,7 +160,7 @@ object StringContextChecker { case p :: parts1 => p :: parts1.map((part : String) => { if (!part.startsWith("%")) { val index = part.indexOf('%') - if (!reporter.hasReported() && index != -1) { + if (!reporter.hasReported && index != -1) { reporter.partError("conversions must follow a splice; use %% for literal %, %n for newline", parts.indexOf(part), index) "%s" + part } else "%s" + part @@ -296,7 +296,7 @@ object StringContextChecker { } //conversion - if((conversion >= l || (!part.charAt(conversion).isLetter && part.charAt(conversion) != '%')) && !reporter.hasReported()) + if((conversion >= l || (!part.charAt(conversion).isLetter && part.charAt(conversion) != '%')) && !reporter.hasReported) reporter.partError("Missing conversion operator in '" + part.substring(pos, conversion) + "'; use %% for literal %, %n for newline", partIndex, pos) val hasWidth = (hasWidth1 && !hasArgumentIndex) || hasWidth2 @@ -337,7 +337,7 @@ object StringContextChecker { if (argumentIndex > maxArgumentIndex || argumentIndex <= 0) reporter.partError("Argument index out of range", partIndex, offset) - if (expected && expectedArgumentIndex != argumentIndex && !reporter.hasReported()) + if (expected && expectedArgumentIndex != argumentIndex && !reporter.hasReported) reporter.partWarning("Index is not this arg", partIndex, offset) } @@ -376,10 +376,10 @@ object StringContextChecker { def checkUniqueFlags(partIndex : Int, flags : List[(Char, Int)], notAllowedFlagOnCondition : List[(Char, Boolean, String)]) = { reporter.resetReported() for {flag <- flags ; (nonAllowedFlag, condition, message) <- notAllowedFlagOnCondition ; if (flag._1 == nonAllowedFlag && condition)} { - if (!reporter.hasReported()) + if (!reporter.hasReported) reporter.partError(message, partIndex, flag._2) } - if (!reporter.hasReported()) + if (!reporter.hasReported) reporter.restoreReported() } @@ -656,7 +656,7 @@ object StringContextChecker { argument match { case Some(argIndex, arg) => { val (hasArgumentIndex, argumentIndex, flags, hasWidth, width, hasPrecision, precision, hasRelative, relativeIndex, conversion) = getFormatSpecifiers(part, argIndex, argIndex + 1, false, formattingStart) - if (!reporter.hasReported()){ + if (!reporter.hasReported){ val conversionWithType = checkFormatSpecifiers(argIndex + 1, hasArgumentIndex, argumentIndex, Some(argIndex + 1), start == 0, maxArgumentIndex, hasRelative, hasWidth, width, hasPrecision, precision, flags, conversion, Some(arg.tpe), part) nextStart = conversion + 1 conversionWithType :: checkPart(part, nextStart, argument, maxArgumentIndex) @@ -668,10 +668,10 @@ object StringContextChecker { reporter.partError("Argument index out of range", 0, argumentIndex) if (hasRelative) reporter.partError("No last arg", 0, relativeIndex) - if (!reporter.hasReported()){ + if (!reporter.hasReported){ val conversionWithType = checkFormatSpecifiers(0, hasArgumentIndex, argumentIndex, None, start == 0, maxArgumentIndex, hasRelative, hasWidth, width, hasPrecision, precision, flags, conversion, None, part) nextStart = conversion + 1 - if (!reporter.hasReported() && part.charAt(conversion) != '%' && part.charAt(conversion) != 'n' && !hasArgumentIndex && !hasRelative) + if (!reporter.hasReported && part.charAt(conversion) != '%' && part.charAt(conversion) != 'n' && !hasArgumentIndex && !hasRelative) reporter.partError("conversions must follow a splice; use %% for literal %, %n for newline", 0, part.indexOf('%')) conversionWithType :: checkPart(part, nextStart, argument, maxArgumentIndex) } else checkPart(part, conversion + 1, argument, maxArgumentIndex) @@ -691,10 +691,10 @@ object StringContextChecker { // add default format val parts = addDefaultFormat(parts0) - if (!parts.isEmpty && !reporter.hasReported()) { + if (!parts.isEmpty && !reporter.hasReported) { if (parts.size == 1 && args.size == 0 && parts.head.size != 0){ val argTypeWithConversion = checkPart(parts.head, 0, None, None) - if (!reporter.hasReported()) + if (!reporter.hasReported) for ((argument, conversionChar, flags) <- argTypeWithConversion) checkArgTypeWithConversion(0, conversionChar, argument, flags, parts.head.indexOf('%')) } else { @@ -702,7 +702,7 @@ object StringContextChecker { for (i <- (0 until args.size)){ val (part, arg) = partWithArgs(i) val argTypeWithConversion = checkPart(part, 0, Some((i, arg)), Some(args.size)) - if (!reporter.hasReported()) + if (!reporter.hasReported) for ((argument, conversionChar, flags) <- argTypeWithConversion) checkArgTypeWithConversion(i + 1, conversionChar, argument, flags, parts(i).indexOf('%')) } diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala index 90b2b4f4cabf..a1458e5acb1b 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala @@ -120,6 +120,11 @@ class StringInterpolatorOpt extends MiniPhase { (sym.name == nme.raw_ && sym.eq(defn.StringContext_raw)) || (sym.name == nme.f && sym.eq(defn.StringContext_f)) || (sym.name == nme.s && sym.eq(defn.StringContext_s)) + def transformF(fun: Tree, args: Tree): Tree = + val (parts1, args1) = FormatInterpolatorTransform.checked(fun, args) + resolveConstructor(defn.StringOps.typeRef, List(parts1)) + .select(nme.format) + .appliedTo(args1) if (isInterpolatedMethod) (tree: @unchecked) match { case StringContextIntrinsic(strs: List[Literal], elems: List[Tree]) => @@ -136,10 +141,7 @@ class StringInterpolatorOpt extends MiniPhase { } result case Apply(intp, args :: Nil) if sym.eq(defn.StringContext_f) => - val partsStr = StringContextChecker.checkedParts(intp, args).mkString - resolveConstructor(defn.StringOps.typeRef, List(Literal(Constant(partsStr)))) - .select(nme.format) - .appliedTo(args) + transformF(intp, args) // Starting with Scala 2.13, s and raw are macros in the standard // library, so we need to expand them manually. // sc.s(args) --> standardInterpolator(processEscapes, args, sc.parts) diff --git a/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala b/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala index 77a54e23ad61..5a9bff781434 100644 --- a/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala +++ b/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala @@ -1423,7 +1423,7 @@ object SourceCode { case '"' => "\\\"" case '\'' => "\\\'" case '\\' => "\\\\" - case _ => if (ch.isControl) f"\u${ch.toInt}%04x" else String.valueOf(ch) + case _ => if (ch.isControl) f"${"\\"}u${ch.toInt}%04x" else String.valueOf(ch) } private def escapedString(str: String): String = str flatMap escapedChar diff --git a/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala b/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala new file mode 100644 index 000000000000..56bc47d172b2 --- /dev/null +++ b/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala @@ -0,0 +1,82 @@ +package dotty.tools +package dotc.transform + +import org.junit.{Test, Assert}, Assert.{assertEquals, assertFalse, assertTrue} + +import scala.collection.mutable.ListBuffer +import scala.language.implicitConversions +import scala.reflect.{ClassTag, classTag} +import scala.util.chaining._ + +import java.util.{Calendar, Date, Formattable} + +import localopt.{FormatChecker, StringContextChecker} + +// TDD for just the Checker +class FormatCheckerTest: + class TestReporter extends StringContextChecker.InterpolationReporter: + private var reported = false + private var oldReported = false + val reports = ListBuffer.empty[(String, Int, Int)] + + def partError(message: String, index: Int, offset: Int): Unit = + reported = true + reports += ((message, index, offset)) + def partWarning(message: String, index: Int, offset: Int): Unit = + reported = true + reports += ((message, index, offset)) + def argError(message: String, index: Int): Unit = + reported = true + reports += ((message, index, 0)) + def strCtxError(message: String): Unit = + reported = true + reports += ((message, 0, 0)) + def argsError(message: String): Unit = + reports += ((message, 0, 0)) + + def hasReported: Boolean = reported + + def resetReported(): Unit = + oldReported = reported + reported = false + + def restoreReported(): Unit = + reported = oldReported + end TestReporter + given TestReporter = TestReporter() + + /* + enum ArgTypeTag: + case BooleanArg, ByteArg, CharArg, ShortArg, IntArg, LongArg, FloatArg, DoubleArg, AnyArg, + StringArg, FormattableArg, BigIntArg, BigDecimalArg, CalendarArg, DateArg + given Conversion[ArgTypeTag, Int] = _.ordinal + def argTypeString(tag: Int) = + if tag < 0 then "Null" + else if tag >= ArgTypeTag.values.length then throw RuntimeException(s"Bad tag $tag") + else ArgTypeTag.values(tag) + */ + + class TestChecker(args: ClassTag[?]*)(using val reporter: TestReporter) extends FormatChecker: + def argType(argi: Int, types: ClassTag[?]*): ClassTag[?] = types.find(_ == args(argi)).getOrElse(types.head) + val argc = args.length + + def checked(parts: String*)(args: ClassTag[?]*): String = + val checker = TestChecker(args*) + val (amended, _) = checker.checked(parts.toList) + assertFalse(checker.reporter.hasReported) + amended.mkString + def assertChecked(parts: String*)(args: ClassTag[?]*)(p: TestReporter => Boolean = _ => true): Unit = + val checker = TestChecker(args*) + checker.checked(parts.toList) + assertTrue(p(checker.reporter)) + def errorIs(msg: String): (TestReporter => Boolean) = _.reports.head._1.contains(msg) + + @Test def `simple string` = assertEquals("xyz", checked("xyz")()) + @Test def `one string` = assertEquals("xyz%s123", checked("xyz", "123")(classTag[String])) + @Test def `in first part` = assertEquals("x%ny%%z%s123", checked("x%ny%%z", "123")(classTag[String])) + @Test def `one int` = assertEquals("xyz%d123", checked("xyz", "%d123")(classTag[Int])) + //@Test def `one bad int`: Unit = assertChecked("xyz", "%d123")(classTag[String])(errorIs("Type error")) + @Test def `extra descriptor` = assertChecked("xyz", "%s12%d3")(classTag[String])(errorIs("conversions must follow")) + @Test def `bad leader`: Unit = assertChecked("%dxyz")()(_.reports.head._1.contains("conversions must follow")) + @Test def `in second part`: Unit = assertEquals("xyz%s1%n2%%3", checked("xyz", "1%n2%%3")(classTag[String])) + @Test def `something weird`: Unit = assertEquals("xyz%tH123", checked("xyz", "%tH123")(classTag[Calendar])) diff --git a/tests/neg/f-interpolator-tests.scala b/tests/neg/f-interpolator-tests.scala new file mode 100644 index 000000000000..ca4b4ad11e39 --- /dev/null +++ b/tests/neg/f-interpolator-tests.scala @@ -0,0 +1,11 @@ + +trait T { + val s = "um" + def `um uh` = f"$s%d" // error + + // "% y" looks like a format conversion because ' ' is a legal flag + def i11256 = + val x = 42 + val y = 27 + f"x % y = ${x % y}%d" // error: illegal conversion character 'y' +} diff --git a/tests/neg/fEscapes.check b/tests/neg/fEscapes.check new file mode 100644 index 000000000000..29fe1412e075 --- /dev/null +++ b/tests/neg/fEscapes.check @@ -0,0 +1,4 @@ +-- Error: tests/neg/fEscapes.scala:5:18 -------------------------------------------------------------------------------- +5 | val fEscape = f"\u$octal%04x" // error + | ^^ + | invalid unicode escape at index 1 of \u diff --git a/tests/neg/fEscapes.scala b/tests/neg/fEscapes.scala new file mode 100644 index 000000000000..c4a9a6ffb200 --- /dev/null +++ b/tests/neg/fEscapes.scala @@ -0,0 +1,5 @@ + +// f-interpolator wasn't doing any escape processing +class C: + val octal = 8 + val fEscape = f"\u$octal%04x" // error diff --git a/tests/run-macros/f-interpolator-tests.scala b/tests/run-macros/f-interpolator-tests.scala index a50d22a4c022..73ad9b8852fa 100755 --- a/tests/run-macros/f-interpolator-tests.scala +++ b/tests/run-macros/f-interpolator-tests.scala @@ -23,6 +23,7 @@ object Test { dateArgsTests specificLiteralsTests argumentsTests + unitTests } def multilineTests = { @@ -199,5 +200,209 @@ object Test { def argumentsTests = { println(f"${"a"}%s ${"b"}%s % "false", + f"${b_true}%b" -> "true", + + f"${null}%b" -> "false", + f"${false}%b" -> "false", + f"${true}%b" -> "true", + f"${true && false}%b" -> "false", + f"${java.lang.Boolean.valueOf(false)}%b" -> "false", + f"${java.lang.Boolean.valueOf(true)}%b" -> "true", + + f"${null}%B" -> "FALSE", + f"${false}%B" -> "FALSE", + f"${true}%B" -> "TRUE", + f"${java.lang.Boolean.valueOf(false)}%B" -> "FALSE", + f"${java.lang.Boolean.valueOf(true)}%B" -> "TRUE", + + f"${"true"}%b" -> "true", + f"${"false"}%b"-> "false", + + // 'h' | 'H' (category: general) + // ----------------------------- + f"${null}%h" -> "null", + f"${f_zero}%h" -> "0", + f"${f_zero_-}%h" -> "80000000", + f"${s}%h" -> "4c01926", + + f"${null}%H" -> "NULL", + f"${s}%H" -> "4C01926", + + // 's' | 'S' (category: general) + // ----------------------------- + f"${null}%s" -> "null", + f"${null}%S" -> "NULL", + f"${s}%s" -> "Scala", + f"${s}%S" -> "SCALA", + f"${5}" -> "5", + f"${i}" -> "42", + f"${Symbol("foo")}" -> "Symbol(foo)", + + f"${Thread.State.NEW}" -> "NEW", + + // 'c' | 'C' (category: character) + // ------------------------------- + f"${120:Char}%c" -> "x", + f"${120:Byte}%c" -> "x", + f"${120:Short}%c" -> "x", + f"${120:Int}%c" -> "x", + f"${java.lang.Character.valueOf('x')}%c" -> "x", + f"${java.lang.Byte.valueOf(120:Byte)}%c" -> "x", + f"${java.lang.Short.valueOf(120:Short)}%c" -> "x", + f"${java.lang.Integer.valueOf(120)}%c" -> "x", + + f"${'x' : java.lang.Character}%c" -> "x", + f"${(120:Byte) : java.lang.Byte}%c" -> "x", + f"${(120:Short) : java.lang.Short}%c" -> "x", + f"${120 : java.lang.Integer}%c" -> "x", + + f"${"Scala"}%c" -> "S", + + // 'd' | 'o' | 'x' | 'X' (category: integral) + // ------------------------------------------ + f"${120:Byte}%d" -> "120", + f"${120:Short}%d" -> "120", + f"${120:Int}%d" -> "120", + f"${120:Long}%d" -> "120", + f"${60 * 2}%d" -> "120", + f"${java.lang.Byte.valueOf(120:Byte)}%d" -> "120", + f"${java.lang.Short.valueOf(120:Short)}%d" -> "120", + f"${java.lang.Integer.valueOf(120)}%d" -> "120", + f"${java.lang.Long.valueOf(120)}%d" -> "120", + f"${120 : java.lang.Integer}%d" -> "120", + f"${120 : java.lang.Long}%d" -> "120", + f"${BigInt(120)}%d" -> "120", + + f"${new java.math.BigInteger("120")}%d" -> "120", + + f"${4}%#10X" -> " 0X4", + + f"She is ${fff}%#s feet tall." -> "She is 4 feet tall.", + + f"Just want to say ${"hello, world"}%#s..." -> "Just want to say hello, world...", + + //{ implicit val strToShort: Conversion[String, Short] = java.lang.Short.parseShort ; f"${"120"}%d" } -> "120", + //{ implicit val strToInt = (s: String) => 42 ; f"${"120"}%d" } -> "42", + + // 'e' | 'E' | 'g' | 'G' | 'f' | 'a' | 'A' (category: floating point) + // ------------------------------------------------------------------ + f"${3.4f}%e" -> locally"3.400000e+00", + f"${3.4}%e" -> locally"3.400000e+00", + f"${3.4f : java.lang.Float}%e" -> locally"3.400000e+00", + f"${3.4 : java.lang.Double}%e" -> locally"3.400000e+00", + + f"${BigDecimal(3.4)}%e" -> locally"3.400000e+00", + + f"${new java.math.BigDecimal(3.4)}%e" -> locally"3.400000e+00", + + f"${3}%e" -> locally"3.000000e+00", + f"${3L}%e" -> locally"3.000000e+00", + + // 't' | 'T' (category: date/time) + // ------------------------------- + f"${cal}%TD" -> "05/26/12", + f"${cal.getTime}%TD" -> "05/26/12", + f"${cal.getTime.getTime}%TD" -> "05/26/12", + f"""${"1234"}%TD""" -> "05/26/12", + + // literals and arg indexes + f"%%" -> "%", + f" mind%n------%nmatter" -> + """| mind + |------ + |matter""".stripMargin.linesIterator.mkString(System.lineSeparator), + f"${i}%d % "42 42 9", + f"${7}%d % "7 7 9", + f"${7}%d %2$$d ${9}%d" -> "7 9 9", + + f"${null}%d % "null FALSE", + + f"${5: Any}" -> "5", + f"${5}%s% "55", + f"${3.14}%s,% locally"3.14,${"3.140000"}", + + f"z" -> "z" + ) + + for ((f, s) <- ss) assertEquals(s, f) + end `f interpolator baseline` + + def fIf = + val res = f"${if true then 2.5 else 2.5}%.2f" + val expected = locally"2.50" + assertEquals(expected, res) + + def fIfNot = + val res = f"${if false then 2.5 else 3.5}%.2f" + val expected = locally"3.50" + assertEquals(expected, res) + + // in Scala 2, [A >: Any] forced not to convert 3 to 3.0; Scala 3 harmonics should also respect lower bound. + def fHeteroArgs() = + val res = f"${3.14}%.2f rounds to ${3}%d" + val expected = locally"${"3.14"} rounds to 3" + assertEquals(expected, res) } +object StringContextTestUtils: + private val decimalSeparator: Char = new DecimalFormat().getDecimalFormatSymbols().getDecimalSeparator() + private val numberPattern = """(\d+)\.(\d+.*)""".r + private def applyProperLocale(number: String): String = + val numberPattern(intPart, fractionalPartAndSuffix) = number + s"$intPart$decimalSeparator$fractionalPartAndSuffix" + + extension (sc: StringContext) + // Use this String interpolator to avoid problems with a locale-dependent decimal mark. + def locally(numbers: String*): String = + val numbersWithCorrectLocale = numbers.map(applyProperLocale) + sc.s(numbersWithCorrectLocale: _*) + + // Handles cases like locally"3.14" - it's prettier than locally"${"3.14"}". + def locally(): String = sc.parts.map(applyProperLocale).mkString diff --git a/tests/run/t6476.check b/tests/run/t6476.check index 69bf68978177..e2a080bcf6dc 100644 --- a/tests/run/t6476.check +++ b/tests/run/t6476.check @@ -1,13 +1,18 @@ "Hello", Alice "Hello", Alice + +"Hello", Alice +"Hello", Alice + \"Hello\", Alice \"Hello\", Alice -\"Hello\", Alice -\"Hello\", Alice + \TILT\ -\\TILT\\ -\\TILT\\ \TILT\ \\TILT\\ + +\TILT\ +\TILT\ \\TILT\\ + \TILT\ diff --git a/tests/run/t6476.scala b/tests/run/t6476.scala index a04645065a2a..25a1d5f03ec1 100644 --- a/tests/run/t6476.scala +++ b/tests/run/t6476.scala @@ -3,21 +3,21 @@ object Test { val person = "Alice" println(s"\"Hello\", $person") println(s"""\"Hello\", $person""") - + println() println(f"\"Hello\", $person") println(f"""\"Hello\", $person""") - + println() println(raw"\"Hello\", $person") println(raw"""\"Hello\", $person""") - + println() println(s"\\TILT\\") println(f"\\TILT\\") println(raw"\\TILT\\") - + println() println(s"""\\TILT\\""") println(f"""\\TILT\\""") println(raw"""\\TILT\\""") - + println() println(raw"""\TILT\""") } }