From 1392f663f063b69a5a53365b4a1fbcbb77c74543 Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Mon, 22 Nov 2021 20:53:53 -0800 Subject: [PATCH] Brace reduction and remove dead code --- compiler/src/dotty/tools/dotc/Compiler.scala | 2 +- .../transform/localopt/FormatChecker.scala | 2 - .../FormatInterpolatorTransform.scala | 86 ++- .../localopt/StringContextChecker.scala | 714 ------------------ .../localopt/StringInterpolatorOpt.scala | 116 ++- .../dotc/transform/FormatCheckerTest.scala | 16 +- tests/run-macros/f-interpolator-tests.scala | 13 +- 7 files changed, 95 insertions(+), 854 deletions(-) delete mode 100644 compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 736e1b08d1e7..4911cdebd138 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -84,7 +84,7 @@ class Compiler { new ExplicitOuter, // Add accessors to outer classes from nested ones. new ExplicitSelf, // Make references to non-trivial self types explicit as casts new ElimByName, // Expand by-name parameter references - new StringInterpolatorOpt) :: // Optimizes raw and s string interpolators by rewriting them to string concatenations + new StringInterpolatorOpt) :: // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions new UninitializedDefs, // Replaces `compiletime.uninitialized` by `_` new InlinePatterns, // Remove placeholders of inlined patterns diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala index 1cbd73db2ba0..54493c552473 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala @@ -9,8 +9,6 @@ import scala.util.matching.Regex.Match import java.util.{Calendar, Date, Formattable} -import StringContextChecker.InterpolationReporter - /** Formatter string checker. */ abstract class FormatChecker(using reporter: InterpolationReporter): diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala b/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala index 820b42d2c225..436b56710370 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/FormatInterpolatorTransform.scala @@ -22,51 +22,7 @@ import scala.util.matching.Regex.Match object FormatInterpolatorTransform: import tpd._ - import StringContextChecker.InterpolationReporter - /* - /** This trait defines a tool to report errors/warnings that do not depend on Position. */ - trait InterpolationReporter: - - /** Reports error/warning of size 1 linked with a part of the StringContext. - * - * @param message the message to report as error/warning - * @param index the index of the part inside the list of parts of the StringContext - * @param offset the index in the part String where the error is - * @return an error/warning depending on the function - */ - def partError(message: String, index: Int, offset: Int): Unit - def partWarning(message: String, index: Int, offset: Int): Unit - - /** Reports error linked with an argument to format. - * - * @param message the message to report as error/warning - * @param index the index of the argument inside the list of arguments of the format function - * @return an error depending on the function - */ - def argError(message: String, index: Int): Unit - - /** Reports error linked with the list of arguments or the StringContext. - * - * @param message the message to report in the error - * @return an error - */ - def strCtxError(message: String): Unit - def argsError(message: String): Unit - - /** Claims whether an error or a warning has been reported - * - * @return true if an error/warning has been reported, false - */ - def hasReported: Boolean - - /** Stores the old value of the reported and reset it to false */ - def resetReported(): Unit - - /** Restores the value of the reported boolean that has been reset */ - def restoreReported(): Unit - end InterpolationReporter - */ class PartsReporter(fun: Tree, args0: Tree, parts: List[Tree], args: List[Tree])(using Context) extends InterpolationReporter: private var reported = false private var oldReported = false @@ -193,3 +149,45 @@ object FormatInterpolatorTransform: (literally(checked.mkString), tpd.SeqLiteral(checker.actuals.toList, elemtpt)) end checked end FormatInterpolatorTransform + +/** This trait defines a tool to report errors/warnings that do not depend on Position. */ +trait InterpolationReporter: + + /** Reports error/warning of size 1 linked with a part of the StringContext. + * + * @param message the message to report as error/warning + * @param index the index of the part inside the list of parts of the StringContext + * @param offset the index in the part String where the error is + * @return an error/warning depending on the function + */ + def partError(message: String, index: Int, offset: Int): Unit + def partWarning(message: String, index: Int, offset: Int): Unit + + /** Reports error linked with an argument to format. + * + * @param message the message to report as error/warning + * @param index the index of the argument inside the list of arguments of the format function + * @return an error depending on the function + */ + def argError(message: String, index: Int): Unit + + /** Reports error linked with the list of arguments or the StringContext. + * + * @param message the message to report in the error + * @return an error + */ + def strCtxError(message: String): Unit + def argsError(message: String): Unit + + /** Claims whether an error or a warning has been reported + * + * @return true if an error/warning has been reported, false + */ + def hasReported: Boolean + + /** Stores the old value of the reported and reset it to false */ + def resetReported(): Unit + + /** Restores the value of the reported boolean that has been reset */ + def restoreReported(): Unit +end InterpolationReporter diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala b/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala deleted file mode 100644 index 18f296af8c4e..000000000000 --- a/compiler/src/dotty/tools/dotc/transform/localopt/StringContextChecker.scala +++ /dev/null @@ -1,714 +0,0 @@ -package dotty.tools.dotc -package transform.localopt - -import dotty.tools.dotc.ast.Trees._ -import dotty.tools.dotc.ast.tpd -import dotty.tools.dotc.core.Decorators._ -import dotty.tools.dotc.core.Constants.Constant -import dotty.tools.dotc.core.Contexts._ -import dotty.tools.dotc.core.StdNames._ -import dotty.tools.dotc.core.NameKinds._ -import dotty.tools.dotc.core.Symbols._ -import dotty.tools.dotc.core.Types._ - -// Ported from old dotty.internal.StringContextMacro -// TODO: port Scala 2 logic? (see https://github.com/scala/scala/blob/2.13.x/src/compiler/scala/tools/reflect/FormatInterpolator.scala#L74) -object StringContextChecker { - import tpd._ - - /** This trait defines a tool to report errors/warnings that do not depend on Position. */ - trait InterpolationReporter { - - /** Reports error/warning of size 1 linked with a part of the StringContext. - * - * @param message the message to report as error/warning - * @param index the index of the part inside the list of parts of the StringContext - * @param offset the index in the part String where the error is - * @return an error/warning depending on the function - */ - def partError(message : String, index : Int, offset : Int) : Unit - def partWarning(message : String, index : Int, offset : Int) : Unit - - /** Reports error linked with an argument to format. - * - * @param message the message to report as error/warning - * @param index the index of the argument inside the list of arguments of the format function - * @return an error depending on the function - */ - def argError(message : String, index : Int) : Unit - - /** Reports error linked with the list of arguments or the StringContext. - * - * @param message the message to report in the error - * @return an error - */ - def strCtxError(message : String) : Unit - def argsError(message : String) : Unit - - /** Claims whether an error or a warning has been reported - * - * @return true if an error/warning has been reported, false - */ - def hasReported : Boolean - - /** Stores the old value of the reported and reset it to false */ - def resetReported() : Unit - - /** Restores the value of the reported boolean that has been reset */ - def restoreReported() : Unit - } - - /** Check the format of the parts of the f".." arguments and returns the string parts of the StringContext */ - def checkedParts(strContext_f: Tree, args0: Tree)(using Context): String = { - - val (partsExpr, parts) = strContext_f match { - case TypeApply(Select(Apply(_, (parts: SeqLiteral) :: Nil), _), _) => - (parts.elems, parts.elems.map { case Literal(Constant(str: String)) => str } ) - case _ => - report.error("Expected statically known String Context", strContext_f.srcPos) - return "" - } - - val args = args0 match { - case args: SeqLiteral => args.elems - case _ => - report.error("Expected statically known argument list", args0.srcPos) - return "" - } - - val reporter = new InterpolationReporter{ - private[this] var reported = false - private[this] var oldReported = false - def partError(message : String, index : Int, offset : Int) : Unit = { - reported = true - val pos = partsExpr(index).sourcePos - val posOffset = pos.withSpan(pos.span.shift(offset)) - report.error(message, posOffset) - } - def partWarning(message : String, index : Int, offset : Int) : Unit = { - reported = true - val pos = partsExpr(index).sourcePos - val posOffset = pos.withSpan(pos.span.shift(offset)) - report.warning(message, posOffset) - } - - def argError(message : String, index : Int) : Unit = { - reported = true - report.error(message, args(index).srcPos) - } - - def strCtxError(message : String) : Unit = { - reported = true - report.error(message, strContext_f.srcPos) - } - def argsError(message : String) : Unit = { - reported = true - report.error(message, args0.srcPos) - } - - def hasReported : Boolean = { - reported - } - - def resetReported() : Unit = { - oldReported = reported - reported = false - } - - def restoreReported() : Unit = { - reported = oldReported - } - } - - checked(parts, args, reporter) - } - - def checked(parts0: List[String], args: List[Tree], reporter: InterpolationReporter)(using Context): String = { - - - /** Checks if the number of arguments are the same as the number of formatting strings - * - * @param format the number of formatting parts in the StringContext - * @param argument the number of arguments to interpolate in the string - * @return reports an error if the number of arguments does not match with the number of formatting strings, - * nothing otherwise - */ - def checkSizes(format : Int, argument : Int) : Unit = { - if (format > argument && !(format == -1 && argument == 0)) - if (argument == 0) - reporter.argsError("too few arguments for interpolated string") - else - reporter.argError("too few arguments for interpolated string", argument - 1) - if (format < argument && !(format == -1 && argument == 0)) - if (argument == 0) - reporter.argsError("too many arguments for interpolated string") - else - reporter.argError("too many arguments for interpolated string", format) - if (format == -1) - reporter.strCtxError("there are no parts") - } - - /** Adds the default "%s" to the Strings that do not have any given format - * - * @param parts the list of parts contained in the StringContext - * @return a new list of string with all a defined formatting or reports an error if the '%' and - * formatting parameter are too far away from the argument that they refer to - * For example : f2"${d}random-leading-junk%d" will lead to an error - */ - def addDefaultFormat(parts : List[String]) : List[String] = parts match { - case Nil => Nil - case p :: parts1 => p :: parts1.map((part : String) => { - if (!part.startsWith("%")) { - val index = part.indexOf('%') - if (!reporter.hasReported && index != -1) { - reporter.partError("conversions must follow a splice; use %% for literal %, %n for newline", parts.indexOf(part), index) - "%s" + part - } else "%s" + part - } else part - }) - } - - /** Checks whether a part contains a formatting substring - * - * @param part the part to check - * @param l the length of the given part - * @param index the index where to start to look for a potential new formatting string - * @return an Option containing the index in the part where a new formatting String starts, None otherwise - */ - def getFormattingSubstring(part : String, l : Int, index : Int) : Option[Int] = { - var i = index - var result : Option[Int] = None - while (i < l){ - if (part.charAt(i) == '%' && result.isEmpty) - result = Some(i) - i += 1 - } - result - } - - /** Finds all the flags that are inside a formatting String from a given index - * - * @param i the index in the String s where to start to check - * @param l the length of s - * @param s the String to check - * @return a list containing all the flags that are inside the formatting String, - * and their index in the String - */ - def getFlags(i : Int, l : Int, s : String) : List[(Char, Int)] = { - def isFlag(c : Char) : Boolean = c match { - case '-' | '#' | '+' | ' ' | '0' | ',' | '(' => true - case _ => false - } - if (i < l && isFlag(s.charAt(i))) (s.charAt(i), i) :: getFlags(i + 1, l, s) - else Nil - } - - /** Skips the Characters that are width or argumentIndex parameters - * - * @param i the index where to start checking in the given String - * @param s the String to check - * @param l the length of s - * @return a tuple containing the index in the String after skipping - * the parameters, true if it has a width parameter and its value, false otherwise - */ - def skipWidth(i : Int, s : String, l : Int) = { - var j = i - var width = (false, 0) - while (j < l && Character.isDigit(s.charAt(j))){ - width = (true, j) - j += 1 - } - (j, width._1, width._2) - } - - /** Retrieves all the formatting parameters from a part and their index in it - * - * @param part the String containing the formatting parameters - * @param argIndex the index of the current argument inside the list of arguments to interpolate - * @param partIndex the index of the current part inside the list of parts in the StringContext - * @param noArg true if there is no arg, i.e. "%%" or "%n" - * @param pos the initial index where to start checking the part - * @return reports an error if any of the size of the arguments and the parts do not match or if a conversion - * parameter is missing. Otherwise, - * the index where the format specifier substring is, - * hasArgumentIndex (true and the index of its corresponding argumentIndex if there is an argument index, false and 0 otherwise) and - * flags that contains the list of flags (empty if there is none), - * hasWidth (true and the index of the width parameter if there is a width, false and 0 otherwise), - * hasPrecision (true and the index of the precision if there is a precision, false and 0 otherwise), - * hasRelative (true if the specifiers use relative indexing, false otherwise) and - * conversion character index - */ - def getFormatSpecifiers(part : String, argIndex : Int, partIndex : Int, noArg : Boolean, pos : Int) : (Boolean, Int, List[(Char, Int)], Boolean, Int, Boolean, Int, Boolean, Int, Int) = { - var conversion = pos - var hasArgumentIndex = false - var argumentIndex = pos - var hasPrecision = false - var precision = pos - val l = part.length - - if (l >= 1 && part.charAt(conversion) == '%') - conversion += 1 - else if (!noArg) - reporter.argError("too many arguments for interpolated string", argIndex) - - //argument index or width - val (i, hasWidth1, width1) = skipWidth(conversion, part, l) - conversion = i - - //argument index - if (conversion < l && part.charAt(conversion) == '$'){ - if (hasWidth1){ - hasArgumentIndex = true - argumentIndex = width1 - conversion += 1 - } else { - reporter.partError("Missing conversion operator in '" + part.substring(0, conversion) + "'; use %% for literal %, %n for newline", partIndex, 0) - } - } - - //relative indexing - val hasRelative = conversion < l && part.charAt(conversion) == '<' - val relativeIndex = conversion - if (hasRelative) - conversion += 1 - - //flags - val flags = getFlags(conversion, l, part) - conversion += flags.size - - //width - val (j, hasWidth2, width2) = skipWidth(conversion, part, l) - conversion = j - - //precision - if (conversion < l && part.charAt(conversion) == '.') { - precision = conversion - conversion += 1 - hasPrecision = true - val oldConversion = conversion - while (conversion < l && Character.isDigit(part.charAt(conversion))) { - conversion += 1 - } - if (oldConversion == conversion) { - reporter.partError("Missing conversion operator in '" + part.substring(pos, oldConversion - 1) + "'; use %% for literal %, %n for newline", partIndex, pos) - hasPrecision = false - } - } - - //conversion - if((conversion >= l || (!part.charAt(conversion).isLetter && part.charAt(conversion) != '%')) && !reporter.hasReported) - reporter.partError("Missing conversion operator in '" + part.substring(pos, conversion) + "'; use %% for literal %, %n for newline", partIndex, pos) - - val hasWidth = (hasWidth1 && !hasArgumentIndex) || hasWidth2 - val width = if (hasWidth1 && !hasArgumentIndex) width1 else width2 - (hasArgumentIndex, argumentIndex, flags, hasWidth, width, hasPrecision, precision, hasRelative, relativeIndex, conversion) - } - - /** Checks if a given type is a subtype of any of the possibilities - * - * @param actualType the given type - * @param expectedType the type we are expecting - * @param argIndex the index of the argument that should type check - * @param possibilities all the types within which we want to find a super type of the actualType - * @return reports a type mismatch error if the actual type is not a subtype of any of the possibilities, - * nothing otherwise - */ - def checkSubtype(actualType: Type, expectedType: String, argIndex: Int, possibilities: List[Type]) = { - if !possibilities.exists(actualType <:< _) then - reporter.argError("type mismatch;\n found : " + actualType.widen.show.stripPrefix("scala.Predef.").stripPrefix("java.lang.").stripPrefix("scala.") + "\n required: " + expectedType, argIndex) - } - - /** Checks whether a given argument index, relative or not, is in the correct bounds - * - * @param partIndex the index of the part we are checking - * @param offset the index in the part where there might be an error - * @param relative true if relative indexing is used, false otherwise - * @param argumentIndex the argument index parameter in the formatting String - * @param expected true if we have an expectedArgumentIndex, false otherwise - * @param expectedArgumentIndex the expected argument index parameter - * @param maxArgumentIndex the maximum argument index parameter that can be used - * @return reports a warning if relative indexing is used but an argument is still given, - * an error is the argument index is not in the bounds [1, number of arguments] - */ - def checkArgumentIndex(partIndex : Int, offset : Int, relative : Boolean, argumentIndex : Int, expected : Boolean, expectedArgumentIndex : Int, maxArgumentIndex : Int) = { - if (relative) - reporter.partWarning("Argument index ignored if '<' flag is present", partIndex, offset) - - if (argumentIndex > maxArgumentIndex || argumentIndex <= 0) - reporter.partError("Argument index out of range", partIndex, offset) - - if (expected && expectedArgumentIndex != argumentIndex && !reporter.hasReported) - reporter.partWarning("Index is not this arg", partIndex, offset) - } - - /** Checks if a parameter is specified whereas it is not allowed - * - * @param hasParameter true if parameter is specified, false otherwise - * @param partIndex the index of the part inside the parts - * @param offset the index in the part where to report an error - * @param parameter the parameter that is not allowed - * @return reports an error if hasParameter is true, nothing otherwise - */ - def checkNotAllowedParameter(hasParameter : Boolean, partIndex : Int, offset : Int, parameter : String) = { - if (hasParameter) - reporter.partError(parameter + " not allowed", partIndex, offset) - } - - /** Checks if the flags are allowed for the conversion - * - * @param partIndex the index of the part in the String Context - * @param flags the specified flags to check - * @param notAllowedFlagsOnCondition a list that maps which flags are allowed depending on the conversion Char - * @return reports an error if the flag is not allowed, nothing otherwise - */ - def checkFlags(partIndex : Int, flags : List[(Char, Int)], notAllowedFlagOnCondition: List[(Char, Boolean, String)]) = { - for {flag <- flags ; (nonAllowedFlag, condition, message) <- notAllowedFlagOnCondition ; if (flag._1 == nonAllowedFlag && condition)} - reporter.partError(message, partIndex, flag._2) - } - - /** Checks if the flags are allowed for the conversion - * - * @param partIndex the index of the part in the String Context - * @param flags the specified flags to check - * @param notAllowedFlagsOnCondition a list that maps which flags are allowed depending on the conversion Char - * @return reports an error only once if at least one of the flags is not allowed, nothing otherwise - */ - def checkUniqueFlags(partIndex : Int, flags : List[(Char, Int)], notAllowedFlagOnCondition : List[(Char, Boolean, String)]) = { - reporter.resetReported() - for {flag <- flags ; (nonAllowedFlag, condition, message) <- notAllowedFlagOnCondition ; if (flag._1 == nonAllowedFlag && condition)} { - if (!reporter.hasReported) - reporter.partError(message, partIndex, flag._2) - } - if (!reporter.hasReported) - reporter.restoreReported() - } - - /** Checks all the formatting parameters for a Character conversion - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param flags the flags parameters inside the formatting part - * @param hasPrecision true if precision parameter is specified, false otherwise - * @param precision the index of the precision parameter inside the part - * @return reports an error - * if precision is specified or if the used flags are different from '-' - */ - def checkCharacterConversion(partIndex : Int, flags : List[(Char, Int)], hasPrecision : Boolean, precisionIndex : Int) = { - val notAllowedFlagOnCondition = for (flag <- List('#', '+', ' ', '0', ',', '(')) yield (flag, true, "Only '-' allowed for c conversion") - checkUniqueFlags(partIndex, flags, notAllowedFlagOnCondition) - checkNotAllowedParameter(hasPrecision, partIndex, precisionIndex, "precision") - } - - /** Checks all the formatting parameters for an Integral conversion - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param argType the type of the argument matching with the given part - * @param conversionChar the Char used for the formatting conversion - * @param flags the flags parameters inside the formatting part - * @param hasPrecision true if precision parameter is specified, false otherwise - * @param precision the index of the precision parameter inside the part - * @return reports an error - * if precision is specified or if the used flags are not allowed : - * ’d’: only ’#’ is allowed, - * ’o’, ’x’, ’X’: ’-’, ’#’, ’0’ are always allowed, depending on the type, this will be checked in the type check step - */ - def checkIntegralConversion(partIndex : Int, argType : Option[Type], conversionChar : Char, flags : List[(Char, Int)], hasPrecision : Boolean, precision : Int) = { - if (conversionChar == 'd') - checkFlags(partIndex, flags, List(('#', true, "# not allowed for d conversion"))) - - checkNotAllowedParameter(hasPrecision, partIndex, precision, "precision") - } - - /** Checks all the formatting parameters for a Floating Point conversion - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param conversionChar the Char used for the formatting conversion - * @param flags the flags parameters inside the formatting part - * @param hasPrecision true if precision parameter is specified, false otherwise - * @param precision the index of the precision parameter inside the part - * @return reports an error - * if precision is specified for 'a', 'A' conversion or if the used flags are '(' and ',' for 'a', 'A' - */ - def checkFloatingPointConversion(partIndex: Int, conversionChar : Char, flags : List[(Char, Int)], hasPrecision : Boolean, precision : Int) = { - if(conversionChar == 'a' || conversionChar == 'A'){ - for {flag <- flags ; if (flag._1 == ',' || flag._1 == '(')} - reporter.partError("'" + flag._1 + "' not allowed for a, A", partIndex, flag._2) - checkNotAllowedParameter(hasPrecision, partIndex, precision, "precision") - } - } - - /** Checks all the formatting parameters for a Time conversion - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param part the part that we are checking - * @param conversionIndex the index of the conversion Char used in the part - * @param flags the flags parameters inside the formatting part - * @param hasPrecision true if precision parameter is specified, false otherwise - * @param precision the index of the precision parameter inside the part - * @return reports an error - * if precision is specified, if the time suffix is not given/incorrect or if the used flags are - * different from '-' - */ - def checkTimeConversion(partIndex : Int, part : String, conversionIndex : Int, flags : List[(Char, Int)], hasPrecision : Boolean, precision : Int) = { - /** Checks whether a time suffix is given and whether it is allowed - * - * @param part the part that we are checking - * @param partIndex the index of the part inside of the parts of the StringContext - * @param conversionIndex the index of the conversion Char inside the part - * @param return reports an error if no suffix is specified or if the given suffix is not - * part of the allowed ones - */ - def checkTime(part : String, partIndex : Int, conversionIndex : Int) : Unit = { - if (conversionIndex + 1 >= part.size) - reporter.partError("Date/time conversion must have two characters", partIndex, conversionIndex) - else { - part.charAt(conversionIndex + 1) match { - case 'H' | 'I' | 'k' | 'l' | 'M' | 'S' | 'L' | 'N' | 'p' | 'z' | 'Z' | 's' | 'Q' => //times - case 'B' | 'b' | 'h' | 'A' | 'a' | 'C' | 'Y' | 'y' | 'j' | 'm' | 'd' | 'e' => //dates - case 'R' | 'T' | 'r' | 'D' | 'F' | 'c' => //dates and times - case c => reporter.partError("'" + c + "' doesn't seem to be a date or time conversion", partIndex, conversionIndex + 1) - } - } - } - - val notAllowedFlagOnCondition = for (flag <- List('#', '+', ' ', '0', ',', '(')) yield (flag, true, "Only '-' allowed for date/time conversions") - checkUniqueFlags(partIndex, flags, notAllowedFlagOnCondition) - checkNotAllowedParameter(hasPrecision, partIndex, precision, "precision") - checkTime(part, partIndex, conversionIndex) - } - - /** Checks all the formatting parameters for a General conversion - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param argType the type of the argument matching with the given part - * @param conversionChar the Char used for the formatting conversion - * @param flags the flags parameters inside the formatting part - * @return reports an error - * if '#' flag is used or if any other flag is used - */ - def checkGeneralConversion(partIndex : Int, argType : Option[Type], conversionChar : Char, flags : List[(Char, Int)]) = { - for {flag <- flags ; if (flag._1 != '-' && flag._1 != '#')} - reporter.partError("Illegal flag '" + flag._1 + "'", partIndex, flag._2) - } - - /** Checks all the formatting parameters for a special Char such as '%' and end of line - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param conversionChar the Char used for the formatting conversion - * @param hasPrecision true if precision parameter is specified, false otherwise - * @param precision the index of the precision parameter inside the part - * @param hasWidth true if width parameter is specified, false otherwise - * @param width the index of the width parameter inside the part - * @return reports an error if precision or width is specified for '%' or - * if precision is specified for end of line - */ - def checkSpecials(partIndex : Int, conversionChar : Char, hasPrecision : Boolean, precision : Int, hasWidth : Boolean, width : Int, flags : List[(Char, Int)]) = conversionChar match { - case 'n' => { - checkNotAllowedParameter(hasPrecision, partIndex, precision, "precision") - checkNotAllowedParameter(hasWidth, partIndex, width, "width") - val notAllowedFlagOnCondition = for (flag <- List('-', '#', '+', ' ', '0', ',', '(')) yield (flag, true, "flags not allowed") - checkUniqueFlags(partIndex, flags, notAllowedFlagOnCondition) - } - case '%' => { - checkNotAllowedParameter(hasPrecision, partIndex, precision, "precision") - val notAllowedFlagOnCondition = for (flag <- List('#', '+', ' ', '0', ',', '(')) yield (flag, true, "Illegal flag '" + flag + "'") - checkFlags(partIndex, flags, notAllowedFlagOnCondition) - } - case _ => // OK - } - - /** Checks whether the format specifiers are correct depending on the conversion parameter - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param part the part to check - * The rest of the inputs correspond to the output of the function getFormatSpecifiers - * @param hasArgumentIndex - * @param actualArgumentIndex - * @param expectedArgumentIndex - * @param firstFormattingSubstring true if it is the first in the list, i.e. not an indexed argument - * @param maxArgumentIndex - * @param hasRelative - * @param hasWidth - * @param hasPrecision - * @param precision - * @param flags - * @param conversion - * @param argType - * @return the argument index and its type if there is an argument, the flags and the conversion parameter - * reports an error/warning if the formatting parameters are not allowed/wrong, nothing otherwise - */ - def checkFormatSpecifiers(partIndex : Int, hasArgumentIndex : Boolean, actualArgumentIndex : Int, expectedArgumentIndex : Option[Int], firstFormattingSubstring : Boolean, maxArgumentIndex : Option[Int], - hasRelative : Boolean, hasWidth : Boolean, width : Int, hasPrecision : Boolean, precision : Int, flags : List[(Char, Int)], conversion : Int, argType : Option[Type], part : String) : (Option[(Type, Int)], Char, List[(Char, Int)])= { - val conversionChar = part.charAt(conversion) - - if (hasArgumentIndex && expectedArgumentIndex.nonEmpty && maxArgumentIndex.nonEmpty && firstFormattingSubstring) - checkArgumentIndex(partIndex, actualArgumentIndex, hasRelative, part.charAt(actualArgumentIndex).asDigit, true, expectedArgumentIndex.get, maxArgumentIndex.get) - else if(hasArgumentIndex && maxArgumentIndex.nonEmpty && !firstFormattingSubstring) - checkArgumentIndex(partIndex, actualArgumentIndex, hasRelative, part.charAt(actualArgumentIndex).asDigit, false, 0, maxArgumentIndex.get) - - conversionChar match { - case 'c' | 'C' => checkCharacterConversion(partIndex, flags, hasPrecision, precision) - case 'd' | 'o' | 'x' | 'X' => checkIntegralConversion(partIndex, argType, conversionChar, flags, hasPrecision, precision) - case 'e' | 'E' |'f' | 'g' | 'G' | 'a' | 'A' => checkFloatingPointConversion(partIndex, conversionChar, flags, hasPrecision, precision) - case 't' | 'T' => checkTimeConversion(partIndex, part, conversion, flags, hasPrecision, precision) - case 'b' | 'B' | 'h' | 'H' | 'S' | 's' => checkGeneralConversion(partIndex, argType, conversionChar, flags) - case 'n' | '%' => checkSpecials(partIndex, conversionChar, hasPrecision, precision, hasWidth, width, flags) - case illegal => reporter.partError("illegal conversion character '" + illegal + "'", partIndex, conversion) - } - - (if (argType.isEmpty) None else Some(argType.get, (partIndex - 1)), conversionChar, flags) - } - - /** Checks whether the argument type, if there is one, type checks with the formatting parameters - * - * @param partIndex the index of the part, that we are checking, inside the parts - * @param conversionChar the character used for the conversion - * @param argument an option containing the type and index of the argument, None if there is no argument - * @param flags the flags used for the formatting - * @param formattingStart the index in the part where the formatting substring starts, i.e. where the '%' is - * @return reports an error/warning if the formatting parameters are not allowed/wrong depending on the type, nothing otherwise - */ - def checkArgTypeWithConversion(partIndex : Int, conversionChar : Char, argument : Option[(Type, Int)], flags : List[(Char, Int)], formattingStart : Int) = { - if (argument.nonEmpty) - checkTypeWithArgs(argument.get, conversionChar, partIndex, flags) - else - checkTypeWithoutArgs(conversionChar, partIndex, flags, formattingStart) - } - - /** Checks whether the argument type checks with the formatting parameters - * - * @param argument the given argument to check - * @param conversionChar the conversion parameter inside the formatting String - * @param partIndex index of the part inside the String Context - * @param flags the list of flags, and their index, used inside the formatting String - * @return reports an error if the argument type does not correspond with the conversion character, - * nothing otherwise - */ - def checkTypeWithArgs(argument : (Type, Int), conversionChar : Char, partIndex : Int, flags : List[(Char, Int)]) = { - def booleans = List(defn.BooleanType, defn.NullType) - def dates = List(defn.LongType, defn.JavaCalendarClass.typeRef, defn.JavaDateClass.typeRef) - def floatingPoints = List(defn.DoubleType, defn.FloatType, defn.JavaBigDecimalClass.typeRef) - def integral = List(defn.IntType, defn.LongType, defn.ShortType, defn.ByteType, defn.JavaBigIntegerClass.typeRef) - def character = List(defn.CharType, defn.ByteType, defn.ShortType, defn.IntType) - - val (argType, argIndex) = argument - conversionChar match { - case 'c' | 'C' => checkSubtype(argType, "Char", argIndex, character) - case 'd' | 'o' | 'x' | 'X' => { - checkSubtype(argType, "Int", argIndex, integral) - if (conversionChar != 'd') { - val notAllowedFlagOnCondition = List(('+', !(argType <:< defn.JavaBigIntegerClass.typeRef), "only use '+' for BigInt conversions to o, x, X"), - (' ', !(argType <:< defn.JavaBigIntegerClass.typeRef), "only use ' ' for BigInt conversions to o, x, X"), - ('(', !(argType <:< defn.JavaBigIntegerClass.typeRef), "only use '(' for BigInt conversions to o, x, X"), - (',', true, "',' only allowed for d conversion of integral types")) - checkFlags(partIndex, flags, notAllowedFlagOnCondition) - } - } - case 'e' | 'E' |'f' | 'g' | 'G' | 'a' | 'A' => checkSubtype(argType, "Double", argIndex, floatingPoints) - case 't' | 'T' => checkSubtype(argType, "Date", argIndex, dates) - case 'b' | 'B' => checkSubtype(argType, "Boolean", argIndex, booleans) - case 'h' | 'H' | 'S' | 's' => - if !(argType <:< defn.JavaFormattableClass.typeRef) then - for flag <- flags; if flag._1 == '#' do - reporter.argError("type mismatch;\n found : " + argType.widen.show.stripPrefix("scala.Predef.").stripPrefix("java.lang.").stripPrefix("scala.") + "\n required: java.util.Formattable", argIndex) - case 'n' | '%' => - case illegal => - } - } - - /** Reports error when the formatting parameter require a specific type but no argument is given - * - * @param conversionChar the conversion parameter inside the formatting String - * @param partIndex index of the part inside the String Context - * @param flags the list of flags, and their index, used inside the formatting String - * @param formattingStart the index in the part where the formatting substring starts, i.e. where the '%' is - * @return reports an error if the formatting parameter refer to the type of the parameter but no parameter is given - * nothing otherwise - */ - def checkTypeWithoutArgs(conversionChar : Char, partIndex : Int, flags : List[(Char, Int)], formattingStart : Int) = { - conversionChar match { - case 'o' | 'x' | 'X' => { - val notAllowedFlagOnCondition = List(('+', true, "only use '+' for BigInt conversions to o, x, X"), - (' ', true, "only use ' ' for BigInt conversions to o, x, X"), - ('(', true, "only use '(' for BigInt conversions to o, x, X"), - (',', true, "',' only allowed for d conversion of integral types")) - checkFlags(partIndex, flags, notAllowedFlagOnCondition) - } - case _ => //OK - } - } - - /** Checks that a given part of the String Context respects every formatting constraint per parameter - * - * @param part a particular part of the String Context - * @param start the index from which we start checking the part - * @param argument an Option containing the argument corresponding to the part and its index in the list of args, - * None if no args are specified. - * @param maxArgumentIndex an Option containing the maximum argument index possible, None if no args are specified - * @return a list with all the elements of the conversion per formatting string - */ - def checkPart(part : String, start : Int, argument : Option[(Int, Tree)], maxArgumentIndex : Option[Int]) : List[(Option[(Type, Int)], Char, List[(Char, Int)])] = { - reporter.resetReported() - val hasFormattingSubstring = getFormattingSubstring(part, part.size, start) - if (hasFormattingSubstring.nonEmpty) { - val formattingStart = hasFormattingSubstring.get - var nextStart = formattingStart - - argument match { - case Some(argIndex, arg) => { - val (hasArgumentIndex, argumentIndex, flags, hasWidth, width, hasPrecision, precision, hasRelative, relativeIndex, conversion) = getFormatSpecifiers(part, argIndex, argIndex + 1, false, formattingStart) - if (!reporter.hasReported){ - val conversionWithType = checkFormatSpecifiers(argIndex + 1, hasArgumentIndex, argumentIndex, Some(argIndex + 1), start == 0, maxArgumentIndex, hasRelative, hasWidth, width, hasPrecision, precision, flags, conversion, Some(arg.tpe), part) - nextStart = conversion + 1 - conversionWithType :: checkPart(part, nextStart, argument, maxArgumentIndex) - } else checkPart(part, conversion + 1, argument, maxArgumentIndex) - } - case None => { - val (hasArgumentIndex, argumentIndex, flags, hasWidth, width, hasPrecision, precision, hasRelative, relativeIndex, conversion) = getFormatSpecifiers(part, 0, 0, true, formattingStart) - if (hasArgumentIndex && !(part.charAt(argumentIndex).asDigit == 1 && (part.charAt(conversion) == 'n' || part.charAt(conversion) == '%'))) - reporter.partError("Argument index out of range", 0, argumentIndex) - if (hasRelative) - reporter.partError("No last arg", 0, relativeIndex) - if (!reporter.hasReported){ - val conversionWithType = checkFormatSpecifiers(0, hasArgumentIndex, argumentIndex, None, start == 0, maxArgumentIndex, hasRelative, hasWidth, width, hasPrecision, precision, flags, conversion, None, part) - nextStart = conversion + 1 - if (!reporter.hasReported && part.charAt(conversion) != '%' && part.charAt(conversion) != 'n' && !hasArgumentIndex && !hasRelative) - reporter.partError("conversions must follow a splice; use %% for literal %, %n for newline", 0, part.indexOf('%')) - conversionWithType :: checkPart(part, nextStart, argument, maxArgumentIndex) - } else checkPart(part, conversion + 1, argument, maxArgumentIndex) - } - } - } else { - reporter.restoreReported() - Nil - } - } - - val argument = args.size - - // check validity of formatting - checkSizes(parts0.size - 1, argument) - - // add default format - val parts = addDefaultFormat(parts0) - - if (!parts.isEmpty && !reporter.hasReported) { - if (parts.size == 1 && args.size == 0 && parts.head.size != 0){ - val argTypeWithConversion = checkPart(parts.head, 0, None, None) - if (!reporter.hasReported) - for ((argument, conversionChar, flags) <- argTypeWithConversion) - checkArgTypeWithConversion(0, conversionChar, argument, flags, parts.head.indexOf('%')) - } else { - val partWithArgs = parts.tail.zip(args) - for (i <- (0 until args.size)){ - val (part, arg) = partWithArgs(i) - val argTypeWithConversion = checkPart(part, 0, Some((i, arg)), Some(args.size)) - if (!reporter.hasReported) - for ((argument, conversionChar, flags) <- argTypeWithConversion) - checkArgTypeWithConversion(i + 1, conversionChar, argument, flags, parts(i).indexOf('%')) - } - } - } - - parts.mkString - } -} diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala index 04dd2a1af25e..529f1face73d 100644 --- a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala +++ b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala @@ -13,108 +13,84 @@ import dotty.tools.dotc.core.Types._ import dotty.tools.dotc.transform.MegaPhase.MiniPhase import dotty.tools.dotc.typer.ConstFold -/** - * MiniPhase to transform s and raw string interpolators from using StringContext to string - * concatenation. Since string concatenation uses the Java String builder, we get a performance - * improvement in terms of these two interpolators. - * - * More info here: - * https://medium.com/@dkomanov/scala-string-interpolation-performance-21dc85e83afd - */ -class StringInterpolatorOpt extends MiniPhase { - import tpd._ +/** MiniPhase to transform s and raw string interpolators from using StringContext to string + * concatenation. Since string concatenation uses the Java String builder, we get a performance + * improvement in terms of these two interpolators. + * + * More info here: + * https://medium.com/@dkomanov/scala-string-interpolation-performance-21dc85e83afd + */ +class StringInterpolatorOpt extends MiniPhase: + import tpd.* override def phaseName: String = "stringInterpolatorOpt" - override def checkPostCondition(tree: tpd.Tree)(using Context): Unit = { - tree match { + override def checkPostCondition(tree: tpd.Tree)(using Context): Unit = + tree match case tree: RefTree => val sym = tree.symbol assert(sym != defn.StringContext_raw && sym != defn.StringContext_s && sym != defn.StringContext_f, i"$tree in ${ctx.owner.showLocated} should have been rewritten by phase $phaseName") case _ => - } - } /** Matches a list of constant literals */ - private object Literals { - def unapply(tree: SeqLiteral)(using Context): Option[List[Literal]] = { - tree.elems match { - case literals if literals.forall(_.isInstanceOf[Literal]) => - Some(literals.map(_.asInstanceOf[Literal])) + private object Literals: + def unapply(tree: SeqLiteral)(using Context): Option[List[Literal]] = + tree.elems match + case literals if literals.forall(_.isInstanceOf[Literal]) => Some(literals.map(_.asInstanceOf[Literal])) case _ => None - } - } - } - private object StringContextApply { - def unapply(tree: Select)(using Context): Boolean = { - tree.symbol.eq(defn.StringContextModule_apply) && - tree.qualifier.symbol.eq(defn.StringContextModule) - } - } + private object StringContextApply: + def unapply(tree: Select)(using Context): Boolean = + (tree.symbol eq defn.StringContextModule_apply) && (tree.qualifier.symbol eq defn.StringContextModule) /** Matches an s or raw string interpolator */ - private object SOrRawInterpolator { - def unapply(tree: Tree)(using Context): Option[(List[Literal], List[Tree])] = { - tree match { - case Apply(Select(Apply(StringContextApply(), List(Literals(strs))), _), - List(SeqLiteral(elems, _))) if elems.length == strs.length - 1 => - Some(strs, elems) + private object SOrRawInterpolator: + def unapply(tree: Tree)(using Context): Option[(List[Literal], List[Tree])] = + tree match + case Apply(Select(Apply(StringContextApply(), List(Literals(strs))), _), List(SeqLiteral(elems, _))) + if elems.length == strs.length - 1 => Some(strs, elems) case _ => None - } - } - } //Extract the position from InvalidUnicodeEscapeException //which due to bincompat reasons is unaccessible. //TODO: remove once there is less restrictive bincompat - private object InvalidEscapePosition { - def unapply(t: Throwable): Option[Int] = t match { + private object InvalidEscapePosition: + def unapply(t: Throwable): Option[Int] = t match case iee: StringContext.InvalidEscapeException => Some(iee.index) - case il: IllegalArgumentException => il.getMessage() match { - case s"""invalid unicode escape at index $index of $_""" => index.toIntOption - case _ => None - } + case iae: IllegalArgumentException => iae.getMessage() match + case s"""invalid unicode escape at index $index of $_""" => index.toIntOption + case _ => None case _ => None - } - } - /** - * Match trees that resemble s and raw string interpolations. In the case of the s - * interpolator, escapes the string constants. Exposes the string constants as well as - * the variable references. - */ - private object StringContextIntrinsic { - def unapply(tree: Apply)(using Context): Option[(List[Literal], List[Tree])] = { - tree match { + /** Match trees that resemble s and raw string interpolations. In the case of the s + * interpolator, escapes the string constants. Exposes the string constants as well as + * the variable references. + */ + private object StringContextIntrinsic: + def unapply(tree: Apply)(using Context): Option[(List[Literal], List[Tree])] = + tree match case SOrRawInterpolator(strs, elems) => - if (tree.symbol == defn.StringContext_raw) Some(strs, elems) - else { // tree.symbol == defn.StringContextS + if tree.symbol == defn.StringContext_raw then Some(strs, elems) + else // tree.symbol == defn.StringContextS import dotty.tools.dotc.util.SourcePosition var stringPosition: SourcePosition = null - try { - val escapedStrs = strs.map(str => { + try + val escapedStrs = strs.map { str => stringPosition = str.sourcePos val escaped = StringContext.processEscapes(str.const.stringValue) cpy.Literal(str)(Constant(escaped)) - }) + } Some(escapedStrs, elems) - } catch { - case t @ InvalidEscapePosition(p) => { + catch + case t @ InvalidEscapePosition(p) => val errorSpan = stringPosition.span.startPos.shift(p) val errorPosition = stringPosition.withSpan(errorSpan) report.error(t.getMessage() + "\n", errorPosition) None - } - } - } case _ => None - } - } - } - override def transformApply(tree: Apply)(using Context): Tree = { + override def transformApply(tree: Apply)(using Context): Tree = def mkConcat(strs: List[Literal], elems: List[Tree]): Tree = val stri = strs.iterator val elemi = elems.iterator @@ -173,12 +149,8 @@ class StringInterpolatorOpt extends MiniPhase { ConstFold.Apply(tree).tpe match case ConstantType(x) => Literal(x).withSpan(tree.span).ensureConforms(tree.tpe) case _ => tree - } - override def transformSelect(tree: Select)(using Context): Tree = { + override def transformSelect(tree: Select)(using Context): Tree = ConstFold.Select(tree).tpe match case ConstantType(x) => Literal(x).withSpan(tree.span).ensureConforms(tree.tpe) case _ => tree - } - -} diff --git a/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala b/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala index 56bc47d172b2..984008c1f8ce 100644 --- a/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala +++ b/compiler/test/dotty/tools/dotc/transform/FormatCheckerTest.scala @@ -6,15 +6,14 @@ import org.junit.{Test, Assert}, Assert.{assertEquals, assertFalse, assertTrue} import scala.collection.mutable.ListBuffer import scala.language.implicitConversions import scala.reflect.{ClassTag, classTag} -import scala.util.chaining._ import java.util.{Calendar, Date, Formattable} -import localopt.{FormatChecker, StringContextChecker} +import localopt.{FormatChecker, InterpolationReporter} // TDD for just the Checker class FormatCheckerTest: - class TestReporter extends StringContextChecker.InterpolationReporter: + class TestReporter extends InterpolationReporter: private var reported = false private var oldReported = false val reports = ListBuffer.empty[(String, Int, Int)] @@ -45,17 +44,6 @@ class FormatCheckerTest: end TestReporter given TestReporter = TestReporter() - /* - enum ArgTypeTag: - case BooleanArg, ByteArg, CharArg, ShortArg, IntArg, LongArg, FloatArg, DoubleArg, AnyArg, - StringArg, FormattableArg, BigIntArg, BigDecimalArg, CalendarArg, DateArg - given Conversion[ArgTypeTag, Int] = _.ordinal - def argTypeString(tag: Int) = - if tag < 0 then "Null" - else if tag >= ArgTypeTag.values.length then throw RuntimeException(s"Bad tag $tag") - else ArgTypeTag.values(tag) - */ - class TestChecker(args: ClassTag[?]*)(using val reporter: TestReporter) extends FormatChecker: def argType(argi: Int, types: ClassTag[?]*): ClassTag[?] = types.find(_ == args(argi)).getOrElse(types.head) val argc = args.length diff --git a/tests/run-macros/f-interpolator-tests.scala b/tests/run-macros/f-interpolator-tests.scala index 73ad9b8852fa..8c59ae19a187 100755 --- a/tests/run-macros/f-interpolator-tests.scala +++ b/tests/run-macros/f-interpolator-tests.scala @@ -1,10 +1,9 @@ -/** - * These tests test all the possible formats the f interpolator has to deal with. - * The tests are sorted by argument category as the arguments are on https://docs.oracle.com/javase/6/docs/api/java/util/Formatter.html#detail - * - * - * Some tests come from https://github.com/lampepfl/dotty/pull/3894/files - */ +/** These tests test all the possible formats the f interpolator has to deal with. + * + * The tests are sorted by argument category as the arguments are on https://docs.oracle.com/javase/6/docs/api/java/util/Formatter.html#detail + * + * Some tests come from https://github.com/lampepfl/dotty/pull/3894/files + */ object Test { def main(args: Array[String]) = { println(f"integer: ${5}%d")