diff --git a/.circleci/config.yml b/.circleci/config.yml index 0c70156df..989939808 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -26,7 +26,7 @@ jobs: test: <<: *defaults steps: - - run: cat /dev/null | sbt compilerJVM/test + - run: cat /dev/null | sbt test workflows: version: 2 build_test_deploy: diff --git a/README.md b/README.md index 188a05fa6..db2c22b2f 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ # Kaitai Struct: compiler -This project is an official reference compiler for [Kaitai Struct](https://github.com/kaitai-io/kaitai_struct) project. +This project is the official reference compiler for the [Kaitai Struct](http://kaitai.io) project. Kaitai Struct is a declarative language used to describe various binary data structures, laid out in files or in memory: i.e. binary @@ -19,115 +19,22 @@ languages. These modules will include the generated code for a parser that can read described data structure from a file / stream and give access to it in a nice, easy-to-comprehend API. -Please refer to [documentation in Kaitai Struct project](https://github.com/kaitai-io/kaitai_struct) -for details on `.ksy` files and general usage patterns. +## Further information -## Trying without install +If you're looking for information on: -Kaitai Struct compiler can be tried instantly, without any downloads -and installation, at - -http://kaitai.io/repl - -Note that this implementation uses the same reference code as in this -repository and executes totally on a client side, without any queries -to server backend. - -## Downloading and installing - -### Linux .deb builds (Debian/Ubuntu) - -There is an official .deb repository available. 
The repository is hosted -at BinTray and signed with BinTray GPG key (`379CE192D401AB61`), so it's -necessary to import that key first if your box haven't used any BinTray -repositories beforehand: - -```shell -echo "deb https://dl.bintray.com/kaitai-io/debian jessie main" | sudo tee /etc/apt/sources.list.d/kaitai.list -sudo apt-key adv --keyserver hkp://pool.sks-keyservers.net --recv 379CE192D401AB61 -sudo apt-get update -sudo apt-get install kaitai-struct-compiler -``` - -### Windows builds - -An official `.msi` installer build is available to download at - -https://bintray.com/kaitai-io/universal/kaitai-struct-compiler/_latestVersion - -### Universal builds - -Basically, everything that can run Java can use so called "universal" -builds: a .zip file that includes all the required .jar files bundled -and launcher scripts for UNIX/Windows systems. No installation -required, one can just unpack and run it. Available at download also at - -https://bintray.com/kaitai-io/universal/kaitai-struct-compiler/_latestVersion - -### Source code - -If you're interested in developing compiler itself, you can check out -source code in repository: - - git clone https://github.com/kaitai-io/kaitai_struct_compiler - -See the [developer documentation](http://doc.kaitai.io/developers.html) for -general pointers on how to proceed with the source code then. - -## Usage - -`kaitai-struct-compiler [options] ...` - -Alternatively, a symlink `ksc` is provided and can be used everywhere -just as full name. 
- -Common options: - -* `...` — source files (.ksy) -* `-t | --target ` — target languages (`graphviz`, `csharp`, - `all`, `perl`, `java`, `go`, `cpp_stl`, `php`, `lua`, `python`, `ruby`, `javascript` - * `all` is a special case: it compiles all possible target - languages, creating language-specific directories (as per language - identifiers) inside output directory, and then creating output - module(s) for each language starting from there -* `-d | --outdir ` — output directory - (filenames will be auto-generated) - -Language-specific options: - -* `--dot-net-namespace ` — .NET namespace (C# only, default: Kaitai) -* `--java-package ` — Java package (Java only, default: root package) -* `--php-namespace ` — PHP namespace (PHP only, default: root package) - -Misc options: - -* `--verbose` — verbose output -* `--help` — display usage information and exit -* `--version` — output version information and exit - -A few examples, given that file `foo.ksy` exists in current directory -and describes format with ID `foo`: - -* `kaitai-struct-compiler -t python foo.ksy` — compile format in - `foo.ksy`, write output in current directory to file `foo.py` -* `kaitai-struct-compiler -t java foo.ksy` — compile format in - `foo.ksy`, create "src" subdir in current one and write output in - `src/Foo.java` -* `kaitai-struct-compiler -t java --java-package org.example foo.ksy` - — compile format in `foo.ksy`, create "src/org/example" subdir tree - in current one and write output in `src/org/example/Foo.java`; - resulting file will bear correct Java package clause. -* `kaitai-struct-compiler -t all -d /tmp/out --java-package org.example foo.ksy` - — compile format in `foo.ksy`, creating a hierarchy of files: - * `/tmp/out/java/src/org/example/Foo.java` - * `/tmp/out/python/foo.py` - * `/tmp/out/ruby/foo.rb` +* Kaitai Struct language itself (`.ksy` files, general usage patterns) + — refer to the [user guide](http://doc.kaitai.io/user_guide.html). 
+* How to download and install Kaitai Struct — see the + [downloads](http://kaitai.io/#download). +* How to build the compiler, run the test suite, and join the + development — see the [developer memo](http://doc.kaitai.io/developers.html). ## Licensing ### Main code -Kaitai Struct compiler itself is copyright (C) 2015-2018 Kaitai +Kaitai Struct compiler itself is copyright (C) 2015-2019 Kaitai Project. This program is free software: you can redistribute it and/or modify diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 3943451bb..5b044839c 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,4 +1,12 @@ -# 0.8 (TBD) +# 0.9 (TBD) + +* New targets support: + * Python with [Construct library](https://construct.readthedocs.io) +* Expression language: + * New methods: + * byte arrays: `length` + +# 0.8 (2018-02-05) * New target languages: * Lua (96% tests pass score) diff --git a/build.sbt b/build.sbt index bf0d21f88..08c9a9ae1 100644 --- a/build.sbt +++ b/build.sbt @@ -5,7 +5,7 @@ import sbt.Keys._ resolvers += Resolver.sonatypeRepo("public") -val VERSION = "0.8" +val VERSION = "0.9-SNAPSHOT" val TARGET_LANGS = "C++/STL, C#, Java, JavaScript, Lua, Perl, PHP, Python, Ruby" lazy val root = project.in(file(".")). @@ -128,6 +128,8 @@ lazy val compiler = crossProject.in(file(".")). 
// https://github.com/sbt/sbt-native-packager/issues/1067 debianNativeBuildOptions in Debian := Seq("-Zgzip", "-z3"), + debianPackageDependencies := Seq("java8-runtime-headless"), + packageSummary in Linux := s"compiler to generate binary data parsers in $TARGET_LANGS", packageSummary in Windows := "Kaitai Struct compiler", packageDescription in Linux := diff --git a/js/src/main/scala/io/kaitai/struct/MainJs.scala b/js/src/main/scala/io/kaitai/struct/MainJs.scala index 2f16766e3..9e11d0ad9 100644 --- a/js/src/main/scala/io/kaitai/struct/MainJs.scala +++ b/js/src/main/scala/io/kaitai/struct/MainJs.scala @@ -1,13 +1,13 @@ package io.kaitai.struct -import io.kaitai.struct.format.{ClassSpec, JavaScriptClassSpecs, JavaScriptKSYParser, KSVersion} +import io.kaitai.struct.format.{JavaScriptKSYParser, KSVersion} import io.kaitai.struct.languages.components.LanguageCompilerStatic +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future import scala.scalajs.js import scala.scalajs.js.JSConverters._ import scala.scalajs.js.annotation.JSExport -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.Future @JSExport object MainJs { @@ -16,13 +16,12 @@ object MainJs { @JSExport def compile(langStr: String, yaml: js.Object, importer: JavaScriptImporter, debug: Boolean = false): js.Promise[js.Dictionary[String]] = { try { - val config = new RuntimeConfig(debug = debug) + // TODO: properly enable verbose logging via a flag + //Log.initFromVerboseFlag(Seq("file", "value", "parent", "type_resolve", "type_valid", "seq_sizes", "import")) + val config = new RuntimeConfig(autoRead = !debug, readStoresPos = debug) val lang = LanguageCompilerStatic.byString(langStr) - val yamlScala = JavaScriptKSYParser.yamlJavascriptToScala(yaml) - val firstSpec = ClassSpec.fromYaml(yamlScala) - val specs = new JavaScriptClassSpecs(importer, firstSpec) - Main.importAndPrecompile(specs, config).map { (_) => + JavaScriptKSYParser.yamlToSpecs(yaml, importer, 
config).map { (specs) => specs.flatMap({ case (_, spec) => val files = Main.compile(specs, spec, lang, config).files files.map((x) => x.fileName -> x.contents).toMap diff --git a/js/src/main/scala/io/kaitai/struct/format/JavaScriptKSYParser.scala b/js/src/main/scala/io/kaitai/struct/format/JavaScriptKSYParser.scala index 4aa0e1f29..4aafd4e4d 100644 --- a/js/src/main/scala/io/kaitai/struct/format/JavaScriptKSYParser.scala +++ b/js/src/main/scala/io/kaitai/struct/format/JavaScriptKSYParser.scala @@ -1,8 +1,25 @@ package io.kaitai.struct.format +import io.kaitai.struct.{JavaScriptImporter, Main, RuntimeConfig} + +import scala.concurrent.Future import scala.scalajs.js +import scala.concurrent.ExecutionContext.Implicits.global object JavaScriptKSYParser { + /** + * Converts first YAML (given as JavaScript object) to the ClassSpecs + * object, fully imported and precompiled. + * @param yaml first KSY file (YAML), given as JavaScript object + * @return future of ClassSpecs object + */ + def yamlToSpecs(yaml: Any, importer: JavaScriptImporter, config: RuntimeConfig): Future[ClassSpecs] = { + val yamlScala = yamlJavascriptToScala(yaml) + val firstSpec = ClassSpec.fromYaml(yamlScala) + val specs = new JavaScriptClassSpecs(importer, firstSpec) + Main.importAndPrecompile(specs, config).map((_) => specs) + } + def yamlJavascriptToScala(src: Any): Any = { src match { case array: js.Array[AnyRef] => diff --git a/jvm/src/main/scala/io/kaitai/struct/JavaMain.scala b/jvm/src/main/scala/io/kaitai/struct/JavaMain.scala index 712b3a549..cb0a1eb29 100644 --- a/jvm/src/main/scala/io/kaitai/struct/JavaMain.scala +++ b/jvm/src/main/scala/io/kaitai/struct/JavaMain.scala @@ -7,6 +7,7 @@ import io.kaitai.struct.CompileLog._ import io.kaitai.struct.JavaMain.CLIConfig import io.kaitai.struct.format.{ClassSpec, ClassSpecs, KSVersion, YAMLParseException} import io.kaitai.struct.formats.JavaKSYParser +import io.kaitai.struct.languages.CppCompiler import 
io.kaitai.struct.languages.components.LanguageCompilerStatic import io.kaitai.struct.precompile.ErrorInInput @@ -24,8 +25,9 @@ object JavaMain { runtime: RuntimeConfig = RuntimeConfig() ) - val ALL_LANGS = LanguageCompilerStatic.NAME_TO_CLASS.keySet - val VALID_LANGS = ALL_LANGS + "all" + val ALL_LANGS = LanguageCompilerStatic.NAME_TO_CLASS.keySet - "cpp_stl" + "cpp_stl_98" + "cpp_stl_11" + val VALID_LANGS = LanguageCompilerStatic.NAME_TO_CLASS.keySet + "all" + val CPP_STANDARDS = Set("98", "11") def parseCommandLine(args: Array[String]): Option[CLIConfig] = { val parser = new scopt.OptionParser[CLIConfig](BuildInfo.name) { @@ -60,10 +62,37 @@ object JavaMain { } text("output directory (filenames will be auto-generated)") val importPathExample = List("", "", "...").mkString(File.pathSeparator) - opt[String]('I', "import-path") valueName(importPathExample) action { (x, c) => + opt[String]('I', "import-path") optional() unbounded() valueName(importPathExample) action { (x, c) => c.copy(importPaths = c.importPaths ++ x.split(File.pathSeparatorChar)) } text(".ksy library search path(s) for imports (see also KSPATH env variable)") + opt[String]("cpp-namespace") valueName("") action { (x, c) => + c.copy( + runtime = c.runtime.copy( + cppConfig = c.runtime.cppConfig.copy( + namespace = x.split("::").toList + ) + ) + ) + } text("C++ namespace (C++ only, default: none)") + + opt[String]("cpp-standard") valueName("") action { (x, c) => + c.copy( + runtime = c.runtime.copy( + cppConfig = x match { + case "98" => c.runtime.cppConfig.copyAsCpp98() + case "11" => c.runtime.cppConfig.copyAsCpp11() + } + ) + ) + } text("C++ standard to target (C++ only, supported: 98, 11, default: 98)") validate { x => + if (CPP_STANDARDS.contains(x)) { + success + } else { + failure(s"'$x' is not a valid C++ standard to target; valid ones are: ${CPP_STANDARDS.mkString(", ")}") + } + } + opt[String]("go-package") valueName("") action { (x, c) => c.copy(runtime = c.runtime.copy(goPackage = x)) } 
text("Go package (Go only, default: none)") @@ -115,9 +144,18 @@ object JavaMain { } } + opt[Unit]("no-auto-read") action { (x, c) => + c.copy(runtime = c.runtime.copy(autoRead = false)) + } text("disable auto-running `_read` in constructor") + + opt[Unit]("read-pos") action { (x, c) => + c.copy(runtime = c.runtime.copy(readStoresPos = true)) + } text("`_read` remembers attribute positions in stream") + opt[Unit]("debug") action { (x, c) => - c.copy(runtime = c.runtime.copy(debug = true)) - } text("enable debugging helpers (mostly used by visualization tools)") + c.copy(runtime = c.runtime.copy(autoRead = false, readStoresPos = true)) + } text("same as --no-auto-read --read-pos (useful for visualization tools)") + help("help") text("display this help and exit") version("version") text("output version information and exit") } @@ -259,10 +297,19 @@ class JavaMain(config: CLIConfig) { def compileOneLang(specs: ClassSpecs, langStr: String, outDir: String): Map[String, SpecEntry] = { Log.fileOps.info(() => s"... compiling it for $langStr... 
") - val lang = LanguageCompilerStatic.byString(langStr) + + val (lang, fixedRuntime) = langStr match { + case "cpp_stl_98" => + (CppCompiler, config.runtime.copy(cppConfig = config.runtime.cppConfig.copyAsCpp98())) + case "cpp_stl_11" => + (CppCompiler, config.runtime.copy(cppConfig = config.runtime.cppConfig.copyAsCpp11())) + case _ => + (LanguageCompilerStatic.byString(langStr), config.runtime) + } + specs.map { case (_, classSpec) => val res = try { - compileSpecAndWriteToFile(specs, classSpec, lang, outDir) + compileSpecAndWriteToFile(specs, classSpec, lang, fixedRuntime, outDir) } catch { case ex: Throwable => if (config.throwExceptions) @@ -277,9 +324,10 @@ class JavaMain(config: CLIConfig) { specs: ClassSpecs, spec: ClassSpec, lang: LanguageCompilerStatic, + runtime: RuntimeConfig, outDir: String ): SpecSuccess = { - val res = Main.compile(specs, spec, lang, config.runtime) + val res = Main.compile(specs, spec, lang, runtime) res.files.foreach { (file) => Log.fileOps.info(() => s".... 
writing ${file.fileName}") @@ -298,7 +346,7 @@ class JavaMain(config: CLIConfig) { private def exceptionToCompileError(ex: Throwable, srcFile: String): CompileError = { if (!config.jsonOutput) - Console.println(ex.getMessage) + Console.err.println(ex.getMessage) ex match { case ype: YAMLParseException => CompileError("(main)", ype.path, ype.msg) diff --git a/jvm/src/main/scala/io/kaitai/struct/formats/JavaKSYParser.scala b/jvm/src/main/scala/io/kaitai/struct/formats/JavaKSYParser.scala index 57e371564..b6326b20e 100644 --- a/jvm/src/main/scala/io/kaitai/struct/formats/JavaKSYParser.scala +++ b/jvm/src/main/scala/io/kaitai/struct/formats/JavaKSYParser.scala @@ -1,6 +1,7 @@ package io.kaitai.struct.formats -import java.io.{File, FileReader} +import java.io._ +import java.nio.charset.Charset import java.util.{List => JList, Map => JMap} import io.kaitai.struct.JavaMain.CLIConfig @@ -25,11 +26,19 @@ object JavaKSYParser { def fileNameToSpec(yamlFilename: String): ClassSpec = { Log.fileOps.info(() => s"reading $yamlFilename...") - val scalaSrc = readerToYaml(new FileReader(yamlFilename)) + + // This complex string of classes is due to the fact that Java's + // default "FileReader" implementation always uses system locale, + // which screws up encoding on some systems and screws up reading + // UTF-8 files with BOM + val fis = new FileInputStream(yamlFilename) + val isr = new InputStreamReader(fis, Charset.forName("UTF-8")) + val br = new BufferedReader(isr) + val scalaSrc = readerToYaml(br) ClassSpec.fromYaml(scalaSrc) } - def readerToYaml(reader: FileReader): Any = { + def readerToYaml(reader: Reader): Any = { val yamlLoader = new Yaml(new SafeConstructor) val javaSrc = yamlLoader.load(reader) yamlJavaToScala(javaSrc) diff --git a/jvm/src/test/scala/io/kaitai/struct/datatype/SwitchType$Test.scala b/jvm/src/test/scala/io/kaitai/struct/datatype/SwitchType$Test.scala new file mode 100644 index 000000000..eb2fef40f --- /dev/null +++ 
b/jvm/src/test/scala/io/kaitai/struct/datatype/SwitchType$Test.scala @@ -0,0 +1,40 @@ +package io.kaitai.struct.datatype + +import io.kaitai.struct.datatype.DataType.SwitchType +import io.kaitai.struct.exprlang.Expressions +import io.kaitai.struct.format.ClassSpec +import org.scalatest.FunSpec +import org.scalatest.Matchers._ + +class SwitchType$Test extends FunSpec { + describe("SwitchType.parseSwitch") { + it ("combines ints properly") { + val t = SwitchType( + Expressions.parse("foo"), + Map( + Expressions.parse("1") -> DataType.IntMultiType(true, DataType.Width2, Some(LittleEndian)), + Expressions.parse("2") -> DataType.IntMultiType(false, DataType.Width4, Some(LittleEndian)) + ) + ) + + t.combinedType should be(DataType.CalcIntType) + } + + it ("combines owning user types properly") { + val ut1 = DataType.UserTypeInstream(List("foo"), None) + ut1.classSpec = Some(ClassSpec.opaquePlaceholder(List("foo"))) + val ut2 = DataType.UserTypeInstream(List("bar"), None) +// ut2.classSpec = Some(ClassSpec.opaquePlaceholder(List("bar"))) + + val t = SwitchType( + Expressions.parse("foo"), + Map( + Expressions.parse("1") -> ut1, + Expressions.parse("2") -> ut2 + ) + ) + + t.combinedType should be(DataType.KaitaiStructType) + } + } +} diff --git a/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala b/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala index 5e0fe54e3..0c4769671 100644 --- a/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala +++ b/jvm/src/test/scala/io/kaitai/struct/exprlang/ExpressionsSpec.scala @@ -131,10 +131,41 @@ class ExpressionsSpec extends FunSpec { Expressions.parse("~(7+3)") should be (UnaryOp(Invert, BinOp(IntNum(7), Add, IntNum(3)))) } + // Enums it("parses port::http") { Expressions.parse("port::http") should be (EnumByLabel(identifier("port"), identifier("http"))) } + it("parses some_type::port::http") { + Expressions.parse("some_type::port::http") should be ( + EnumByLabel( + identifier("port"), + 
identifier("http"), + typeId(absolute = false, Seq("some_type")) + ) + ) + } + + it("parses parent_type::child_type::port::http") { + Expressions.parse("parent_type::child_type::port::http") should be ( + EnumByLabel( + identifier("port"), + identifier("http"), + typeId(absolute = false, Seq("parent_type", "child_type")) + ) + ) + } + + it("parses ::parent_type::child_type::port::http") { + Expressions.parse("::parent_type::child_type::port::http") should be ( + EnumByLabel( + identifier("port"), + identifier("http"), + typeId(absolute = true, Seq("parent_type", "child_type")) + ) + ) + } + it("parses port::http.to_i + 8000 == 8080") { Expressions.parse("port::http.to_i + 8000 == 8080") should be ( Compare( @@ -171,6 +202,36 @@ class ExpressionsSpec extends FunSpec { Expressions.parse("truer") should be (Name(identifier("truer"))) } + // Boolean operations + it("parses not foo") { + Expressions.parse("not foo") should be ( + UnaryOp( + Ast.unaryop.Not, + Name(identifier("foo")) + ) + ) + } + + it("parses note_len") { + Expressions.parse("note_len") should be (Name(identifier("note_len"))) + } + + it("parses notnot") { + Expressions.parse("notnot") should be (Name(identifier("notnot"))) + } + + it("parses not not true") { + Expressions.parse("not not true") should be ( + UnaryOp( + Ast.unaryop.Not, + UnaryOp( + Ast.unaryop.Not, + Bool(true) + ) + ) + ) + } + // String literals it("parses simple string") { Expressions.parse("\"abc\"") should be (Str("abc")) @@ -202,23 +263,57 @@ class ExpressionsSpec extends FunSpec { // Casts it("parses 123.as") { - Expressions.parse("123.as") should be (CastToType(IntNum(123),identifier("u4"))) + Expressions.parse("123.as") should be ( + CastToType(IntNum(123), typeId(false, Seq("u4"))) + ) } it("parses (123).as") { - Expressions.parse("(123).as") should be (CastToType(IntNum(123),identifier("u4"))) + Expressions.parse("(123).as") should be ( + CastToType(IntNum(123), typeId(false, Seq("u4"))) + ) } it("parses \"str\".as") { - 
Expressions.parse("\"str\".as") should be (CastToType(Str("str"),identifier("x"))) + Expressions.parse("\"str\".as") should be ( + CastToType(Str("str"), typeId(false, Seq("x"))) + ) } it("parses foo.as") { - Expressions.parse("foo.as") should be (CastToType(Name(identifier("foo")),identifier("x"))) + Expressions.parse("foo.as") should be ( + CastToType(Name(identifier("foo")), typeId(false, Seq("x"))) + ) } it("parses foo.as < x > ") { - Expressions.parse("foo.as < x > ") should be (CastToType(Name(identifier("foo")),identifier("x"))) + Expressions.parse("foo.as < x > ") should be ( + CastToType(Name(identifier("foo")), typeId(false, Seq("x"))) + ) + } + + it("parses foo.as") { + Expressions.parse("foo.as") should be ( + CastToType(Name(identifier("foo")), typeId(false, Seq("bar", "baz"))) + ) + } + + it("parses foo.as<::bar::baz>") { + Expressions.parse("foo.as<::bar::baz>") should be ( + CastToType(Name(identifier("foo")), typeId(true, Seq("bar", "baz"))) + ) + } + + it("parses foo.as") { + Expressions.parse("foo.as") should be ( + CastToType(Name(identifier("foo")), typeId(false, Seq("bar"), true)) + ) + } + + it("parses foo.as<::bar::baz[]>") { + Expressions.parse("foo.as<::bar::baz[]>") should be ( + CastToType(Name(identifier("foo")), typeId(true, Seq("bar", "baz"), true)) + ) } it("parses foo.as") { diff --git a/jvm/src/test/scala/io/kaitai/struct/translators/TranslatorSpec.scala b/jvm/src/test/scala/io/kaitai/struct/translators/TranslatorSpec.scala index 70a6efa1b..e5247c27d 100644 --- a/jvm/src/test/scala/io/kaitai/struct/translators/TranslatorSpec.scala +++ b/jvm/src/test/scala/io/kaitai/struct/translators/TranslatorSpec.scala @@ -6,8 +6,8 @@ import io.kaitai.struct.exprlang.{Ast, Expressions} import io.kaitai.struct.format.ClassSpec import io.kaitai.struct.languages._ import io.kaitai.struct.languages.components.LanguageCompilerStatic -import io.kaitai.struct.{ImportList, RuntimeConfig} -import org.scalatest.FunSuite +import 
io.kaitai.struct.{ImportList, RuntimeConfig, StringLanguageOutputWriter} +import org.scalatest.{FunSuite, Tag} import org.scalatest.Matchers._ class TranslatorSpec extends FunSuite { @@ -34,7 +34,6 @@ class TranslatorSpec extends FunSuite { everybodyExcept("3 / 2", "(3 / 2)", Map( JavaScriptCompiler -> "Math.floor(3 / 2)", - LuaCompiler -> "3 / 2", PerlCompiler -> "int(3 / 2)", PHPCompiler -> "intval(3 / 2)", PythonCompiler -> "3 // 2" @@ -44,7 +43,6 @@ class TranslatorSpec extends FunSuite { everybodyExcept("(1 + 2) / (7 * 8)", "((1 + 2) / (7 * 8))", Map( JavaScriptCompiler -> "Math.floor((1 + 2) / (7 * 8))", - LuaCompiler -> "(1 + 2) / (7 * 8)", PerlCompiler -> "int((1 + 2) / (7 * 8))", PHPCompiler -> "intval((1 + 2) / (7 * 8))", PythonCompiler -> "(1 + 2) // (7 * 8)" @@ -57,6 +55,13 @@ class TranslatorSpec extends FunSuite { full("2 < 3 ? \"foo\" : \"bar\"", CalcIntType, CalcStrType, Map[LanguageCompilerStatic, String]( CppCompiler -> "(2 < 3) ? (std::string(\"foo\")) : (std::string(\"bar\"))", CSharpCompiler -> "2 < 3 ? \"foo\" : \"bar\"", + GoCompiler -> """var tmp1 string; + |if (2 < 3) { + | tmp1 = "foo" + |} else { + | tmp1 = "bar" + |} + |tmp1""".stripMargin, JavaCompiler -> "2 < 3 ? \"foo\" : \"bar\"", JavaScriptCompiler -> "2 < 3 ? \"foo\" : \"bar\"", LuaCompiler -> "2 < 3 and \"foo\" or \"bar\"", @@ -66,8 +71,12 @@ class TranslatorSpec extends FunSuite { RubyCompiler -> "2 < 3 ? 
\"foo\" : \"bar\"" )) - everybody("~777", "~777") - everybody("~(7+3)", "~((7 + 3))") + everybodyExcept("~777", "~777", Map[LanguageCompilerStatic, String]( + GoCompiler -> "^777" + )) + everybodyExcept("~(7+3)", "~((7 + 3))", Map[LanguageCompilerStatic, String]( + GoCompiler -> "^((7 + 3))" + )) // Simple float operations everybody("1.2 + 3.4", "(1.2 + 3.4)", CalcFloatType) @@ -84,6 +93,7 @@ class TranslatorSpec extends FunSuite { full("true", CalcBooleanType, CalcBooleanType, Map[LanguageCompilerStatic, String]( CppCompiler -> "true", CSharpCompiler -> "true", + GoCompiler -> "true", JavaCompiler -> "true", JavaScriptCompiler -> "true", LuaCompiler -> "true", @@ -96,6 +106,7 @@ class TranslatorSpec extends FunSuite { full("false", CalcBooleanType, CalcBooleanType, Map[LanguageCompilerStatic, String]( CppCompiler -> "false", CSharpCompiler -> "false", + GoCompiler -> "false", JavaCompiler -> "false", JavaScriptCompiler -> "false", LuaCompiler -> "false", @@ -108,6 +119,11 @@ class TranslatorSpec extends FunSuite { full("some_bool.to_i", CalcBooleanType, CalcIntType, Map[LanguageCompilerStatic, String]( CppCompiler -> "some_bool()", CSharpCompiler -> "(SomeBool ? 1 : 0)", + GoCompiler -> """tmp1 := 0 + |if this.SomeBool { + | tmp1 = 1 + |} + |tmp1""".stripMargin, JavaCompiler -> "(someBool() ? 
1 : 0)", JavaScriptCompiler -> "(this.someBool | 0)", LuaCompiler -> "self.some_bool and 1 or 0", @@ -121,6 +137,7 @@ class TranslatorSpec extends FunSuite { full("foo_str", CalcStrType, CalcStrType, Map[LanguageCompilerStatic, String]( CppCompiler -> "foo_str()", CSharpCompiler -> "FooStr", + GoCompiler -> "this.FooStr", JavaCompiler -> "fooStr()", JavaScriptCompiler -> "this.fooStr", LuaCompiler -> "self.foo_str", @@ -130,9 +147,10 @@ class TranslatorSpec extends FunSuite { RubyCompiler -> "foo_str" )) - full("foo_block", userType("block"), userType("block"), Map[LanguageCompilerStatic, String]( + full("foo_block", userType(List("block")), userType(List("block")), Map[LanguageCompilerStatic, String]( CppCompiler -> "foo_block()", CSharpCompiler -> "FooBlock", + GoCompiler -> "this.FooBlock", JavaCompiler -> "fooBlock()", JavaScriptCompiler -> "this.fooBlock", LuaCompiler -> "self.foo_block", @@ -145,6 +163,7 @@ class TranslatorSpec extends FunSuite { full("foo.bar", FooBarProvider, CalcStrType, Map[LanguageCompilerStatic, String]( CppCompiler -> "foo()->bar()", CSharpCompiler -> "Foo.Bar", + GoCompiler -> "this.Foo.Bar", JavaCompiler -> "foo().bar()", JavaScriptCompiler -> "this.foo.bar", LuaCompiler -> "self.foo.bar", @@ -157,6 +176,7 @@ class TranslatorSpec extends FunSuite { full("foo.inner.baz", FooBarProvider, CalcIntType, Map[LanguageCompilerStatic, String]( CppCompiler -> "foo()->inner()->baz()", CSharpCompiler -> "Foo.Inner.Baz", + GoCompiler -> "this.Foo.Inner.Baz", JavaCompiler -> "foo().inner().baz()", JavaScriptCompiler -> "this.foo.inner.baz", LuaCompiler -> "self.foo.inner.baz", @@ -166,9 +186,10 @@ class TranslatorSpec extends FunSuite { RubyCompiler -> "foo.inner.baz" )) - full("_root.foo", userType("block"), userType("block"), Map[LanguageCompilerStatic, String]( + full("_root.foo", userType(List("top_class", "block")), userType(List("top_class", "block")), Map[LanguageCompilerStatic, String]( CppCompiler -> "_root()->foo()", CSharpCompiler -> 
"M_Root.Foo", + GoCompiler -> "this._root.Foo", JavaCompiler -> "_root.foo()", JavaScriptCompiler -> "this._root.foo", LuaCompiler -> "self._root.foo", @@ -181,6 +202,7 @@ class TranslatorSpec extends FunSuite { full("a != 2 and a != 5", CalcIntType, CalcBooleanType, Map[LanguageCompilerStatic, String]( CppCompiler -> "a() != 2 && a() != 5", CSharpCompiler -> "A != 2 && A != 5", + GoCompiler -> "a != 2 && a != 5", JavaCompiler -> "a() != 2 && a() != 5", JavaScriptCompiler -> "this.a != 2 && this.a != 5", LuaCompiler -> "self.a ~= 2 and self.a ~= 5", @@ -193,7 +215,8 @@ class TranslatorSpec extends FunSuite { // Arrays full("[0, 1, 100500]", CalcIntType, ArrayType(CalcIntType), Map[LanguageCompilerStatic, String]( CSharpCompiler -> "new List { 0, 1, 100500 }", - JavaCompiler -> "new ArrayList(Arrays.asList(0L, 1L, 100500L))", + GoCompiler -> "[]int{0, 1, 100500}", + JavaCompiler -> "new ArrayList(Arrays.asList(0, 1L, 100500))", JavaScriptCompiler -> "[0, 1, 100500]", LuaCompiler -> "{0, 1, 100500}", PerlCompiler -> "(0, 1, 100500)", @@ -205,31 +228,45 @@ class TranslatorSpec extends FunSuite { full("[34, 0, 10, 64, 65, 66, 92]", CalcIntType, CalcBytesType, Map[LanguageCompilerStatic, String]( CppCompiler -> "std::string(\"\\x22\\x00\\x0A\\x40\\x41\\x42\\x5C\", 7)", CSharpCompiler -> "new byte[] { 34, 0, 10, 64, 65, 66, 92 }", + GoCompiler -> "\"\\x22\\x00\\x0A\\x40\\x41\\x42\\x5C\"", JavaCompiler -> "new byte[] { 34, 0, 10, 64, 65, 66, 92 }", JavaScriptCompiler -> "[34, 0, 10, 64, 65, 66, 92]", LuaCompiler -> "\"\\034\\000\\010\\064\\065\\066\\092\"", PerlCompiler -> "pack('C*', (34, 0, 10, 64, 65, 66, 92))", PHPCompiler -> "\"\\x22\\x00\\x0A\\x40\\x41\\x42\\x5C\"", - PythonCompiler -> "struct.pack('7b', 34, 0, 10, 64, 65, 66, 92)", + PythonCompiler -> "b\"\\x22\\x00\\x0A\\x40\\x41\\x42\\x5C\"", RubyCompiler -> "[34, 0, 10, 64, 65, 66, 92].pack('C*')" )) full("[255, 0, 255]", CalcIntType, CalcBytesType, Map[LanguageCompilerStatic, String]( CppCompiler -> 
"std::string(\"\\xFF\\x00\\xFF\", 3)", CSharpCompiler -> "new byte[] { 255, 0, 255 }", + GoCompiler -> "\"\\xFF\\x00\\xFF\"", JavaCompiler -> "new byte[] { -1, 0, -1 }", JavaScriptCompiler -> "[255, 0, 255]", LuaCompiler -> "\"\\255\\000\\255\"", PerlCompiler -> "pack('C*', (255, 0, 255))", PHPCompiler -> "\"\\xFF\\x00\\xFF\"", - PythonCompiler -> "struct.pack('3b', -1, 0, -1)", + PythonCompiler -> "b\"\\255\\000\\255\"", RubyCompiler -> "[255, 0, 255].pack('C*')" )) + full("[0, 1, 2].length", CalcIntType, CalcIntType, Map[LanguageCompilerStatic, String]( + CppCompiler -> "std::string(\"\\x00\\x01\\x02\", 3).length()", + GoCompiler -> "len(\"\\x00\\x01\\x02\")", + JavaCompiler -> "new byte[] { 0, 1, 2 }.length", + LuaCompiler -> "string.len(\"str\")", + PerlCompiler -> "length(pack('C*', (0, 1, 2)))", + PHPCompiler -> "strlen(\"\\x00\\x01\\x02\")", + PythonCompiler -> "len(b\"\\x00\\x01\\x02\")", + RubyCompiler -> "[0, 1, 2].pack('C*').size" + )) + full("a[42]", ArrayType(CalcStrType), CalcStrType, Map[LanguageCompilerStatic, String]( CppCompiler -> "a()->at(42)", CSharpCompiler -> "A[42]", - JavaCompiler -> "a().get(42)", + GoCompiler -> "this.A[42]", + JavaCompiler -> "a().get((int) 42)", JavaScriptCompiler -> "this.a[42]", LuaCompiler -> "self.a[43]", PHPCompiler -> "$this->a()[42]", @@ -240,6 +277,7 @@ class TranslatorSpec extends FunSuite { full("a[42 - 2]", ArrayType(CalcStrType), CalcStrType, Map[LanguageCompilerStatic, String]( CppCompiler -> "a()->at((42 - 2))", CSharpCompiler -> "A[(42 - 2)]", + GoCompiler -> "this.A[(42 - 2)]", JavaCompiler -> "a().get((42 - 2))", JavaScriptCompiler -> "this.a[(42 - 2)]", LuaCompiler -> "self.a[(43 - 2)]", @@ -251,6 +289,7 @@ class TranslatorSpec extends FunSuite { full("a.first", ArrayType(CalcIntType), CalcIntType, Map[LanguageCompilerStatic, String]( CppCompiler -> "a()->front()", CSharpCompiler -> "A[0]", + GoCompiler -> "this.A[0]", JavaCompiler -> "a().get(0)", JavaScriptCompiler -> "this.a[0]", LuaCompiler -> 
"self.a[1]", @@ -261,7 +300,8 @@ class TranslatorSpec extends FunSuite { full("a.last", ArrayType(CalcIntType), CalcIntType, Map[LanguageCompilerStatic, String]( CppCompiler -> "a()->back()", - CSharpCompiler -> "A[A.Length - 1]", + CSharpCompiler -> "A[A.Count - 1]", + GoCompiler -> "this.A[len(this.A)-1]", JavaCompiler -> "a().get(a().size() - 1)", JavaScriptCompiler -> "this.a[this.a.length - 1]", LuaCompiler -> "self.a[#self.a]", @@ -273,6 +313,7 @@ class TranslatorSpec extends FunSuite { full("a.size", ArrayType(CalcIntType), CalcIntType, Map[LanguageCompilerStatic, String]( CppCompiler -> "a()->size()", CSharpCompiler -> "A.Count", + GoCompiler -> "len(this.A)", JavaCompiler -> "a().size()", JavaScriptCompiler -> "this.a.length", LuaCompiler -> "#self.a", @@ -286,6 +327,7 @@ class TranslatorSpec extends FunSuite { full("\"str\"", CalcIntType, CalcStrType, Map[LanguageCompilerStatic, String]( CppCompiler -> "std::string(\"str\")", CSharpCompiler -> "\"str\"", + GoCompiler -> "\"str\"", JavaCompiler -> "\"str\"", JavaScriptCompiler -> "\"str\"", LuaCompiler -> "\"str\"", @@ -298,6 +340,7 @@ class TranslatorSpec extends FunSuite { full("\"str\\nnext\"", CalcIntType, CalcStrType, Map[LanguageCompilerStatic, String]( CppCompiler -> "std::string(\"str\\nnext\")", CSharpCompiler -> "\"str\\nnext\"", + GoCompiler -> "\"str\\nnext\"", JavaCompiler -> "\"str\\nnext\"", JavaScriptCompiler -> "\"str\\nnext\"", LuaCompiler -> "\"str\\nnext\"", @@ -310,6 +353,7 @@ class TranslatorSpec extends FunSuite { full("\"str\\u000anext\"", CalcIntType, CalcStrType, Map[LanguageCompilerStatic, String]( CppCompiler -> "std::string(\"str\\nnext\")", CSharpCompiler -> "\"str\\nnext\"", + GoCompiler -> "\"str\\u000anext\"", JavaCompiler -> "\"str\\nnext\"", JavaScriptCompiler -> "\"str\\nnext\"", LuaCompiler -> "\"str\\nnext\"", @@ -322,6 +366,7 @@ class TranslatorSpec extends FunSuite { full("\"str\\0next\"", CalcIntType, CalcStrType, Map[LanguageCompilerStatic, String]( CppCompiler -> 
"std::string(\"str\\000next\", 8)", CSharpCompiler -> "\"str\\0next\"", + GoCompiler -> "\"str\\000next\"", JavaCompiler -> "\"str\\000next\"", JavaScriptCompiler -> "\"str\\000next\"", LuaCompiler -> "\"str\\000next\"", @@ -367,8 +412,9 @@ class TranslatorSpec extends FunSuite { full("\"str\".length", CalcIntType, CalcIntType, Map[LanguageCompilerStatic, String]( CppCompiler -> "std::string(\"str\").length()", CSharpCompiler -> "\"str\".Length", + GoCompiler -> "utf8.RuneCountInString(\"str\")", JavaCompiler -> "\"str\".length()", - JavaScriptCompiler -> "#\"str\"", + JavaScriptCompiler -> "\"str\".length", LuaCompiler -> "string.len(\"str\")", PerlCompiler -> "length(\"str\")", PHPCompiler -> "strlen(\"str\")", @@ -379,6 +425,7 @@ class TranslatorSpec extends FunSuite { full("\"str\".reverse", CalcIntType, CalcStrType, Map[LanguageCompilerStatic, String]( CppCompiler -> "kaitai::kstream::reverse(std::string(\"str\"))", CSharpCompiler -> "new string(Array.Reverse(\"str\".ToCharArray()))", + GoCompiler -> "kaitai.StringReverse(\"str\")", JavaCompiler -> "new StringBuilder(\"str\").reverse().toString()", JavaScriptCompiler -> "Array.from(\"str\").reverse().join('')", LuaCompiler -> "string.reverse(\"str\")", @@ -391,6 +438,7 @@ class TranslatorSpec extends FunSuite { full("\"12345\".to_i", CalcIntType, CalcIntType, Map[LanguageCompilerStatic, String]( CppCompiler -> "std::stoi(std::string(\"12345\"))", CSharpCompiler -> "Convert.ToInt64(\"12345\", 10)", + GoCompiler -> "func()(int){i, err := strconv.Atoi(\"12345\"); if (err != nil) { panic(err) }; return i}()", JavaCompiler -> "Long.parseLong(\"12345\", 10)", JavaScriptCompiler -> "Number.parseInt(\"12345\", 10)", LuaCompiler -> "tonumber(\"12345\")", @@ -403,6 +451,7 @@ class TranslatorSpec extends FunSuite { full("\"1234fe\".to_i(16)", CalcIntType, CalcIntType, Map[LanguageCompilerStatic, String]( CppCompiler -> "std::stoi(std::string(\"1234fe\"), 0, 16)", CSharpCompiler -> "Convert.ToInt64(\"1234fe\", 16)", + 
GoCompiler -> "func()(int64){i, err := strconv.ParseInt(\"1234fe\", 16, 64); if (err != nil) { panic(err) }; return i}()", JavaCompiler -> "Long.parseLong(\"1234fe\", 16)", JavaScriptCompiler -> "Number.parseInt(\"1234fe\", 16)", LuaCompiler -> "tonumber(\"1234fe\", 16)", @@ -416,6 +465,7 @@ class TranslatorSpec extends FunSuite { full("other.as.bar", FooBarProvider, CalcStrType, Map[LanguageCompilerStatic, String]( CppCompiler -> "static_cast(other())->bar()", CSharpCompiler -> "((Block) (Other)).Bar", + GoCompiler -> "this.Other.(Block).Bar", JavaCompiler -> "((Block) (other())).bar()", JavaScriptCompiler -> "this.other.bar", LuaCompiler -> "self.other.bar", @@ -425,15 +475,112 @@ class TranslatorSpec extends FunSuite { RubyCompiler -> "other.bar" )) + full("other.as.baz", FooBarProvider, CalcIntType, Map[LanguageCompilerStatic, String]( + CppCompiler -> "static_cast(other())->baz()", + CSharpCompiler -> "((Block.Innerblock) (Other)).Baz", + GoCompiler -> "this.Other.(Block.Innerblock).Baz", + JavaCompiler -> "((Block.Innerblock) (other())).baz()", + JavaScriptCompiler -> "this.other.baz", + LuaCompiler -> "self.other.baz", + PerlCompiler -> "$self->other()->baz()", + PHPCompiler -> "$this->other()->baz()", + PythonCompiler -> "self.other.baz", + RubyCompiler -> "other.baz" + )) + + // primitive pure types + full("(1 + 2).as", CalcIntType, IntMultiType(true, Width2, None), Map[LanguageCompilerStatic, String]( + CppCompiler -> "static_cast((1 + 2))", + CSharpCompiler -> "((short) ((1 + 2)))", + GoCompiler -> "int16((1 + 2))", + JavaCompiler -> "((short) ((1 + 2)))", + JavaScriptCompiler -> "(1 + 2)", + LuaCompiler -> "(1 + 2)", + PerlCompiler -> "(1 + 2)", + PHPCompiler -> "(1 + 2)", + PythonCompiler -> "(1 + 2)", + RubyCompiler -> "(1 + 2)" + )) + + // empty array casting + full("[].as", CalcIntType, CalcBytesType, Map[LanguageCompilerStatic, String]( + CppCompiler -> "std::string(\"\", 0)", + CSharpCompiler -> "new byte[] { }", + GoCompiler -> "\"\"", + 
JavaCompiler -> "new byte[] { }", + JavaScriptCompiler -> "[]", + LuaCompiler -> "\"\"", + PerlCompiler -> "pack('C*', ())", + PHPCompiler -> "\"\"", + PythonCompiler -> "b\"\"", + RubyCompiler -> "[].pack('C*')" + )) + + full("[].as", CalcIntType, ArrayType(Int1Type(false)), Map[LanguageCompilerStatic, String]( + CppCompiler -> "std::string(\"\")", + CSharpCompiler -> "new List { }", + GoCompiler -> "[]uint8{}", + JavaCompiler -> "new ArrayList(Arrays.asList())", + JavaScriptCompiler -> "[]", + LuaCompiler -> "{}", + PerlCompiler -> "()", + PHPCompiler -> "[]", + PythonCompiler -> "[]", + RubyCompiler -> "[]" + )) + + full("[].as", CalcIntType, ArrayType(FloatMultiType(Width8, None)), Map[LanguageCompilerStatic, String]( + CppCompiler -> "std::string(\"\", 0)", + CSharpCompiler -> "new List { }", + GoCompiler -> "[]float64{}", + JavaCompiler -> "new ArrayList(Arrays.asList())", + JavaScriptCompiler -> "[]", + LuaCompiler -> "{}", + PerlCompiler -> "()", + PHPCompiler -> "[]", + PythonCompiler -> "[]", + RubyCompiler -> "[]" + )) + + // type enforcement: casting to non-literal byte array + full("[0 + 1, 5].as", CalcIntType, CalcBytesType, Map[LanguageCompilerStatic, String]( + CppCompiler -> "???", + CSharpCompiler -> "new byte[] { (0 + 1), 5 }", + GoCompiler -> "string([]byte{(0 + 1), 5})", + JavaCompiler -> "new byte[] { (0 + 1), 5 }", + JavaScriptCompiler -> "new Uint8Array([(0 + 1), 5])", + LuaCompiler -> "???", + PerlCompiler -> "pack('C*', ((0 + 1), 5))", + PHPCompiler -> "pack('C*', (0 + 1), 5)", + PythonCompiler -> "struct.pack('2b', (0 + 1), 5)", + RubyCompiler -> "[(0 + 1), 5].pack('C*')" + )) + + // type enforcement: casting to array of integers + full("[0, 1, 2].as", CalcIntType, ArrayType(Int1Type(false)), Map[LanguageCompilerStatic, String]( + CSharpCompiler -> "new List { 0, 1, 2 }", + GoCompiler -> "[]uint8{0, 1, 2}", + JavaCompiler -> "new ArrayList(Arrays.asList(0, 1, 2))", + JavaScriptCompiler -> "[0, 1, 2]", + LuaCompiler -> "{0, 1, 2}", + 
PerlCompiler -> "(0, 1, 2)", + PHPCompiler -> "[0, 1, 2]", + PythonCompiler -> "[0, 1, 2]", + RubyCompiler -> "[0, 1, 2]" + )) + def runTest(src: String, tp: TypeProvider, expType: DataType, expOut: ResultMap) { var eo: Option[Ast.expr] = None test(s"_expr:$src") { eo = Some(Expressions.parse(src)) } - val langs = Map[LanguageCompilerStatic, BaseTranslator]( - CppCompiler -> new CppTranslator(tp, new ImportList()), + val goOutput = new StringLanguageOutputWriter(" ") + + val langs = Map[LanguageCompilerStatic, AbstractTranslator with TypeDetector]( + CppCompiler -> new CppTranslator(tp, new ImportList(), RuntimeConfig()), CSharpCompiler -> new CSharpTranslator(tp, new ImportList()), + GoCompiler -> new GoTranslator(goOutput, tp, new ImportList()), JavaCompiler -> new JavaTranslator(tp, new ImportList()), JavaScriptCompiler -> new JavaScriptTranslator(tp), LuaCompiler -> new LuaTranslator(tp, new ImportList()), @@ -445,14 +592,19 @@ class TranslatorSpec extends FunSuite { langs.foreach { case (langObj, tr) => val langName = LanguageCompilerStatic.CLASS_TO_NAME(langObj) - test(s"$langName:$src") { + test(s"$langName:$src", Tag(langName), Tag(src)) { eo match { case Some(e) => - val tr: BaseTranslator = langs(langObj) + val tr: AbstractTranslator with TypeDetector = langs(langObj) expOut.get(langObj) match { case Some(expResult) => tr.detectType(e) should be(expType) - tr.translate(e) should be(expResult) + val actResult1 = tr.translate(e) + val actResult2 = langObj match { + case GoCompiler => goOutput.result + actResult1 + case _ => actResult1 + } + actResult2 should be(expResult) case None => fail("no expected result") } @@ -469,10 +621,10 @@ class TranslatorSpec extends FunSuite { abstract class FakeTypeProvider extends TypeProvider { val nowClass = ClassSpec.opaquePlaceholder(List("top_class")) - override def resolveEnum(enumName: String) = + override def resolveEnum(inType: Ast.typeId, enumName: String) = throw new NotImplementedError - override def 
resolveType(typeName: String): DataType = + override def resolveType(typeName: Ast.typeId): DataType = throw new NotImplementedError override def isLazy(attrName: String): Boolean = false @@ -485,30 +637,57 @@ class TranslatorSpec extends FunSuite { override def determineType(inClass: ClassSpec, name: String): DataType = t } + /** + * Emulates the following system of types: + * + *
+    *   meta:
+    *     id: top_class
+    *   types:
+    *     block:
+    *       seq:
+    *         - id: bar
+    *           type: str
+    *         - id: inner
+    *           type: innerblock
+    *       types:
+    *         innerblock:
+    *           instances:
+    *             baz:
+    *               value: 123
+    * 
+ */ case object FooBarProvider extends FakeTypeProvider { override def determineType(name: String): DataType = { name match { - case "foo" => userType("block") + case "foo" => userType(List("top_class", "block")) } } override def determineType(inClass: ClassSpec, name: String): DataType = { - (inClass.name, name) match { - case (List("block"), "bar") => CalcStrType - case (List("block"), "inner") => userType("innerblock") - case (List("innerblock"), "baz") => CalcIntType + (inClass.name.last, name) match { + case ("block", "bar") => CalcStrType + case ("block", "inner") => userType(List("top_class", "block", "innerblock")) + case ("innerblock", "baz") => CalcIntType } } - override def resolveType(typeName: String): DataType = { - typeName match { - case "top_class" | "block" | "innerblock" => userType(typeName) + override def resolveType(typeName: Ast.typeId): DataType = { + typeName.names match { + case Seq("top_class") => + userType(List("top_class")) + case Seq("block") | + Seq("top_class", "block") => + userType(List("top_class", "block")) + case Seq("innerblock") | + Seq("block", "innerblock") | + Seq("top_class", "block", "innerblock") => + userType(List("top_class", "block", "innerblock")) } } } - def userType(name: String) = { - val lname = List(name) + def userType(lname: List[String]) = { val cs = ClassSpec.opaquePlaceholder(lname) val ut = UserTypeInstream(lname, None) ut.classSpec = Some(cs) diff --git a/jvm/src/test/scala/io/kaitai/struct/translators/TypeDetector$Test.scala b/jvm/src/test/scala/io/kaitai/struct/translators/TypeDetector$Test.scala new file mode 100644 index 000000000..1e81a236e --- /dev/null +++ b/jvm/src/test/scala/io/kaitai/struct/translators/TypeDetector$Test.scala @@ -0,0 +1,16 @@ +package io.kaitai.struct.translators + +import io.kaitai.struct.datatype.DataType._ +import org.scalatest.FunSpec +import org.scalatest.Matchers._ + +class TypeDetector$Test extends FunSpec { + describe("TypeDetector") { + it("combines ints properly") { 
+ val ut1 = CalcUserType(List("foo"), None) + val ut2 = CalcUserType(List("bar"), None) + + TypeDetector.combineTypes(ut1, ut2) should be(CalcKaitaiStructType) + } + } +} diff --git a/lib_bintray.sh b/lib_bintray.sh new file mode 100644 index 000000000..6ccc6faee --- /dev/null +++ b/lib_bintray.sh @@ -0,0 +1,84 @@ +# Shell library to handle uploads & publishing of artifacts to Bintray + +# All functions get their settings from global variables: +# +# * Authentication: +# * BINTRAY_USER +# * BINTRAY_API_KEY - to be passed as secret env variable# +# * Package ID / upload coordinates: +# * BINTRAY_ACCOUNT +# * BINTRAY_REPO +# * BINTRAY_PACKAGE +# * BINTRAY_VERSION +# * Debian-specific settings: +# * BINTRAY_DEB_DISTRIBUTION +# * BINTRAY_DEB_ARCH +# * BINTRAY_DEB_COMPONENT +# * Debug options: +# * BINTRAY_CURL_ARGS - set to `-vv` for verbose output + +## +# Creates version for a package at Bintray +bintray_create_version() +{ + echo "bintray_create_version(repo=${BINTRAY_REPO}, package=${BINTRAY_PACKAGE}, version=${BINTRAY_VERSION})" + + curl $BINTRAY_CURL_ARGS -f \ + "-u$BINTRAY_USER:$BINTRAY_API_KEY" \ + -H "Content-Type: application/json" \ + -X POST "https://api.bintray.com/packages/${BINTRAY_ACCOUNT}/${BINTRAY_REPO}/${BINTRAY_PACKAGE}/versions" \ + --data "{ \"name\": \"$BINTRAY_VERSION\", \"release_notes\": \"auto\", \"released\": \"\" }" +# --data "{ \"name\": \"$version\", \"release_notes\": \"auto\", \"release_url\": \"$BASE_DESC/$RPM_NAME\", \"released\": \"\" }" +} + +## +# Uploads generic file to Bintray. 
+# +# Input: +# $1 = filename to upload +bintray_upload_generic() +{ + local filename="$1" + + echo "bintray_upload_generic(repo=${BINTRAY_REPO}, package=${BINTRAY_PACKAGE}, version=${BINTRAY_VERSION}, filename=${filename})" + + curl $BINTRAY_CURL_ARGS -f \ + -T "$filename" \ + "-u$BINTRAY_USER:$BINTRAY_API_KEY" \ + -H "X-Bintray-Package: $BINTRAY_PACKAGE" \ + -H "X-Bintray-Version: $BINTRAY_VERSION" \ + https://api.bintray.com/content/$BINTRAY_ACCOUNT/$BINTRAY_REPO/ +} + +## +# Uploads deb package to Bintray. +# +# Input: +# $1 = filename to upload +bintray_upload_deb() +{ + local filename="$1" + + echo "bintray_upload_deb(repo=${BINTRAY_REPO}, package=${BINTRAY_PACKAGE}, version=${BINTRAY_VERSION}, filename=${filename})" + + curl $BINTRAY_CURL_ARGS -f \ + -T "$filename" \ + "-u$BINTRAY_USER:$BINTRAY_API_KEY" \ + -H "X-Bintray-Package: $BINTRAY_PACKAGE" \ + -H "X-Bintray-Version: $BINTRAY_VERSION" \ + -H "X-Bintray-Debian-Distribution: $BINTRAY_DEB_DISTRIBUTION" \ + -H "X-Bintray-Debian-Architecture: $BINTRAY_DEB_ARCH" \ + -H "X-Bintray-Debian-Component: $BINTRAY_DEB_COMPONENT" \ + https://api.bintray.com/content/$BINTRAY_ACCOUNT/$BINTRAY_REPO/ +} + +bintray_publish_version() +{ + echo "bintray_publish_version(repo=${BINTRAY_REPO}, package=${BINTRAY_PACKAGE}, version=${BINTRAY_VERSION})" + + curl $BINTRAY_CURL_ARGS -f \ + "-u$BINTRAY_USER:$BINTRAY_API_KEY" \ + -H "Content-Type: application/json" \ + -X POST "https://api.bintray.com/content/$BINTRAY_ACCOUNT/$BINTRAY_REPO/$BINTRAY_PACKAGE/$BINTRAY_VERSION/publish" \ + --data "{ \"discard\": \"false\" }" +} diff --git a/project/build.properties b/project/build.properties index e98ac44be..210243d0d 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version = 1.1.0-RC4 +sbt.version = 1.1.1 diff --git a/project/plugins.sbt b/project/plugins.sbt index 322ef8094..5973f265a 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,5 +1,5 @@ logLevel := Level.Warn 
-addSbtPlugin("com.typesafe.sbt" % "sbt-native-packager" % "1.3.2") +addSbtPlugin("com.typesafe.sbt" % "sbt-native-packager" % "1.3.12") addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.7.0") addSbtPlugin("org.scala-js" % "sbt-scalajs" % "0.6.21") diff --git a/publish_deb_to_bintray.sh b/publish_deb_to_bintray.sh new file mode 100755 index 000000000..28a557ab6 --- /dev/null +++ b/publish_deb_to_bintray.sh @@ -0,0 +1,21 @@ +#!/bin/sh -ef + +. ./lib_bintray.sh + +# Config +BINTRAY_USER=greycat +BINTRAY_ACCOUNT=kaitai-io +BINTRAY_REPO=debian_unstable +BINTRAY_PACKAGE=kaitai-struct-compiler +BINTRAY_VERSION="$KAITAI_STRUCT_VERSION" +# BINTRAY_API_KEY comes from encrypted variables from web UI + +BINTRAY_DEB_DISTRIBUTION=jessie +BINTRAY_DEB_ARCH=all +BINTRAY_DEB_COMPONENT=main + +#BINTRAY_CURL_ARGS=-v + +bintray_create_version +bintray_upload_deb "jvm/target/kaitai-struct-compiler_${KAITAI_STRUCT_VERSION}_all.deb" +bintray_publish_version diff --git a/publish_zip_to_bintray.sh b/publish_zip_to_bintray.sh new file mode 100755 index 000000000..a21705393 --- /dev/null +++ b/publish_zip_to_bintray.sh @@ -0,0 +1,17 @@ +#!/bin/sh -ef + +. 
./lib_bintray.sh + +# Config +BINTRAY_USER=greycat +BINTRAY_ACCOUNT=kaitai-io +BINTRAY_REPO=universal_unstable +BINTRAY_PACKAGE=kaitai-struct-compiler +BINTRAY_VERSION="$KAITAI_STRUCT_VERSION" +# BINTRAY_API_KEY comes from encrypted variables from web UI + +#BINTRAY_CURL_ARGS=-v + +bintray_create_version +bintray_upload_generic "jvm/target/universal/kaitai-struct-compiler-${KAITAI_STRUCT_VERSION}.zip" +bintray_publish_version diff --git a/shared/src/main/scala/io/kaitai/struct/ClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/ClassCompiler.scala index 660ba8522..b05eff43c 100644 --- a/shared/src/main/scala/io/kaitai/struct/ClassCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/ClassCompiler.scala @@ -4,10 +4,8 @@ import io.kaitai.struct.CompileLog.FileSuccess import io.kaitai.struct.datatype.DataType._ import io.kaitai.struct.datatype._ import io.kaitai.struct.exprlang.Ast -import io.kaitai.struct.format._ -import io.kaitai.struct.languages.components.{LanguageCompiler, LanguageCompilerStatic} - -import scala.collection.mutable.ListBuffer +import io.kaitai.struct.format.{AttrSpec, _} +import io.kaitai.struct.languages.components.{ExtraAttrs, LanguageCompiler, LanguageCompilerStatic} class ClassCompiler( classSpecs: ClassSpecs, @@ -51,24 +49,20 @@ class ClassCompiler( if (lang.innerDocstrings) compileClassDoc(curClass) - val extraAttrs = ListBuffer[AttrSpec]() - extraAttrs += AttrSpec(List(), RootIdentifier, UserTypeInstream(topClassName, None)) - extraAttrs += AttrSpec(List(), ParentIdentifier, curClass.parentType) - // Forward declarations for recursive types curClass.types.foreach { case (typeName, _) => lang.classForwardDeclaration(List(typeName)) } if (lang.innerEnums) compileEnums(curClass) - if (lang.debug) + if (lang.config.readStoresPos) lang.debugClassSequence(curClass.seq) // Constructor compileConstructor(curClass) // Read method(s) - compileEagerRead(curClass.seq, extraAttrs, curClass.meta.endian) + compileEagerRead(curClass.seq, 
curClass.meta.endian) // Destructor compileDestructor(curClass) @@ -80,10 +74,17 @@ class ClassCompiler( provider.nowClass = curClass } - compileInstances(curClass, extraAttrs) + compileInstances(curClass) // Attributes declarations and readers - val allAttrs: List[MemberSpec] = curClass.seq ++ curClass.params ++ extraAttrs + val allAttrs: List[MemberSpec] = + curClass.seq ++ + curClass.params ++ + List( + AttrSpec(List(), RootIdentifier, CalcUserType(topClassName, None)), + AttrSpec(List(), ParentIdentifier, curClass.parentType) + ) ++ + ExtraAttrs.forClassSpec(curClass, lang) compileAttrDeclarations(allAttrs) compileAttrReaders(allAttrs) @@ -113,17 +114,48 @@ class ClassCompiler( curClass.meta.endian.contains(InheritedEndian), curClass.params ) + compileInit(curClass) curClass.instances.foreach { case (instName, _) => lang.instanceClear(instName) } - if (!lang.debug) + if (lang.config.autoRead) lang.runRead() lang.classConstructorFooter } /** - * Compiles destructor for a given class. It should clean up everything + * Compile initialization of class members for a given type. Typically + * this is only required for languages which both: + * + * * don't perform auto-initialization of object with some default + * values (like 0s) on object creation, + * * require these members to be initialized because any other + * procedures with object (e.g. destruction) will require that + * + * Currently, this is only applicable to C++ without smart pointers, + * as destructors we'll generate will rely on pointers being set to + * null. 
+ * @param curClass current type to generate code for + */ + def compileInit(curClass: ClassSpec) = { + curClass.seq.foreach((attr) => compileAttrInit(attr)) + curClass.instances.foreach { case (_, instSpec) => + instSpec match { + case pis: ParseInstanceSpec => compileAttrInit(pis) + case _: ValueInstanceSpec => // ignore for now + } + } + } + + def compileAttrInit(originalAttr: AttrLikeSpec): Unit = { + val extraAttrs = ExtraAttrs.forAttr(originalAttr, lang) + val allAttrs = List(originalAttr) ++ extraAttrs + allAttrs.foreach((attr) => lang.attrInit(attr)) + } + + /** + * Compiles destructor for a given type. It should clean up everything * (i.e. every applicable allocated seq / instance attribute variables, and * any extra attribute variables, if they were used). - * @param curClass current class to generate code for + * @param curClass current type to generate code for */ def compileDestructor(curClass: ClassSpec) = { lang.classDestructorHeader(curClass.name, curClass.parentType, topClassName) @@ -137,6 +169,11 @@ class ClassCompiler( lang.classDestructorFooter } + /** + * Iterates over a given list of attributes and generates attribute + * declarations for each of them. + * @param attrs attribute list to traverse + */ def compileAttrDeclarations(attrs: List[MemberSpec]): Unit = { attrs.foreach { (attr) => val isNullable = if (lang.switchBytesOnlyAsRaw) { @@ -153,7 +190,7 @@ class ClassCompiler( * readers (AKA getters) for each of them. 
* @param attrs attribute list to traverse */ - def compileAttrReaders(attrs: List[MemberSpec]): Unit = + def compileAttrReaders(attrs: List[MemberSpec]): Unit = { attrs.foreach { (attr) => // FIXME: Python should have some form of attribute docs too if (!attr.doc.isEmpty && !lang.innerDocstrings) @@ -165,31 +202,50 @@ class ClassCompiler( } lang.attributeReader(attr.id, attr.dataTypeComposite, isNullable) } + } - def compileEagerRead(seq: List[AttrSpec], extraAttrs: ListBuffer[AttrSpec], endian: Option[Endianness]): Unit = { + /** + * Compiles everything related to "eager reading" for a given list of + * sequence attributes and endianness. Depending on endianness: + * + * * For types known to have fixed endianness, we do just "_read" method. + * * For types with ambiguous endianness, we'll do `_read` + "_read_le" + + * "_read_be" methods. If endianness needs to be calculated, we'll perform + * that calculation in "_read". If it's inherited, then we'll just make + * decision based on that inherited setting. 
+ * + * @param seq list of sequence attributes + * @param endian endianness setting + */ + def compileEagerRead(seq: List[AttrSpec], endian: Option[Endianness]): Unit = { endian match { case None | Some(_: FixedEndian) => - compileSeqProc(seq, extraAttrs, None) + compileSeqProc(seq, None) case Some(ce: CalcEndian) => lang.readHeader(None, false) compileCalcEndian(ce) lang.runReadCalc() lang.readFooter() - compileSeqProc(seq, extraAttrs, Some(LittleEndian)) - compileSeqProc(seq, extraAttrs, Some(BigEndian)) + compileSeqProc(seq, Some(LittleEndian)) + compileSeqProc(seq, Some(BigEndian)) case Some(InheritedEndian) => lang.readHeader(None, false) lang.runReadCalc() lang.readFooter() - compileSeqProc(seq, extraAttrs, Some(LittleEndian)) - compileSeqProc(seq, extraAttrs, Some(BigEndian)) + compileSeqProc(seq, Some(LittleEndian)) + compileSeqProc(seq, Some(BigEndian)) } } val IS_LE_ID = SpecialIdentifier("_is_le") + /** + * Compiles endianness calculation procedure and stores result in a special + * attribute [[IS_LE_ID]]. Typically occurs as part of "_read" method. + * @param ce calculated endianness specification + */ def compileCalcEndian(ce: CalcEndian): Unit = { def renderProc(result: FixedEndian): Unit = { val v = Ast.expr.Bool(result == LittleEndian) @@ -202,49 +258,51 @@ class ClassCompiler( /** * Compiles seq reading method (complete with header and footer). * @param seq sequence of attributes - * @param extraAttrs extra attributes to be allocated * @param defEndian default endianness */ - def compileSeqProc(seq: List[AttrSpec], extraAttrs: ListBuffer[AttrSpec], defEndian: Option[FixedEndian]) = { + def compileSeqProc(seq: List[AttrSpec], defEndian: Option[FixedEndian]) = { lang.readHeader(defEndian, seq.isEmpty) - compileSeq(seq, extraAttrs, defEndian) + compileSeq(seq, defEndian) lang.readFooter() } /** * Compiles seq reading method body (only reading statements). 
* @param seq sequence of attributes - * @param extraAttrs extra attributes to be allocated * @param defEndian default endianness */ - def compileSeq(seq: List[AttrSpec], extraAttrs: ListBuffer[AttrSpec], defEndian: Option[FixedEndian]) = { + def compileSeq(seq: List[AttrSpec], defEndian: Option[FixedEndian]) = { var wasUnaligned = false seq.foreach { (attr) => val nowUnaligned = isUnalignedBits(attr.dataType) if (wasUnaligned && !nowUnaligned) lang.alignToByte(lang.normalIO) - lang.attrParse(attr, attr.id, extraAttrs, defEndian) + lang.attrParse(attr, attr.id, defEndian) wasUnaligned = nowUnaligned } } + /** + * Compiles all enums specifications for a given type. + * @param curClass current type to generate code for + */ def compileEnums(curClass: ClassSpec): Unit = curClass.enums.foreach { case(_, enumColl) => compileEnum(curClass, enumColl) } /** * Compile subclasses for a given class. - * @param curClass current class to generate code for + * @param curClass current type to generate code for */ def compileSubclasses(curClass: ClassSpec): Unit = curClass.types.foreach { case (_, intClass) => compileClass(intClass) } - def compileInstances(curClass: ClassSpec, extraAttrs: ListBuffer[AttrSpec]) = { + def compileInstances(curClass: ClassSpec) = { curClass.instances.foreach { case (instName, instSpec) => - compileInstance(curClass.name, instName, instSpec, extraAttrs, curClass.meta.endian) + compileInstance(curClass.name, instName, instSpec, curClass.meta.endian) } } - def compileInstance(className: List[String], instName: InstanceIdentifier, instSpec: InstanceSpec, extraAttrs: ListBuffer[AttrSpec], endian: Option[Endianness]): Unit = { + def compileInstance(className: List[String], instName: InstanceIdentifier, instSpec: InstanceSpec, endian: Option[Endianness]): Unit = { // Determine datatype val dataType = instSpec.dataTypeComposite @@ -255,7 +313,7 @@ class ClassCompiler( lang.instanceHeader(className, instName, dataType, instSpec.isNullable) if 
(lang.innerDocstrings) compileInstanceDoc(instName, instSpec) - lang.instanceCheckCacheAndReturn(instName) + lang.instanceCheckCacheAndReturn(instName, dataType) instSpec match { case vi: ValueInstanceSpec => @@ -263,11 +321,11 @@ class ClassCompiler( lang.instanceCalculate(instName, dataType, vi.value) lang.attrParseIfFooter(vi.ifExpr) case i: ParseInstanceSpec => - lang.attrParse(i, instName, extraAttrs, endian) + lang.attrParse(i, instName, endian) } lang.instanceSetCalculated(instName) - lang.instanceReturn(instName) + lang.instanceReturn(instName, dataType) lang.instanceFooter } diff --git a/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala b/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala index 11eda9e07..04e8258da 100644 --- a/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala +++ b/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala @@ -2,6 +2,7 @@ package io.kaitai.struct import io.kaitai.struct.datatype.DataType import io.kaitai.struct.datatype.DataType._ +import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.format._ import io.kaitai.struct.precompile.{EnumNotFoundError, FieldNotFoundError, TypeNotFoundError, TypeUndecidedError} import io.kaitai.struct.translators.TypeProvider @@ -62,13 +63,14 @@ class ClassTypeProvider(classSpecs: ClassSpecs, var topClass: ClassSpec) extends case GenericStructClassSpec => KaitaiStructType case cs: ClassSpec => - val ut = UserTypeInstream(cs.name, None) + val ut = CalcUserType(cs.name, None) ut.classSpec = Some(cs) ut } } - override def resolveEnum(enumName: String): EnumSpec = resolveEnum(nowClass, enumName) + override def resolveEnum(inType: Ast.typeId, enumName: String): EnumSpec = + resolveEnum(resolveClassSpec(inType), enumName) def resolveEnum(inClass: ClassSpec, enumName: String): EnumSpec = { inClass.enums.get(enumName) match { @@ -84,22 +86,42 @@ class ClassTypeProvider(classSpecs: ClassSpecs, var topClass: ClassSpec) extends } } - override def 
resolveType(typeName: String): DataType = resolveType(nowClass, typeName) + override def resolveType(typeName: Ast.typeId): DataType = + makeUserType(resolveClassSpec(typeName)) - def resolveType(inClass: ClassSpec, typeName: String): DataType = { + def resolveClassSpec(typeName: Ast.typeId): ClassSpec = + resolveClassSpec( + if (typeName.absolute) topClass else nowClass, + typeName.names + ) + + def resolveClassSpec(inClass: ClassSpec, typeName: Seq[String]): ClassSpec = { + if (typeName.isEmpty) + return inClass + + val headTypeName :: restTypesNames = typeName.toList + val nextClass = resolveClassSpec(inClass, headTypeName) + if (restTypesNames.isEmpty) { + nextClass + } else { + resolveClassSpec(nextClass, restTypesNames) + } + } + + def resolveClassSpec(inClass: ClassSpec, typeName: String): ClassSpec = { if (inClass.name.last == typeName) - return makeUserType(inClass) + return inClass inClass.types.get(typeName) match { case Some(spec) => - makeUserType(spec) + spec case None => // let's try upper levels of hierarchy inClass.upClass match { - case Some(upClass) => resolveType(upClass, typeName) + case Some(upClass) => resolveClassSpec(upClass, typeName) case None => classSpecs.get(typeName) match { - case Some(spec) => makeUserType(spec) + case Some(spec) => spec case None => throw new TypeNotFoundError(typeName, nowClass) } diff --git a/shared/src/main/scala/io/kaitai/struct/ConstructClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/ConstructClassCompiler.scala new file mode 100644 index 000000000..a4b6a3c53 --- /dev/null +++ b/shared/src/main/scala/io/kaitai/struct/ConstructClassCompiler.scala @@ -0,0 +1,223 @@ +package io.kaitai.struct + +import io.kaitai.struct.datatype.DataType._ +import io.kaitai.struct.datatype._ +import io.kaitai.struct.exprlang.Ast +import io.kaitai.struct.format._ +import io.kaitai.struct.languages.components.{LanguageCompiler, LanguageCompilerStatic} +import io.kaitai.struct.translators.ConstructTranslator + +class 
ConstructClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends AbstractCompiler { + val out = new StringLanguageOutputWriter(indent) + val importList = new ImportList + + val provider = new ClassTypeProvider(classSpecs, topClass) + val translator = new ConstructTranslator(provider, importList) + + override def compile: CompileLog.SpecSuccess = { + out.puts("from construct import *") + out.puts("from construct.lib import *") + out.puts + + compileClass(topClass) + + out.puts(s"_schema = ${type2class(topClass)}") + + CompileLog.SpecSuccess( + "", + List(CompileLog.FileSuccess( + outFileName(topClass.nameAsStr), + out.result + )) + ) + } + + def compileClass(cs: ClassSpec): Unit = { + cs.types.foreach { case (_, typeSpec) => compileClass(typeSpec) } + + cs.enums.foreach { case (_, enumSpec) => compileEnum(enumSpec) } + + out.puts(s"${type2class(cs)} = Struct(") + out.inc + + provider.nowClass = cs + + cs.seq.foreach((seqAttr) => compileAttr(seqAttr)) + cs.instances.foreach { case (id, instSpec) => + instSpec match { + case vis: ValueInstanceSpec => + compileValueInstance(id, vis) + case pis: ParseInstanceSpec => + compileParseInstance(pis) + } + } + + out.dec + out.puts(")") + out.puts + } + + def compileAttr(attr: AttrLikeSpec): Unit = { + out.puts(s"'${idToStr(attr.id)}' / ${compileAttrBody(attr)},") + } + + def compileValueInstance(id: Identifier, vis: ValueInstanceSpec): Unit = { + val typeStr = s"Computed(lambda this: ${translator.translate(vis.value)})" + val typeStr2 = wrapWithIf(typeStr, vis.ifExpr) + out.puts(s"'${idToStr(id)}' / $typeStr2,") + } + + def compileParseInstance(attr: ParseInstanceSpec): Unit = { + attr.pos match { + case None => + compileAttr(attr) + case Some(pos) => + out.puts(s"'${idToStr(attr.id)}' / " + + s"Pointer(${translator.translate(pos)}, ${compileAttrBody(attr)}),") + } + } + + def compileAttrBody(attr: AttrLikeSpec): String = { + val typeStr1 = typeToStr(attr.dataType) + val typeStr2 = wrapWithRepeat(typeStr1, 
attr.cond.repeat, attr.dataType) + wrapWithIf(typeStr2, attr.cond.ifExpr) + } + + def wrapWithRepeat(typeStr: String, repeat: RepeatSpec, dataType: DataType) = repeat match { + case RepeatExpr(expr) => + s"Array(${translator.translate(expr)}, $typeStr)" + case RepeatUntil(expr) => + provider._currentIteratorType = Some(dataType) + s"RepeatUntil(lambda obj_, list_, this: ${translator.translate(expr)}, $typeStr)" + case RepeatEos => + s"GreedyRange($typeStr)" + case NoRepeat => + typeStr + } + + def wrapWithIf(typeStr: String, ifExpr: Option[Ast.expr]) = ifExpr match { + case Some(expr) => s"If(${translator.translate(expr)}, $typeStr)" + case None => typeStr + } + + def compileEnum(enumSpec: EnumSpec): Unit = { + out.puts(s"def ${enumToStr(enumSpec)}(subcon):") + out.inc + out.puts("return Enum(subcon,") + out.inc + enumSpec.sortedSeq.foreach { case (number, valueSpec) => + out.puts(s"${valueSpec.name}=$number,") + } + out.dec + out.puts(")") + out.dec + out.puts + } + + def idToStr(id: Identifier): String = { + id match { + case SpecialIdentifier(name) => name + case NamedIdentifier(name) => name + case NumberedIdentifier(idx) => s"_${NumberedIdentifier.TEMPLATE}$idx" + case InstanceIdentifier(name) => name + } + } + + def type2class(cs: ClassSpec) = cs.name.mkString("__") + + def enumToStr(enumSpec: EnumSpec) = enumSpec.name.mkString("__") + + def typeToStr(dataType: DataType): String = dataType match { + case fbt: FixedBytesType => + s"Const(${translator.doByteArrayLiteral(fbt.contents)})" + case Int1Type(signed) => + s"Int8${signToStr(signed)}b" + case IntMultiType(signed, width, endianOpt) => + s"Int${width.width * 8}${signToStr(signed)}${fixedEndianToStr(endianOpt.get)}" + case FloatMultiType(width, endianOpt) => + s"Float${width.width * 8}${fixedEndianToStr(endianOpt.get)}" + case BytesEosType(terminator, include, padRight, process) => + "GreedyBytes" + case blt: BytesLimitType => + attrBytesLimitType(blt) + case btt: BytesTerminatedType => + 
attrBytesTerminatedType(btt, "GreedyBytes") + case StrFromBytesType(bytes, encoding) => + bytes match { + case BytesEosType(terminator, include, padRight, process) => + s"GreedyString(encoding='$encoding')" + case blt: BytesLimitType => + attrBytesLimitType(blt, s"GreedyString(encoding='$encoding')") + case btt: BytesTerminatedType => + attrBytesTerminatedType(btt, s"GreedyString(encoding='$encoding')") + } + case ut: UserTypeInstream => + s"LazyBound(lambda: ${type2class(ut.classSpec.get)})" + case utb: UserTypeFromBytes => + utb.bytes match { + //case BytesEosType(terminator, include, padRight, process) => + case BytesLimitType(size, terminator, include, padRight, process) => + s"FixedSized(${translator.translate(size)}, LazyBound(lambda: ${type2class(utb.classSpec.get)}))" + //case BytesTerminatedType(terminator, include, consume, eosError, process) => + case _ => "???" + } + case et: EnumType => + s"${enumToStr(et.enumSpec.get)}(${typeToStr(et.basedOn)})" + case st: SwitchType => + attrSwitchType(st) + case _ => "???" 
+ } + + def attrBytesLimitType(blt: BytesLimitType, subcon: String = "GreedyBytes"): String = { + val subcon2 = blt.terminator match { + case None => subcon + case Some(term) => + val termStr = "\\x%02X".format(term & 0xff) + s"NullTerminated($subcon, term=b'$termStr', include=${translator.doBoolLiteral(blt.include)})" + } + val subcon3 = blt.padRight match { + case None => subcon2 + case Some(padRight) => + val padStr = "\\x%02X".format(padRight & 0xff) + s"NullStripped($subcon2, pad=b'$padStr')" + } + s"FixedSized(${translator.translate(blt.size)}, $subcon3)" + } + + def attrBytesTerminatedType(btt: BytesTerminatedType, subcon: String): String = { + val termStr = "\\x%02X".format(btt.terminator & 0xff) + s"NullTerminated($subcon, " + + s"term=b'$termStr', " + + s"include=${translator.doBoolLiteral(btt.include)}, " + + s"consume=${translator.doBoolLiteral(btt.consume)})" + } + + def attrSwitchType(st: SwitchType): String = { + val cases = st.cases.filter { case (caseExpr, _) => + caseExpr != SwitchType.ELSE_CONST + }.map { case (caseExpr, caseType) => + s"${translator.translate(caseExpr)}: ${typeToStr(caseType)}, " + } + + val defaultSuffix = st.cases.get(SwitchType.ELSE_CONST).map((t) => + s", default=${typeToStr(t)}" + ).getOrElse("") + + s"Switch(${translator.translate(st.on)}, {${cases.mkString}}$defaultSuffix)" + } + + def signToStr(signed: Boolean) = if (signed) "s" else "u" + + def fixedEndianToStr(e: FixedEndian) = e match { + case LittleEndian => "l" + case BigEndian => "b" + } + + def indent: String = "\t" + def outFileName(topClassName: String): String = s"$topClassName.py" +} + +object ConstructClassCompiler extends LanguageCompilerStatic { + // FIXME: Unused, should be probably separated from LanguageCompilerStatic + override def getCompiler(tp: ClassTypeProvider, config: RuntimeConfig): LanguageCompiler = ??? 
+} diff --git a/shared/src/main/scala/io/kaitai/struct/DocClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/DocClassCompiler.scala new file mode 100644 index 000000000..685cd1c38 --- /dev/null +++ b/shared/src/main/scala/io/kaitai/struct/DocClassCompiler.scala @@ -0,0 +1,84 @@ +package io.kaitai.struct + +import io.kaitai.struct.format._ +import io.kaitai.struct.precompile.CalculateSeqSizes +import io.kaitai.struct.translators.RubyTranslator + +abstract class DocClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends AbstractCompiler { + val provider = new ClassTypeProvider(classSpecs, topClass) + val translator = new RubyTranslator(provider) + + // TODO: move it into SingleOutputFile equivalent + val out = new StringLanguageOutputWriter(indent) + def outFileName(topClass: ClassSpec): String + def indent: String + // END move to SingleOutputFile + + def nowClass: ClassSpec = provider.nowClass + def nowClassName = provider.nowClass.name + + override def compile: CompileLog.SpecSuccess = { + fileHeader(topClass) + compileClass(topClass) + fileFooter(topClass) + + CompileLog.SpecSuccess( + "", + List(CompileLog.FileSuccess( + outFileName(topClass), + out.result + )) + ) + } + + def compileClass(curClass: ClassSpec): Unit = { + provider.nowClass = curClass + + classHeader(curClass) + + // Sequence + compileSeq(curClass) + + // Instances + curClass.instances.foreach { case (_, instSpec) => + instSpec match { + case pis: ParseInstanceSpec => + compileParseInstance(curClass, pis) + case vis: ValueInstanceSpec => + compileValueInstance(vis) + } + } + + // Enums + curClass.enums.foreach { case(enumName, enumColl) => compileEnum(enumName, enumColl) } + + // Recursive types + curClass.types.foreach { case (_, intClass) => compileClass(intClass) } + + classFooter(curClass) + } + + def compileSeq(curClass: ClassSpec): Unit = { + seqHeader(curClass) + + CalculateSeqSizes.forEachSeqAttr(curClass, (attr, seqPos, sizeElement, sizeContainer) => { + 
compileSeqAttr(curClass, attr, seqPos, sizeElement, sizeContainer) + }) + + seqFooter(curClass) + } + + def fileHeader(topClass: ClassSpec): Unit + def fileFooter(topClass: ClassSpec): Unit + + def classHeader(classSpec: ClassSpec): Unit + def classFooter(classSpec: ClassSpec): Unit + + def seqHeader(classSpec: ClassSpec): Unit + def seqFooter(classSpec: ClassSpec): Unit + + def compileSeqAttr(classSpec: ClassSpec, attr: AttrSpec, seqPos: Option[Int], sizeElement: Sized, sizeContainer: Sized): Unit + def compileParseInstance(classSpec: ClassSpec, inst: ParseInstanceSpec): Unit + def compileValueInstance(vis: ValueInstanceSpec): Unit + def compileEnum(enumName: String, enumColl: EnumSpec): Unit +} diff --git a/shared/src/main/scala/io/kaitai/struct/GoClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/GoClassCompiler.scala index 0e57e55bb..098881829 100644 --- a/shared/src/main/scala/io/kaitai/struct/GoClassCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/GoClassCompiler.scala @@ -17,16 +17,18 @@ class GoClassCompiler( override def compileClass(curClass: ClassSpec): Unit = { provider.nowClass = curClass - val extraAttrs = ListBuffer[AttrSpec]() - extraAttrs += AttrSpec(List(), IoIdentifier, KaitaiStreamType) - extraAttrs += AttrSpec(List(), RootIdentifier, UserTypeInstream(topClassName, None)) - extraAttrs += AttrSpec(List(), ParentIdentifier, curClass.parentType) - - extraAttrs ++= getExtraAttrs(curClass) + val extraAttrs = List( + AttrSpec(List(), IoIdentifier, KaitaiStreamType), + AttrSpec(List(), RootIdentifier, UserTypeInstream(topClassName, None)), + AttrSpec(List(), ParentIdentifier, curClass.parentType) + ) ++ ExtraAttrs.forClassSpec(curClass, lang) if (!curClass.doc.isEmpty) lang.classDoc(curClass.name, curClass.doc) + // Enums declaration defines types, so they need to go first + compileEnums(curClass) + // Basic struct declaration lang.classHeader(curClass.name) compileAttrDeclarations(curClass.seq ++ extraAttrs) @@ -36,19 +38,17 @@ 
class GoClassCompiler( lang.classFooter(curClass.name) // Constructor = Read() function - compileReadFunction(curClass, extraAttrs) + compileReadFunction(curClass) - compileInstances(curClass, extraAttrs) + compileInstances(curClass) compileAttrReaders(curClass.seq ++ extraAttrs) - compileEnums(curClass) - // Recursive types compileSubclasses(curClass) } - def compileReadFunction(curClass: ClassSpec, extraAttrs: ListBuffer[AttrSpec]) = { + def compileReadFunction(curClass: ClassSpec) = { lang.classConstructorHeader( curClass.name, curClass.parentType, @@ -61,11 +61,11 @@ class GoClassCompiler( case Some(fe: FixedEndian) => Some(fe) case _ => None } - compileSeq(curClass.seq, extraAttrs, defEndian) + compileSeq(curClass.seq, defEndian) lang.classConstructorFooter } - override def compileInstance(className: List[String], instName: InstanceIdentifier, instSpec: InstanceSpec, extraAttrs: ListBuffer[AttrSpec], endian: Option[Endianness]): Unit = { + override def compileInstance(className: List[String], instName: InstanceIdentifier, instSpec: InstanceSpec, endian: Option[Endianness]): Unit = { // FIXME: support calculated endianness // Determine datatype @@ -74,7 +74,7 @@ class GoClassCompiler( if (!instSpec.doc.isEmpty) lang.attributeDoc(instName, instSpec.doc) lang.instanceHeader(className, instName, dataType, instSpec.isNullable) - lang.instanceCheckCacheAndReturn(instName) + lang.instanceCheckCacheAndReturn(instName, dataType) instSpec match { case vi: ValueInstanceSpec => @@ -82,17 +82,11 @@ class GoClassCompiler( lang.instanceCalculate(instName, dataType, vi.value) lang.attrParseIfFooter(vi.ifExpr) case i: ParseInstanceSpec => - lang.attrParse(i, instName, extraAttrs, None) // FIXME + lang.attrParse(i, instName, None) // FIXME } lang.instanceSetCalculated(instName) - lang.instanceReturn(instName) + lang.instanceReturn(instName, dataType) lang.instanceFooter } - - def getExtraAttrs(curClass: ClassSpec): List[AttrSpec] = { - curClass.seq.foldLeft(List[AttrSpec]())( - 
(attrs, attr) => attrs ++ ExtraAttrs.forAttr(attr) - ) - } } diff --git a/shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala index 03d8ef1ce..2769d08de 100644 --- a/shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala @@ -3,7 +3,6 @@ package io.kaitai.struct import io.kaitai.struct.datatype.DataType import io.kaitai.struct.datatype.DataType._ import io.kaitai.struct.exprlang.Ast -import io.kaitai.struct.exprlang.Ast.expr import io.kaitai.struct.format._ import io.kaitai.struct.languages.components.{LanguageCompiler, LanguageCompilerStatic} import io.kaitai.struct.precompile.CalculateSeqSizes @@ -100,18 +99,6 @@ class GraphvizClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends tableEnd } - def seqPosToStr(seqPos: Option[Int]): Option[String] = { - seqPos.map { (pos) => - val posByte = pos / 8 - val posBit = pos % 8 - if (posBit == 0) { - s"$posByte" - } else { - s"$posByte:$posBit" - } - } - } - def compileParseInstance(className: List[String], id: InstanceIdentifier, inst: ParseInstanceSpec): Unit = { val name = id.name val lastInstPos = inst.pos @@ -272,46 +259,46 @@ class GraphvizClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends } } - def expressionSize(ex: expr, attrName: String): String = { + def expressionSize(ex: Ast.expr, attrName: String): String = { expression(ex, getGraphvizNode(nowClassName, nowClass, attrName) + s":${attrName}_size", STYLE_EDGE_SIZE) } - def expressionPos(ex: expr, attrName: String): String = { + def expressionPos(ex: Ast.expr, attrName: String): String = { expression(ex, getGraphvizNode(nowClassName, nowClass, attrName) + s":${attrName}_pos", STYLE_EDGE_POS) } - def expressionType(ex: expr, attrName: String): String = { + def expressionType(ex: Ast.expr, attrName: String): String = { expression(ex, getGraphvizNode(nowClassName, 
nowClass, attrName) + s":${attrName}_type", STYLE_EDGE_VALUE) } - def expression(e: expr, portName: String, style: String): String = { + def expression(e: Ast.expr, portName: String, style: String): String = { affectedVars(e).foreach((v) => links += ((v, portName, style)) ) htmlEscape(translator.translate(e)) } - def affectedVars(e: expr): List[String] = { + def affectedVars(e: Ast.expr): List[String] = { e match { - case expr.BoolOp(op, values) => + case Ast.expr.BoolOp(op, values) => values.flatMap(affectedVars).toList - case expr.BinOp(left, op, right) => + case Ast.expr.BinOp(left, op, right) => affectedVars(left) ++ affectedVars(right) - case expr.UnaryOp(op, operand) => + case Ast.expr.UnaryOp(op, operand) => affectedVars(operand) - case expr.IfExp(condition, ifTrue, ifFalse) => + case Ast.expr.IfExp(condition, ifTrue, ifFalse) => affectedVars(condition) ++ affectedVars(ifTrue) ++ affectedVars(ifFalse) // case expr.Dict(keys, values) => - case expr.Compare(left, ops, right) => + case Ast.expr.Compare(left, ops, right) => affectedVars(left) ++ affectedVars(right) // case expr.Call(func, args) => - case expr.IntNum(_) | expr.FloatNum(_) | expr.Str(_) | expr.Bool(_) => + case Ast.expr.IntNum(_) | Ast.expr.FloatNum(_) | Ast.expr.Str(_) | Ast.expr.Bool(_) => List() - case expr.EnumByLabel(enumName, label) => + case _: Ast.expr.EnumByLabel => List() - case expr.EnumById(enumName, id) => + case Ast.expr.EnumById(_, id, _) => affectedVars(id) - case expr.Attribute(value, attr) => + case Ast.expr.Attribute(value, attr) => val targetClass = translator.detectType(value) targetClass match { case t: UserType => @@ -328,17 +315,20 @@ class GraphvizClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends case _ => affectedVars(value) } - case expr.Subscript(value, idx) => + case Ast.expr.Subscript(value, idx) => affectedVars(value) ++ affectedVars(idx) case SwitchType.ELSE_CONST => // "_" is a special const for List() - case expr.Name(Ast.identifier("_io")) => - 
// "_io" is a special const too - List() - case expr.Name(id) => - List(resolveLocalNode(id.name)) - case expr.List(elts) => + case Ast.expr.Name(id) => + if (id.name.charAt(0) == '_') { + // other special consts like "_io", "_index", etc + List() + } else { + // this must be local name, resolve it + List(resolveLocalNode(id.name)) + } + case Ast.expr.List(elts) => elts.flatMap(affectedVars).toList } } @@ -444,4 +434,22 @@ object GraphvizClassCompiler extends LanguageCompilerStatic { def htmlEscape(s: String): String = { s.replaceAll("&", "&").replaceAll("<", "<").replaceAll(">", ">").replaceAll("\"", """) } + + /** + * Converts bit-level position into byte/bit human-readable combination. + * @param seqPos optional number of bits + * @return fractional human-readable string which displays "bytes:bits", + * akin to "minutes:seconds" time display + */ + def seqPosToStr(seqPos: Option[Int]): Option[String] = { + seqPos.map { (pos) => + val posByte = pos / 8 + val posBit = pos % 8 + if (posBit == 0) { + s"$posByte" + } else { + s"$posByte:$posBit" + } + } + } } diff --git a/shared/src/main/scala/io/kaitai/struct/HtmlClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/HtmlClassCompiler.scala new file mode 100644 index 000000000..9d9d74c1a --- /dev/null +++ b/shared/src/main/scala/io/kaitai/struct/HtmlClassCompiler.scala @@ -0,0 +1,152 @@ +package io.kaitai.struct + +import io.kaitai.struct.datatype.DataType +import io.kaitai.struct.datatype.DataType.UserType +import io.kaitai.struct.exprlang.Ast +import io.kaitai.struct.format._ +import io.kaitai.struct.languages.components.{LanguageCompiler, LanguageCompilerStatic} + +class HtmlClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends DocClassCompiler(classSpecs, topClass) { + import HtmlClassCompiler._ + + override def outFileName(topClass: ClassSpec): String = s"${topClass.nameAsStr}.html" + + override def indent: String = "" + + override def fileHeader(topClass: ClassSpec): Unit = { + out.puts( + 
s""" + | + | + | + | + | + | + | + | + | + | + | ${type2str(topClass.name.last)} format specification + | + | +
+ |

${type2str(topClass.name.last)} format specification

+ | + """.stripMargin) + + // TODO: parse & output meta/title, meta/file-extensions, etc + } + + override def fileFooter(topClass: ClassSpec): Unit = { + out.puts( + """ + |
+ | + | + | + | + | + | + | + """.stripMargin) + } + + override def classHeader(classSpec: ClassSpec): Unit = { + out.puts(s"") + out.puts(s"<$headerByIndent>Type: ${type2str(classSpec.name.last)}") + out.puts + + classSpec.doc.summary.foreach(summary => + out.puts(s"

$summary

") + ) + out.inc + } + + override def classFooter(classSpec: ClassSpec): Unit = { + out.dec + } + + override def seqHeader(classSpec: ClassSpec): Unit = { + out.puts("") + out.puts("") + } + + override def seqFooter(classSpec: ClassSpec): Unit = { + out.puts("
OffsetSizeIDTypeNote
") + } + + override def compileSeqAttr(classSpec: ClassSpec, attr: AttrSpec, seqPos: Option[Int], sizeElement: Sized, sizeContainer: Sized): Unit = { + out.puts("") + out.puts(s"${GraphvizClassCompiler.seqPosToStr(seqPos).getOrElse("???")}") + out.puts(s"...") + out.puts(s"${attr.id.humanReadable}") + out.puts(s"${kaitaiType2NativeType(attr.dataType)}") + out.puts(s"${attr.doc.summary.getOrElse("")}") + out.puts("") + } + + override def compileParseInstance(classSpec: ClassSpec, inst: ParseInstanceSpec): Unit = { + out.puts(s"

Parse instance: ${inst.id.humanReadable}

") + out.puts("") + out.puts("") + out.puts(s"") + out.puts(s"") + out.puts(s"") + out.puts(s"") + out.puts(s"") + out.puts("") + out.puts("
${expression(inst.pos)}...${inst.id.humanReadable}${kaitaiType2NativeType(inst.dataType)}${inst.doc.summary.getOrElse("")}
") + } + + override def compileValueInstance(vis: ValueInstanceSpec): Unit = { + out.puts(s"value instance: ${vis}") + } + + override def compileEnum(enumName: String, enumColl: EnumSpec): Unit = { + out.puts(s"") + out.puts(s"<$headerByIndent>Enum: $enumName") + out.puts + + out.puts("") + out.puts("") + out.puts("") + out.puts("") + + enumColl.sortedSeq.foreach { case (id, value) => + out.puts("") + out.puts(s"") + out.puts("") + } + + out.puts("
IDNameNote
$id${value.name}${value.doc.summary.getOrElse("")}
") + } + + def headerByIndent: String = s"h${out.indentLevel + 1}" + + def expression(exOpt: Option[Ast.expr]): String = { + exOpt match { + case Some(ex) => translator.translate(ex) + case None => "" + } + } +} + +object HtmlClassCompiler extends LanguageCompilerStatic { + // FIXME: Unused, should be probably separated from LanguageCompilerStatic + override def getCompiler( + tp: ClassTypeProvider, + config: RuntimeConfig + ): LanguageCompiler = ??? + + def type2str(name: String): String = Utils.upperCamelCase(name) + + def classSpec2Anchor(spec: ClassSpec): String = "type-" + spec.name.mkString("-") + + def enumSpec2Anchor(spec: EnumSpec): String = "enum-" + spec.name.mkString("-") + + def kaitaiType2NativeType(attrType: DataType): String = attrType match { + case ut: UserType => + "" + type2str(ut.name.last) + "" + case _ => GraphvizClassCompiler.dataTypeName(attrType) + } +} diff --git a/shared/src/main/scala/io/kaitai/struct/Log.scala b/shared/src/main/scala/io/kaitai/struct/Log.scala index 1aeb161f6..aef4e1b93 100644 --- a/shared/src/main/scala/io/kaitai/struct/Log.scala +++ b/shared/src/main/scala/io/kaitai/struct/Log.scala @@ -27,7 +27,8 @@ object Log { "type_resolve", "type_valid", "seq_sizes", - "import" + "import", + "enum_resolve" ) var fileOps: Logger = NullLogger @@ -37,6 +38,7 @@ object Log { var typeValid: Logger = NullLogger var seqSizes: Logger = NullLogger var importOps: Logger = NullLogger + var enumResolve: Logger = NullLogger def initFromVerboseFlag(subsystems: Seq[String]): Unit = { fileOps = NullLogger @@ -50,6 +52,7 @@ object Log { case "type_valid" => typeValid = ConsoleLogger case "seq_sizes" => seqSizes = ConsoleLogger case "import" => importOps = ConsoleLogger + case "enum_resolve" => enumResolve = ConsoleLogger } } } diff --git a/shared/src/main/scala/io/kaitai/struct/Main.scala b/shared/src/main/scala/io/kaitai/struct/Main.scala index 420775f38..15f00b34f 100644 --- a/shared/src/main/scala/io/kaitai/struct/Main.scala +++ 
b/shared/src/main/scala/io/kaitai/struct/Main.scala @@ -1,7 +1,7 @@ package io.kaitai.struct import io.kaitai.struct.format.{ClassSpec, ClassSpecs, GenericStructClassSpec} -import io.kaitai.struct.languages.GoCompiler +import io.kaitai.struct.languages.{GoCompiler, RustCompiler} import io.kaitai.struct.languages.components.LanguageCompilerStatic import io.kaitai.struct.precompile._ @@ -59,6 +59,12 @@ object Main { new GraphvizClassCompiler(specs, spec) case GoCompiler => new GoClassCompiler(specs, spec, config) + case RustCompiler => + new RustClassCompiler(specs, spec, config) + case ConstructClassCompiler => + new ConstructClassCompiler(specs, spec) + case HtmlClassCompiler => + new HtmlClassCompiler(specs, spec) case _ => new ClassCompiler(specs, spec, config, lang) } @@ -74,7 +80,7 @@ object Main { */ private def updateConfig(config: RuntimeConfig, topClass: ClassSpec): RuntimeConfig = { if (topClass.meta.forceDebug) { - config.copy(debug = true) + config.copy(autoRead = false, readStoresPos = true) } else { config } diff --git a/shared/src/main/scala/io/kaitai/struct/RuntimeConfig.scala b/shared/src/main/scala/io/kaitai/struct/RuntimeConfig.scala index aa7a56abf..fe275701e 100644 --- a/shared/src/main/scala/io/kaitai/struct/RuntimeConfig.scala +++ b/shared/src/main/scala/io/kaitai/struct/RuntimeConfig.scala @@ -1,8 +1,72 @@ package io.kaitai.struct +/** + * C++-specific runtime configuration of the compiler. + * @param usePragmaOnce If true, use `#pragma once` in headers. If false (default), + * use `#ifndef`-`#define`-`#endif` guards. + * @param pointers Choose which style of pointers to use. + */ +case class CppRuntimeConfig( + namespace: List[String] = List(), + usePragmaOnce: Boolean = false, + pointers: CppRuntimeConfig.Pointers = CppRuntimeConfig.RawPointers +) { + /** + * Copies this C++ runtime config, applying all the default settings for + * C++98 target. 
+ */ + def copyAsCpp98() = copy( + usePragmaOnce = false, + pointers = CppRuntimeConfig.RawPointers + ) + + /** + * Copies this C++ runtime config, applying all the default settings for + * C++11 target. + */ + def copyAsCpp11() = copy( + usePragmaOnce = true, + pointers = CppRuntimeConfig.UniqueAndRawPointers + ) +} + +object CppRuntimeConfig { + sealed trait Pointers + case object RawPointers extends Pointers + case object SharedPointers extends Pointers + case object UniqueAndRawPointers extends Pointers +} + +/** + * Runtime configuration of the compiler which controls certain aspects of + * code generation for target languages. + * @param autoRead If true, constructor (or equivalent) invocation would + * automatically run `_read` (or equivalent), thus allowing to + * run parsing just by constructing an object, passing a stream + * into it. If false, `_read` would be made public and it is + * expected to be invoked manually. + * @param readStoresPos If true, parser (`_read` or equivalent) will store + * positions of all the attributes relative to the stream; + * not required for production usage (as it is typically slow + * and memory-consuming), but it is crucial for visualizers, + * IDEs, etc, to be able to display data layout. + * @param opaqueTypes If true, invoking any unknown type will be treated as it was + * "opaque" type, i.e. an external KaitaiStruct-compatible type + * defined somewhere else. If false, it will be reported as + * precompile error. 
+ * @param cppConfig C++-specific configuration + * @param goPackage Go package name + * @param javaPackage Java package name + * @param javaFromFileClass Java class to be invoked in `fromFile` helper methods + * @param dotNetNamespace .NET (C#) namespace + * @param phpNamespace PHP namespace + * @param pythonPackage Python package name + */ case class RuntimeConfig( - debug: Boolean = false, + autoRead: Boolean = true, + readStoresPos: Boolean = false, opaqueTypes: Boolean = false, + cppConfig: CppRuntimeConfig = CppRuntimeConfig(), goPackage: String = "", javaPackage: String = "", javaFromFileClass: String = "io.kaitai.struct.ByteBufferKaitaiStream", diff --git a/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala new file mode 100644 index 000000000..a148abc43 --- /dev/null +++ b/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala @@ -0,0 +1,104 @@ +package io.kaitai.struct + +import io.kaitai.struct.datatype.DataType.{KaitaiStreamType, UserTypeInstream} +import io.kaitai.struct.datatype.{Endianness, FixedEndian, InheritedEndian} +import io.kaitai.struct.format._ +import io.kaitai.struct.languages.RustCompiler +import io.kaitai.struct.languages.components.ExtraAttrs + +import scala.collection.mutable.ListBuffer + +class RustClassCompiler( + classSpecs: ClassSpecs, + override val topClass: ClassSpec, + config: RuntimeConfig +) extends ClassCompiler(classSpecs, topClass, config, RustCompiler) { + + override def compileClass(curClass: ClassSpec): Unit = { + provider.nowClass = curClass + + val extraAttrs = ListBuffer[AttrSpec]() + extraAttrs += AttrSpec(List(), IoIdentifier, KaitaiStreamType) + extraAttrs += AttrSpec(List(), RootIdentifier, UserTypeInstream(topClassName, None)) + extraAttrs += AttrSpec(List(), ParentIdentifier, curClass.parentType) + + extraAttrs ++= ExtraAttrs.forClassSpec(curClass, lang) + + if (!curClass.doc.isEmpty) + lang.classDoc(curClass.name, 
curClass.doc) + + // Basic struct declaration + lang.classHeader(curClass.name) + + compileAttrDeclarations(curClass.seq ++ extraAttrs) + curClass.instances.foreach { case (instName, instSpec) => + compileInstanceDeclaration(instName, instSpec) + } + + // Constructor = Read() function + compileReadFunction(curClass) + + compileInstances(curClass) + + compileAttrReaders(curClass.seq ++ extraAttrs) + lang.classFooter(curClass.name) + + compileEnums(curClass) + + // Recursive types + compileSubclasses(curClass) + } + + def compileReadFunction(curClass: ClassSpec) = { + lang.classConstructorHeader( + curClass.name, + curClass.parentType, + topClassName, + curClass.meta.endian.contains(InheritedEndian), + curClass.params + ) + + // FIXME + val defEndian = curClass.meta.endian match { + case Some(fe: FixedEndian) => Some(fe) + case _ => None + } + + lang.readHeader(defEndian, false) + + compileSeq(curClass.seq, defEndian) + lang.classConstructorFooter + } + + override def compileInstances(curClass: ClassSpec) = { + lang.instanceDeclHeader(curClass.name) + curClass.instances.foreach { case (instName, instSpec) => + compileInstance(curClass.name, instName, instSpec, curClass.meta.endian) + } + } + + override def compileInstance(className: List[String], instName: InstanceIdentifier, instSpec: InstanceSpec, endian: Option[Endianness]): Unit = { + // FIXME: support calculated endianness + + // Determine datatype + val dataType = instSpec.dataTypeComposite + + if (!instSpec.doc.isEmpty) + lang.attributeDoc(instName, instSpec.doc) + lang.instanceHeader(className, instName, dataType, instSpec.isNullable) + lang.instanceCheckCacheAndReturn(instName, dataType) + + instSpec match { + case vi: ValueInstanceSpec => + lang.attrParseIfHeader(instName, vi.ifExpr) + lang.instanceCalculate(instName, dataType, vi.value) + lang.attrParseIfFooter(vi.ifExpr) + case i: ParseInstanceSpec => + lang.attrParse(i, instName, None) // FIXME + } + + lang.instanceSetCalculated(instName) + 
lang.instanceReturn(instName, dataType) + lang.instanceFooter + } +} diff --git a/shared/src/main/scala/io/kaitai/struct/TypeProcessor.scala b/shared/src/main/scala/io/kaitai/struct/TypeProcessor.scala index 7aaafaae4..92314ae9d 100644 --- a/shared/src/main/scala/io/kaitai/struct/TypeProcessor.scala +++ b/shared/src/main/scala/io/kaitai/struct/TypeProcessor.scala @@ -36,8 +36,8 @@ object TypeProcessor { } else { List() } - case SwitchType(_, cases) => - cases.flatMap { case (_, ut) => + case st: SwitchType => + st.cases.flatMap { case (_, ut) => getOpaqueDataTypes(ut) } case _ => diff --git a/shared/src/main/scala/io/kaitai/struct/Utils.scala b/shared/src/main/scala/io/kaitai/struct/Utils.scala index a14816f37..bae0c7050 100644 --- a/shared/src/main/scala/io/kaitai/struct/Utils.scala +++ b/shared/src/main/scala/io/kaitai/struct/Utils.scala @@ -5,6 +5,11 @@ import java.nio.charset.Charset import scala.collection.mutable.ListBuffer object Utils { + /** + * BigInt-typed max value of unsigned 64-bit integer. + */ + val MAX_UINT64 = BigInt("18446744073709551615") + private val RDecimal = "^(-?[0-9]+)$".r private val RHex = "^0x([0-9a-fA-F]+)$".r @@ -108,4 +113,23 @@ object Utils { } else { fullPath } + + /** + * Performs safe lookup for up to `len` character in a given + * string `src`, starting at `from`. 
+ * @param src string to work on + * @param from starting character index + * @param len max length of substring + * @return substring of `src`, starting at `from`, up to `len` chars max + */ + def safeLookup(src: String, from: Int, len: Int): String = { + val maxLen = src.length + if (from >= maxLen) { + "" + } else { + val to = from + len + val safeTo = if (to > maxLen) maxLen else to + src.substring(from, safeTo) + } + } } diff --git a/shared/src/main/scala/io/kaitai/struct/datatype/DataType.scala b/shared/src/main/scala/io/kaitai/struct/datatype/DataType.scala index 78e111043..e23f39658 100644 --- a/shared/src/main/scala/io/kaitai/struct/datatype/DataType.scala +++ b/shared/src/main/scala/io/kaitai/struct/datatype/DataType.scala @@ -4,7 +4,13 @@ import io.kaitai.struct.exprlang.{Ast, Expressions} import io.kaitai.struct.format._ import io.kaitai.struct.translators.TypeDetector -sealed trait DataType +sealed trait DataType { + /** + * @return Data type as non-owning data type. Default implementation + * always returns itself, complex types + */ + def asNonOwning: DataType = this +} /** * A collection of case objects and classes that are used to represent internal @@ -21,14 +27,14 @@ object DataType { * A common trait for all types that can be read with a simple, * parameterless KaitaiStream API call. 
*/ - trait ReadableType extends DataType { + sealed trait ReadableType extends DataType { def apiCall(defEndian: Option[FixedEndian]): String } - abstract class NumericType extends DataType - abstract class BooleanType extends DataType + abstract sealed class NumericType extends DataType + abstract sealed class BooleanType extends DataType - abstract class IntType extends NumericType + abstract sealed class IntType extends NumericType case object CalcIntType extends IntType case class Int1Type(signed: Boolean) extends IntType with ReadableType { override def apiCall(defEndian: Option[FixedEndian]): String = if (signed) "s1" else "u1" @@ -87,43 +93,103 @@ object DataType { case class StrFromBytesType(bytes: BytesType, encoding: String) extends StrType case object CalcBooleanType extends BooleanType - case class ArrayType(elType: DataType) extends DataType + /** + * Complex data type is a data type which creation and destruction is + * not an atomic, built-in operation, but rather a sequence of new/delete + * operations. The main common trait for all complex data types is a flag + * that determines whether they're "owning" or "borrowed". Owning objects + * manage their own creation/destruction, borrowed rely on other doing + * that. + */ + abstract sealed class ComplexDataType extends DataType { + /** + * @return If true, this is "owning" type: for languages where data ownership + * matters, this one represents primary owner of the data block, who + * will be responsible for whole life cycle: creation of the object + * and its destruction. + */ + def isOwning: Boolean + } + + /** + * Common abstract ancestor for all types which can treated as "user types". + * Namely, this typically means that this type has a name, may have some + * parameters, and forced parent expression. 
+ * @param name name of the type, might include several components + * @param forcedParent optional parent enforcement expression + * @param args parameters passed into this type as extra arguments + */ abstract class UserType( val name: List[String], val forcedParent: Option[Ast.expr], var args: Seq[Ast.expr] - ) extends DataType { + ) extends ComplexDataType { var classSpec: Option[ClassSpec] = None def isOpaque = { val cs = classSpec.get cs.isTopLevel || cs.meta.isOpaque } + + override def asNonOwning: UserType = { + if (!isOwning) { + this + } else { + val r = CalcUserType(name, forcedParent, args) + r.classSpec = classSpec + r + } + } } case class UserTypeInstream( _name: List[String], _forcedParent: Option[Ast.expr], _args: Seq[Ast.expr] = Seq() - ) extends UserType(_name, _forcedParent, _args) + ) extends UserType(_name, _forcedParent, _args) { + def isOwning = true + } case class UserTypeFromBytes( _name: List[String], _forcedParent: Option[Ast.expr], _args: Seq[Ast.expr] = Seq(), bytes: BytesType, override val process: Option[ProcessExpr] - ) extends UserType(_name, _forcedParent, _args) with Processing + ) extends UserType(_name, _forcedParent, _args) with Processing { + override def isOwning = true + } + case class CalcUserType( + _name: List[String], + _forcedParent: Option[Ast.expr], + _args: Seq[Ast.expr] = Seq() + ) extends UserType(_name, _forcedParent, _args) { + override def isOwning = false + } + + case class ArrayType(elType: DataType) extends ComplexDataType { + override def isOwning: Boolean = true + override def asNonOwning: CalcArrayType = CalcArrayType(elType) + } + case class CalcArrayType(elType: DataType) extends ComplexDataType { + override def isOwning: Boolean = false + } val USER_TYPE_NO_PARENT = Ast.expr.Bool(false) case object AnyType extends DataType - case object KaitaiStructType extends DataType + case object KaitaiStructType extends ComplexDataType { + def isOwning = true + override def asNonOwning: DataType = 
CalcKaitaiStructType + } + case object CalcKaitaiStructType extends ComplexDataType { + def isOwning = false + } case object KaitaiStreamType extends DataType case class EnumType(name: List[String], basedOn: IntType) extends DataType { var enumSpec: Option[EnumSpec] = None } - case class SwitchType(on: Ast.expr, cases: Map[Ast.expr, DataType]) extends DataType { + case class SwitchType(on: Ast.expr, cases: Map[Ast.expr, DataType], isOwning: Boolean = true) extends ComplexDataType { def combinedType: DataType = TypeDetector.combineTypes(cases.values) /** @@ -160,6 +226,8 @@ object DataType { cases.values.exists((t) => t.isInstanceOf[UserTypeFromBytes] || t.isInstanceOf[BytesType] ) + + override def asNonOwning: DataType = SwitchType(on, cases, false) } object SwitchType { @@ -301,7 +369,11 @@ object DataType { } } - arg.enumRef match { + applyEnumType(r, arg.enumRef, path) + } + + private def applyEnumType(r: DataType, enumRef: Option[String], path: List[String]) = { + enumRef match { case Some(enumName) => r match { case numType: IntType => EnumType(classNameToList(enumName), numType) @@ -316,40 +388,45 @@ object DataType { private val RePureIntType = """([us])(2|4|8)""".r private val RePureFloatType = """f(4|8)""".r - def pureFromString(dto: Option[String]): DataType = { - dto match { - case None => CalcBytesType - case Some(dt) => dt match { - case "u1" => Int1Type(false) - case "s1" => Int1Type(true) - case RePureIntType(signStr, widthStr) => - IntMultiType( - signStr match { - case "s" => true - case "u" => false - }, - widthStr match { - case "2" => Width2 - case "4" => Width4 - case "8" => Width8 - }, - None - ) - case RePureFloatType(widthStr) => - FloatMultiType( - widthStr match { - case "4" => Width4 - case "8" => Width8 - }, - None - ) - case "str" => CalcStrType - case "bool" => CalcBooleanType - case "struct" => KaitaiStructType - case "io" => KaitaiStreamType - case "any" => AnyType - } - } + def pureFromString(dto: Option[String], enumRef: 
Option[String], path: List[String]): DataType = + applyEnumType(pureFromString(dto), enumRef, path) + + def pureFromString(dto: Option[String]): DataType = dto match { + case None => CalcBytesType + case Some(dt) => pureFromString(dt) + } + + def pureFromString(dt: String): DataType = dt match { + case "bytes" => CalcBytesType + case "u1" => Int1Type(false) + case "s1" => Int1Type(true) + case RePureIntType(signStr, widthStr) => + IntMultiType( + signStr match { + case "s" => true + case "u" => false + }, + widthStr match { + case "2" => Width2 + case "4" => Width4 + case "8" => Width8 + }, + None + ) + case RePureFloatType(widthStr) => + FloatMultiType( + widthStr match { + case "4" => Width4 + case "8" => Width8 + }, + None + ) + case "str" => CalcStrType + case "bool" => CalcBooleanType + case "struct" => CalcKaitaiStructType + case "io" => KaitaiStreamType + case "any" => AnyType + case _ => CalcUserType(classNameToList(dt), None) } def getEncoding(curEncoding: Option[String], metaDef: MetaSpec, path: List[String]): String = { diff --git a/shared/src/main/scala/io/kaitai/struct/datatype/Endianness.scala b/shared/src/main/scala/io/kaitai/struct/datatype/Endianness.scala index 9db938c3b..4606d88e9 100644 --- a/shared/src/main/scala/io/kaitai/struct/datatype/Endianness.scala +++ b/shared/src/main/scala/io/kaitai/struct/datatype/Endianness.scala @@ -6,7 +6,7 @@ import io.kaitai.struct.format.{ParseUtils, YAMLParseException} sealed trait Endianness -abstract class FixedEndian extends Endianness { +sealed abstract class FixedEndian extends Endianness { def toSuffix: String } case object LittleEndian extends FixedEndian { diff --git a/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala b/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala index ff23decd6..94131177e 100644 --- a/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala +++ b/shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala @@ -21,6 +21,9 @@ package io.kaitai.struct.exprlang */ 
object Ast { case class identifier(name: String) + case class typeId(absolute: Boolean, names: Seq[String], isArray: Boolean = false) + + val EmptyTypeId = typeId(false, Seq()) // BoolOp() can use left & right? sealed trait expr @@ -36,11 +39,11 @@ object Ast { case class FloatNum(n: BigDecimal) extends expr case class Str(s: String) extends expr case class Bool(n: Boolean) extends expr - case class EnumByLabel(enumName: identifier, label: identifier) extends expr - case class EnumById(enumName: identifier, id: expr) extends expr + case class EnumByLabel(enumName: identifier, label: identifier, inType: typeId = EmptyTypeId) extends expr + case class EnumById(enumName: identifier, id: expr, inType: typeId = EmptyTypeId) extends expr case class Attribute(value: expr, attr: identifier) extends expr - case class CastToType(value: expr, typeName: identifier) extends expr + case class CastToType(value: expr, typeName: typeId) extends expr case class Subscript(value: expr, idx: expr) extends expr case class Name(id: identifier) extends expr case class List(elts: Seq[expr]) extends expr diff --git a/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala b/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala index 1e0d6f1ea..405de889d 100644 --- a/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala +++ b/shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala @@ -27,6 +27,10 @@ import fastparse.StringReprOps object Expressions { val NAME: P[Ast.identifier] = Lexical.identifier + val TYPE_NAME: P[Ast.typeId] = P("::".!.? 
~ NAME.rep(1, "::") ~ ("[" ~ "]").!.?).map { + case (first, names: Seq[Ast.identifier], arrayStr) => + Ast.typeId(first.nonEmpty, names.map((el) => el.name), arrayStr.nonEmpty) + } val INT_NUMBER = Lexical.integer val FLOAT_NUMBER = Lexical.floatnumber val STRING: P[String] = Lexical.stringliteral @@ -43,7 +47,7 @@ object Expressions { case Seq(x) => x case xs => Ast.expr.BoolOp(Ast.boolop.And, xs) } - val not_test: P[Ast.expr] = P( ("not" ~ not_test).map(Ast.expr.UnaryOp(Ast.unaryop.Not, _)) | comparison ) + val not_test: P[Ast.expr] = P( (kw("not") ~ not_test).map(Ast.expr.UnaryOp(Ast.unaryop.Not, _)) | comparison ) val comparison: P[Ast.expr] = P( expr ~ (comp_op ~ expr).? ).map{ case (lhs, None) => lhs case (lhs, Some(chunks)) => @@ -130,7 +134,7 @@ object Expressions { val trailer: P[Ast.expr => Ast.expr] = { val call = P("(" ~ arglist ~ ")").map{ case (args) => (lhs: Ast.expr) => Ast.expr.Call(lhs, args)} val slice = P("[" ~ test ~ "]").map{ case (args) => (lhs: Ast.expr) => Ast.expr.Subscript(lhs, args)} - val cast = P( "." ~ "as" ~ "<" ~ NAME ~ ">" ).map( + val cast = P( "." ~ "as" ~ "<" ~ TYPE_NAME ~ ">" ).map( typeName => (lhs: Ast.expr) => Ast.expr.CastToType(lhs, typeName) ) val attr = P("." ~ NAME).map(id => (lhs: Ast.expr) => Ast.expr.Attribute(lhs, id)) @@ -156,8 +160,18 @@ object Expressions { val testlist1: P[Seq[Ast.expr]] = P( test.rep(1, sep = ",") ) - val enumByName: P[Ast.expr.EnumByLabel] = P( (NAME) ~ "::" ~ (NAME) ).map { - case(enumName, enumLabel) => Ast.expr.EnumByLabel(enumName, enumLabel) + val enumByName: P[Ast.expr.EnumByLabel] = P("::".!.? 
~ NAME.rep(2, "::")).map { + case (first, names: Seq[Ast.identifier]) => + val isAbsolute = first.nonEmpty + val (enumName, enumLabel) = names.takeRight(2) match { + case Seq(a, b) => (a, b) + } + val typePath = names.dropRight(2) + if (typePath.isEmpty) { + Ast.expr.EnumByLabel(enumName, enumLabel, Ast.EmptyTypeId) + } else { + Ast.expr.EnumByLabel(enumName, enumLabel, Ast.typeId(isAbsolute, typePath.map(_.name))) + } } val topExpr: P[Ast.expr] = P( test ~ End ) @@ -171,7 +185,7 @@ object Expressions { def parseList(src: String): Seq[Ast.expr] = realParse(src, topExprList) private def realParse[T](src: String, parser: P[T]): T = { - val r = parser.parse(src) + val r = parser.parse(src.trim) r match { case Parsed.Success(value, _) => value case f: Parsed.Failure => diff --git a/shared/src/main/scala/io/kaitai/struct/format/AttrSpec.scala b/shared/src/main/scala/io/kaitai/struct/format/AttrSpec.scala index 444d81578..60d547e89 100644 --- a/shared/src/main/scala/io/kaitai/struct/format/AttrSpec.scala +++ b/shared/src/main/scala/io/kaitai/struct/format/AttrSpec.scala @@ -10,12 +10,6 @@ import io.kaitai.struct.exprlang.{Ast, Expressions} import scala.collection.JavaConversions._ -sealed trait RepeatSpec -case class RepeatExpr(expr: Ast.expr) extends RepeatSpec -case class RepeatUntil(expr: Ast.expr) extends RepeatSpec -case object RepeatEos extends RepeatSpec -case object NoRepeat extends RepeatSpec - case class ConditionalSpec(ifExpr: Option[Ast.expr], repeat: RepeatSpec) trait AttrLikeSpec extends MemberSpec { @@ -173,20 +167,17 @@ object AttrSpec { val process = ProcessExpr.fromStr(ParseUtils.getOptValueStr(srcMap, "process", path), path) // TODO: add proper path propagation val contents = srcMap.get("contents").map(parseContentSpec(_, path ++ List("contents"))) - val size = ParseUtils.getOptValueStr(srcMap, "size", path).map(Expressions.parse) + val size = ParseUtils.getOptValueExpression(srcMap, "size", path) val sizeEos = ParseUtils.getOptValueBool(srcMap, 
"size-eos", path).getOrElse(false) - val ifExpr = ParseUtils.getOptValueStr(srcMap, "if", path).map(Expressions.parse) + val ifExpr = ParseUtils.getOptValueExpression(srcMap, "if", path) val encoding = ParseUtils.getOptValueStr(srcMap, "encoding", path) - val repeat = ParseUtils.getOptValueStr(srcMap, "repeat", path) - val repeatExpr = ParseUtils.getOptValueStr(srcMap, "repeat-expr", path).map(Expressions.parse) - val repeatUntil = ParseUtils.getOptValueStr(srcMap, "repeat-until", path).map(Expressions.parse) val terminator = ParseUtils.getOptValueInt(srcMap, "terminator", path) val consume = ParseUtils.getOptValueBool(srcMap, "consume", path).getOrElse(true) val include = ParseUtils.getOptValueBool(srcMap, "include", path).getOrElse(false) val eosError = ParseUtils.getOptValueBool(srcMap, "eos-error", path).getOrElse(true) val padRight = ParseUtils.getOptValueInt(srcMap, "pad-right", path) val enum = ParseUtils.getOptValueStr(srcMap, "enum", path) - val parent = ParseUtils.getOptValueStr(srcMap, "parent", path).map(Expressions.parse) + val parent = ParseUtils.getOptValueExpression(srcMap, "parent", path) val typObj = srcMap.get("type") @@ -216,14 +207,14 @@ object AttrSpec { } } - val (repeatSpec, legalRepeatKeys) = parseRepeat(repeat, repeatExpr, repeatUntil, path) + val (repeatSpec, legalRepeatKeys) = RepeatSpec.fromYaml(srcMap, path) val legalKeys = LEGAL_KEYS ++ legalRepeatKeys ++ (dataType match { case _: BytesType => LEGAL_KEYS_BYTES case _: StrFromBytesType => LEGAL_KEYS_STR case _: UserType => LEGAL_KEYS_BYTES case EnumType(_, _) => LEGAL_KEYS_ENUM - case SwitchType(on, cases) => LEGAL_KEYS_BYTES + case _: SwitchType => LEGAL_KEYS_BYTES case _ => Set() }) @@ -265,18 +256,11 @@ object AttrSpec { metaDef: MetaSpec, arg: YamlAttrArgs ): DataType = { - val _on = ParseUtils.getValueStr(switchSpec, "switch-on", path) + val on = ParseUtils.getValueExpression(switchSpec, "switch-on", path) val _cases = ParseUtils.getValueMapStrStr(switchSpec, "cases", path) 
ParseUtils.ensureLegalKeys(switchSpec, LEGAL_KEYS_SWITCH, path) - val on = try { - Expressions.parse(_on) - } catch { - case epe: Expressions.ParseException => - throw YAMLParseException.expression(epe, path ++ List("switch-on")) - } - val cases = _cases.map { case (condition, typeName) => val casePath = path ++ List("cases", condition) val condType = DataType.fromYaml( @@ -311,42 +295,4 @@ object AttrSpec { SwitchType(on, cases ++ addCases) } - - private def parseRepeat( - repeat: Option[String], - rExpr: Option[Ast.expr], - rUntil: Option[Ast.expr], - path: List[String] - ): (RepeatSpec, Set[String]) = { - repeat match { - case None => - (NoRepeat, Set()) - case Some("until") => - val spec = rUntil match { - case Some(expr) => RepeatUntil(expr) - case None => - throw new YAMLParseException( - "`repeat: until` requires a `repeat-until` expression", - path ++ List("repeat") - ) - } - (spec, Set("repeat-until")) - case Some("expr") => - val spec = rExpr match { - case Some(expr) => RepeatExpr(expr) - case None => - throw new YAMLParseException( - "`repeat: expr` requires a `repeat-expr` expression", - path ++ List("repeat") - ) - } - (spec, Set("repeat-expr")) - case Some("eos") => - (RepeatEos, Set()) - case Some(other) => - throw YAMLParseException.badDictValue( - Set("until", "expr", "eos"), other, path ++ List("repeat") - ) - } - } } diff --git a/shared/src/main/scala/io/kaitai/struct/format/ClassSpec.scala b/shared/src/main/scala/io/kaitai/struct/format/ClassSpec.scala index 2a6a34dab..60582238d 100644 --- a/shared/src/main/scala/io/kaitai/struct/format/ClassSpec.scala +++ b/shared/src/main/scala/io/kaitai/struct/format/ClassSpec.scala @@ -1,7 +1,8 @@ package io.kaitai.struct.format import io.kaitai.struct.datatype.DataType -import io.kaitai.struct.datatype.DataType.{KaitaiStructType, UserTypeInstream} +import io.kaitai.struct.datatype.DataType._ + import scala.collection.mutable /** @@ -53,8 +54,8 @@ case class ClassSpec( var seqSize: Sized = 
NotCalculatedSized def parentType: DataType = parentClass match { - case UnknownClassSpec | GenericStructClassSpec => KaitaiStructType - case t: ClassSpec => UserTypeInstream(t.name, None) + case UnknownClassSpec | GenericStructClassSpec => CalcKaitaiStructType + case t: ClassSpec => CalcUserType(t.name, None) } /** @@ -67,6 +68,21 @@ case class ClassSpec( typeSpec.forEachRec(proc) } } + + override def equals(obj: Any): Boolean = obj match { + case other: ClassSpec => + path == other.path && + isTopLevel == other.isTopLevel && + meta == other.meta && + doc == other.doc && + params == other.params && + seq == other.seq && + types == other.types && + instances == other.instances && + enums == other.enums && + name == other.name + case _ => false + } } object ClassSpec { diff --git a/shared/src/main/scala/io/kaitai/struct/format/EnumSpec.scala b/shared/src/main/scala/io/kaitai/struct/format/EnumSpec.scala index 8f10e19f2..2794e0b85 100644 --- a/shared/src/main/scala/io/kaitai/struct/format/EnumSpec.scala +++ b/shared/src/main/scala/io/kaitai/struct/format/EnumSpec.scala @@ -3,6 +3,12 @@ package io.kaitai.struct.format case class EnumSpec(map: Map[Long, EnumValueSpec]) { var name = List[String]() + /** + * @return Absolute name of enum as string, components separated by + * double colon operator `::` + */ + def nameAsStr = name.mkString("::") + /** * Stabilize order of generated enums by sorting it by integer ID - it * both looks nicer and doesn't screw diffs in generated code. 
diff --git a/shared/src/main/scala/io/kaitai/struct/format/InstanceSpec.scala b/shared/src/main/scala/io/kaitai/struct/format/InstanceSpec.scala index 6d6ca2821..4da5050c8 100644 --- a/shared/src/main/scala/io/kaitai/struct/format/InstanceSpec.scala +++ b/shared/src/main/scala/io/kaitai/struct/format/InstanceSpec.scala @@ -18,7 +18,7 @@ case class ValueInstanceSpec( override def isNullable: Boolean = ifExpr.isDefined } case class ParseInstanceSpec( - id: Identifier, + id: InstanceIdentifier, path: List[String], private val _doc: DocSpec, dataType: DataType, @@ -41,7 +41,7 @@ object InstanceSpec { def fromYaml(src: Any, path: List[String], metaDef: MetaSpec, id: InstanceIdentifier): InstanceSpec = { val srcMap = ParseUtils.asMapStr(src, path) - ParseUtils.getOptValueStr(srcMap, "value", path).map(Expressions.parse) match { + ParseUtils.getOptValueExpression(srcMap, "value", path) match { case Some(value) => // value instance ParseUtils.ensureLegalKeys(srcMap, LEGAL_KEYS_VALUE_INST, path, Some("value instance")) @@ -54,7 +54,7 @@ object InstanceSpec { Ast.expr.EnumById(Ast.identifier(enumName), value) } - val ifExpr = ParseUtils.getOptValueStr(srcMap, "if", path).map(Expressions.parse) + val ifExpr = ParseUtils.getOptValueExpression(srcMap, "if", path) ValueInstanceSpec( path, @@ -65,8 +65,8 @@ object InstanceSpec { ) case None => // normal positional instance - val pos = ParseUtils.getOptValueStr(srcMap, "pos", path).map(Expressions.parse) - val io = ParseUtils.getOptValueStr(srcMap, "io", path).map(Expressions.parse) + val pos = ParseUtils.getOptValueExpression(srcMap, "pos", path) + val io = ParseUtils.getOptValueExpression(srcMap, "io", path) val fakeAttrMap = srcMap.filterKeys((key) => key != "pos" && key != "io") val a = AttrSpec.fromYaml(fakeAttrMap, path, metaDef, id) diff --git a/shared/src/main/scala/io/kaitai/struct/format/ParamDefSpec.scala b/shared/src/main/scala/io/kaitai/struct/format/ParamDefSpec.scala index b5cef37be..5d25eaf62 100644 --- 
a/shared/src/main/scala/io/kaitai/struct/format/ParamDefSpec.scala +++ b/shared/src/main/scala/io/kaitai/struct/format/ParamDefSpec.scala @@ -22,6 +22,7 @@ object ParamDefSpec { val LEGAL_KEYS = Set( "id", "type", + "enum", "doc", "doc-ref" ) @@ -29,7 +30,11 @@ object ParamDefSpec { def fromYaml(srcMap: Map[String, Any], path: List[String], id: Identifier): ParamDefSpec = { val doc = DocSpec.fromYaml(srcMap, path) val typeStr = ParseUtils.getOptValueStr(srcMap, "type", path) - val dataType = DataType.pureFromString(typeStr) + val enumRef = ParseUtils.getOptValueStr(srcMap, "enum", path) + + val dataType = DataType.pureFromString(typeStr, enumRef, path) + + ParseUtils.ensureLegalKeys(srcMap, LEGAL_KEYS, path, Some("parameter definition")) ParamDefSpec(path, id, dataType, doc) } diff --git a/shared/src/main/scala/io/kaitai/struct/format/ParseUtils.scala b/shared/src/main/scala/io/kaitai/struct/format/ParseUtils.scala index 614f6954d..83430cdb1 100644 --- a/shared/src/main/scala/io/kaitai/struct/format/ParseUtils.scala +++ b/shared/src/main/scala/io/kaitai/struct/format/ParseUtils.scala @@ -1,6 +1,7 @@ package io.kaitai.struct.format import io.kaitai.struct.Utils +import io.kaitai.struct.exprlang.{Ast, Expressions} object ParseUtils { def ensureLegalKeys(src: Map[String, Any], legalKeys: Set[String], path: List[String], where: Option[String] = None) = { @@ -83,6 +84,24 @@ object ParseUtils { } } + def getValueExpression(src: Map[String, Any], field: String, path: List[String]): Ast.expr = { + try { + Expressions.parse(getValueStr(src, field, path)) + } catch { + case epe: Expressions.ParseException => + throw YAMLParseException.expression(epe, path ++ List(field)) + } + } + + def getOptValueExpression(src: Map[String, Any], field: String, path: List[String]): Option[Ast.expr] = { + try { + getOptValueStr(src, field, path).map(Expressions.parse) + } catch { + case epe: Expressions.ParseException => + throw YAMLParseException.expression(epe, path ++ List(field)) + } + } + /** * Gets a list of 
T-typed values from a given YAML map's key "field", * reporting errors accurately and ensuring type safety. diff --git a/shared/src/main/scala/io/kaitai/struct/format/ProcessExpr.scala b/shared/src/main/scala/io/kaitai/struct/format/ProcessExpr.scala index 7fe8a6cf9..9764f8731 100644 --- a/shared/src/main/scala/io/kaitai/struct/format/ProcessExpr.scala +++ b/shared/src/main/scala/io/kaitai/struct/format/ProcessExpr.scala @@ -20,20 +20,25 @@ object ProcessExpr { case None => None case Some(op) => - Some(op match { - case "zlib" => - ProcessZlib - case ReXor(arg) => - ProcessXor(Expressions.parse(arg)) - case ReRotate(dir, arg) => - ProcessRotate(dir == "l", Expressions.parse(arg)) - case ReCustom(name, args) => - ProcessCustom(name.split('.').toList, Expressions.parseList(args)) - case ReCustomNoArg(name) => - ProcessCustom(name.split('.').toList, Seq()) - case _ => - throw YAMLParseException.badProcess(op, path) - }) + try { + Some(op match { + case "zlib" => + ProcessZlib + case ReXor(arg) => + ProcessXor(Expressions.parse(arg)) + case ReRotate(dir, arg) => + ProcessRotate(dir == "l", Expressions.parse(arg)) + case ReCustom(name, args) => + ProcessCustom(name.split('.').toList, Expressions.parseList(args)) + case ReCustomNoArg(name) => + ProcessCustom(name.split('.').toList, Seq()) + case _ => + throw YAMLParseException.badProcess(op, path) + }) + } catch { + case epe: Expressions.ParseException => + throw YAMLParseException.expression(epe, path) + } } } } diff --git a/shared/src/main/scala/io/kaitai/struct/format/RepeatSpec.scala b/shared/src/main/scala/io/kaitai/struct/format/RepeatSpec.scala new file mode 100644 index 000000000..01187c090 --- /dev/null +++ b/shared/src/main/scala/io/kaitai/struct/format/RepeatSpec.scala @@ -0,0 +1,51 @@ +package io.kaitai.struct.format + +import io.kaitai.struct.exprlang.Ast + +sealed trait RepeatSpec +case class RepeatExpr(expr: Ast.expr) extends RepeatSpec +case class RepeatUntil(expr: Ast.expr) extends RepeatSpec +case 
object RepeatEos extends RepeatSpec +case object NoRepeat extends RepeatSpec + +object RepeatSpec { + def fromYaml( + srcMap: Map[String, Any], + path: List[String] + ): (RepeatSpec, Set[String]) = { + val repeat = ParseUtils.getOptValueStr(srcMap, "repeat", path) + val repeatExpr = ParseUtils.getOptValueExpression(srcMap, "repeat-expr", path) + val repeatUntil = ParseUtils.getOptValueExpression(srcMap, "repeat-until", path) + + repeat match { + case None => + (NoRepeat, Set()) + case Some("until") => + val spec = repeatUntil match { + case Some(expr) => RepeatUntil(expr) + case None => + throw new YAMLParseException( + "`repeat: until` requires a `repeat-until` expression", + path ++ List("repeat") + ) + } + (spec, Set("repeat-until")) + case Some("expr") => + val spec = repeatExpr match { + case Some(expr) => RepeatExpr(expr) + case None => + throw new YAMLParseException( + "`repeat: expr` requires a `repeat-expr` expression", + path ++ List("repeat") + ) + } + (spec, Set("repeat-expr")) + case Some("eos") => + (RepeatEos, Set()) + case Some(other) => + throw YAMLParseException.badDictValue( + Set("until", "expr", "eos"), other, path ++ List("repeat") + ) + } + } +} \ No newline at end of file diff --git a/shared/src/main/scala/io/kaitai/struct/format/YAMLParseException.scala b/shared/src/main/scala/io/kaitai/struct/format/YAMLParseException.scala index bedf3bb70..bf5829e00 100644 --- a/shared/src/main/scala/io/kaitai/struct/format/YAMLParseException.scala +++ b/shared/src/main/scala/io/kaitai/struct/format/YAMLParseException.scala @@ -1,6 +1,7 @@ package io.kaitai.struct.format import fastparse.StringReprOps +import io.kaitai.struct.Utils import io.kaitai.struct.datatype.DataType import io.kaitai.struct.exprlang.Expressions @@ -37,8 +38,22 @@ object YAMLParseException { def expression(epe: Expressions.ParseException, path: List[String]): YAMLParseException = { val f = epe.failure val pos = StringReprOps.prettyIndex(f.extra.input, f.index) + + // Try to diagnose 
most common errors and provide a friendly suggestion + val lookup2 = Utils.safeLookup(epe.src, f.index, 2) + val suggestion: String = (if (lookup2 == "&&") { + Some("and") + } else if (lookup2 == "||") { + Some("or") + } else { + None + }).map((x) => s", did you mean '$x'?").getOrElse("") + + val expected = f.extra.traced.expected.replace("\n", "\\n") + new YAMLParseException( - s"parsing expression '${epe.src}' failed on $pos, expected ${f.extra.traced.expected.replaceAll("\n", "\\n")}", + s"parsing expression '${epe.src}' failed on $pos, " + + s"expected $expected$suggestion", path ) } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/CSharpCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/CSharpCompiler.scala index 55fcd971a..c4985ed24 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/CSharpCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/CSharpCompiler.scala @@ -2,10 +2,10 @@ package io.kaitai.struct.languages import io.kaitai.struct._ import io.kaitai.struct.datatype.DataType._ -import io.kaitai.struct.datatype.{CalcEndian, DataType, FixedEndian, InheritedEndian} +import io.kaitai.struct.datatype._ import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.exprlang.Ast.expr -import io.kaitai.struct.format.{RepeatUntil, _} +import io.kaitai.struct.format._ import io.kaitai.struct.languages.components._ import io.kaitai.struct.translators.{CSharpTranslator, TypeDetector} @@ -135,7 +135,7 @@ class CSharpCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } override def readHeader(endian: Option[FixedEndian], isEmpty: Boolean) = { - val readAccessAndType = if (debug) { + val readAccessAndType = if (!config.autoRead) { "public" } else { "private" @@ -341,9 +341,11 @@ class CSharpCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("}") } - override def handleAssignmentSimple(id: 
Identifier, expr: String): Unit = out.puts(s"${privateMemberName(id)} = $expr;") - } + + override def handleAssignmentTempVar(dataType: DataType, id: String, expr: String): Unit = + out.puts(s"${kaitaiType2NativeType(dataType)} $id = $expr;") override def parseExpr(dataType: DataType, assignType: DataType, io: String, defEndian: Option[FixedEndian]): String = { dataType match { @@ -391,6 +393,9 @@ class CSharpCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) expr2 } + override def userTypeDebugRead(id: String): Unit = + out.puts(s"$id._read();") + /** * Designates switch mode. If false, we're doing real switch-case for this * attribute. If true, we're doing if-based emulation. @@ -497,14 +502,14 @@ class CSharpCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("}") } - override def instanceCheckCacheAndReturn(instName: InstanceIdentifier): Unit = { + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { out.puts(s"if (${flagForInstName(instName)})") out.inc - instanceReturn(instName) + instanceReturn(instName, dataType) out.dec } - override def instanceReturn(instName: InstanceIdentifier): Unit = { + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { out.puts(s"return ${privateMemberName(instName)};") } @@ -601,7 +606,7 @@ object CSharpCompiler extends LanguageCompilerStatic case _: BytesType => "byte[]" case AnyType => "object" - case KaitaiStructType => kstructName + case KaitaiStructType | CalcKaitaiStructType => kstructName case KaitaiStreamType => kstreamName case t: UserType => types2class(t.name) @@ -609,7 +614,7 @@ object CSharpCompiler extends LanguageCompilerStatic case ArrayType(inType) => s"List<${kaitaiType2NativeType(inType)}>" - case SwitchType(_, cases) => kaitaiType2NativeType(TypeDetector.combineTypes(cases.values)) + case st: SwitchType => kaitaiType2NativeType(st.combinedType) } } @@ -625,7 +630,10 @@ object CSharpCompiler 
extends LanguageCompilerStatic } } - def types2class(names: List[String]) = names.map(x => type2class(x)).mkString(".") + def types2class(typeName: Ast.typeId): String = + // FIXME: handle absolute + types2class(typeName.names) + def types2class(names: Iterable[String]) = names.map(type2class).mkString(".") override def kstructName = "KaitaiStruct" override def kstreamName = "KaitaiStream" diff --git a/shared/src/main/scala/io/kaitai/struct/languages/CppCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/CppCompiler.scala index 3ebf25e96..93a3884cb 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/CppCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/CppCompiler.scala @@ -1,5 +1,6 @@ package io.kaitai.struct.languages +import io.kaitai.struct.CppRuntimeConfig._ import io.kaitai.struct._ import io.kaitai.struct.datatype.DataType._ import io.kaitai.struct.datatype.{CalcEndian, DataType, FixedEndian, InheritedEndian} @@ -9,8 +10,6 @@ import io.kaitai.struct.format._ import io.kaitai.struct.languages.components._ import io.kaitai.struct.translators.{CppTranslator, TypeDetector} -import scala.collection.mutable.ListBuffer - class CppCompiler( typeProvider: ClassTypeProvider, config: RuntimeConfig @@ -25,7 +24,7 @@ class CppCompiler( val importListSrc = new ImportList val importListHdr = new ImportList - override val translator = new CppTranslator(typeProvider, importListSrc) + override val translator = new CppTranslator(typeProvider, importListSrc, config) val outSrcHeader = new StringLanguageOutputWriter(indent) val outHdrHeader = new StringLanguageOutputWriter(indent) val outSrc = new StringLanguageOutputWriter(indent) @@ -54,11 +53,16 @@ class CppCompiler( override def fileHeader(topClassName: String): Unit = { outSrcHeader.puts(s"// $headerComment") outSrcHeader.puts + outSrcHeader.puts("#include ") outSrcHeader.puts("#include \"" + outFileName(topClassName) + ".h\"") outSrcHeader.puts - outHdrHeader.puts(s"#ifndef 
${defineName(topClassName)}") - outHdrHeader.puts(s"#define ${defineName(topClassName)}") + if (config.cppConfig.usePragmaOnce) { + outHdrHeader.puts("#pragma once") + } else { + outHdrHeader.puts(s"#ifndef ${defineName(topClassName)}") + outHdrHeader.puts(s"#define ${defineName(topClassName)}") + } outHdrHeader.puts outHdrHeader.puts(s"// $headerComment") outHdrHeader.puts @@ -67,6 +71,13 @@ class CppCompiler( importListHdr.add("stdint.h") + config.cppConfig.pointers match { + case SharedPointers | UniqueAndRawPointers => + importListHdr.add("memory") + case RawPointers => + // no extra includes + } + // API compatibility check val minVer = KSVersion.minimalRuntime.toInt outHdr.puts @@ -76,11 +87,27 @@ class CppCompiler( KSVersion.minimalRuntime + " or later is required\"" ) outHdr.puts("#endif") + + config.cppConfig.namespace.foreach { (namespace) => + outSrc.puts(s"namespace $namespace {") + outSrc.inc + outHdr.puts(s"namespace $namespace {") + outHdr.inc + } } override def fileFooter(topClassName: String): Unit = { - outHdr.puts - outHdr.puts(s"#endif // ${defineName(topClassName)}") + config.cppConfig.namespace.foreach { (_) => + outSrc.dec + outSrc.puts("}") + outHdr.dec + outHdr.puts("}") + } + + if (!config.cppConfig.usePragmaOnce) { + outHdr.puts + outHdr.puts(s"#endif // ${defineName(topClassName)}") + } } override def opaqueClassDeclaration(classSpec: ClassSpec): Unit = { @@ -89,8 +116,15 @@ class CppCompiler( } override def classHeader(name: List[String]): Unit = { + val className = types2class(List(name.last)) + + val extraInherits = config.cppConfig.pointers match { + case RawPointers | UniqueAndRawPointers => "" + case SharedPointers => s", std::enable_shared_from_this<$className>" + } + outHdr.puts - outHdr.puts(s"class ${types2class(List(name.last))} : public $kstructName {") + outHdr.puts(s"class $className : public $kstructName$extraInherits {") outHdr.inc accessMode = PrivateAccess ensureMode(PublicAccess) @@ -129,33 +163,46 @@ class 
CppCompiler( s"${kaitaiType2NativeType(p.dataType)} ${paramName(p.id)}" ), "", ", ", ", ") + val classNameBrief = types2class(List(name.last)) + // Parameter names val pIo = paramName(IoIdentifier) val pParent = paramName(ParentIdentifier) val pRoot = paramName(RootIdentifier) // Types - val tIo = s"$kstreamName*" + val tIo = kaitaiType2NativeType(KaitaiStreamType) val tParent = kaitaiType2NativeType(parentType) - val tRoot = s"${types2class(rootClassName)}*" + val tRoot = kaitaiType2NativeType(CalcUserType(rootClassName, None)) outHdr.puts - outHdr.puts(s"${types2class(List(name.last))}($paramsArg" + + outHdr.puts(s"$classNameBrief($paramsArg" + s"$tIo $pIo, " + - s"$tParent $pParent = 0, " + - s"$tRoot $pRoot = 0$endianSuffixHdr);" + s"$tParent $pParent = $nullPtr, " + + s"$tRoot $pRoot = $nullPtr$endianSuffixHdr);" ) outSrc.puts - outSrc.puts(s"${types2class(name)}::${types2class(List(name.last))}($paramsArg" + + outSrc.puts(s"${types2class(name)}::$classNameBrief($paramsArg" + s"$tIo $pIo, " + s"$tParent $pParent, " + s"$tRoot $pRoot$endianSuffixSrc) : $kstructName($pIo) {" ) outSrc.inc + + // In shared pointers mode, this is required to be able to work with shared pointers to this + // in a constructor. This is obviously a hack and not a good practice. 
+ // https://forum.libcinder.org/topic/solution-calling-shared-from-this-in-the-constructor + if (config.cppConfig.pointers == CppRuntimeConfig.SharedPointers) { + outSrc.puts(s"const auto weakPtrTrick = std::shared_ptr<$classNameBrief>(this, []($classNameBrief*){});") + } + handleAssignmentSimple(ParentIdentifier, pParent) handleAssignmentSimple(RootIdentifier, if (name == rootClassName) { - "this" + config.cppConfig.pointers match { + case RawPointers | UniqueAndRawPointers => "this" + case SharedPointers => "shared_from_this()" + } } else { pRoot }) @@ -216,7 +263,9 @@ class CppCompiler( case Some(e) => s"_${e.toSuffix}" case None => "" } - ensureMode(PrivateAccess) + + ensureMode(if (config.autoRead) PrivateAccess else PublicAccess) + outHdr.puts(s"void _read$suffix();") outSrc.puts outSrc.puts(s"void ${types2class(typeProvider.nowClass.name)}::_read$suffix() {") @@ -251,7 +300,7 @@ class CppCompiler( override def attributeReader(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = { ensureMode(PublicAccess) - outHdr.puts(s"${kaitaiType2NativeType(attrType)} ${publicMemberName(attrName)}() const { return ${privateMemberName(attrName)}; }") + outHdr.puts(s"${kaitaiType2NativeType(attrType.asNonOwning)} ${publicMemberName(attrName)}() const { return ${nonOwningPointer(attrName, attrType)}; }") } override def universalDoc(doc: DocSpec): Unit = { @@ -275,6 +324,16 @@ class CppCompiler( outHdr.puts( " */") } + override def attrInit(attr: AttrLikeSpec): Unit = { + attr.dataTypeComposite match { + case _: UserType | _: ArrayType | KaitaiStreamType => + // data type will be pointer to user type, std::vector or stream, so we need to init it + outSrc.puts(s"${privateMemberName(attr.id)} = $nullPtr;") + case _ => + // no init required for value types + } + } + override def attrDestructor(attr: AttrLikeSpec, id: Identifier): Unit = { val checkLazy = if (attr.isLazy) { Some(calculatedFlagForName(id)) @@ -311,35 +370,37 @@ class CppCompiler( def 
destructMember(id: Identifier, innerType: DataType, isArray: Boolean, hasRaw: Boolean, hasIO: Boolean): Unit = { if (isArray) { - // raw is std::vector*, no need to delete its contents, but we - // need to clean up the vector pointer itself - if (hasRaw) - outSrc.puts(s"delete ${privateMemberName(RawIdentifier(id))};") - - // IO is std::vector*, needs destruction of both members - // and the vector pointer itself - if (hasIO) { - val ioVar = privateMemberName(IoStorageIdentifier(RawIdentifier(id))) - destructVector(s"$kstreamName*", ioVar) - outSrc.puts(s"delete $ioVar;") - } + if (config.cppConfig.pointers == CppRuntimeConfig.RawPointers) { + // raw is std::vector*, no need to delete its contents, but we + // need to clean up the vector pointer itself + if (hasRaw) + outSrc.puts(s"delete ${privateMemberName(RawIdentifier(id))};") + + // IO is std::vector*, needs destruction of both members + // and the vector pointer itself + if (hasIO) { + val ioVar = privateMemberName(IoStorageIdentifier(RawIdentifier(id))) + destructVector(s"$kstreamName*", ioVar) + outSrc.puts(s"delete $ioVar;") + } - // main member contents - if (needsDestruction(innerType)) { - val arrVar = privateMemberName(id) + // main member contents + if (needsDestruction(innerType)) { + val arrVar = privateMemberName(id) + + // C++ specific substitution: AnyType results from generic struct + raw bytes + // so we would assume that only generic struct needs to be cleaned up + val realType = innerType match { + case AnyType => KaitaiStructType + case _ => innerType + } - // C++ specific substitution: AnyType results from generic struct + raw bytes - // so we would assume that only generic struct needs to be cleaned up - val realType = innerType match { - case AnyType => KaitaiStructType - case _ => innerType + destructVector(kaitaiType2NativeType(realType), arrVar) } - destructVector(kaitaiType2NativeType(realType), arrVar) + // main member is a std::vector of something, always needs destruction + 
outSrc.puts(s"delete ${privateMemberName(id)};") } - - // main member is a std::vector of something, always needs destruction - outSrc.puts(s"delete ${privateMemberName(id)};") } else { // raw is just a string, no need to cleanup => we ignore `hasRaw` @@ -347,7 +408,7 @@ class CppCompiler( if (hasIO) outSrc.puts(s"delete ${privateMemberName(IoStorageIdentifier(RawIdentifier(id)))};") - if (needsDestruction(innerType)) + if (config.cppConfig.pointers == CppRuntimeConfig.RawPointers && needsDestruction(innerType)) outSrc.puts(s"delete ${privateMemberName(id)};") } } @@ -416,7 +477,7 @@ class CppCompiler( } } - override def allocateIO(id: Identifier, rep: RepeatSpec, extraAttrs: ListBuffer[AttrSpec]): String = { + override def allocateIO(id: Identifier, rep: RepeatSpec): String = { val memberName = privateMemberName(id) val ioId = IoStorageIdentifier(id) @@ -428,18 +489,17 @@ class CppCompiler( val newStream = s"new $kstreamName($args)" - val (ioType, ioName) = rep match { + val ioName = rep match { case NoRepeat => outSrc.puts(s"${privateMemberName(ioId)} = $newStream;") - (KaitaiStreamType, privateMemberName(ioId)) + privateMemberName(ioId) case _ => val localIO = s"io_${idToStr(id)}" outSrc.puts(s"$kstreamName* $localIO = $newStream;") outSrc.puts(s"${privateMemberName(ioId)}->push_back($localIO);") - (ArrayType(KaitaiStreamType), localIO) + localIO } - Utils.addUniqueAttr(extraAttrs, AttrSpec(List(), ioId, ioType)) ioName } @@ -486,10 +546,10 @@ class CppCompiler( importListHdr.add("vector") if (needRaw) { - outSrc.puts(s"${privateMemberName(RawIdentifier(id))} = new std::vector();") - outSrc.puts(s"${privateMemberName(IoStorageIdentifier(RawIdentifier(id)))} = new std::vector<$kstreamName*>();") + outSrc.puts(s"${privateMemberName(RawIdentifier(id))} = ${newVector(CalcBytesType)};") + outSrc.puts(s"${privateMemberName(IoStorageIdentifier(RawIdentifier(id)))} = ${newVector(KaitaiStreamType)};") } - outSrc.puts(s"${privateMemberName(id)} = new 
std::vector<${kaitaiType2NativeType(dataType)}>();") + outSrc.puts(s"${privateMemberName(id)} = ${newVector(dataType)};") outSrc.puts("{") outSrc.inc outSrc.puts("int i = 0;") @@ -498,7 +558,7 @@ class CppCompiler( } override def handleAssignmentRepeatEos(id: Identifier, expr: String): Unit = { - outSrc.puts(s"${privateMemberName(id)}->push_back($expr);") + outSrc.puts(s"${privateMemberName(id)}->push_back(${stdMoveWrap(expr)});") } override def condRepeatEosFooter: Unit = { @@ -516,20 +576,20 @@ class CppCompiler( outSrc.puts(s"int $lenVar = ${expression(repeatExpr)};") if (needRaw) { val rawId = privateMemberName(RawIdentifier(id)) - outSrc.puts(s"$rawId = new std::vector();") + outSrc.puts(s"$rawId = ${newVector(CalcBytesType)};") outSrc.puts(s"$rawId->reserve($lenVar);") val ioId = privateMemberName(IoStorageIdentifier(RawIdentifier(id))) - outSrc.puts(s"$ioId = new std::vector<$kstreamName*>();") + outSrc.puts(s"$ioId = ${newVector(KaitaiStreamType)};") outSrc.puts(s"$ioId->reserve($lenVar);") } - outSrc.puts(s"${privateMemberName(id)} = new std::vector<${kaitaiType2NativeType(dataType)}>();") + outSrc.puts(s"${privateMemberName(id)} = ${newVector(dataType)};") outSrc.puts(s"${privateMemberName(id)}->reserve($lenVar);") outSrc.puts(s"for (int i = 0; i < $lenVar; i++) {") outSrc.inc } override def handleAssignmentRepeatExpr(id: Identifier, expr: String): Unit = { - outSrc.puts(s"${privateMemberName(id)}->push_back($expr);") + outSrc.puts(s"${privateMemberName(id)}->push_back(${stdMoveWrap(expr)});") } override def condRepeatExprFooter: Unit = { @@ -541,26 +601,41 @@ class CppCompiler( importListHdr.add("vector") if (needRaw) { - outSrc.puts(s"${privateMemberName(RawIdentifier(id))} = new std::vector();") - outSrc.puts(s"${privateMemberName(IoStorageIdentifier(RawIdentifier(id)))} = new std::vector<$kstreamName*>();") + outSrc.puts(s"${privateMemberName(RawIdentifier(id))} = ${newVector(CalcBytesType)};") + 
outSrc.puts(s"${privateMemberName(IoStorageIdentifier(RawIdentifier(id)))} = ${newVector(KaitaiStreamType)};") } - outSrc.puts(s"${privateMemberName(id)} = new std::vector<${kaitaiType2NativeType(dataType)}>();") + outSrc.puts(s"${privateMemberName(id)} = ${newVector(dataType)};") outSrc.puts("{") outSrc.inc outSrc.puts("int i = 0;") - outSrc.puts(s"${kaitaiType2NativeType(dataType)} ${translator.doName("_")};") + outSrc.puts(s"${kaitaiType2NativeType(dataType.asNonOwning)} ${translator.doName("_")};") outSrc.puts("do {") outSrc.inc } + private val ReStdUniquePtr = "^std::unique_ptr<(.*?)>\\((.*?)\\)$".r + override def handleAssignmentRepeatUntil(id: Identifier, expr: String, isRaw: Boolean): Unit = { val (typeDecl, tempVar) = if (isRaw) { ("std::string ", translator.doName(Identifier.ITERATOR2)) } else { ("", translator.doName(Identifier.ITERATOR)) } - outSrc.puts(s"$typeDecl$tempVar = $expr;") - outSrc.puts(s"${privateMemberName(id)}->push_back($tempVar);") + + val (wrappedTempVar, rawPtrExpr) = if (config.cppConfig.pointers == UniqueAndRawPointers) { + expr match { + case ReStdUniquePtr(cppClass, innerExpr) => + (s"std::move(std::unique_ptr<$cppClass>($tempVar))", innerExpr) + case _ => + (tempVar, expr) + } + } else { + (tempVar, expr) + } + + outSrc.puts(s"$typeDecl$tempVar = $rawPtrExpr;") + + outSrc.puts(s"${privateMemberName(id)}->push_back($wrappedTempVar);") } override def condRepeatUntilFooter(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, untilExpr: expr): Unit = { @@ -576,6 +651,9 @@ class CppCompiler( outSrc.puts(s"${privateMemberName(id)} = $expr;") } + override def handleAssignmentTempVar(dataType: DataType, id: String, expr: String): Unit = + outSrc.puts(s"${kaitaiType2NativeType(dataType)} $id = $expr;") + override def parseExpr(dataType: DataType, assignType: DataType, io: String, defEndian: Option[FixedEndian]): String = { dataType match { case t: ReadableType => @@ -596,9 +674,13 @@ class CppCompiler( "" } else { val parent = 
t.forcedParent match { - case Some(USER_TYPE_NO_PARENT) => "0" + case Some(USER_TYPE_NO_PARENT) => nullPtr case Some(fp) => translator.translate(fp) - case None => "this" + case None => + config.cppConfig.pointers match { + case RawPointers | UniqueAndRawPointers => "this" + case SharedPointers => s"shared_from_this()" + } } val addEndian = t.classSpec.get.meta.endian match { case Some(InheritedEndian) => ", m__is_le" @@ -606,7 +688,28 @@ class CppCompiler( } s", $parent, ${privateMemberName(RootIdentifier)}$addEndian" } - s"new ${types2class(t.name)}($addParams$io$addArgs)" + config.cppConfig.pointers match { + case RawPointers => + s"new ${types2class(t.name)}($addParams$io$addArgs)" + case SharedPointers => + s"std::make_shared<${types2class(t.name)}>($addParams$io$addArgs)" + case UniqueAndRawPointers => + importListSrc.add("memory") + // C++14 + //s"std::make_unique<${types2class(t.name)}>($addParams$io$addArgs)" + s"std::unique_ptr<${types2class(t.name)}>(new ${types2class(t.name)}($addParams$io$addArgs))" + } + } + } + + def newVector(elType: DataType): String = { + val cppElType = kaitaiType2NativeType(elType) + config.cppConfig.pointers match { + case RawPointers => + s"new std::vector<$cppElType>()" + case UniqueAndRawPointers => + s"std::unique_ptr>(new std::vector<$cppElType>())" + // TODO: C++14 with std::make_unique } } @@ -622,6 +725,9 @@ class CppCompiler( expr2 } + override def userTypeDebugRead(id: String): Unit = + outSrc.puts(s"$id->_read();") + /** * Designates switch mode. If false, we're doing real switch-case for this * attribute. If true, we're doing if-based emulation. 
@@ -706,10 +812,10 @@ class CppCompiler( override def instanceHeader(className: List[String], instName: InstanceIdentifier, dataType: DataType, isNullable: Boolean): Unit = { ensureMode(PublicAccess) - outHdr.puts(s"${kaitaiType2NativeType(dataType)} ${publicMemberName(instName)}();") + outHdr.puts(s"${kaitaiType2NativeType(dataType.asNonOwning)} ${publicMemberName(instName)}();") outSrc.puts - outSrc.puts(s"${kaitaiType2NativeType(dataType, true)} ${types2class(className)}::${publicMemberName(instName)}() {") + outSrc.puts(s"${kaitaiType2NativeType(dataType.asNonOwning, true)} ${types2class(className)}::${publicMemberName(instName)}() {") outSrc.inc } @@ -718,16 +824,15 @@ class CppCompiler( outSrc.puts("}") } - override def instanceCheckCacheAndReturn(instName: InstanceIdentifier): Unit = { + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { outSrc.puts(s"if (${calculatedFlagForName(instName)})") outSrc.inc - instanceReturn(instName) + instanceReturn(instName, dataType) outSrc.dec } - override def instanceReturn(instName: InstanceIdentifier): Unit = { - outSrc.puts(s"return ${privateMemberName(instName)};") - } + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = + outSrc.puts(s"return ${nonOwningPointer(instName, attrType)};") override def enumDeclaration(curClass: List[String], enumName: String, enumColl: Seq[(Long, EnumValueSpec)]): Unit = { val enumClass = types2class(List(enumName)) @@ -752,61 +857,6 @@ class CppCompiler( def value2Const(enumName: String, label: String) = (enumName + "_" + label).toUpperCase - def kaitaiType2NativeType(attrType: DataType, absolute: Boolean = false): String = { - attrType match { - case Int1Type(false) => "uint8_t" - case IntMultiType(false, Width2, _) => "uint16_t" - case IntMultiType(false, Width4, _) => "uint32_t" - case IntMultiType(false, Width8, _) => "uint64_t" - - case Int1Type(true) => "int8_t" - case IntMultiType(true, Width2, _) 
=> "int16_t" - case IntMultiType(true, Width4, _) => "int32_t" - case IntMultiType(true, Width8, _) => "int64_t" - - case FloatMultiType(Width4, _) => "float" - case FloatMultiType(Width8, _) => "double" - - case BitsType(_) => "uint64_t" - - case _: BooleanType => "bool" - case CalcIntType => "int32_t" - case CalcFloatType => "double" - - case _: StrType => "std::string" - case _: BytesType => "std::string" - - case t: UserType => - val typeStr = types2class(if (absolute) { - t.classSpec.get.name - } else { - t.name - }) - s"$typeStr*" - - case t: EnumType => - types2class(if (absolute) { - t.enumSpec.get.name - } else { - t.name - }) - - case ArrayType(inType) => s"std::vector<${kaitaiType2NativeType(inType, absolute)}>*" - - case KaitaiStreamType => s"$kstreamName*" - case KaitaiStructType => s"$kstructName*" - - case SwitchType(on, cases) => - kaitaiType2NativeType(TypeDetector.combineTypes( - // C++ does not have a concept of AnyType, and common use case "lots of incompatible UserTypes - // for cases + 1 BytesType for else" combined would result in exactly AnyType - so we try extra - // hard to avoid that here with this pre-filtering. In C++, "else" case with raw byte array would - // be available through _raw_* attribute anyway. 
- cases.filterNot { case (caseExpr, caseValue) => caseExpr == SwitchType.ELSE_CONST }.values - ), absolute) - } - } - def defineName(className: String) = className.toUpperCase + "_H_" /** @@ -857,7 +907,40 @@ class CppCompiler( } } - def type2class(name: String) = name + "_t" + override def type2class(className: String): String = CppCompiler.type2class(className) + + def kaitaiType2NativeType(attrType: DataType, absolute: Boolean = false): String = + CppCompiler.kaitaiType2NativeType(config.cppConfig, attrType, absolute) + + def nullPtr: String = config.cppConfig.pointers match { + case RawPointers => "0" + case SharedPointers | UniqueAndRawPointers => "nullptr" + } + + def nonOwningPointer(attrName: Identifier, attrType: DataType): String = { + config.cppConfig.pointers match { + case RawPointers => + privateMemberName(attrName) + case UniqueAndRawPointers => + attrType match { + case st: SwitchType => + nonOwningPointer(attrName, combineSwitchType(st)) + case t: ComplexDataType => + if (t.isOwning) { + s"${privateMemberName(attrName)}.get()" + } else { + privateMemberName(attrName) + } + case _ => + privateMemberName(attrName) + } + } + } + + def stdMoveWrap(expr: String): String = config.cppConfig.pointers match { + case UniqueAndRawPointers => s"std::move($expr)" + case _ => expr + } } object CppCompiler extends LanguageCompilerStatic with StreamStructNames { @@ -869,10 +952,106 @@ object CppCompiler extends LanguageCompilerStatic with StreamStructNames { override def kstructName = "kaitai::kstruct" override def kstreamName = "kaitai::kstream" - def types2class(components: List[String]) = { - components.map { - case "kaitai_struct" => "kaitai::kstruct" - case s => s + "_t" - }.mkString("::") + def kaitaiType2NativeType(config: CppRuntimeConfig, attrType: DataType, absolute: Boolean = false): String = { + attrType match { + case Int1Type(false) => "uint8_t" + case IntMultiType(false, Width2, _) => "uint16_t" + case IntMultiType(false, Width4, _) => "uint32_t" + 
case IntMultiType(false, Width8, _) => "uint64_t" + + case Int1Type(true) => "int8_t" + case IntMultiType(true, Width2, _) => "int16_t" + case IntMultiType(true, Width4, _) => "int32_t" + case IntMultiType(true, Width8, _) => "int64_t" + + case FloatMultiType(Width4, _) => "float" + case FloatMultiType(Width8, _) => "double" + + case BitsType(_) => "uint64_t" + + case _: BooleanType => "bool" + case CalcIntType => "int32_t" + case CalcFloatType => "double" + + case _: StrType => "std::string" + case _: BytesType => "std::string" + + case t: UserType => + val typeStr = types2class(if (absolute) { + t.classSpec.get.name + } else { + t.name + }) + config.pointers match { + case RawPointers => s"$typeStr*" + case SharedPointers => s"std::shared_ptr<$typeStr>" + case UniqueAndRawPointers => + if (t.isOwning) s"std::unique_ptr<$typeStr>" else s"$typeStr*" + } + + case t: EnumType => + types2class(if (absolute) { + t.enumSpec.get.name + } else { + t.name + }) + + case ArrayType(inType) => config.pointers match { + case RawPointers => s"std::vector<${kaitaiType2NativeType(config, inType, absolute)}>*" + case UniqueAndRawPointers => s"std::unique_ptr>" + } + case CalcArrayType(inType) => s"std::vector<${kaitaiType2NativeType(config, inType, absolute)}>*" + + case KaitaiStreamType => s"$kstreamName*" + case KaitaiStructType => config.pointers match { + case RawPointers => s"$kstructName*" + case SharedPointers => s"std::shared_ptr<$kstructName>" + case UniqueAndRawPointers => s"std::unique_ptr<$kstructName>" + } + case CalcKaitaiStructType => config.pointers match { + case RawPointers => s"$kstructName*" + case SharedPointers => s"std::shared_ptr<$kstructName>" + case UniqueAndRawPointers => s"$kstructName*" + } + + case st: SwitchType => + kaitaiType2NativeType(config, combineSwitchType(st), absolute) + } + } + + /** + * C++ does not have a concept of AnyType, and common use case "lots of + * incompatible UserTypes for cases + 1 BytesType for else" combined would + * result 
in exactly AnyType - so we try extra hard to avoid that here with + * this pre-filtering. In C++, "else" case with raw byte array would + * be available through _raw_* attribute anyway. + * + * @param st switch type to combine into one overall type + * @return + */ + def combineSwitchType(st: SwitchType): DataType = { + val ct1 = TypeDetector.combineTypes( + st.cases.filterNot { + case (caseExpr, _) => caseExpr == SwitchType.ELSE_CONST + }.values + ) + if (st.isOwning) { + ct1 + } else { + ct1.asNonOwning + } + } + + def types2class(typeName: Ast.typeId) = { + typeName.names.map(type2class).mkString( + if (typeName.absolute) "::" else "", + "::", + "" + ) } + + def types2class(components: List[String]) = + components.map(type2class).mkString("::") + + def type2class(name: String) = name + "_t" } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/GoCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/GoCompiler.scala index 430ff940b..7d5d0e5eb 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/GoCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/GoCompiler.scala @@ -1,7 +1,7 @@ package io.kaitai.struct.languages -import io.kaitai.struct.datatype.{DataType, FixedEndian} import io.kaitai.struct.datatype.DataType._ +import io.kaitai.struct.datatype.{DataType, FixedEndian} import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.format._ import io.kaitai.struct.languages.components._ @@ -16,8 +16,7 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) with UniversalFooter with UniversalDoc with AllocateIOLocalVar - with GoReads - with FixedContentsUsingArrayByteLiteral { + with GoReads { import GoCompiler._ override val translator = new GoTranslator(out, typeProvider, importList) @@ -117,8 +116,23 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def attrParseHybrid(leProc: () => Unit, beProc: () => Unit): Unit = ??? 
- override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = { - out.puts(s"${privateMemberName(attrName)} = $normalIO.ensureFixedContents($contents);") + override def attrFixedContentsParse(attrName: Identifier, contents: Array[Byte]): Unit = { + out.puts(s"${privateMemberName(attrName)}, err = $normalIO.ReadBytes(${contents.length})") + + out.puts(s"if err != nil {") + out.inc + out.puts("return err") + out.dec + out.puts("}") + + importList.add("bytes") + importList.add("errors") + val expected = translator.resToStr(translator.doByteArrayLiteral(contents)) + out.puts(s"if !bytes.Equal(${privateMemberName(attrName)}, $expected) {") + out.inc + out.puts("return errors.New(\"Unexpected fixed contents\")") + out.dec + out.puts("}") } override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier): Unit = { @@ -146,7 +160,7 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) val ioName = idToStr(IoStorageIdentifier(varName)) val args = rep match { - case RepeatEos | RepeatExpr(_) => s"$javaName.get($javaName.size() - 1)" + case RepeatEos | RepeatExpr(_) => s"$javaName[len($javaName) - 1]" case RepeatUntil(_) => translator.specialName(Identifier.ITERATOR2) case NoRepeat => javaName } @@ -158,8 +172,8 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } override def useIO(ioEx: Ast.expr): String = { - out.puts(s"$kstreamName io = ${expression(ioEx)};") - "io" + out.puts(s"thisIo := ${expression(ioEx)}") + "thisIo" } override def pushPos(io: String): Unit = { @@ -191,10 +205,19 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def condRepeatEosHeader(id: Identifier, io: String, dataType: DataType, needRaw: Boolean): Unit = { if (needRaw) - out.puts(s"${privateMemberName(RawIdentifier(id))} = new ArrayList();") + out.puts(s"${privateMemberName(RawIdentifier(id))} = make([][]byte, 0);") //out.puts(s"${privateMemberName(id)} = 
make(${kaitaiType2NativeType(ArrayType(dataType))})") - out.puts(s"for !$io.EOF() {") + out.puts(s"for {") out.inc + + val eofVar = translator.allocateLocalVar() + out.puts(s"${translator.localVarName(eofVar)}, err := this._io.EOF()") + translator.outAddErrCheck() + out.puts(s"if ${translator.localVarName(eofVar)} {") + out.inc + out.puts("break") + out.dec + out.puts("}") } override def handleAssignmentRepeatEos(id: Identifier, r: TranslatorResult): Unit = { @@ -205,7 +228,7 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def condRepeatExprHeader(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, repeatExpr: Ast.expr): Unit = { if (needRaw) - out.puts(s"${privateMemberName(RawIdentifier(id))} = new ArrayList((int) (${expression(repeatExpr)}));") + out.puts(s"${privateMemberName(RawIdentifier(id))} = make([][]byte, ${expression(repeatExpr)})") out.puts(s"${privateMemberName(id)} = make(${kaitaiType2NativeType(ArrayType(dataType))}, ${expression(repeatExpr)})") out.puts(s"for i := range ${privateMemberName(id)} {") out.inc @@ -219,30 +242,25 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def condRepeatUntilHeader(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, untilExpr: Ast.expr): Unit = { if (needRaw) - out.puts(s"${privateMemberName(RawIdentifier(id))} = new ArrayList();") - out.puts(s"${privateMemberName(id)} = new ${kaitaiType2NativeType(ArrayType(dataType))}();") - out.puts("{") - out.inc - out.puts(s"${kaitaiType2NativeType(dataType)} ${translator.specialName(Identifier.ITERATOR)};") - out.puts("do {") + out.puts(s"${privateMemberName(RawIdentifier(id))} = make([][]byte, 0);") + out.puts("for {") out.inc } override def handleAssignmentRepeatUntil(id: Identifier, r: TranslatorResult, isRaw: Boolean): Unit = { val expr = translator.resToStr(r) - val (typeDecl, tempVar) = if (isRaw) { - ("byte[] ", translator.specialName(Identifier.ITERATOR2)) - } else { - 
("", translator.specialName(Identifier.ITERATOR)) - } - out.puts(s"$typeDecl$tempVar = $expr;") - out.puts(s"${privateMemberName(id)}.add($tempVar);") + val tempVar = translator.specialName(Identifier.ITERATOR) + out.puts(s"$tempVar := $expr") + out.puts(s"${privateMemberName(id)} = append(${privateMemberName(id)}, $tempVar)") } override def condRepeatUntilFooter(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, untilExpr: Ast.expr): Unit = { typeProvider._currentIteratorType = Some(dataType) + out.puts(s"if ${expression(untilExpr)} {") + out.inc + out.puts("break") out.dec - out.puts(s"} while (!(${expression(untilExpr)}));") + out.puts("}") out.dec out.puts("}") } @@ -263,7 +281,7 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) case BytesTerminatedType(terminator, include, consume, eosError, _) => s"$io.ReadBytesTerm($terminator, $include, $consume, $eosError)" case BitsType1 => - s"$io.ReadBitsInt(1) != 0" + s"$io.ReadBitsInt(1)" case BitsType(width: Int) => s"$io.ReadBitsInt($width)" case t: UserType => @@ -297,18 +315,7 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts(s"switch (${expression(on)}) {") override def switchCaseStart(condition: Ast.expr): Unit = { - // Java is very specific about what can be used as "condition" in "case - // condition:". - val condStr = condition match { - case Ast.expr.EnumByLabel(enumName, enumVal) => - // If switch is over a enum, only literal enum values are supported, - // and they must be written as "MEMBER", not "SomeEnum.MEMBER". 
- value2Const(enumVal.name) - case _ => - expression(condition) - } - - out.puts(s"case $condStr: {") + out.puts(s"case ${expression(condition)}: {") out.inc } @@ -351,14 +358,14 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts(s"${privateMemberName(instName)} = $converted") } - override def instanceCheckCacheAndReturn(instName: InstanceIdentifier): Unit = { + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { out.puts(s"if (this.${calculatedFlagForName(instName)}) {") out.inc - instanceReturn(instName) + instanceReturn(instName, dataType) universalFooter } - override def instanceReturn(instName: InstanceIdentifier): Unit = { + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { out.puts(s"return ${privateMemberName(instName)}, nil") } @@ -366,42 +373,22 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts(s"this.${calculatedFlagForName(instName)} = true") override def enumDeclaration(curClass: List[String], enumName: String, enumColl: Seq[(Long, EnumValueSpec)]): Unit = { - val enumClass = type2class(enumName) + val fullEnumName: List[String] = curClass ++ List(enumName) + val fullEnumNameStr = types2class(fullEnumName) out.puts - out.puts(s"public enum $enumClass {") + out.puts(s"type $fullEnumNameStr int") + out.puts("const (") out.inc - if (enumColl.size > 1) { - enumColl.dropRight(1).foreach { case (id, label) => - out.puts(s"${value2Const(label.name)}($id),") - } - } - enumColl.last match { - case (id, label) => - out.puts(s"${value2Const(label.name)}($id);") + enumColl.foreach { case (id, label) => + out.puts(s"${enumToStr(fullEnumName, label.name)} $fullEnumNameStr = $id") } - out.puts - out.puts("private final long id;") - out.puts(s"$enumClass(long id) { this.id = id; }") - out.puts("public long id() { return id; }") - out.puts(s"private static final Map byId = new HashMap(${enumColl.size});") - 
out.puts("static {") - out.inc - out.puts(s"for ($enumClass e : $enumClass.values())") - out.inc - out.puts(s"byId.put(e.id(), e);") - out.dec out.dec - out.puts("}") - out.puts(s"public static $enumClass byId(long id) { return byId.get(id); }") - out.dec - out.puts("}") + out.puts(")") } - def value2Const(s: String) = s.toUpperCase - def idToStr(id: Identifier): String = { id match { case SpecialIdentifier(name) => name @@ -473,22 +460,31 @@ object GoCompiler extends LanguageCompilerStatic case AnyType => "interface{}" case KaitaiStreamType => "*" + kstreamName - case KaitaiStructType => kstructName + case KaitaiStructType | CalcKaitaiStructType => kstructName case t: UserType => "*" + types2class(t.classSpec match { case Some(cs) => cs.name case None => t.name }) - case EnumType(name, _) => types2class(name) + case t: EnumType => types2class(t.enumSpec.get.name) case ArrayType(inType) => s"[]${kaitaiType2NativeType(inType)}" - case SwitchType(_, cases) => kaitaiType2NativeType(TypeDetector.combineTypes(cases.values)) + case st: SwitchType => kaitaiType2NativeType(st.combinedType) } } def types2class(names: List[String]) = names.map(x => type2class(x)).mkString("_") + def enumToStr(enumTypeAbs: List[String]): String = { + val enumName = enumTypeAbs.last + val enumClass: List[String] = enumTypeAbs.dropRight(1) + enumToStr(enumClass, enumName) + } + + def enumToStr(typeName: List[String], enumName: String): String = + types2class(typeName) + "__" + type2class(enumName) + override def kstreamName: String = "kaitai.Stream" override def kstructName: String = "interface{}" } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/JavaCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/JavaCompiler.scala index b973b15f8..c2ee08faf 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/JavaCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/JavaCompiler.scala @@ -72,7 +72,7 @@ class JavaCompiler(typeProvider: ClassTypeProvider, 
config: RuntimeConfig) out.puts(s"public ${staticStr}class ${type2class(name)} extends $kstructName {") out.inc - if (debug) { + if (config.readStoresPos) { out.puts("public Map _attrStart = new HashMap();") out.puts("public Map _attrEnd = new HashMap();") out.puts("public Map> _arrStart = new HashMap>();") @@ -179,7 +179,7 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } override def readHeader(endian: Option[FixedEndian], isEmpty: Boolean) = { - val readAccessAndType = if (debug) { + val readAccessAndType = if (!config.autoRead) { "public" } else { "private" @@ -362,7 +362,7 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } override def condRepeatEosFooter: Unit = { - out.puts("i++;") + out.puts("i = i + 1;") out.dec out.puts("}") out.dec @@ -371,8 +371,8 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def condRepeatExprHeader(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, repeatExpr: expr): Unit = { if (needRaw) - out.puts(s"${privateMemberName(RawIdentifier(id))} = new ArrayList(Long.valueOf(${expression(repeatExpr)}).intValue());") - out.puts(s"${idToStr(id)} = new ${kaitaiType2JavaType(ArrayType(dataType))}(Long.valueOf(${expression(repeatExpr)}).intValue());") + out.puts(s"${privateMemberName(RawIdentifier(id))} = new ArrayList(((Number) (${expression(repeatExpr)})).intValue());") + out.puts(s"${idToStr(id)} = new ${kaitaiType2JavaType(ArrayType(dataType))}(((Number) (${expression(repeatExpr)})).intValue());") out.puts(s"for (int i = 0; i < ${expression(repeatExpr)}; i++) {") out.inc @@ -389,7 +389,7 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts(s"${privateMemberName(id)} = new ${kaitaiType2JavaType(ArrayType(dataType))}();") out.puts("{") out.inc - out.puts(s"${kaitaiType2JavaType(dataType)} ${translator.doName("_")};") + out.puts(s"${kaitaiType2JavaType(dataType)} 
${translator.doName(Identifier.ITERATOR)} = null;") out.puts("int i = 0;") out.puts("do {") out.inc @@ -409,7 +409,7 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def condRepeatUntilFooter(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, untilExpr: expr): Unit = { typeProvider._currentIteratorType = Some(dataType) - out.puts("i++;") + out.puts("i = i + 1;") out.dec out.puts(s"} while (!(${expression(untilExpr)}));") out.dec @@ -530,10 +530,10 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) // Java is very specific about what can be used as "condition" in "case // condition:". val condStr = condition match { - case Ast.expr.EnumByLabel(_, enumVal) => + case enumByLabel: Ast.expr.EnumByLabel => // If switch is over a enum, only literal enum values are supported, // and they must be written as "MEMBER", not "SomeEnum.MEMBER". - value2Const(enumVal.name) + value2Const(enumByLabel.label.name) case _ => expression(condition) } @@ -579,14 +579,14 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.inc } - override def instanceCheckCacheAndReturn(instName: InstanceIdentifier): Unit = { + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { out.puts(s"if (${privateMemberName(instName)} != null)") out.inc - instanceReturn(instName) + instanceReturn(instName, dataType) out.dec } - override def instanceReturn(instName: InstanceIdentifier): Unit = { + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { out.puts(s"return ${privateMemberName(instName)};") } @@ -719,7 +719,7 @@ object JavaCompiler extends LanguageCompilerStatic case AnyType => "Object" case KaitaiStreamType => kstreamName - case KaitaiStructType => kstructName + case KaitaiStructType | CalcKaitaiStructType => kstructName case t: UserType => types2class(t.name) case EnumType(name, _) => types2class(name) @@ -763,14 
+763,14 @@ object JavaCompiler extends LanguageCompilerStatic case AnyType => "Object" case KaitaiStreamType => kstreamName - case KaitaiStructType => kstructName + case KaitaiStructType | CalcKaitaiStructType => kstructName /* both struct types map to kstructName */ - case t: UserType => type2class(t.name.last) + case t: UserType => types2class(t.name) /* passes the whole name path instead of only its last component */ case EnumType(name, _) => types2class(name) case ArrayType(inType) => s"ArrayList<${kaitaiType2JavaTypeBoxed(inType)}>" - case SwitchType(_, cases) => kaitaiType2JavaTypeBoxed(TypeDetector.combineTypes(cases.values)) + case st: SwitchType => kaitaiType2JavaTypeBoxed(st.combinedType) /* boxed type of the switch's combinedType */ } } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/JavaScriptCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/JavaScriptCompiler.scala index f3d5fb3f7..bcdef1942 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/JavaScriptCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/JavaScriptCompiler.scala @@ -103,7 +103,7 @@ class JavaScriptCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) // Store parameters passed to us params.foreach((p) => handleAssignmentSimple(p.id, paramName(p.id))) - if (debug) { + if (config.readStoresPos) { /* this._debug is initialised only when config.readStoresPos is set */ out.puts("this._debug = {};") } out.puts @@ -295,7 +295,7 @@ class JavaScriptCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) if (needRaw) out.puts(s"${privateMemberName(RawIdentifier(id))} = [];") out.puts(s"${privateMemberName(id)} = [];") - if (debug) + if (config.readStoresPos) out.puts(s"this._debug.${idToStr(id)}.arr = [];") out.puts("var i = 0;") out.puts(s"while (!$io.isEof()) {") @@ -316,7 +316,7 @@ class JavaScriptCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) if (needRaw) out.puts(s"${privateMemberName(RawIdentifier(id))} = new Array(${expression(repeatExpr)});") out.puts(s"${privateMemberName(id)} = new Array(${expression(repeatExpr)});") - if (debug) + if (config.readStoresPos) out.puts(s"this._debug.${idToStr(id)}.arr = new
Array(${expression(repeatExpr)});") out.puts(s"for (var i = 0; i < ${expression(repeatExpr)}; i++) {") out.inc @@ -335,7 +335,7 @@ class JavaScriptCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) if (needRaw) out.puts(s"${privateMemberName(RawIdentifier(id))} = []") out.puts(s"${privateMemberName(id)} = []") - if (debug) + if (config.readStoresPos) /* the per-attribute arr list is emitted only when config.readStoresPos is set */ out.puts(s"this._debug.${idToStr(id)}.arr = [];") out.puts("var i = 0;") out.puts("do {") @@ -388,7 +388,7 @@ class JavaScriptCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) case _ => "" } val addParams = Utils.join(t.args.map((a) => translator.translate(a)), ", ", ", ", "") - s"new ${type2class(t.name.last)}($io, $parent, $root$addEndian$addParams)" + s"new ${types2class(t.name)}($io, $parent, $root$addEndian$addParams)" /* class reference is built from the whole name path, not only its last component */ } } @@ -503,14 +503,14 @@ class JavaScriptCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("});") } - override def instanceCheckCacheAndReturn(instName: InstanceIdentifier): Unit = { + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { /* dataType is only passed through to instanceReturn */ out.puts(s"if (${privateMemberName(instName)} !== undefined)") out.inc - instanceReturn(instName) + instanceReturn(instName, dataType) out.dec } - override def instanceReturn(instName: InstanceIdentifier): Unit = { + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { /* attrType is not used in the emitted return statement */ out.puts(s"return ${privateMemberName(instName)};") } @@ -597,6 +597,5 @@ object JavaScriptCompiler extends LanguageCompilerStatic // FIXME: probably KaitaiStruct will emerge some day in JavaScript runtime, but for now it is unused override def kstructName: String = ???
- def types2class(types: List[String]): String = - types.map(JavaScriptCompiler.type2class).mkString(".") + def types2class(types: List[String]): String = types.map(type2class).mkString(".") /* applies type2class to every path component and joins the results with "." */ } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/LuaCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/LuaCompiler.scala index ebba1a88c..a671b90cf 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/LuaCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/LuaCompiler.scala @@ -263,15 +263,15 @@ class LuaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("end") out.puts } - override def instanceCheckCacheAndReturn(instName: InstanceIdentifier): Unit = { + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { /* dataType is only passed through to instanceReturn */ out.puts(s"if self.${idToStr(instName)} ~= nil then") out.inc - instanceReturn(instName) + instanceReturn(instName, dataType) out.dec out.puts("end") out.puts } - override def instanceReturn(instName: InstanceIdentifier): Unit = + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = /* attrType is not used in the emitted return statement */ out.puts(s"return ${privateMemberName(instName)}") override def enumDeclaration(curClass: List[String], enumName: String, enumColl: Seq[(Long, EnumValueSpec)]): Unit = { @@ -357,6 +357,9 @@ class LuaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) expr2 } + override def userTypeDebugRead(id: String): Unit = /* emits a "<id>:_read()" call on the given object expression */ + out.puts(s"$id:_read()") + override def switchStart(id: Identifier, on: Ast.expr): Unit = out.puts(s"local _on = ${expression(on)}") override def switchCaseFirstStart(condition: Ast.expr): Unit = { diff --git a/shared/src/main/scala/io/kaitai/struct/languages/PHPCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/PHPCompiler.scala index 2bef260ad..111c718e0 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/PHPCompiler.scala +++
b/shared/src/main/scala/io/kaitai/struct/languages/PHPCompiler.scala @@ -133,8 +133,10 @@ class PHPCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) case Some(e) => s"${e.toSuffix.toUpperCase}" case None => "" } + val access = if (config.autoRead) "private" else "public" /* _read is emitted as public only when autoRead is off */ + out.puts - out.puts(s"private function _read$suffix() {") + out.puts(s"$access function _read$suffix() {") out.inc } @@ -304,6 +306,9 @@ class PHPCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts(s"${privateMemberName(id)} = $expr;") } + override def handleAssignmentTempVar(dataType: DataType, id: String, expr: String): Unit = /* emits a plain "<id> = <expr>;" assignment; dataType is not used */ + out.puts(s"$id = $expr;") + override def parseExpr(dataType: DataType, assignType: DataType, io: String, defEndian: Option[FixedEndian]): String = { dataType match { case t: ReadableType => @@ -350,6 +355,9 @@ class PHPCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) expr2 } + override def userTypeDebugRead(id: String): Unit = /* emits a "<id>->_read();" call on the given object expression */ + out.puts(s"$id->_read();") + override def switchStart(id: Identifier, on: Ast.expr): Unit = { val onType = translator.detectType(on) @@ -379,14 +387,14 @@ class PHPCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.inc } - override def instanceCheckCacheAndReturn(instName: InstanceIdentifier): Unit = { + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { /* dataType is only passed through to instanceReturn */ out.puts(s"if (${privateMemberName(instName)} !== null)") out.inc - instanceReturn(instName) + instanceReturn(instName, dataType) out.dec } - override def instanceReturn(instName: InstanceIdentifier): Unit = { + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { /* attrType is not used in the emitted return statement */ out.puts(s"return ${privateMemberName(instName)};") } @@ -450,7 +458,7 @@ class PHPCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) case ArrayType(_) => "array" - case KaitaiStructType => kstructName + case KaitaiStructType | CalcKaitaiStructType => kstructName /* both struct types map to kstructName */ case
KaitaiStreamType => kstreamName } } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/PerlCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/PerlCompiler.scala index fdf020569..c6a24915b 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/PerlCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/PerlCompiler.scala @@ -363,11 +363,11 @@ class PerlCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("my ($self) = @_;") } - override def instanceCheckCacheAndReturn(instName: InstanceIdentifier): Unit = { + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { /* dataType is not used: the cached member is returned by a single emitted statement */ out.puts(s"return ${privateMemberName(instName)} if (${privateMemberName(instName)});") } - override def instanceReturn(instName: InstanceIdentifier): Unit = { + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { /* attrType is not used in the emitted return statement */ out.puts(s"return ${privateMemberName(instName)};") } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/PythonCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/PythonCompiler.scala index 6026981fe..4ca11c961 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/PythonCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/PythonCompiler.scala @@ -66,12 +66,13 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def opaqueClassDeclaration(classSpec: ClassSpec): Unit = { val name = classSpec.name.head - val prefix = config.pythonPackage match { - case "" => "" - case "." => "." - case pkg => s"$pkg."
- } - out.puts(s"from $prefix$name import ${type2class(name)}") + out.puts( + if (config.pythonPackage.nonEmpty) { /* with a package configured the module is imported from it, otherwise by a plain import */ + s"from ${config.pythonPackage} import $name" + } else { + s"import $name" + } + ) } override def classHeader(name: String): Unit = { @@ -94,6 +95,11 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) // Store parameters passed to us params.foreach((p) => handleAssignmentSimple(p.id, paramName(p.id))) + + if (config.readStoresPos) { /* self._debug is created only when config.readStoresPos is set */ + importList.add("import collections") + out.puts("self._debug = collections.defaultdict(dict)") + } } override def runRead(): Unit = { @@ -248,6 +254,40 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def alignToByte(io: String): Unit = out.puts(s"$io.align_to_byte()") + override def attrDebugStart(attrId: Identifier, attrType: DataType, ios: Option[String], rep: RepeatSpec): Unit = { /* emits code recording the stream position before an attribute is read; nothing is emitted when ios is empty or for raw/special identifiers */ + ios.foreach { (io) => + val name = attrId match { + case _: RawIdentifier | _: SpecialIdentifier => return + case _ => idToStr(attrId) + } + rep match { + case NoRepeat => + out.puts(s"self._debug['$name']['start'] = $io.pos()") + case _: RepeatExpr | RepeatEos | _: RepeatUntil => /* repeated attributes append one dict with a start entry per element, creating the arr list on first use */ + out.puts(s"if not 'arr' in self._debug['$name']:") + out.inc + out.puts(s"self._debug['$name']['arr'] = []") + out.dec + out.puts(s"self._debug['$name']['arr'].append({'start': $io.pos()})") + } + } + } + + override def attrDebugEnd(attrId: Identifier, attrType: DataType, io: String, rep: RepeatSpec): Unit = { /* emits code recording the stream position after an attribute is read; raw/special identifiers are skipped */ + val name = attrId match { + case _: RawIdentifier | _: SpecialIdentifier => return + case _ => idToStr(attrId) + } + rep match { + case NoRepeat => + out.puts(s"self._debug['$name']['end'] = $io.pos()") + case _: RepeatExpr => /* indexes the arr entry with the emitted loop variable i */ + out.puts(s"self._debug['$name']['arr'][i]['end'] = $io.pos()") + case RepeatEos | _: RepeatUntil => /* indexes the arr entry with the current length of the attribute list minus one */ + out.puts(s"self._debug['$name']['arr'][len(${privateMemberName(attrId)}) - 1]['end'] = $io.pos()") + } + } + override def condIfHeader(expr: Ast.expr): Unit =
{ out.puts(s"if ${expression(expr)}:") out.inc @@ -306,6 +346,9 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def handleAssignmentSimple(id: Identifier, expr: String): Unit = out.puts(s"${privateMemberName(id)} = $expr") + override def handleAssignmentTempVar(dataType: DataType, id: String, expr: String): Unit = /* emits a plain "<id> = <expr>" assignment; dataType is not used */ + out.puts(s"$id = $expr") + override def parseExpr(dataType: DataType, assignType: DataType, io: String, defEndian: Option[FixedEndian]): String = { dataType match { case t: ReadableType => @@ -335,7 +378,7 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } s", $parent, self._root$addEndian" } - s"${types2class(t.classSpec.get.name)}($addParams$io$addArgs)" + s"${userType2class(t)}($addParams$io$addArgs)" /* class reference now comes from userType2class */ } } @@ -351,6 +394,9 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) expr2 } + override def userTypeDebugRead(id: String): Unit = /* emits a "<id>._read()" call on the given object expression */ + out.puts(s"$id._read()") + override def switchStart(id: Identifier, on: Ast.expr): Unit = { out.puts(s"_on = ${expression(on)}") } @@ -381,15 +427,15 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.inc } - override def instanceCheckCacheAndReturn(instName: InstanceIdentifier): Unit = { + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { /* dataType is only passed through to instanceReturn */ out.puts(s"if hasattr(self, '${idToStr(instName)}'):") out.inc - instanceReturn(instName) + instanceReturn(instName, dataType) out.dec out.puts } - override def instanceReturn(instName: InstanceIdentifier): Unit = { + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { /* attrType is not used in the emitted return statement */ // not very efficient, probably should be some other way to do that, but for now it will do: // workaround to avoid Python generating an "AttributeError: instance has no attribute" out.puts(s"return ${privateMemberName(instName)} if hasattr(self, '${idToStr(instName)}') else None") @@ -405,6 +451,11 @@
class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.dec } + override def debugClassSequence(seq: List[AttrSpec]) = { /* emits a SEQ_FIELDS list holding the double-quoted id of every sequence attribute */ + val seqStr = seq.map((attr) => "\"" + idToStr(attr.id) + "\"").mkString(", ") + out.puts(s"SEQ_FIELDS = [$seqStr]") + } + def bool2Py(b: Boolean): String = if (b) { "True" } else { "False" } def idToStr(id: Identifier): String = { @@ -429,6 +480,17 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } override def localTemporaryName(id: Identifier): String = s"_t_${idToStr(id)}" + + def userType2class(t: UserType): String = { /* class reference for a user type; an opaque type whose top-level name differs from the current class's top-level name gets that name as a dotted prefix */ + val name = t.classSpec.get.name + val firstName = name.head + val prefix = if (t.isOpaque && firstName != translator.provider.nowClass.name.head) { + s"$firstName." + } else { + "" + } + s"$prefix${types2class(name)}" + } } object PythonCompiler extends LanguageCompilerStatic diff --git a/shared/src/main/scala/io/kaitai/struct/languages/RubyCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/RubyCompiler.scala index 1e588698a..516cc9268 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/RubyCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/RubyCompiler.scala @@ -64,7 +64,7 @@ class RubyCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def classHeader(name: String): Unit = { out.puts(s"class ${type2class(name)} < $kstructName") out.inc - if (debug) + if (config.readStoresPos) /* the _debug reader is emitted only when config.readStoresPos is set */ out.puts("attr_reader :_debug") } @@ -88,7 +88,7 @@ class RubyCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) // Store parameters passed to us params.foreach((p) => handleAssignmentSimple(p.id, paramName(p.id))) - if (debug) { + if (config.readStoresPos) { out.puts("@_debug = {}") } } @@ -308,6 +308,7 @@ class RubyCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) if (needRaw) out.puts(s"${privateMemberName(RawIdentifier(id))} = []") out.puts(s"${privateMemberName(id)} = []") +
out.puts(s"${translator.doName(Identifier.ITERATOR)} = nil") /* emits an assignment of nil to the iterator name before the begin block is opened */ out.puts("i = 0") out.puts("begin") out.inc @@ -361,7 +362,7 @@ class RubyCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } s", $parent, @_root$addEndian" } - s"${type2class(t.name.last)}.new($io$addArgs$addParams)" + s"${types2class(t.name)}.new($io$addArgs$addParams)" /* class reference is built from the whole name path, not only its last component */ } } @@ -404,11 +405,11 @@ class RubyCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.inc } - override def instanceCheckCacheAndReturn(instName: InstanceIdentifier): Unit = { + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { /* dataType is not used: the cached member is returned by a single emitted statement */ out.puts(s"return ${privateMemberName(instName)} unless ${privateMemberName(instName)}.nil?") } - override def instanceReturn(instName: InstanceIdentifier): Unit = { + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { /* attrType is not used; the member name alone is emitted as the final expression */ out.puts(privateMemberName(instName)) } @@ -452,6 +453,8 @@ class RubyCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def publicMemberName(id: Identifier): String = idToStr(id) override def localTemporaryName(id: Identifier): String = s"_t_${idToStr(id)}" + + def types2class(names: List[String]) = names.map(type2class).mkString("::") /* applies type2class to every path component and joins the results with "::" */ } object RubyCompiler extends LanguageCompilerStatic diff --git a/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala new file mode 100644 index 000000000..38886e3b4 --- /dev/null +++ b/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala @@ -0,0 +1,605 @@ +package io.kaitai.struct.languages + +import io.kaitai.struct.{ClassTypeProvider, RuntimeConfig, Utils, _} +import io.kaitai.struct.datatype.DataType._ +import io.kaitai.struct.datatype.{CalcEndian, DataType, FixedEndian, InheritedEndian} +import io.kaitai.struct.exprlang.Ast +import io.kaitai.struct.format.{NoRepeat, RepeatEos, RepeatExpr, RepeatSpec, _} +import
io.kaitai.struct.languages.components._ +import io.kaitai.struct.translators.{RustTranslator, TypeDetector} + +class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) + extends LanguageCompiler(typeProvider, config) + with ObjectOrientedLanguage + with UpperCamelCaseClasses + with SingleOutputFile + with AllocateIOLocalVar + with UniversalFooter + with UniversalDoc + with FixedContentsUsingArrayByteLiteral + with EveryReadIsExpression { + + import RustCompiler._ + + override def innerClasses = false + + override def innerEnums = false + + override val translator: RustTranslator = new RustTranslator(typeProvider, config) + + override def universalFooter: Unit = { /* closes a block: dedent, then emit a closing brace */ + out.dec + out.puts("}") + } + + override def outImports(topClass: ClassSpec) = /* renders every collected import as a "use <x>;" line, newline-separated with a trailing newline */ + importList.toList.map((x) => s"use $x;").mkString("", "\n", "\n") + + override def indent: String = " " + override def outFileName(topClassName: String): String = s"$topClassName.rs" + + override def fileHeader(topClassName: String): Unit = { /* writes the header comment and registers the imports every generated file uses */ + outHeader.puts(s"// $headerComment") + outHeader.puts + + importList.add("std::option::Option") + importList.add("std::boxed::Box") + importList.add("std::io::Result") + importList.add("std::io::Cursor") + importList.add("std::vec::Vec") + importList.add("std::default::Default") + importList.add("kaitai_struct::KaitaiStream") + importList.add("kaitai_struct::KaitaiStruct") + + out.puts + } + + override def opaqueClassDeclaration(classSpec: ClassSpec): Unit = { /* registers "<pkg>::<name>" for the opaque class in the import list; emits nothing to out */ + val name = type2class(classSpec.name.last) + val pkg = type2classAbs(classSpec.name) + + importList.add(s"$pkg::$name") + } + + override def classHeader(name: List[String]): Unit = + classHeader(name, Some(kstructName)) + + def classHeader(name: List[String], parentClass: Option[String]): Unit = { /* opens the struct definition; parentClass is not used in the emitted text */ + out.puts("#[derive(Default)]") + out.puts(s"pub struct ${type2class(name)} {") + } + + override def classFooter(name: List[String]): Unit = universalFooter + + override def classConstructorHeader(name:
List[String], parentType: DataType, rootClassName: List[String], isHybrid: Boolean, params: List[ParamDefSpec]): Unit = { /* closes the struct body, opens the KaitaiStruct impl and emits the whole "new" constructor */ + out.puts("}") + out.puts + + out.puts(s"impl KaitaiStruct for ${type2class(name)} {") + out.inc + + // Parameter names + val pIo = paramName(IoIdentifier) /* NOTE(review): pIo, pParent, pRoot, tIo and tParent are computed but not referenced in the rest of this method */ + val pParent = paramName(ParentIdentifier) + val pRoot = paramName(RootIdentifier) + + // Types + val tIo = kstreamName + val tParent = kaitaiType2NativeType(parentType) + + out.puts(s"fn new(stream: &mut S,") + out.puts(s" _parent: &Option>,") + out.puts(s" _root: &Option>)") + out.puts(s" -> Result") + out.inc + out.puts(s"where Self: Sized {") + + out.puts(s"let mut s: Self = Default::default();") + out.puts + + out.puts(s"s.stream = stream;") + + out.puts(s"s.read(stream, _parent, _root)?;") + out.puts + + out.puts("Ok(s)") + out.dec + out.puts("}") + out.puts + } + + override def runRead(): Unit = { /* intentionally empty: nothing is emitted here */ + + } + + override def runReadCalc(): Unit = { + + } + + override def readHeader(endian: Option[FixedEndian], isEmpty: Boolean) = { /* opens the emitted "read" method; endian and isEmpty are not used */ + out.puts + out.puts(s"fn read(&mut self,") + out.puts(s" stream: &mut S,") + out.puts(s" _parent: &Option>,") + out.puts(s" _root: &Option>)") + out.puts(s" -> Result<()>") + out.inc + out.puts(s"where Self: Sized {") + } + + override def readFooter(): Unit = { /* ends the emitted "read" method with Ok(()) and a closing brace */ + out.puts + out.puts("Ok(())") + out.dec + out.puts("}") + } + + override def attributeDeclaration(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = { /* emits one struct field per attribute; parent, root and io identifiers emit nothing */ + attrName match { + case ParentIdentifier | RootIdentifier | IoIdentifier => + // just ignore it for now + case IoIdentifier => /* NOTE(review): unreachable, IoIdentifier is already matched by the first case, so no "stream" field is ever declared although the emitted constructor assigns s.stream */ + out.puts(s" stream: ${kaitaiType2NativeType(attrType)},") + case _ => + out.puts(s" pub ${idToStr(attrName)}: ${kaitaiType2NativeType(attrType)},") + } + } + + override def attributeReader(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = { + + } + + override def universalDoc(doc: DocSpec): Unit = { + if (doc.summary.isDefined) { + out.puts + out.puts("/*") +
doc.summary.foreach((summary) => out.putsLines(" * ", summary)) + out.puts(" */") + } + } + + override def attrParseHybrid(leProc: () => Unit, beProc: () => Unit): Unit = { /* emits an if/else on the endianness member: leProc fills the first branch, beProc the else branch */ + out.puts(s"if ${privateMemberName(EndianIdentifier)} {") /* Rust-style condition built from the endianness member, replacing a PHP expression that was emitted verbatim */ + out.inc + leProc() + out.dec + out.puts("} else {") + out.inc + beProc() + out.dec + out.puts("}") + } + + override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = + out.puts(s"${privateMemberName(attrName)} = $normalIO.ensureFixedContents($contents);") + + override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier): Unit = { /* emits the statement(s) that turn the raw source member into the processed destination member */ + val srcName = privateMemberName(varSrc) + val destName = privateMemberName(varDest) + + proc match { + case ProcessXor(xorValue) => + val procName = translator.detectType(xorValue) match { + case _: IntType => "processXorOne" + case _: BytesType => "processXorMany" + } + out.puts(s"$destName = $kstreamName::$procName($srcName, ${expression(xorValue)});") + case ProcessZlib => + out.puts(s"$destName = $kstreamName::processZlib($srcName);") + case ProcessRotate(isLeft, rotValue) => + val expr = if (isLeft) { + expression(rotValue) + } else { + s"8 - (${expression(rotValue)})" + } + out.puts(s"$destName = $kstreamName::processRotateLeft($srcName, $expr, 1);") + case ProcessCustom(name, args) => + val procClass = if (name.length == 1) { + val onlyName = name.head + val className = type2class(onlyName) + importList.add(s"$onlyName::$className") + className + } else { + val pkgName = type2classAbs(name.init) + val className = type2class(name.last) + importList.add(s"$pkgName::$className") + s"$pkgName::$className" + } + + out.puts(s"let _process = $procClass::new(${args.map(expression).mkString(", ")});") + out.puts(s"$destName = _process.decode($srcName);") + } + } + + override def allocateIO(id: Identifier, rep: RepeatSpec): String = { + val memberName = privateMemberName(id) + + val args = rep match { + case RepeatEos | RepeatExpr(_) => s"$memberName.last()" + case RepeatUntil(_) =>
translator.doLocalName(Identifier.ITERATOR2) + case NoRepeat => memberName + } + + out.puts(s"let mut io = Cursor::new($args);") + "io" + } + + override def useIO(ioEx: Ast.expr): String = { /* binds the translated IO expression to a local named io and returns that name */ + out.puts(s"let mut io = ${expression(ioEx)};") + "io" + } + + override def pushPos(io: String): Unit = + out.puts(s"let _pos = $io.pos();") + + override def seek(io: String, pos: Ast.expr): Unit = + out.puts(s"$io.seek(${expression(pos)});") + + override def popPos(io: String): Unit = + out.puts(s"$io.seek(_pos);") + + override def alignToByte(io: String): Unit = + out.puts(s"$io.alignToByte();") + + override def condIfHeader(expr: Ast.expr): Unit = { + out.puts(s"if ${expression(expr)} {") + out.inc + } + + override def condRepeatEosHeader(id: Identifier, io: String, dataType: DataType, needRaw: Boolean): Unit = { /* initialises the target collection(s) with vec!(), as the other repeat headers do, then opens the read-until-EOF loop */ + if (needRaw) + out.puts(s"${privateMemberName(RawIdentifier(id))} = vec!();") + out.puts(s"${privateMemberName(id)} = vec!();") + out.puts(s"while !$io.isEof() {") + out.inc + } + + override def handleAssignmentRepeatEos(id: Identifier, expr: String): Unit = { + out.puts(s"${privateMemberName(id)}.push($expr);") + } + + override def condRepeatEosFooter: Unit = { + super.condRepeatEosFooter + } + + override def condRepeatExprHeader(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, repeatExpr: Ast.expr): Unit = { /* initialises the target collection(s), then opens a counted for loop over 0..<repeat expression> */ + if (needRaw) + out.puts(s"${privateMemberName(RawIdentifier(id))} = vec!();") + out.puts(s"${privateMemberName(id)} = vec!();") + out.puts(s"for i in 0..${expression(repeatExpr)} {") + out.inc + } + + override def handleAssignmentRepeatExpr(id: Identifier, expr: String): Unit = { + out.puts(s"${privateMemberName(id)}.push($expr);") + } + + override def condRepeatUntilHeader(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, untilExpr: Ast.expr): Unit = { /* initialises the target collection(s), then opens a "while {" block; untilExpr is not used here */ + if (needRaw) + out.puts(s"${privateMemberName(RawIdentifier(id))} = vec!();") + out.puts(s"${privateMemberName(id)} = vec!();") + out.puts("while {") + out.inc + } + + override
def handleAssignmentRepeatUntil(id: Identifier, expr: String, isRaw: Boolean): Unit = { /* binds the parsed element to a temporary name, then pushes that binding, so the parse expression is emitted only once per iteration */ + val tempVar = if (isRaw) { + translator.doLocalName(Identifier.ITERATOR2) + } else { + translator.doLocalName(Identifier.ITERATOR) + } + out.puts(s"let $tempVar = $expr;") + out.puts(s"${privateMemberName(id)}.push($tempVar);") /* push, like the other repeat handlers; emitting the expression again here would repeat the read */ + } + + override def condRepeatUntilFooter(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, untilExpr: Ast.expr): Unit = { /* the negated until-expression is emitted as the last expression of the while-condition block, followed by an empty loop body */ + typeProvider._currentIteratorType = Some(dataType) + out.puts(s"!(${expression(untilExpr)})") + out.dec + out.puts("} { }") + } + + override def handleAssignmentSimple(id: Identifier, expr: String): Unit = { + out.puts(s"${privateMemberName(id)} = $expr;") + } + + override def parseExpr(dataType: DataType, assignType: DataType, io: String, defEndian: Option[FixedEndian]): String = { + dataType match { + case t: ReadableType => + s"$io.read_${t.apiCall(defEndian)}()?" + case blt: BytesLimitType => + s"$io.read_bytes(${expression(blt.size)})?" + case _: BytesEosType => + s"$io.read_bytes_full()?" + case BytesTerminatedType(terminator, include, consume, eosError, _) => + s"$io.read_bytes_term($terminator, $include, $consume, $eosError)?" + case BitsType1 => + s"$io.read_bits_int(1)? != 0" + case BitsType(width: Int) => + s"$io.read_bits_int($width)?"
+ case t: UserType => + val addParams = Utils.join(t.args.map((a) => translator.translate(a)), "", ", ", ", ") + val addArgs = if (t.isOpaque) { + "" + } else { + val parent = t.forcedParent match { + case Some(USER_TYPE_NO_PARENT) => "null" + case Some(fp) => translator.translate(fp) + case None => "self" + } + val addEndian = t.classSpec.get.meta.endian match { + case Some(InheritedEndian) => s", ${privateMemberName(EndianIdentifier)}" + case _ => "" + } + s", $parent, ${privateMemberName(RootIdentifier)}$addEndian" + } + + s"Box::new(${translator.types2classAbs(t.classSpec.get.name)}::new(self.stream, self, _root)?)" + } + } + + override def bytesPadTermExpr(expr0: String, padRight: Option[Int], terminator: Option[Int], include: Boolean): String = { /* wraps expr0 in the strip-right and terminate helpers, each only when its option is present */ + val expr1 = padRight match { + case Some(padByte) => s"$kstreamName::bytesStripRight($expr0, $padByte)" + case None => expr0 + } + val expr2 = terminator match { + case Some(term) => s"$kstreamName::bytesTerminate($expr1, $term, $include)" + case None => expr1 + } + expr2 + } + + var switchIfs = false + val NAME_SWITCH_ON = Ast.expr.Name(Ast.identifier(Identifier.SWITCH_ON)) + + override def switchStart(id: Identifier, on: Ast.expr): Unit = { /* array and byte-array switch values select the if-chain form (switchIfs); every other type opens a match */ + val onType = translator.detectType(on) + + switchIfs = onType match { + case _: ArrayType | _: BytesType => true + case _ => false + } + + if (!switchIfs) { + out.puts(s"match ${expression(on)} {") + out.inc + } + } + + def switchCmpExpr(condition: Ast.expr): String = + expression( + Ast.expr.Compare( + NAME_SWITCH_ON, + Ast.cmpop.Eq, + condition + ) + ) + + override def switchCaseFirstStart(condition: Ast.expr): Unit = { + if (switchIfs) { + out.puts(s"if ${switchCmpExpr(condition)} {") + out.inc + } else { + switchCaseStart(condition) + } + } + + override def switchCaseStart(condition: Ast.expr): Unit = { /* non-first case: "else if <cmp> {" in if-chain mode, "<pattern> => {" as a match arm otherwise */ + if (switchIfs) { + out.puts(s"else if ${switchCmpExpr(condition)} {") + out.inc + } else { + out.puts(s"${expression(condition)} => {") + out.inc + } + } + + override
def switchCaseEnd(): Unit = { + if (switchIfs) { + out.dec + out.puts("}") + } else { + out.dec + out.puts("},") + } + } + + override def switchElseStart(): Unit = { + if (switchIfs) { + out.puts("else {") + out.inc + } else { + out.puts("_ => {") + out.inc + } + } + + override def switchElseEnd(): Unit = { + out.dec + out.puts("}") + } + + override def switchEnd(): Unit = universalFooter + + override def instanceDeclaration(attrName: InstanceIdentifier, attrType: DataType, isNullable: Boolean): Unit = { + out.puts(s" pub ${idToStr(attrName)}: Option<${kaitaiType2NativeType(attrType)}>,") + } + + override def instanceDeclHeader(className: List[String]): Unit = { + out.dec + out.puts("}") + out.puts + + out.puts(s"impl ${type2class(className)} {") + out.inc + } + + override def instanceHeader(className: List[String], instName: InstanceIdentifier, dataType: DataType, isNullable: Boolean): Unit = { + out.puts(s"fn ${idToStr(instName)}(&mut self) -> ${kaitaiType2NativeType(dataType)} {") + out.inc + } + + override def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit = { + out.puts(s"if let Some(x) = ${privateMemberName(instName)} {") + out.inc + out.puts("return x;") + out.dec + out.puts("}") + out.puts + } + + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { + out.puts(s"return ${privateMemberName(instName)};") + } + + override def enumDeclaration(curClass: List[String], enumName: String, enumColl: Seq[(Long, EnumValueSpec)]): Unit = { + val enumClass = type2class(curClass ::: List(enumName)) + + out.puts(s"enum $enumClass {") + out.inc + + enumColl.foreach { case (id, label) => + universalDoc(label.doc) + out.puts(s"${value2Const(label.name)},") + } + + out.dec + out.puts("}") + } + + def value2Const(label: String) = label.toUpperCase + + def idToStr(id: Identifier): String = { + id match { + case SpecialIdentifier(name) => name + case NamedIdentifier(name) => Utils.lowerCamelCase(name) + case 
NumberedIdentifier(idx) => s"_${NumberedIdentifier.TEMPLATE}$idx" + case InstanceIdentifier(name) => Utils.lowerCamelCase(name) + case RawIdentifier(innerId) => "_raw_" + idToStr(innerId) + } + } + + override def privateMemberName(id: Identifier): String = { + id match { + case IoIdentifier => s"self.stream" + case RootIdentifier => s"_root" + case ParentIdentifier => s"_parent" + case _ => s"self.${idToStr(id)}" + } + } + + override def publicMemberName(id: Identifier) = idToStr(id) + + override def localTemporaryName(id: Identifier): String = s"$$_t_${idToStr(id)}" + + override def paramName(id: Identifier): String = s"${idToStr(id)}" + + def kaitaiType2NativeType(attrType: DataType): String = { + attrType match { + case Int1Type(false) => "u8" + case IntMultiType(false, Width2, _) => "u16" + case IntMultiType(false, Width4, _) => "u32" + case IntMultiType(false, Width8, _) => "u64" + + case Int1Type(true) => "i8" + case IntMultiType(true, Width2, _) => "i16" + case IntMultiType(true, Width4, _) => "i32" + case IntMultiType(true, Width8, _) => "i64" + + case FloatMultiType(Width4, _) => "f32" + case FloatMultiType(Width8, _) => "f64" + + case BitsType(_) => "u64" + + case _: BooleanType => "bool" + case CalcIntType => "i32" + case CalcFloatType => "f64" + + case _: StrType => "String" + case _: BytesType => "Vec<u8>" + + case t: UserType => t.classSpec match { + case Some(cs) => s"Box<${type2class(cs.name)}>" + case None => s"Box<${type2class(t.name)}>" + } + + case t: EnumType => t.enumSpec match { + case Some(cs) => s"Box<${type2class(cs.name)}>" + case None => s"Box<${type2class(t.name)}>" + } + + case ArrayType(inType) => s"Vec<${kaitaiType2NativeType(inType)}>" + + case KaitaiStreamType => s"Option<Box<KaitaiStream>>" + case KaitaiStructType | CalcKaitaiStructType => s"Option<Box<KaitaiStruct>>" + + case st: SwitchType => kaitaiType2NativeType(st.combinedType) + } + } + + def kaitaiType2Default(attrType: DataType): String = { + attrType match { + case Int1Type(false) => "0" + case
IntMultiType(false, Width2, _) => "0" + case IntMultiType(false, Width4, _) => "0" + case IntMultiType(false, Width8, _) => "0" + + case Int1Type(true) => "0" + case IntMultiType(true, Width2, _) => "0" + case IntMultiType(true, Width4, _) => "0" + case IntMultiType(true, Width8, _) => "0" + + case FloatMultiType(Width4, _) => "0" + case FloatMultiType(Width8, _) => "0" + + case BitsType(_) => "0" + + case _: BooleanType => "false" + case CalcIntType => "0" + case CalcFloatType => "0" + + case _: StrType => "\"\"" + case _: BytesType => "vec!()" + + case t: UserType => "Default::default()" + case t: EnumType => "Default::default()" + + case ArrayType(inType) => "vec!()" + + case KaitaiStreamType => "None" + case KaitaiStructType | CalcKaitaiStructType => "None" + + case _: SwitchType => "" + // TODO + } + } + + def type2class(names: List[String]) = types2classRel(names) + + def type2classAbs(names: List[String]) = + names.mkString("::") +} + +object RustCompiler extends LanguageCompilerStatic + with StreamStructNames + with UpperCamelCaseClasses { + override def getCompiler( + tp: ClassTypeProvider, + config: RuntimeConfig + ): LanguageCompiler = new RustCompiler(tp, config) + + override def kstructName = "&Option<Box<KaitaiStruct>>" + override def kstreamName = "&mut S" + + def types2class(typeName: Ast.typeId) = { + typeName.names.map(type2class).mkString( + if (typeName.absolute) "__" else "", + "__", + "" + ) + } + + def types2classRel(names: List[String]) = + names.map(type2class).mkString("__") +} diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/AllocateAndStoreIO.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/AllocateAndStoreIO.scala index cac6a7fdc..731cfc2ef 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/AllocateAndStoreIO.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/AllocateAndStoreIO.scala @@ -1,14 +1,22 @@ package io.kaitai.struct.languages.components -import io.kaitai.struct.format.{AttrSpec,
Identifier, RepeatSpec} - -import scala.collection.mutable.ListBuffer +import io.kaitai.struct.datatype.DataType.{ArrayType, KaitaiStreamType} +import io.kaitai.struct.format._ /** * Allocates new IO and returns attribute identifier that it will be stored * at. This is used for languages without garbage collection that need to * keep track of allocated IOs. */ -trait AllocateAndStoreIO { - def allocateIO(id: Identifier, rep: RepeatSpec, extraAttrs: ListBuffer[AttrSpec]): String +trait AllocateAndStoreIO extends ExtraAttrs { + def allocateIO(id: Identifier, rep: RepeatSpec): String + + override def extraAttrForIO(id: Identifier, rep: RepeatSpec): List[AttrSpec] = { + val ioId = IoStorageIdentifier(id) + val ioType = rep match { + case NoRepeat => KaitaiStreamType + case _ => ArrayType(KaitaiStreamType) + } + List(AttrSpec(List(), ioId, ioType)) + } } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/AllocateIOLocalVar.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/AllocateIOLocalVar.scala index 6334d759a..fd84dce12 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/AllocateIOLocalVar.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/AllocateIOLocalVar.scala @@ -1,11 +1,13 @@ package io.kaitai.struct.languages.components -import io.kaitai.struct.format.{Identifier, RepeatSpec} +import io.kaitai.struct.format.{AttrSpec, Identifier, RepeatSpec} /** * Allocates new auxiliary IOs as local vars - no references saved and thus * probably garbage collector will deal with them. 
*/ -trait AllocateIOLocalVar { +trait AllocateIOLocalVar extends ExtraAttrs { def allocateIO(varName: Identifier, rep: RepeatSpec): String + + override def extraAttrForIO(id: Identifier, rep: RepeatSpec): List[AttrSpec] = List() } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/CommonReads.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/CommonReads.scala index 350d1c091..d938b7b98 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/CommonReads.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/CommonReads.scala @@ -5,10 +5,8 @@ import io.kaitai.struct.datatype.DataType.{SwitchType, UserTypeFromBytes} import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.format._ -import scala.collection.mutable.ListBuffer - trait CommonReads extends LanguageCompiler { - override def attrParse(attr: AttrLikeSpec, id: Identifier, extraAttrs: ListBuffer[AttrSpec], defEndian: Option[Endianness]): Unit = { + override def attrParse(attr: AttrLikeSpec, id: Identifier, defEndian: Option[Endianness]): Unit = { attrParseIfHeader(id, attr.cond.ifExpr) // Manage IO & seeking for ParseInstances @@ -28,22 +26,22 @@ trait CommonReads extends LanguageCompiler { normalIO } - if (debug) + if (config.readStoresPos) attrDebugStart(id, attr.dataType, Some(io), NoRepeat) defEndian match { case Some(_: CalcEndian) | Some(InheritedEndian) => attrParseHybrid( - () => attrParse0(id, attr, io, extraAttrs, Some(LittleEndian)), - () => attrParse0(id, attr, io, extraAttrs, Some(BigEndian)) + () => attrParse0(id, attr, io, Some(LittleEndian)), + () => attrParse0(id, attr, io, Some(BigEndian)) ) case None => - attrParse0(id, attr, io, extraAttrs, None) + attrParse0(id, attr, io, None) case Some(fe: FixedEndian) => - attrParse0(id, attr, io, extraAttrs, Some(fe)) + attrParse0(id, attr, io, Some(fe)) } - if (debug) + if (config.readStoresPos) attrDebugEnd(id, attr.dataType, io, NoRepeat) // More position management after 
parsing for ParseInstanceSpecs @@ -57,26 +55,26 @@ trait CommonReads extends LanguageCompiler { attrParseIfFooter(attr.cond.ifExpr) } - def attrParse0(id: Identifier, attr: AttrLikeSpec, io: String, extraAttrs: ListBuffer[AttrSpec], defEndian: Option[FixedEndian]): Unit = { + def attrParse0(id: Identifier, attr: AttrLikeSpec, io: String, defEndian: Option[FixedEndian]): Unit = { attr.cond.repeat match { case RepeatEos => condRepeatEosHeader(id, io, attr.dataType, needRaw(attr.dataType)) - attrParse2(id, attr.dataType, io, extraAttrs, attr.cond.repeat, false, defEndian) + attrParse2(id, attr.dataType, io, attr.cond.repeat, false, defEndian) condRepeatEosFooter case RepeatExpr(repeatExpr: Ast.expr) => condRepeatExprHeader(id, io, attr.dataType, needRaw(attr.dataType), repeatExpr) - attrParse2(id, attr.dataType, io, extraAttrs, attr.cond.repeat, false, defEndian) + attrParse2(id, attr.dataType, io, attr.cond.repeat, false, defEndian) condRepeatExprFooter case RepeatUntil(untilExpr: Ast.expr) => condRepeatUntilHeader(id, io, attr.dataType, needRaw(attr.dataType), untilExpr) - attrParse2(id, attr.dataType, io, extraAttrs, attr.cond.repeat, false, defEndian) + attrParse2(id, attr.dataType, io, attr.cond.repeat, false, defEndian) condRepeatUntilFooter(id, io, attr.dataType, needRaw(attr.dataType), untilExpr) case NoRepeat => - attrParse2(id, attr.dataType, io, extraAttrs, attr.cond.repeat, false, defEndian) + attrParse2(id, attr.dataType, io, attr.cond.repeat, false, defEndian) } } - def attrParse2(id: Identifier, dataType: DataType, io: String, extraAttrs: ListBuffer[AttrSpec], rep: RepeatSpec, isRaw: Boolean, defEndian: Option[FixedEndian], assignType: Option[DataType] = None): Unit + def attrParse2(id: Identifier, dataType: DataType, io: String, rep: RepeatSpec, isRaw: Boolean, defEndian: Option[FixedEndian], assignType: Option[DataType] = None): Unit def needRaw(dataType: DataType): Boolean = { dataType match { diff --git 
a/shared/src/main/scala/io/kaitai/struct/languages/components/EveryReadIsExpression.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/EveryReadIsExpression.scala index fcee72026..08ee46d99 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/EveryReadIsExpression.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/EveryReadIsExpression.scala @@ -25,7 +25,6 @@ trait EveryReadIsExpression id: Identifier, dataType: DataType, io: String, - extraAttrs: ListBuffer[AttrSpec], rep: RepeatSpec, isRaw: Boolean, defEndian: Option[FixedEndian], @@ -33,16 +32,16 @@ trait EveryReadIsExpression ): Unit = { val assignType = assignTypeOpt.getOrElse(dataType) - if (debug && rep != NoRepeat) + if (config.readStoresPos && rep != NoRepeat) attrDebugStart(id, dataType, Some(io), rep) dataType match { case FixedBytesType(c, _) => attrFixedContentsParse(id, c) case t: UserType => - attrUserTypeParse(id, t, io, extraAttrs, rep, defEndian) + attrUserTypeParse(id, t, io, rep, defEndian) case t: BytesType => - attrBytesTypeParse(id, t, io, extraAttrs, rep, isRaw) + attrBytesTypeParse(id, t, io, rep, isRaw) case st: SwitchType => val isNullable = if (switchBytesOnlyAsRaw) { st.isNullableSwitchRaw @@ -50,7 +49,7 @@ trait EveryReadIsExpression st.isNullable } - attrSwitchTypeParse(id, st.on, st.cases, io, extraAttrs, rep, defEndian, isNullable, st.combinedType) + attrSwitchTypeParse(id, st.on, st.cases, io, rep, defEndian, isNullable, st.combinedType) case t: StrFromBytesType => val expr = translator.bytesToStr(parseExprBytes(t.bytes, io), Ast.expr.Str(t.encoding)) handleAssignment(id, expr, rep, isRaw) @@ -62,7 +61,7 @@ trait EveryReadIsExpression handleAssignment(id, expr, rep, isRaw) } - if (debug && rep != NoRepeat) + if (config.readStoresPos && rep != NoRepeat) attrDebugEnd(id, dataType, io, rep) } @@ -70,17 +69,13 @@ trait EveryReadIsExpression id: Identifier, dataType: BytesType, io: String, - extraAttrs: ListBuffer[AttrSpec], rep: 
RepeatSpec, isRaw: Boolean ): Unit = { // use intermediate variable name, if we'll be doing post-processing val rawId = dataType.process match { case None => id - case Some(_) => - val rawId = RawIdentifier(id) - Utils.addUniqueAttr(extraAttrs, AttrSpec(List(), rawId, dataType)) - rawId + case Some(_) => RawIdentifier(id) } val expr = parseExprBytes(dataType, io) @@ -104,25 +99,23 @@ trait EveryReadIsExpression } } - def attrUserTypeParse(id: Identifier, dataType: UserType, io: String, extraAttrs: ListBuffer[AttrSpec], rep: RepeatSpec, defEndian: Option[FixedEndian]): Unit = { + def attrUserTypeParse(id: Identifier, dataType: UserType, io: String, rep: RepeatSpec, defEndian: Option[FixedEndian]): Unit = { val newIO = dataType match { case knownSizeType: UserTypeFromBytes => // we have a fixed buffer, thus we shall create separate IO for it val rawId = RawIdentifier(id) val byteType = knownSizeType.bytes - attrParse2(rawId, byteType, io, extraAttrs, rep, true, defEndian) + attrParse2(rawId, byteType, io, rep, true, defEndian) val extraType = rep match { case NoRepeat => byteType case _ => ArrayType(byteType) } - Utils.addUniqueAttr(extraAttrs, AttrSpec(List(), rawId, extraType)) - this match { case thisStore: AllocateAndStoreIO => - thisStore.allocateIO(rawId, rep, extraAttrs) + thisStore.allocateIO(rawId, rep) case thisLocal: AllocateIOLocalVar => thisLocal.allocateIO(rawId, rep) } @@ -131,13 +124,13 @@ trait EveryReadIsExpression io } val expr = parseExpr(dataType, dataType, newIO, defEndian) - if (!debug) { + if (config.autoRead) { handleAssignment(id, expr, rep, false) } else { - // Debug mode requires one to actually call "_read" method on constructed user type, - // and this must be done as a separate statement - or else exception handler would - // blast the whole structure, not only this element. 
This, in turn, makes us assign - // constructed element to a temporary variable in case on repetitions + // Disabled autoRead requires one to actually call `_read` method on constructed + // user type, and this must be done as a separate statement - or else exception + // handler would blast the whole structure, not only this element. This, in turn, + // makes us assign constructed element to a temporary variable in case of arrays. rep match { case NoRepeat => handleAssignmentSimple(id, expr) @@ -156,7 +149,6 @@ trait EveryReadIsExpression on: Ast.expr, cases: Map[Ast.expr, DataType], io: String, - extraAttrs: ListBuffer[AttrSpec], rep: RepeatSpec, defEndian: Option[FixedEndian], isNullable: Boolean, @@ -169,17 +161,17 @@ trait EveryReadIsExpression (dataType) => { if (isNullable) condIfSetNonNull(id) - attrParse2(id, dataType, io, extraAttrs, rep, false, defEndian, Some(assignType)) + attrParse2(id, dataType, io, rep, false, defEndian, Some(assignType)) }, (dataType) => if (switchBytesOnlyAsRaw) { dataType match { case t: BytesType => - attrParse2(RawIdentifier(id), dataType, io, extraAttrs, rep, false, defEndian, Some(assignType)) + attrParse2(RawIdentifier(id), dataType, io, rep, false, defEndian, Some(assignType)) case _ => - attrParse2(id, dataType, io, extraAttrs, rep, false, defEndian, Some(assignType)) + attrParse2(id, dataType, io, rep, false, defEndian, Some(assignType)) } } else { - attrParse2(id, dataType, io, extraAttrs, rep, false, defEndian, Some(assignType)) + attrParse2(id, dataType, io, rep, false, defEndian, Some(assignType)) } ) } @@ -201,10 +193,10 @@ trait EveryReadIsExpression def parseExpr(dataType: DataType, assignType: DataType, io: String, defEndian: Option[FixedEndian]): String def bytesPadTermExpr(expr0: String, padRight: Option[Int], terminator: Option[Int], include: Boolean): String - def userTypeDebugRead(id: String): Unit = {} + def userTypeDebugRead(id: String): Unit = ??? 
def instanceCalculate(instName: Identifier, dataType: DataType, value: Ast.expr): Unit = { - if (debug) + if (config.readStoresPos) attrDebugStart(instName, dataType, None, NoRepeat) handleAssignmentSimple(instName, expression(value)) } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/ExtraAttrs.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/ExtraAttrs.scala index a1f9e6f5c..90f19d3a9 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/ExtraAttrs.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/ExtraAttrs.scala @@ -4,6 +4,14 @@ import io.kaitai.struct.datatype.DataType import io.kaitai.struct.datatype.DataType._ import io.kaitai.struct.format._ +/** + * Trait to be implemented by all [[LanguageCompiler]] compilers: supplies extra attributes + * when we'll be allocating new IOs. + */ +trait ExtraAttrs { + def extraAttrForIO(id: Identifier, rep: RepeatSpec): List[AttrSpec] +} + /** * Generates list of extra attributes required to store intermediate / * virtual stuff for every attribute like: @@ -13,21 +21,44 @@ import io.kaitai.struct.format._ * * unprocessed / postprocessed byte arrays */ object ExtraAttrs { - def forAttr(attr: AttrLikeSpec): List[AttrSpec] = - forAttr(attr.id, attr.dataType) + def forClassSpec(curClass: ClassSpec, compiler: ExtraAttrs): List[AttrSpec] = { + // We want only values of ParseInstances, which are AttrSpecLike. + // ValueInstances are ignored, as they can't currently generate + // any extra attributes (i.e. 
no `size`, no `process`, etc) + val parseInstances = curClass.instances.values.collect { + case inst: AttrLikeSpec => inst + } + + (curClass.seq ++ parseInstances).foldLeft(List[AttrSpec]())( + (attrs, attr) => attrs ++ ExtraAttrs.forAttr(attr, compiler) + ) + } + + def forAttr(attr: AttrLikeSpec, compiler: ExtraAttrs): Iterable[AttrSpec] = + forAttr(attr.id, attr.dataType, attr.cond, compiler) - def forAttr(id: Identifier, dataType: DataType): List[AttrSpec] = { + private + def forAttr(id: Identifier, dataType: DataType, condSpec: ConditionalSpec, compiler: ExtraAttrs): Iterable[AttrSpec] = { dataType match { case bt: BytesType => + // Byte array: only need extra attrs if `process` is used bt.process match { case None => List() case Some(_) => val rawId = RawIdentifier(id) - List(AttrSpec(List(), rawId, bt)) + List(AttrSpec(List(), rawId, bt, condSpec)) ++ + compiler.extraAttrForIO(id, condSpec.repeat) } case utb: UserTypeFromBytes => + // User type in a substream val rawId = RawIdentifier(id) - List(AttrSpec(List(), rawId, utb.bytes)) ++ forAttr(rawId, utb.bytes) + (List(AttrSpec(List(), rawId, utb.bytes, condSpec)) ++ + compiler.extraAttrForIO(rawId, condSpec.repeat) ++ + forAttr(rawId, utb.bytes, condSpec, compiler)).toList.distinct + case st: SwitchType => + st.cases.flatMap { case (_, caseType) => + forAttr(id, caseType, condSpec, compiler) + }.toList.distinct case _ => List() } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/GoReads.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/GoReads.scala index 876515960..5e70658a2 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/GoReads.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/GoReads.scala @@ -1,11 +1,11 @@ package io.kaitai.struct.languages.components import io.kaitai.struct.Utils -import io.kaitai.struct.datatype.{BigEndian, DataType, FixedEndian} import io.kaitai.struct.datatype.DataType._ +import 
io.kaitai.struct.datatype.{DataType, FixedEndian} import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.format._ -import io.kaitai.struct.translators.{GoTranslator, TranslatorResult} +import io.kaitai.struct.translators.{GoTranslator, ResultString, TranslatorResult} import scala.collection.mutable.ListBuffer @@ -16,7 +16,6 @@ trait GoReads extends CommonReads with ObjectOrientedLanguage with SwitchOps { id: Identifier, dataType: DataType, io: String, - extraAttrs: ListBuffer[AttrSpec], rep: RepeatSpec, isRaw: Boolean, defEndian: Option[FixedEndian], @@ -26,7 +25,7 @@ trait GoReads extends CommonReads with ObjectOrientedLanguage with SwitchOps { case FixedBytesType(c, _) => attrFixedContentsParse(id, c) case t: UserType => - attrUserTypeParse(id, t, io, extraAttrs, rep, defEndian) + attrUserTypeParse(id, t, io, rep, defEndian) // case t: BytesType => // attrBytesTypeParse(id, t, io, extraAttrs, rep, isRaw) // case SwitchType(on, cases) => @@ -35,9 +34,16 @@ trait GoReads extends CommonReads with ObjectOrientedLanguage with SwitchOps { val r1 = translator.outVarCheckRes(parseExprBytes(t.bytes, io)) val expr = translator.bytesToStr(translator.resToStr(r1), Ast.expr.Str(t.encoding)) handleAssignment(id, expr, rep, isRaw) -// case t: EnumType => -// val expr = translator.doEnumById(t.enumSpec.get.name, parseExpr(t.basedOn, io)) -// handleAssignment(id, expr, rep, isRaw) + case t: EnumType => + val r1 = translator.outVarCheckRes(parseExpr(t.basedOn, io, defEndian)) + val enumSpec = t.enumSpec.get + val expr = translator.trEnumById(enumSpec.name, translator.resToStr(r1)) + handleAssignment(id, expr, rep, isRaw) + case BitsType1 => + val expr = parseExpr(dataType, io, defEndian) + val r1 = translator.outVarCheckRes(expr) + val r2 = ResultString(s"${translator.resToStr(r1)} != 0") + handleAssignment(id, r2, rep, isRaw) case _ => val expr = parseExpr(dataType, io, defEndian) val r = translator.outVarCheckRes(expr) @@ -60,25 +66,23 @@ trait GoReads extends CommonReads 
with ObjectOrientedLanguage with SwitchOps { expr } - def attrUserTypeParse(id: Identifier, dataType: UserType, io: String, extraAttrs: ListBuffer[AttrSpec], rep: RepeatSpec, defEndian: Option[FixedEndian]): Unit = { + def attrUserTypeParse(id: Identifier, dataType: UserType, io: String, rep: RepeatSpec, defEndian: Option[FixedEndian]): Unit = { val newIO = dataType match { case knownSizeType: UserTypeFromBytes => // we have a fixed buffer, thus we shall create separate IO for it val rawId = RawIdentifier(id) val byteType = knownSizeType.bytes - attrParse2(rawId, byteType, io, extraAttrs, rep, true, defEndian) + attrParse2(rawId, byteType, io, rep, true, defEndian) val extraType = rep match { case NoRepeat => byteType case _ => ArrayType(byteType) } - Utils.addUniqueAttr(extraAttrs, AttrSpec(List(), rawId, extraType)) - this match { case thisStore: AllocateAndStoreIO => - thisStore.allocateIO(rawId, rep, extraAttrs) + thisStore.allocateIO(rawId, rep) case thisLocal: AllocateIOLocalVar => thisLocal.allocateIO(rawId, rep) } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompiler.scala index 77ec48c19..f41214463 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompiler.scala @@ -10,8 +10,9 @@ import scala.collection.mutable.ListBuffer abstract class LanguageCompiler( typeProvider: ClassTypeProvider, - config: RuntimeConfig -) extends SwitchOps { + val config: RuntimeConfig +) extends SwitchOps + with ExtraAttrs { val translator: AbstractTranslator @@ -51,7 +52,7 @@ abstract class LanguageCompiler( */ def innerDocstrings: Boolean = false - def debug = config.debug + def debug: Boolean = !config.autoRead && config.readStoresPos def indent: String def outFileName(topClassName: String): String @@ -88,8 +89,9 @@ abstract class 
LanguageCompiler( def attributeReader(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit def attributeDoc(id: Identifier, doc: DocSpec): Unit = {} - def attrParse(attr: AttrLikeSpec, id: Identifier, extraAttrs: ListBuffer[AttrSpec], defEndian: Option[Endianness]): Unit + def attrParse(attr: AttrLikeSpec, id: Identifier, defEndian: Option[Endianness]): Unit def attrParseHybrid(leProc: () => Unit, beProc: () => Unit): Unit + def attrInit(attr: AttrLikeSpec): Unit = {} def attrDestructor(attr: AttrLikeSpec, id: Identifier): Unit = {} def attrFixedContentsParse(attrName: Identifier, contents: Array[Byte]): Unit @@ -117,13 +119,14 @@ abstract class LanguageCompiler( def popPos(io: String): Unit def alignToByte(io: String): Unit + def instanceDeclHeader(className: List[String]): Unit = {} def instanceClear(instName: InstanceIdentifier): Unit = {} def instanceSetCalculated(instName: InstanceIdentifier): Unit = {} def instanceDeclaration(attrName: InstanceIdentifier, attrType: DataType, isNullable: Boolean): Unit = attributeDeclaration(attrName, attrType, isNullable) def instanceHeader(className: List[String], instName: InstanceIdentifier, dataType: DataType, isNullable: Boolean): Unit def instanceFooter: Unit - def instanceCheckCacheAndReturn(instName: InstanceIdentifier): Unit - def instanceReturn(instName: InstanceIdentifier): Unit + def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit + def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit def instanceCalculate(instName: Identifier, dataType: DataType, value: Ast.expr) def enumDeclaration(curClass: List[String], enumName: String, enumColl: Seq[(Long, EnumValueSpec)]): Unit diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompilerStatic.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompilerStatic.scala index 93d0dc023..bbb497c96 100644 --- 
a/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompilerStatic.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompilerStatic.scala @@ -2,7 +2,6 @@ package io.kaitai.struct.languages.components import io.kaitai.struct._ import io.kaitai.struct.languages._ -import io.kaitai.struct.translators.{BaseTranslator, TypeProvider} trait LanguageCompilerStatic { def getCompiler(tp: ClassTypeProvider, config: RuntimeConfig): LanguageCompiler @@ -10,17 +9,20 @@ trait LanguageCompilerStatic { object LanguageCompilerStatic { val NAME_TO_CLASS: Map[String, LanguageCompilerStatic] = Map( + "construct" -> ConstructClassCompiler, "cpp_stl" -> CppCompiler, "csharp" -> CSharpCompiler, "graphviz" -> GraphvizClassCompiler, "go" -> GoCompiler, + "html" -> HtmlClassCompiler, "java" -> JavaCompiler, "javascript" -> JavaScriptCompiler, "lua" -> LuaCompiler, "perl" -> PerlCompiler, "php" -> PHPCompiler, "python" -> PythonCompiler, - "ruby" -> RubyCompiler + "ruby" -> RubyCompiler, + "rust" -> RustCompiler ) val CLASS_TO_NAME: Map[LanguageCompilerStatic, String] = NAME_TO_CLASS.map(_.swap) diff --git a/shared/src/main/scala/io/kaitai/struct/precompile/LoadImports.scala b/shared/src/main/scala/io/kaitai/struct/precompile/LoadImports.scala index 1cb0c0304..bd9684704 100644 --- a/shared/src/main/scala/io/kaitai/struct/precompile/LoadImports.scala +++ b/shared/src/main/scala/io/kaitai/struct/precompile/LoadImports.scala @@ -58,9 +58,18 @@ class LoadImports(specs: ClassSpecs) { } futureSpec.flatMap { case optSpec => + Log.importOps.info(() => { + val specNameAsStr = optSpec.map(_.nameAsStr).getOrElse("") + s".. LoadImports: loadImport($name, workDir = $workDir), got spec=$specNameAsStr" + }) optSpec match { case Some(spec) => val specName = spec.name.head + // Check if spec name does not match file name. If it doesn't match, + // it is probably already a serious error. + if (name != specName) + Log.importOps.warn(() => s"... 
expected to have type name $name, but got $specName") + // Check if we've already had this spec in our ClassSpecs. If we do, // don't do anything: we've already processed it and reprocessing it // might lead to infinite recursion. @@ -71,6 +80,7 @@ class LoadImports(specs: ClassSpecs) { specs(specName) = spec processClass(spec, ImportPath.updateWorkDir(workDir, impPath)) } else { + Log.importOps.warn(() => s"... we have that already, ignoring") Future { List() } } case None => diff --git a/shared/src/main/scala/io/kaitai/struct/precompile/ResolveTypes.scala b/shared/src/main/scala/io/kaitai/struct/precompile/ResolveTypes.scala index 1fe9eb0c6..179f46813 100644 --- a/shared/src/main/scala/io/kaitai/struct/precompile/ResolveTypes.scala +++ b/shared/src/main/scala/io/kaitai/struct/precompile/ResolveTypes.scala @@ -18,19 +18,21 @@ class ResolveTypes(specs: ClassSpecs, opaqueTypes: Boolean) { * @param curClass class to start from, might be top-level class */ def resolveUserTypes(curClass: ClassSpec): Unit = { - curClass.seq.foreach((attr) => resolveUserTypeForAttr(curClass, attr)) + curClass.seq.foreach((attr) => resolveUserTypeForMember(curClass, attr)) curClass.instances.foreach { case (_, instSpec) => instSpec match { case pis: ParseInstanceSpec => - resolveUserTypeForAttr(curClass, pis) + resolveUserTypeForMember(curClass, pis) case _: ValueInstanceSpec => // ignore all other types of instances } } + + curClass.params.foreach((paramDef) => resolveUserTypeForMember(curClass, paramDef)) } - def resolveUserTypeForAttr(curClass: ClassSpec, attr: AttrLikeSpec): Unit = + def resolveUserTypeForMember(curClass: ClassSpec, attr: MemberSpec): Unit = resolveUserType(curClass, attr.dataType, attr.path) def resolveUserType(curClass: ClassSpec, dataType: DataType, path: List[String]): Unit = { @@ -43,8 +45,8 @@ class ResolveTypes(specs: ClassSpecs, opaqueTypes: Boolean) { val err = new EnumNotFoundError(et.name.mkString("::"), curClass) throw new 
YAMLParseException(err.getMessage, path) } - case SwitchType(_, cases) => - cases.foreach { case (caseName, ut) => + case st: SwitchType => + st.cases.foreach { case (caseName, ut) => resolveUserType(curClass, ut, path ++ List("type", "cases", caseName.toString)) } case _ => @@ -105,21 +107,34 @@ class ResolveTypes(specs: ClassSpecs, opaqueTypes: Boolean) { } else { // Check if top-level specs has this name // If there's None => no luck at all - specs.get(firstName) + val resolvedTop = specs.get(firstName) + resolvedTop match { + case None => None + case Some(classSpec) => if (restNames.isEmpty) { + resolvedTop + } else { + resolveUserType(classSpec, restNames, path) + } + } } } } } def resolveEnumSpec(curClass: ClassSpec, typeName: List[String]): Option[EnumSpec] = { - // Console.println(s"resolveEnumSpec: at ${curClass.name} doing ${typeName.mkString("|")}") - val res = realResolveEnumSpec(curClass, typeName) - // Console.println(" => " + (res match { - // case None => "???" - // case Some(x) => x.name.mkString("|") - // })) + Log.enumResolve.info(() => s"resolveEnumSpec: at ${curClass.name} doing ${typeName.mkString("|")}") - res + val res = realResolveEnumSpec(curClass, typeName) + res match { + case None => { + Log.enumResolve.info(() => s" => ???") + res + } + case Some(x) => { + Log.enumResolve.info(() => s" => ${x.nameAsStr}") + res + } + } } private def realResolveEnumSpec(curClass: ClassSpec, typeName: List[String]): Option[EnumSpec] = { diff --git a/shared/src/main/scala/io/kaitai/struct/precompile/TypeValidator.scala b/shared/src/main/scala/io/kaitai/struct/precompile/TypeValidator.scala index 6bd2b1210..94d1a504f 100644 --- a/shared/src/main/scala/io/kaitai/struct/precompile/TypeValidator.scala +++ b/shared/src/main/scala/io/kaitai/struct/precompile/TypeValidator.scala @@ -43,7 +43,7 @@ class TypeValidator(specs: ClassSpecs, topClass: ClassSpec) { curClass.instances.foreach { case (_, inst) => inst match { case pis: ParseInstanceSpec => - 
validateAttr(pis) + validateParseInstance(pis) case vis: ValueInstanceSpec => // TODO } @@ -77,6 +77,22 @@ class TypeValidator(specs: ClassSpecs, topClass: ClassSpec) { validateDataType(attr.dataType, path) } + def validateParseInstance(pis: ParseInstanceSpec): Unit = { + validateAttr(pis) + + Log.typeValid.info(() => s"validateParseInstance(${pis.id.humanReadable})") + + pis.io match { + case Some(io) => checkAssertObject(io, KaitaiStreamType, "IO stream", pis.path, "io") + case None => // all good + } + + pis.pos match { + case Some(pos) => checkAssert[IntType](pos, "integer", pis.path, "pos") + case None => // all good + } + } + /** * Validates single non-composite data type, checking all expressions * inside data type definition. @@ -189,4 +205,47 @@ class TypeValidator(specs: ClassSpecs, topClass: ClassSpec) { throw new ErrorInInput(err, path ++ List(pathKey)) } } + + /** + * Checks that expression's type conforms to a given datatype, otherwise + * throw a human-readable exception, with some pointers that would help + * finding the expression in source .ksy. + * + * This version works with case objects. 
+ * + * @param expr expression to check + * @param expectStr string to include + * @param path path to expression base + * @param pathKey key that contains expression in given path + */ + def checkAssertObject( + expr: Ast.expr, + expected: Object, + expectStr: String, + path: List[String], + pathKey: String + ): Unit = { + try { + val detected = detector.detectType(expr) + if (detected == expected) { + // good + } else { + detected match { + case st: SwitchType => + val combinedType = st.combinedType + if (combinedType == expected) { + // good + } else { + throw YAMLParseException.exprType(expectStr, combinedType, path ++ List(pathKey)) + } + case actual => throw YAMLParseException.exprType(expectStr, actual, path ++ List(pathKey)) + } + } + } catch { + case err: InvalidIdentifier => + throw new ErrorInInput(err, path ++ List(pathKey)) + case err: ExpressionError => + throw new ErrorInInput(err, path ++ List(pathKey)) + } + } } diff --git a/shared/src/main/scala/io/kaitai/struct/precompile/ValueTypesDeriver.scala b/shared/src/main/scala/io/kaitai/struct/precompile/ValueTypesDeriver.scala index 50cf82c30..eb1937025 100644 --- a/shared/src/main/scala/io/kaitai/struct/precompile/ValueTypesDeriver.scala +++ b/shared/src/main/scala/io/kaitai/struct/precompile/ValueTypesDeriver.scala @@ -24,7 +24,7 @@ class ValueTypesDeriver(specs: ClassSpecs, topClass: ClassSpec) { vi.dataType match { case None => try { - val viType = detector.detectType(vi.value) + val viType = detector.detectType(vi.value).asNonOwning vi.dataType = Some(viType) Log.typeProcValue.info(() => s"${instName.name} derived type: $viType") hasChanged = true diff --git a/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala index 2c2c808cb..9d8dedbd0 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/BaseTranslator.scala @@ -16,7 
+16,8 @@ import io.kaitai.struct.precompile.TypeMismatchError * * Given that there are many of these abstract methods, to make it more * maintainable, they are grouped into several abstract traits: - * [[CommonLiterals]], [[CommonOps]]. + * [[CommonLiterals]], [[CommonOps]], [[CommonMethods]], + * [[CommonArraysAndCast]]. * * This translator implementation also handles user-defined types and * fields properly - it uses given [[TypeProvider]] to resolve these. @@ -28,6 +29,7 @@ abstract class BaseTranslator(val provider: TypeProvider) with AbstractTranslator with CommonLiterals with CommonOps + with CommonArraysAndCast[String] with CommonMethods[String] { /** @@ -52,11 +54,11 @@ abstract class BaseTranslator(val provider: TypeProvider) doStringLiteral(s) case Ast.expr.Bool(n) => doBoolLiteral(n) - case Ast.expr.EnumById(enumType, id) => - val enumSpec = provider.resolveEnum(enumType.name) + case Ast.expr.EnumById(enumType, id, inType) => + val enumSpec = provider.resolveEnum(inType, enumType.name) doEnumById(enumSpec.name, translate(id)) - case Ast.expr.EnumByLabel(enumType, label) => - val enumSpec = provider.resolveEnum(enumType.name) + case Ast.expr.EnumByLabel(enumType, label, inType) => + val enumSpec = provider.resolveEnum(inType, enumType.name) doEnumByLabel(enumSpec.name, label.name) case Ast.expr.Name(name: Ast.identifier) => doLocalName(name.name) @@ -83,9 +85,11 @@ abstract class BaseTranslator(val provider: TypeProvider) doStrCompareOp(left, op, right) case (_: BytesType, _: BytesType) => doBytesCompareOp(left, op, right) - case (EnumType(ltype, _), EnumType(rtype, _)) => - if (ltype != rtype) { - throw new TypeMismatchError(s"can't compare enums type $ltype and $rtype") + case (et1: EnumType, et2: EnumType) => + val et1Spec = et1.enumSpec.get + val et2Spec = et2.enumSpec.get + if (et1Spec != et2Spec) { + throw new TypeMismatchError(s"can't compare enums type ${et1Spec.nameAsStr} and ${et2Spec.nameAsStr}") } else { doEnumCompareOp(left, op, right) } @@ 
-117,34 +121,19 @@ abstract class BaseTranslator(val provider: TypeProvider) case call: Ast.expr.Call => translateCall(call) case Ast.expr.List(values: Seq[Ast.expr]) => - val t = detectArrayType(values) - t match { - case Int1Type(_) => - val literalBytes: Seq[Byte] = values.map { - case Ast.expr.IntNum(x) => - if (x < 0 || x > 0xff) { - throw new TypeMismatchError(s"got a weird byte value in byte array: $x") - } else { - x.toByte - } - case n => - throw new TypeMismatchError(s"got $n in byte array, unable to put it literally") - } - doByteArrayLiteral(literalBytes) - case _ => - doArrayLiteral(t, values) - } - case Ast.expr.CastToType(value, typeName) => - doCast(value, typeName.name) + doGuessArrayLiteral(values) + case ctt: Ast.expr.CastToType => + doCastOrArray(ctt) } } def doSubscript(container: Ast.expr, idx: Ast.expr): String def doIfExp(condition: Ast.expr, ifTrue: Ast.expr, ifFalse: Ast.expr): String - def doCast(value: Ast.expr, typeName: String): String = translate(value) + def doCast(value: Ast.expr, typeName: DataType): String = translate(value) def doArrayLiteral(t: DataType, value: Seq[Ast.expr]): String = "[" + value.map((v) => translate(v)).mkString(", ") + "]" def doByteArrayLiteral(arr: Seq[Byte]): String = "[" + arr.map(_ & 0xff).mkString(", ") + "]" + def doByteArrayNonLiteral(elts: Seq[Ast.expr]): String = ??? 
def doLocalName(s: String): String = doName(s) def doName(s: String): String diff --git a/shared/src/main/scala/io/kaitai/struct/translators/CSharpTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/CSharpTranslator.scala index 26a030b89..31c948e1a 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/CSharpTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/CSharpTranslator.scala @@ -17,6 +17,8 @@ class CSharpTranslator(provider: TypeProvider, importList: ImportList) extends B override def doByteArrayLiteral(arr: Seq[Byte]): String = s"new byte[] { ${arr.map(_ & 0xff).mkString(", ")} }" + override def doByteArrayNonLiteral(elts: Seq[Ast.expr]): String = + s"new byte[] { ${elts.map(translate).mkString(", ")} }" override val asciiCharQuoteMap: Map[Char, String] = Map( '\t' -> "\\t", @@ -81,8 +83,8 @@ class CSharpTranslator(provider: TypeProvider, importList: ImportList) extends B s"${translate(container)}[${translate(idx)}]" override def doIfExp(condition: expr, ifTrue: expr, ifFalse: expr): String = s"(${translate(condition)} ? 
${translate(ifTrue)} : ${translate(ifFalse)})" - override def doCast(value: Ast.expr, typeName: String): String = - s"((${Utils.upperCamelCase(typeName)}) (${translate(value)}))" + override def doCast(value: Ast.expr, typeName: DataType): String = + s"((${CSharpCompiler.kaitaiType2NativeType(typeName)}) (${translate(value)}))" // Predefined methods of various types override def strToInt(s: expr, base: expr): String = { diff --git a/shared/src/main/scala/io/kaitai/struct/translators/CommonArraysAndCast.scala b/shared/src/main/scala/io/kaitai/struct/translators/CommonArraysAndCast.scala new file mode 100644 index 000000000..945ec52b8 --- /dev/null +++ b/shared/src/main/scala/io/kaitai/struct/translators/CommonArraysAndCast.scala @@ -0,0 +1,101 @@ +package io.kaitai.struct.translators + +import io.kaitai.struct.datatype.DataType +import io.kaitai.struct.datatype.DataType._ +import io.kaitai.struct.exprlang.Ast +import io.kaitai.struct.precompile.TypeMismatchError + +/** + * Common implementation of arrays translations: + * + * * type guessing + * * type enforcing with a cast + * * rendering of byte arrays and true arrays + * * call to actual casting implementation + * + * @tparam T translation result type + */ +trait CommonArraysAndCast[T] extends TypeDetector { + /** + * Processes elements inside a given [[Ast.expr.List]] element to render them + * as either byte array literal or true array. 
+ * @param values elements from a list + * @return translation result + */ + def doGuessArrayLiteral(values: Seq[Ast.expr]): T = { + val elementType = detectArrayType(values) + elementType match { + case Int1Type(_) => + val literalBytes: Seq[Byte] = values.map { + case Ast.expr.IntNum(x) => + if (x < 0 || x > 0xff) { + throw new TypeMismatchError(s"got a weird byte value in byte array: $x") + } else { + x.toByte + } + case n => + throw new TypeMismatchError(s"got $n in byte array, unable to put it literally") + } + doByteArrayLiteral(literalBytes) + case _ => + doArrayLiteral(elementType, values) + } + } + + /** + * Processes an [[Ast.expr.CastToType]] element, by checking if + * this is a literal array type enforcement cast first, and + * rendering it accordingly as a proper literal, or invoking + * the normal [[doCast]] otherwise. + * @param v CastToType element + * @return translation result + */ + def doCastOrArray(v: Ast.expr.CastToType): T = { + val castToType = detectCastType(v.typeName) + + v.value match { + case array: Ast.expr.List => + // Special handling for literal arrays: if cast is present, + // then we don't need to guess the data type + castToType match { + case _: BytesType => + doByteArray(array.elts) + case ArrayType(elType) => + doArrayLiteral(elType, array.elts) + case _ => + // No luck, this is some kind of weird cast, not a type enforcement; + // Just do it and let real type casting deal with it. 
+ doCast(v.value, castToType) + } + case _ => + doCast(v.value, castToType) + } + } + + def doCast(value: Ast.expr, typeName: DataType): T + def doArrayLiteral(t: DataType, value: Seq[Ast.expr]): T + def doByteArrayLiteral(arr: Seq[Byte]): T + def doByteArrayNonLiteral(elts: Seq[Ast.expr]): T + + private def doByteArray(elts: Seq[Ast.expr]): T = { + valuesAsByteArrayLiteral(elts) match { + case Some(arr) => + doByteArrayLiteral(arr) + case None => + doByteArrayNonLiteral(elts) + } + } + + private def valuesAsByteArrayLiteral(elts: Seq[Ast.expr]): Option[Seq[Byte]] = { + Some(elts.map { + case Ast.expr.IntNum(x) => + if (x < 0 || x > 0xff) { + return None + } else { + x.toByte + } + case _ => + return None + }) + } +} diff --git a/shared/src/main/scala/io/kaitai/struct/translators/CommonMethods.scala b/shared/src/main/scala/io/kaitai/struct/translators/CommonMethods.scala index b71faa868..511080510 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/CommonMethods.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/CommonMethods.scala @@ -5,6 +5,13 @@ import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.precompile.TypeMismatchError abstract trait CommonMethods[T] extends TypeDetector { + /** + * Translates a certain attribute call (as in `foo.bar`) into a rendition + * of expression in certain target language. 
+ * @note Must be kept in sync with [[TypeDetector.detectAttributeType]] + * @param call attribute call expression to translate + * @return result of translation as [[T]] + */ def translateAttribute(call: Ast.expr.Attribute): T = { val attr = call.attr val value = call.value @@ -12,6 +19,10 @@ abstract trait CommonMethods[T] extends TypeDetector { valType match { case ut: UserType => userTypeField(ut, value, attr.name) + case _: BytesType => + attr.name match { + case "length" => bytesLength(value) + } case _: StrType => attr.name match { case "length" => strLength(value) @@ -53,6 +64,13 @@ abstract trait CommonMethods[T] extends TypeDetector { } } + /** + * Translates a certain function call (as in `foo.bar(arg1, arg2)`) into a + * rendition of expression in certain target language. + * @note Must be kept in sync with [[TypeDetector.detectCallType]] + * @param call function call expression to translate + * @return result of translation as [[T]] + */ def translateCall(call: Ast.expr.Call): T = { val func = call.func val args = call.args @@ -72,6 +90,8 @@ abstract trait CommonMethods[T] extends TypeDetector { def userTypeField(ut: UserType, value: Ast.expr, name: String): T + def bytesLength(b: Ast.expr): T = ??? 
+ def strLength(s: Ast.expr): T def strReverse(s: Ast.expr): T def strToInt(s: Ast.expr, base: Ast.expr): T diff --git a/shared/src/main/scala/io/kaitai/struct/translators/ConstructTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/ConstructTranslator.scala new file mode 100644 index 000000000..1476d214c --- /dev/null +++ b/shared/src/main/scala/io/kaitai/struct/translators/ConstructTranslator.scala @@ -0,0 +1,31 @@ +package io.kaitai.struct.translators + +import io.kaitai.struct.ImportList +import io.kaitai.struct.exprlang.Ast +import io.kaitai.struct.format.Identifier + +class ConstructTranslator(provider: TypeProvider, importList: ImportList) extends PythonTranslator(provider, importList) { + override def doLocalName(s: String) = { + s match { + case Identifier.ITERATOR => "obj_" + case Identifier.INDEX => "i" + case Identifier.ROOT => "this._root" + case Identifier.IO => "_io" + case _ => s"this.${doName(s)}" + } + } + + override def doName(s: String) = { + s match { + case Identifier.PARENT => "_" + case _ => s + } + } + + override def kaitaiStreamSize(value: Ast.expr): String = + s"stream_size(${translate(value)})" + override def kaitaiStreamEof(value: Ast.expr): String = + s"stream_iseof(${translate(value)})" + override def kaitaiStreamPos(value: Ast.expr): String = + s"stream_tell(${translate(value)})" +} diff --git a/shared/src/main/scala/io/kaitai/struct/translators/CppTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/CppTranslator.scala index 2ee150df3..fb4b64e37 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/CppTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/CppTranslator.scala @@ -2,17 +2,68 @@ package io.kaitai.struct.translators import java.nio.charset.Charset -import io.kaitai.struct.{ImportList, Utils} -import io.kaitai.struct.exprlang.Ast -import io.kaitai.struct.exprlang.Ast.expr +import io.kaitai.struct.CppRuntimeConfig.{RawPointers, SharedPointers, 
UniqueAndRawPointers} import io.kaitai.struct.datatype.DataType import io.kaitai.struct.datatype.DataType._ +import io.kaitai.struct.exprlang.Ast +import io.kaitai.struct.exprlang.Ast.expr import io.kaitai.struct.format.Identifier import io.kaitai.struct.languages.CppCompiler +import io.kaitai.struct.{ImportList, RuntimeConfig, Utils} -class CppTranslator(provider: TypeProvider, importListSrc: ImportList) extends BaseTranslator(provider) { +class CppTranslator(provider: TypeProvider, importListSrc: ImportList, config: RuntimeConfig) extends BaseTranslator(provider) { val CHARSET_UTF8 = Charset.forName("UTF-8") + /** + * Handles integer literals for C++ by appending a relevant suffix to + * decimal notation. + * + * Note that suffixes essentially mean "unsigned long", "long long", + * and "unsigned long long", which are not really guaranteed to match + * `uint32_t`, `int64_t` and `uint64_t`, but this works for the majority + * of current compilers. + * + * For reference, ranges of integers that are used in this conversion are: + * + * * int32_t (no suffix): -2147483648..2147483647 + * * uint32_t (UL): 0..4294967295 + * * int64_t (LL): -9223372036854775808..9223372036854775807 + * * uint64_t (ULL): 0..18446744073709551615 + * + * Merging all these ranges, we get the following decision tree: + * + * * -9223372036854775808..-2147483649 => LL + * * -2147483648..2147483647 => no suffix + * * 2147483648..4294967295 => UL + * * 4294967296..9223372036854775807 => LL + * * 9223372036854775808..18446744073709551615 => ULL + * + * Beyond these boundaries, C++ is unlikely to be able to represent + * these anyway, so we just drop the suffix and hope for a miracle. 
+ * + * @param n integer to render + * @return rendered integer literal in C++ syntax as string + */ + override def doIntLiteral(n: BigInt): String = { + val suffix = if (n < -9223372036854775808L) { + "" // too low, no suffix would help anyway + } else if (n <= -2147483649L) { + "LL" // -9223372036854775808..-2147483649 + } else if (n <= 2147483647L) { + "" // -2147483648..2147483647 + } else if (n <= 4294967295L) { + "UL" // 2147483648..4294967295 + } else if (n <= 9223372036854775807L) { + "LL" // 4294967296..9223372036854775807 + } else if (n <= Utils.MAX_UINT64) { + "ULL" // 9223372036854775808..18446744073709551615 + } else { + "" // too high, no suffix would help anyway + } + + s"$n$suffix" + } + /** * Handles string literal for C++ by wrapping a C `const char*`-style string * into a std::string constructor. Note that normally std::string @@ -77,7 +128,8 @@ class CppTranslator(provider: TypeProvider, importListSrc: ImportList) extends B } override def doEnumByLabel(enumType: List[String], label: String): String = - (enumType.last + "_" + label).toUpperCase + CppCompiler.types2class(enumType.dropRight(1)) + "::" + + (enumType.last + "_" + label).toUpperCase override def doEnumById(enumType: List[String], id: String): String = s"static_cast<${CppCompiler.types2class(enumType)}>($id)" @@ -95,8 +147,20 @@ class CppTranslator(provider: TypeProvider, importListSrc: ImportList) extends B s"${translate(container)}->at(${translate(idx)})" override def doIfExp(condition: expr, ifTrue: expr, ifFalse: expr): String = s"((${translate(condition)}) ? 
(${translate(ifTrue)}) : (${translate(ifFalse)}))" - override def doCast(value: Ast.expr, typeName: String): String = - s"static_cast<${CppCompiler.types2class(List(typeName))}*>(${translate(value)})" + override def doCast(value: Ast.expr, typeName: DataType): String = + config.cppConfig.pointers match { + case RawPointers | UniqueAndRawPointers => + cppStaticCast(value, typeName) + case SharedPointers => + typeName match { + case ut: UserType => + s"std::static_pointer_cast<${CppCompiler.types2class(ut.classSpec.get.name)}>(${translate(value)})" + case _ => cppStaticCast(value, typeName) + } + } + + def cppStaticCast(value: Ast.expr, typeName: DataType): String = + s"static_cast<${CppCompiler.kaitaiType2NativeType(config.cppConfig, typeName)}>(${translate(value)})" // Predefined methods of various types override def strToInt(s: expr, base: expr): String = { @@ -124,6 +188,8 @@ class CppTranslator(provider: TypeProvider, importListSrc: ImportList) extends B } override def bytesToStr(bytesExpr: String, encoding: Ast.expr): String = s"${CppCompiler.kstreamName}::bytes_to_str($bytesExpr, ${translate(encoding)})" + override def bytesLength(b: Ast.expr): String = + s"${translate(b)}.length()" override def strLength(s: expr): String = s"${translate(s)}.length()" override def strReverse(s: expr): String = diff --git a/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala index b4be487eb..358c1e3b1 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/GoTranslator.scala @@ -1,8 +1,9 @@ package io.kaitai.struct.translators +import io.kaitai.struct.datatype.DataType import io.kaitai.struct.datatype.DataType._ import io.kaitai.struct.exprlang.Ast -import io.kaitai.struct.format.{ClassSpec, Identifier} +import io.kaitai.struct.format.Identifier import io.kaitai.struct.languages.GoCompiler import 
io.kaitai.struct.precompile.TypeMismatchError import io.kaitai.struct.{ImportList, StringLanguageOutputWriter, Utils} @@ -16,8 +17,11 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo with AbstractTranslator with CommonLiterals with CommonOps + with CommonArraysAndCast[TranslatorResult] with CommonMethods[TranslatorResult] { + import io.kaitai.struct.languages.GoCompiler._ + var returnRes: Option[String] = None override def translate(v: Ast.expr): String = resToStr(translateExpr(v)) @@ -37,8 +41,8 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo trStringLiteral(s) case Ast.expr.Bool(n) => trBoolLiteral(n) - -// case Ast.expr.BoolOp(op, values) => + case Ast.expr.BoolOp(op, values) => + trBooleanOp(op, values) case Ast.expr.BinOp(left: Ast.expr, op: Ast.operator, right: Ast.expr) => (detectType(left), detectType(right), op) match { case (_: NumericType, _: NumericType, _) => @@ -48,16 +52,44 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo case (ltype, rtype, _) => throw new TypeMismatchError(s"can't do $ltype $op $rtype") } -// case Ast.expr.UnaryOp(op, operand) => -// case Ast.expr.IfExp(condition, ifTrue, ifFalse) => -// case Ast.expr.Compare(left, ops, right) => -// case Ast.expr.EnumByLabel(enumName, label) => -// case Ast.expr.EnumById(enumName, id) => -// case Ast.expr.CastToType(value, typeName) => -// case Ast.expr.Subscript(value, idx) => + case Ast.expr.UnaryOp(op, operand) => + ResultString(unaryOp(op) + (operand match { + case Ast.expr.IntNum(_) | Ast.expr.FloatNum(_) => + translate(operand) + case _ => + s"(${translate(operand)})" + })) + case Ast.expr.IfExp(condition, ifTrue, ifFalse) => + trIfExp(condition, ifTrue, ifFalse) + case Ast.expr.Compare(left, op, right) => + (detectType(left), detectType(right)) match { + case (_: NumericType, _: NumericType) => + trNumericCompareOp(left, op, right) + case (_: StrType, _: StrType) => + 
trStrCompareOp(left, op, right) + case (_: BytesType, _: BytesType) => + trBytesCompareOp(left, op, right) + case (_: BooleanType, _: BooleanType) => + trNumericCompareOp(left, op, right) + case (_: EnumType, _: EnumType) => + trNumericCompareOp(left, op, right) + case (ltype, rtype) => + throw new TypeMismatchError(s"can't do $ltype $op $rtype") + } + case Ast.expr.EnumById(enumType, id, inType) => + val enumSpec = provider.resolveEnum(inType, enumType.name) + trEnumById(enumSpec.name, translate(id)) + case Ast.expr.EnumByLabel(enumType, label, inType) => + val enumSpec = provider.resolveEnum(inType, enumType.name) + trEnumByLabel(enumSpec.name, label.name) + case ctt: Ast.expr.CastToType => + doCastOrArray(ctt) + case Ast.expr.Subscript(container, idx) => + trSubscript(container, idx) case Ast.expr.Name(name: Ast.identifier) => trLocalName(name.name) -// case Ast.expr.List(elts) => + case Ast.expr.List(elts) => + doGuessArrayLiteral(elts) case call: Ast.expr.Attribute => translateAttribute(call) case call: Ast.expr.Call => @@ -70,30 +102,68 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo def trStringLiteral(s: String): TranslatorResult = ResultString(doStringLiteral(s)) def trBoolLiteral(n: Boolean): TranslatorResult = ResultString(doBoolLiteral(n)) - def trNumericBinOp(left: Ast.expr, op: Ast.operator, right: Ast.expr) = - ResultString(numericBinOp(left, op, right)) + def trBooleanOp(op: Ast.boolop, values: Seq[Ast.expr]) = + ResultString(doBooleanOp(op, values)) + + def trNumericBinOp(left: Ast.expr, op: Ast.operator, right: Ast.expr): TranslatorResult = { + (detectType(left), detectType(right), op) match { + case (t1: IntType, t2: IntType, Ast.operator.Mod) => + val v1 = allocateLocalVar() + out.puts(s"${localVarName(v1)} := ${translate(left)} % ${translate(right)}") + out.puts(s"if ${localVarName(v1)} < 0 {") + out.inc + out.puts(s"${localVarName(v1)} += ${translate(right)}") + out.dec + out.puts("}") + ResultLocalVar(v1) + 
+ case _ => + ResultString(numericBinOp(left, op, right)) + } + } def trStrConcat(left: Ast.expr, right: Ast.expr): TranslatorResult = ResultString(translate(left) + " + " + translate(right)) -// override def doArrayLiteral(t: DataType, value: Seq[Ast.expr]): String = { -// val javaType = JavaCompiler.kaitaiType2JavaTypeBoxed(t) -// val commaStr = value.map((v) => translate(v)).mkString(", ") -// s"new ArrayList<$javaType>(Arrays.asList($commaStr))" -// } -// -// override def doByteArrayLiteral(arr: Seq[Byte]): String = -// s"new byte[] { ${arr.mkString(", ")} }" + def trNumericCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): TranslatorResult = + ResultString(doNumericCompareOp(left, op, right)) - override def numericBinOp(left: Ast.expr, op: Ast.operator, right: Ast.expr) = { - (detectType(left), detectType(right), op) match { - case (_: IntType, _: IntType, Ast.operator.Mod) => - s"${GoCompiler.kstreamName}.mod(${translate(left)}, ${translate(right)})" + def trStrCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): TranslatorResult = + ResultString(doStrCompareOp(left, op, right)) + + def trBytesCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): TranslatorResult = { + importList.add("bytes") + op match { + case Ast.cmpop.Eq => + ResultString(s"bytes.Equal(${translate(left)}, ${translate(right)})") case _ => - super.numericBinOp(left, op, right) + ResultString(s"(bytes.Compare(${translate(left)}, ${translate(right)}) ${cmpOp(op)} 0)") + } + } + + override def doIntLiteral(n: BigInt): String = { + if (n < -9223372036854775808L) { + s"$n" // too low, no type conversion would help anyway + } else if (n <= -2147483649L) { + s"int64($n)" // -9223372036854775808..-2147483649 + } else if (n <= 2147483647L) { + s"$n" // -2147483648..2147483647 + } else if (n <= 4294967295L) { + s"uint32($n)" // 2147483648..4294967295 + } else if (n <= 9223372036854775807L) { + s"int64($n)" // 4294967296..9223372036854775807 + } else if (n <= Utils.MAX_UINT64) { + 
s"uint64($n)" // 9223372036854775808..18446744073709551615 + } else { + s"$n" // too high, no type conversion would help anyway } } + override def unaryOp(op: Ast.unaryop): String = op match { + case Ast.unaryop.Invert => "^" + case Ast.unaryop.Minus => "-" + case Ast.unaryop.Not => "!" + } + def trLocalName(s: String): TranslatorResult = { s match { case Identifier.ROOT | @@ -124,25 +194,29 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo "_buf" } -// override def doEnumByLabel(enumTypeAbs: List[String], label: String): String = -// s"${enumClass(enumTypeAbs)}.${label.toUpperCase}" -// override def doEnumById(enumTypeAbs: List[String], id: String): String = -// s"${enumClass(enumTypeAbs)}.byId($id)" + def trSubscript(container: Ast.expr, idx: Ast.expr) = + ResultString(s"${translate(container)}[${translate(idx)}]") - def enumClass(enumTypeAbs: List[String]): String = { - val enumTypeRel = Utils.relClass(enumTypeAbs, provider.nowClass.name) - enumTypeRel.map((x) => Utils.upperCamelCase(x)).mkString(".") + def trIfExp(condition: Ast.expr, ifTrue: Ast.expr, ifFalse: Ast.expr): ResultLocalVar = { + val v1 = allocateLocalVar() + val typ = detectType(ifTrue) + out.puts(s"var ${localVarName(v1)} ${GoCompiler.kaitaiType2NativeType(typ)};") + out.puts(s"if (${translate(condition)}) {") + out.inc + out.puts(s"${localVarName(v1)} = ${translate(ifTrue)}") + out.dec + out.puts("} else {") + out.inc + out.puts(s"${localVarName(v1)} = ${translate(ifFalse)}") + out.dec + out.puts("}") + ResultLocalVar(v1) } - override def doStrCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr) = { - if (op == Ast.cmpop.Eq) { - s"${translate(left)}.equals(${translate(right)})" - } else if (op == Ast.cmpop.NotEq) { - s"!(${translate(left)}).equals(${translate(right)})" - } else { - s"(${translate(left)}.compareTo(${translate(right)}) ${cmpOp(op)} 0)" - } - } + def trEnumByLabel(enumTypeAbs: List[String], label: String) = + 
ResultString(GoCompiler.enumToStr(enumTypeAbs, label)) + def trEnumById(enumTypeAbs: List[String], id: String) = + ResultString(s"${types2class(enumTypeAbs)}($id)") override def doBytesCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): String = { op match { @@ -155,20 +229,18 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo } } -// override def doSubscript(container: Ast.expr, idx: Ast.expr): String = -// s"${translate(container)}.get((int) ${translate(idx)})" -// override def doIfExp(condition: Ast.expr, ifTrue: Ast.expr, ifFalse: Ast.expr): String = -// s"(${translate(condition)} ? ${translate(ifTrue)} : ${translate(ifFalse)})" -// override def doCast(value: Ast.expr, typeName: String): String = -// s"((${Utils.upperCamelCase(typeName)}) (${translate(value)}))" + override def doCast(value: Ast.expr, typeName: DataType): TranslatorResult = ??? + + override def doArrayLiteral(t: DataType, value: Seq[Ast.expr]) = + ResultString(s"[]${GoCompiler.kaitaiType2NativeType(t)}{${value.map(translate).mkString(", ")}}") + + override def doByteArrayLiteral(arr: Seq[Byte]): TranslatorResult = + ResultString("[]uint8{" + arr.map(_ & 0xff).mkString(", ") + "}") + + override def doByteArrayNonLiteral(elts: Seq[Ast.expr]): TranslatorResult = + ResultString("[]uint8{" + elts.map(translate).mkString(", ") + "}") // Predefined methods of various types -// override def strToInt(s: Ast.expr, base: Ast.expr): String = -// s"Long.parseLong(${translate(s)}, ${translate(base)})" -// override def enumToInt(v: Ast.expr, et: EnumType): String = -// s"${translate(v)}.id()" -// override def intToStr(i: Ast.expr, base: Ast.expr): String = -// s"Long.toString(${translate(i)}, ${translate(base)})" val IMPORT_CHARMAP = "golang.org/x/text/encoding/charmap" @@ -182,6 +254,9 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo "big5" -> ("traditionalchinese.Big5", "golang.org/x/text/encoding/traditionalchinese") ) + override def 
bytesToStr(value: Ast.expr, expr: Ast.expr): TranslatorResult = + bytesToStr(translate(value), expr) + def bytesToStr(bytesExpr: String, encoding: Ast.expr): TranslatorResult = { val enc = encoding match { case Ast.expr.Str(s) => s @@ -204,21 +279,20 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo } } -// override def strLength(s: Ast.expr): String = -// s"${translate(s)}.length()" // override def strReverse(s: Ast.expr): String = // s"new StringBuilder(${translate(s)}).reverse().toString()" // override def strSubstring(s: Ast.expr, from: Ast.expr, to: Ast.expr): String = // s"${translate(s)}.substring(${translate(from)}, ${translate(to)})" -// override def arrayFirst(a: Ast.expr): String = -// s"${translate(a)}.get(0)" -// override def arrayLast(a: Ast.expr): String = { -// val v = translate(a) -// s"$v.get($v.size() - 1)" -// } -// override def arraySize(a: Ast.expr): String = -// s"${translate(a)}.size()" + override def arrayFirst(a: Ast.expr): TranslatorResult = + ResultString(s"${translate(a)}[0]") + override def arrayLast(a: Ast.expr): ResultString = { + val v = allocateLocalVar() + out.puts(s"${localVarName(v)} := ${translate(a)}") + ResultString(s"${localVarName(v)}[len(${localVarName(v)}) - 1]") + } + override def arraySize(a: Ast.expr): TranslatorResult = + ResultString(s"len(${translate(a)})") // override def arrayMin(a: Ast.expr): String = // s"Collections.min(${translate(a)})" // override def arrayMax(a: Ast.expr): String = @@ -248,38 +322,81 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo ResultString(s"utf8.RuneCountInString(${translate(s)})") } - override def strReverse(s: Ast.expr): TranslatorResult = ??? - - override def strToInt(s: Ast.expr, base: Ast.expr): TranslatorResult = ??? 
+ override def strReverse(s: Ast.expr): TranslatorResult = { + ResultString(s"kaitai.StringReverse(${translate(s)})") + } - override def strSubstring(s: Ast.expr, from: Ast.expr, to: Ast.expr): TranslatorResult = ??? + override def strToInt(s: Ast.expr, base: Ast.expr): TranslatorResult = { + importList.add("strconv") + ResultString(s"strconv.ParseInt(${translate(s)}, ${translate(base)}, 0)") + } - override def bytesToStr(value: Ast.expr, expr: Ast.expr): TranslatorResult = ??? + override def strSubstring(s: Ast.expr, from: Ast.expr, to: Ast.expr): TranslatorResult = { + ResultString(s"${translate(s)}[${translate(from)}:${translate(to)}]") + } - override def intToStr(value: Ast.expr, num: Ast.expr): TranslatorResult = ??? + override def intToStr(value: Ast.expr, base: Ast.expr): TranslatorResult = { + importList.add("strconv") + ResultString(s"strconv.FormatInt(int64(${translate(value)}), ${translate(base)})") + } - override def floatToInt(value: Ast.expr): TranslatorResult = + override def floatToInt(value: Ast.expr) = ResultString(s"int(${translate(value)})") - override def kaitaiStreamSize(value: Ast.expr): TranslatorResult = ??? + override def kaitaiStreamSize(value: Ast.expr) = + outVarCheckRes(s"${translate(value)}.Size()") - override def kaitaiStreamEof(value: Ast.expr): TranslatorResult = ??? + override def kaitaiStreamEof(value: Ast.expr) = + outVarCheckRes(s"${translate(value)}.EOF()") - override def kaitaiStreamPos(value: Ast.expr): TranslatorResult = ??? + override def kaitaiStreamPos(value: Ast.expr) = + outVarCheckRes(s"${translate(value)}.Pos()") - override def arrayFirst(a: Ast.expr): TranslatorResult = ??? - - override def arrayLast(a: Ast.expr): TranslatorResult = ??? - - override def arraySize(a: Ast.expr): TranslatorResult = ??? - - override def arrayMin(a: Ast.expr): TranslatorResult = ??? 
+ override def arrayMin(a: Ast.expr): ResultLocalVar = { + val min = allocateLocalVar() + val value = allocateLocalVar() + out.puts(s"${localVarName(min)} := ${translate(a)}[0]") + out.puts(s"for _, ${localVarName(value)} := range ${translate(a)} {") + out.inc + out.puts(s"if ${localVarName(min)} > ${localVarName(value)} {") + out.inc + out.puts(s"${localVarName(min)} = ${localVarName(value)}") + out.dec + out.puts("}") + out.dec + out.puts("}") + ResultLocalVar(min) + } - override def arrayMax(a: Ast.expr): TranslatorResult = ??? + override def arrayMax(a: Ast.expr): ResultLocalVar = { + val max = allocateLocalVar() + val value = allocateLocalVar() + out.puts(s"${localVarName(max)} := ${translate(a)}[0]") + out.puts(s"for _, ${localVarName(value)} := range ${translate(a)} {") + out.inc + out.puts(s"if ${localVarName(max)} < ${localVarName(value)} {") + out.inc + out.puts(s"${localVarName(max)} = ${localVarName(value)}") + out.dec + out.puts("}") + out.dec + out.puts("}") + ResultLocalVar(max) + } - override def enumToInt(value: Ast.expr, et: EnumType): TranslatorResult = ??? + override def enumToInt(value: Ast.expr, et: EnumType) = + translateExpr(value) - override def boolToInt(value: Ast.expr): TranslatorResult = ??? 
+ override def boolToInt(value: Ast.expr): ResultLocalVar = { + val v = allocateLocalVar() + out.puts(s"${localVarName(v)} := 0") + out.puts(s"if ${translate(value)} {") + out.inc + out.puts(s"${localVarName(v)} = 1") + out.dec + out.puts("}") + ResultLocalVar(v) + } def userType(dataType: UserType, io: String) = { val v = allocateLocalVar() diff --git a/shared/src/main/scala/io/kaitai/struct/translators/JavaScriptTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/JavaScriptTranslator.scala index 3605adea4..c938009e3 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/JavaScriptTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/JavaScriptTranslator.scala @@ -8,6 +8,9 @@ import io.kaitai.struct.format.Identifier import io.kaitai.struct.languages.JavaScriptCompiler class JavaScriptTranslator(provider: TypeProvider) extends BaseTranslator(provider) { + override def doByteArrayNonLiteral(elts: Seq[Ast.expr]): String = + s"new Uint8Array([${elts.map(translate).mkString(", ")}])" + /** * JavaScript rendition of common control character that would use hex form, * not octal. 
"Octal" control character string literals might be accepted diff --git a/shared/src/main/scala/io/kaitai/struct/translators/JavaTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/JavaTranslator.scala index 9d7dea40e..84af611ad 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/JavaTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/JavaTranslator.scala @@ -10,46 +10,29 @@ import io.kaitai.struct.languages.JavaCompiler class JavaTranslator(provider: TypeProvider, importList: ImportList) extends BaseTranslator(provider) { override def doIntLiteral(n: BigInt): String = { - val literal = n.toString + val literal = if (n > Long.MaxValue) { + "0x" + n.toString(16) + } else { + n.toString + } val suffix = if (n > Int.MaxValue) "L" else "" - s"${literal}${suffix}" - } - - /** - * Wrapper for {@link #doIntLiteral(BigInt)} if {@code CalcIntType} is known to be needed. - *

- * {@link #doIntLiteral(BigInt)} doesn't work for statements like {@code new ArrayList(Arrays.asList(0, 1, 100500))} - * because it doesn't know that a {@code long} is always needed, even if the value of the number - * wouldn't need it. Java by default assumes {@code int} for numeric literals and would create an - * array with a different type than required. - *

- */ - def doIntLiteralCalcIntType(n: BigInt): String = { - val literal = doIntLiteral(n) - val isLong = JavaCompiler.kaitaiType2JavaTypePrim(CalcIntType) == "long" - val suffixNeeded = isLong && !literal.endsWith("L") - val suffix = if (suffixNeeded) "L" else "" - - s"${literal}${suffix}" + s"$literal$suffix" } override def doArrayLiteral(t: DataType, value: Seq[expr]): String = { val javaType = JavaCompiler.kaitaiType2JavaTypeBoxed(t) - val values = t match { - case CalcIntType => value.map((v) => v match { - case Ast.expr.IntNum(n) => doIntLiteralCalcIntType(n) - case _ => throw new UnsupportedOperationException("CalcIntType should only be used for numbers.") - }) - case _ => value.map((v) => translate(v)) - } - val commaStr = values.mkString(", ") + val commaStr = value.map((v) => translate(v)).mkString(", ") + importList.add("java.util.ArrayList") + importList.add("java.util.Arrays") s"new ArrayList<$javaType>(Arrays.asList($commaStr))" } override def doByteArrayLiteral(arr: Seq[Byte]): String = s"new byte[] { ${arr.mkString(", ")} }" + override def doByteArrayNonLiteral(elts: Seq[expr]): String = + s"new byte[] { ${elts.map(translate).mkString(", ")} }" override def numericBinOp(left: Ast.expr, op: Ast.operator, right: Ast.expr) = { (detectType(left), detectType(right), op) match { @@ -118,8 +101,8 @@ class JavaTranslator(provider: TypeProvider, importList: ImportList) extends Bas override def doIfExp(condition: expr, ifTrue: expr, ifFalse: expr): String = s"(${translate(condition)} ? 
${translate(ifTrue)} : ${translate(ifFalse)})" - override def doCast(value: Ast.expr, typeName: String): String = - s"((${Utils.upperCamelCase(typeName)}) (${translate(value)}))" + override def doCast(value: Ast.expr, typeName: DataType): String = + s"((${JavaCompiler.kaitaiType2JavaType(typeName)}) (${translate(value)}))" // Predefined methods of various types override def strToInt(s: expr, base: expr): String = @@ -134,6 +117,8 @@ class JavaTranslator(provider: TypeProvider, importList: ImportList) extends Bas importList.add("java.nio.charset.Charset") s"new String($bytesExpr, Charset.forName(${translate(encoding)}))" } + override def bytesLength(b: Ast.expr): String = + s"${translate(b)}.length" override def strLength(s: expr): String = s"${translate(s)}.length()" override def strReverse(s: expr): String = diff --git a/shared/src/main/scala/io/kaitai/struct/translators/PHPTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/PHPTranslator.scala index 4b0aa9699..cfd59a8e9 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/PHPTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/PHPTranslator.scala @@ -10,6 +10,8 @@ import io.kaitai.struct.{RuntimeConfig, Utils} class PHPTranslator(provider: TypeProvider, config: RuntimeConfig) extends BaseTranslator(provider) { override def doByteArrayLiteral(arr: Seq[Byte]): String = "\"" + Utils.hexEscapeByteArray(arr) + "\"" + override def doByteArrayNonLiteral(elts: Seq[Ast.expr]): String = + s"pack('C*', ${elts.map(translate).mkString(", ")})" // http://php.net/manual/en/language.types.string.php#language.types.string.syntax.double override val asciiCharQuoteMap: Map[Char, String] = Map( @@ -95,6 +97,8 @@ class PHPTranslator(provider: TypeProvider, config: RuntimeConfig) extends BaseT } override def bytesToStr(bytesExpr: String, encoding: Ast.expr): String = s"${PHPCompiler.kstreamName}::bytesToStr($bytesExpr, ${translate(encoding)})" + override def bytesLength(b: Ast.expr): 
String = + s"strlen(${translate(b)})" override def strLength(s: expr): String = s"strlen(${translate(s)})" override def strReverse(s: expr): String = diff --git a/shared/src/main/scala/io/kaitai/struct/translators/PerlTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/PerlTranslator.scala index cc6c30ef7..92eb49cfd 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/PerlTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/PerlTranslator.scala @@ -49,6 +49,8 @@ class PerlTranslator(provider: TypeProvider, importList: ImportList) extends Bas override def doByteArrayLiteral(arr: Seq[Byte]): String = s"pack('C*', (${arr.map(_ & 0xff).mkString(", ")}))" + override def doByteArrayNonLiteral(elts: Seq[Ast.expr]): String = + s"pack('C*', (${elts.map(translate).mkString(", ")}))" override def anyField(value: Ast.expr, attrName: String): String = s"${translate(value)}->${doName(attrName)}" diff --git a/shared/src/main/scala/io/kaitai/struct/translators/PythonTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/PythonTranslator.scala index ef6fa6b63..7f44219de 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/PythonTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/PythonTranslator.scala @@ -4,7 +4,7 @@ import io.kaitai.struct.{ImportList, Utils} import io.kaitai.struct.datatype.DataType._ import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.format.Identifier -import io.kaitai.struct.languages.PythonCompiler +import io.kaitai.struct.languages.{PythonCompiler, RubyCompiler} class PythonTranslator(provider: TypeProvider, importList: ImportList) extends BaseTranslator(provider) { override def numericBinOp(left: Ast.expr, op: Ast.operator, right: Ast.expr) = { @@ -36,8 +36,11 @@ class PythonTranslator(provider: TypeProvider, importList: ImportList) extends B '\b' -> "\\b" ) - override def doByteArrayLiteral(arr: Seq[Byte]): String = { + override def 
doByteArrayLiteral(arr: Seq[Byte]): String = "b\"" + Utils.hexEscapeByteArray(arr) + "\"" + override def doByteArrayNonLiteral(elts: Seq[Ast.expr]): String = { + importList.add("import struct") + s"struct.pack('${elts.length}b', ${elts.map(translate).mkString(", ")})" } override def doLocalName(s: String) = { @@ -51,8 +54,8 @@ class PythonTranslator(provider: TypeProvider, importList: ImportList) extends B override def doEnumByLabel(enumTypeAbs: List[String], label: String): String = s"${PythonCompiler.types2class(enumTypeAbs)}.$label" - override def doEnumById(enumTypeAbs: List[String], id: String) = - s"${PythonCompiler.types2class(enumTypeAbs)}($id)" + override def doEnumById(enumTypeAbs: List[String], id: String): String = + s"${PythonCompiler.kstreamName}.resolve_enum(${PythonCompiler.types2class(enumTypeAbs)}, $id)" override def booleanOp(op: Ast.boolop) = op match { case Ast.boolop.Or => "or" @@ -98,6 +101,8 @@ class PythonTranslator(provider: TypeProvider, importList: ImportList) extends B } override def bytesToStr(bytesExpr: String, encoding: Ast.expr): String = s"($bytesExpr).decode(${translate(encoding)})" + override def bytesLength(value: Ast.expr): String = + s"len(${translate(value)})" override def strLength(value: Ast.expr): String = s"len(${translate(value)})" override def strReverse(value: Ast.expr): String = diff --git a/shared/src/main/scala/io/kaitai/struct/translators/RubyTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/RubyTranslator.scala index 52975cee5..c2ee9a100 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/RubyTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/RubyTranslator.scala @@ -1,5 +1,6 @@ package io.kaitai.struct.translators +import io.kaitai.struct.Utils import io.kaitai.struct.datatype.DataType.EnumType import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.exprlang.Ast.expr @@ -9,6 +10,8 @@ import io.kaitai.struct.languages.RubyCompiler class 
RubyTranslator(provider: TypeProvider) extends BaseTranslator(provider) { override def doByteArrayLiteral(arr: Seq[Byte]): String = s"${super.doByteArrayLiteral(arr)}.pack('C*')" + override def doByteArrayNonLiteral(elts: Seq[Ast.expr]): String = + s"[${elts.map(translate).mkString(", ")}].pack('C*')" // https://github.com/ruby/ruby/blob/trunk/doc/syntax/literals.rdoc#strings // https://github.com/ruby/ruby/blob/trunk/string.c - see "rb_str_inspect" @@ -29,6 +32,7 @@ class RubyTranslator(provider: TypeProvider) extends BaseTranslator(provider) { override def doName(s: String) = { s match { + case Identifier.ITERATOR => "_it" case Identifier.INDEX => "i" // FIXME: probably would clash with attribute named "i" case _ => s } @@ -37,7 +41,29 @@ class RubyTranslator(provider: TypeProvider) extends BaseTranslator(provider) { override def doEnumByLabel(enumTypeAbs: List[String], label: String): String = s":${enumTypeAbs.last}_$label" override def doEnumById(enumType: List[String], id: String): String = - s"${RubyCompiler.kstreamName}::resolve_enum(${enumType.last.toUpperCase}, $id)" + s"${RubyCompiler.kstreamName}::resolve_enum(${enumDirectMap(enumType)}, $id)" + + def enumDirectMap(enumTypeAndName: List[String]): String = { + val enumTypeAbs = enumTypeAndName.dropRight(1) + val enumTypeName = enumTypeAndName.last.toUpperCase + + val enumTypeRel = Utils.relClass(enumTypeAbs, provider.nowClass.name) + + if (enumTypeRel.nonEmpty) { + (enumTypeRel.map((x) => Utils.upperCamelCase(x)) ++ List(enumTypeName)).mkString("::") + } else { + enumTypeName + } + } + + def enumInverseMap(et: EnumType): String = { + val enumTypeAndName = et.enumSpec.get.name + val enumDirectMap = this.enumDirectMap(enumTypeAndName) + val enumNameDirect = enumTypeAndName.last.toUpperCase + val enumNameInverse = RubyCompiler.inverseEnumName(enumNameDirect) + + enumDirectMap.replace(enumNameDirect, enumNameInverse) + } override def doSubscript(container: Ast.expr, idx: Ast.expr): String = 
s"${translate(container)}[${translate(idx)}]" @@ -53,13 +79,15 @@ class RubyTranslator(provider: TypeProvider) extends BaseTranslator(provider) { }) } override def enumToInt(v: Ast.expr, et: EnumType): String = - s"${RubyCompiler.inverseEnumName(et.name.last.toUpperCase)}[${translate(v)}]" + s"${enumInverseMap(et)}[${translate(v)}]" override def floatToInt(v: Ast.expr): String = s"(${translate(v)}).to_i" override def intToStr(i: Ast.expr, base: Ast.expr): String = translate(i) + s".to_s(${translate(base)})" override def bytesToStr(bytesExpr: String, encoding: Ast.expr): String = s"($bytesExpr).force_encoding(${translate(encoding)})" + override def bytesLength(b: Ast.expr): String = + s"${translate(b)}.size" override def strLength(s: Ast.expr): String = s"${translate(s)}.size" override def strReverse(s: Ast.expr): String = diff --git a/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala new file mode 100644 index 000000000..267ea06b4 --- /dev/null +++ b/shared/src/main/scala/io/kaitai/struct/translators/RustTranslator.scala @@ -0,0 +1,129 @@ +package io.kaitai.struct.translators + +import io.kaitai.struct.datatype.DataType._ +import io.kaitai.struct.exprlang.Ast +import io.kaitai.struct.exprlang.Ast.expr +import io.kaitai.struct.format.Identifier +import io.kaitai.struct.languages.RustCompiler +import io.kaitai.struct.{RuntimeConfig, Utils} + +class RustTranslator(provider: TypeProvider, config: RuntimeConfig) extends BaseTranslator(provider) { + override def doByteArrayLiteral(arr: Seq[Byte]): String = + "vec!([" + arr.map((x) => + "%0#2x".format(x & 0xff) + ).mkString(", ") + "])" + override def doByteArrayNonLiteral(elts: Seq[Ast.expr]): String = + s"pack('C*', ${elts.map(translate).mkString(", ")})" + + override val asciiCharQuoteMap: Map[Char, String] = Map( + '\t' -> "\\t", + '\n' -> "\\n", + '\r' -> "\\r", + '"' -> "\\\"", + '\\' -> "\\\\" + ) + + override def 
strLiteralUnicode(code: Char): String = + "\\u{%x}".format(code.toInt) + + override def numericBinOp(left: Ast.expr, op: Ast.operator, right: Ast.expr) = { + (detectType(left), detectType(right), op) match { + case (_: IntType, _: IntType, Ast.operator.Div) => + s"${translate(left)} / ${translate(right)}" + case (_: IntType, _: IntType, Ast.operator.Mod) => + s"${translate(left)} % ${translate(right)}" + case _ => + super.numericBinOp(left, op, right) + } + } + + override def doLocalName(s: String) = { + s match { + case Identifier.ITERATOR => "tmpa" + case Identifier.ITERATOR2 => "tmpb" + case Identifier.INDEX => "i" + case _ => s"self.${doName(s)}" + } + } + + override def doName(s: String) = s + + override def doEnumByLabel(enumTypeAbs: List[String], label: String): String = { + val enumClass = types2classAbs(enumTypeAbs) + s"$enumClass::${label.toUpperCase}" + } + override def doEnumById(enumTypeAbs: List[String], id: String) = + // Just an integer, without any casts / resolutions - one would have to look up constants manually + id + + override def doSubscript(container: expr, idx: expr): String = + s"${translate(container)}[${translate(idx)}]" + override def doIfExp(condition: expr, ifTrue: expr, ifFalse: expr): String = + "if " + translate(condition) + + " { " + translate(ifTrue) + " } else { " + + translate(ifFalse) + "}" + + // Predefined methods of various types + override def strConcat(left: Ast.expr, right: Ast.expr): String = + "format!(\"{}{}\", " + translate(left) + ", " + translate(right) + ")" + + override def strToInt(s: expr, base: expr): String = + translate(base) match { + case "10" => + s"${translate(s)}.parse().unwrap()" + case _ => + "panic!(\"Converting from string to int in base {} is unimplemented\"" + translate(base) + ")" + } + + override def enumToInt(v: expr, et: EnumType): String = + translate(v) + + override def boolToInt(v: expr): String = + s"${translate(v)} as i32" + + override def floatToInt(v: expr): String = + s"${translate(v)} 
as i32" + + override def intToStr(i: expr, base: expr): String = { + val baseStr = translate(base) + baseStr match { + case "10" => + s"${translate(i)}.to_string()" + case _ => + s"base_convert(strval(${translate(i)}), 10, $baseStr)" + } + } + override def bytesToStr(bytesExpr: String, encoding: Ast.expr): String = + translate(encoding) match { + case "\"ASCII\"" => + s"String::from_utf8_lossy($bytesExpr)" + case _ => + "panic!(\"Unimplemented encoding for bytesToStr: {}\", " + + translate(encoding) + ")" + } + override def bytesLength(b: Ast.expr): String = + s"${translate(b)}.len()" + override def strLength(s: expr): String = + s"${translate(s)}.len()" + override def strReverse(s: expr): String = + s"${translate(s)}.graphemes(true).rev().flat_map(|g| g.chars()).collect()" + override def strSubstring(s: expr, from: expr, to: expr): String = + s"${translate(s)}.substring(${translate(from)}, ${translate(to)})" + + override def arrayFirst(a: expr): String = + s"${translate(a)}.first()" + override def arrayLast(a: expr): String = + s"${translate(a)}.last()" + override def arraySize(a: expr): String = + s"${translate(a)}.len()" + override def arrayMin(a: Ast.expr): String = + s"${translate(a)}.iter().min()" + override def arrayMax(a: Ast.expr): String = + s"${translate(a)}.iter().max()" + + def types2classAbs(names: List[String]) = + names match { + case List("kaitai_struct") => RustCompiler.kstructName + case _ => RustCompiler.types2classRel(names) + } +} diff --git a/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala b/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala index 89be684ba..5f9aed1b5 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala @@ -3,7 +3,6 @@ package io.kaitai.struct.translators import io.kaitai.struct.datatype.DataType import io.kaitai.struct.datatype.DataType._ import io.kaitai.struct.exprlang.Ast 
-import io.kaitai.struct.exprlang.Ast.expr import io.kaitai.struct.precompile.{TypeMismatchError, TypeUndecidedError} /** @@ -49,13 +48,13 @@ class TypeDetector(provider: TypeProvider) { case Ast.expr.FloatNum(_) => CalcFloatType case Ast.expr.Str(_) => CalcStrType case Ast.expr.Bool(_) => CalcBooleanType - case Ast.expr.EnumByLabel(enumType, _) => + case Ast.expr.EnumByLabel(enumType, _, inType) => val t = EnumType(List(enumType.name), CalcIntType) - t.enumSpec = Some(provider.resolveEnum(enumType.name)) + t.enumSpec = Some(provider.resolveEnum(inType, enumType.name)) t - case Ast.expr.EnumById(enumType, _) => + case Ast.expr.EnumById(enumType, _, inType) => val t = EnumType(List(enumType.name), CalcIntType) - t.enumSpec = Some(provider.resolveEnum(enumType.name)) + t.enumSpec = Some(provider.resolveEnum(inType, enumType.name)) t case Ast.expr.Name(name: Ast.identifier) => provider.determineType(name.name) case Ast.expr.UnaryOp(op: Ast.unaryop, v: Ast.expr) => @@ -90,7 +89,7 @@ class TypeDetector(provider: TypeProvider) { } }) CalcBooleanType - case Ast.expr.IfExp(condition: expr, ifTrue: expr, ifFalse: expr) => + case Ast.expr.IfExp(condition: Ast.expr, ifTrue: Ast.expr, ifFalse: Ast.expr) => detectType(condition) match { case _: BooleanType => val trueType = detectType(ifTrue) @@ -108,75 +107,105 @@ class TypeDetector(provider: TypeProvider) { case cntType => throw new TypeMismatchError(s"unable to apply operation [] to $cntType") } case Ast.expr.Attribute(value: Ast.expr, attr: Ast.identifier) => - val valType = detectType(value) - valType match { - case KaitaiStructType => - throw new TypeMismatchError(s"called attribute '${attr.name}' on generic struct expression '$value'") - case t: UserType => - t.classSpec match { - case Some(tt) => provider.determineType(tt, attr.name) - case None => throw new TypeUndecidedError(s"expression '$value' has undecided type '${t.name}' (while asking for attribute '${attr.name}')") - } - case _: StrType => - attr.name match { - 
case "length" => CalcIntType - case "reverse" => CalcStrType - case "to_i" => CalcIntType - case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") - } - case _: IntType => - attr.name match { - case "to_s" => CalcStrType - case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") - } - case _: FloatType => - attr.name match { - case "to_i" => CalcIntType - case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") - } - case ArrayType(inType) => - attr.name match { - case "first" | "last" | "min" | "max" => inType - case "size" => CalcIntType - case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") - } - case KaitaiStreamType => - attr.name match { - case "size" => CalcIntType - case "pos" => CalcIntType - case "eof" => CalcBooleanType - case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") - } - case et: EnumType => - attr.name match { - case "to_i" => CalcIntType - case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") - } - case _: BooleanType => - attr.name match { - case "to_i" => CalcIntType - case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") - } - case _ => - throw new TypeMismatchError(s"don't know how to call anything on $valType") - } - case Ast.expr.Call(func: Ast.expr, args: Seq[Ast.expr]) => - func match { - case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) => - val objType = detectType(obj) - (objType, methodName.name) match { - case (_: StrType, "substring") => CalcStrType - case (_: StrType, "to_i") => CalcIntType - case _ => throw new RuntimeException(s"don't know how to call method '$methodName' of object type '$objType'") - } - } + 
detectAttributeType(value, attr) + case call: Ast.expr.Call => + detectCallType(call) case Ast.expr.List(values: Seq[Ast.expr]) => detectArrayType(values) match { case Int1Type(_) => CalcBytesType case t => ArrayType(t) } - case Ast.expr.CastToType(value, typeName) => - provider.resolveType(typeName.name) + case Ast.expr.CastToType(_, typeName) => + detectCastType(typeName) + } + } + + /** + * Detects resulting data type of a given attribute expression. + * + * @note Must be kept in sync with [[CommonMethods.translateAttribute]] + * @param value value part of attribute expression + * @param attr attribute identifier part of attribute expression + * @return data type + */ + def detectAttributeType(value: Ast.expr, attr: Ast.identifier) = { + val valType = detectType(value) + valType match { + case KaitaiStructType | CalcKaitaiStructType => + throw new TypeMismatchError(s"called attribute '${attr.name}' on generic struct expression '$value'") + case t: UserType => + t.classSpec match { + case Some(tt) => provider.determineType(tt, attr.name) + case None => throw new TypeUndecidedError(s"expression '$value' has undecided type '${t.name}' (while asking for attribute '${attr.name}')") + } + case _: BytesType => + attr.name match { + case "length" => CalcIntType + case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") + } + case _: StrType => + attr.name match { + case "length" => CalcIntType + case "reverse" => CalcStrType + case "to_i" => CalcIntType + case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") + } + case _: IntType => + attr.name match { + case "to_s" => CalcStrType + case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") + } + case _: FloatType => + attr.name match { + case "to_i" => CalcIntType + case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on 
expression of type $valType") + } + case ArrayType(inType) => + attr.name match { + case "first" | "last" | "min" | "max" => inType + case "size" => CalcIntType + case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") + } + case KaitaiStreamType => + attr.name match { + case "size" => CalcIntType + case "pos" => CalcIntType + case "eof" => CalcBooleanType + case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") + } + case et: EnumType => + attr.name match { + case "to_i" => CalcIntType + case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") + } + case _: BooleanType => + attr.name match { + case "to_i" => CalcIntType + case _ => throw new TypeMismatchError(s"called invalid attribute '${attr.name}' on expression of type $valType") + } + case _ => + throw new TypeMismatchError(s"don't know how to call anything on $valType") + } + } + + /** + * Detects resulting data type of a given function call expression. Typical function + * call expression in KSY is `foo.bar(arg1, arg2)`, which is represented in AST as + * `Call(Attribute(foo, bar), Seq(arg1, arg2))`. 
+ * @note Must be kept in sync with [[CommonMethods.translateCall]] + * @param call function call expression + * @return data type + */ + def detectCallType(call: Ast.expr.Call): DataType = { + call.func match { + case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) => + val objType = detectType(obj) + // TODO: check number and type of arguments in `call.args` + (objType, methodName.name) match { + case (_: StrType, "substring") => CalcStrType + case (_: StrType, "to_i") => CalcIntType + case _ => throw new RuntimeException(s"don't know how to call method '$methodName' of object type '$objType'") + } } } @@ -188,7 +217,7 @@ class TypeDetector(provider: TypeProvider) { * @param values * @return */ - def detectArrayType(values: Seq[expr]): DataType = { + def detectArrayType(values: Seq[Ast.expr]): DataType = { var t1o: Option[DataType] = None values.foreach { v => @@ -204,6 +233,35 @@ class TypeDetector(provider: TypeProvider) { case Some(t) => t } } + + /** + * Detects cast type determined by a typeId definition. + * @param typeName typeId definition to use + * @return data type + */ + def detectCastType(typeName: Ast.typeId): DataType = { + val singleType = if ((!typeName.absolute) && typeName.names.size == 1) { + // May be it's a reserved pure data type name? + DataType.pureFromString(Some(typeName.names(0))) match { + case _: UserType => + // No, it's a user type, let's try to resolve it through provider + provider.resolveType(typeName) + case primitiveType => + // Yes, it is! 
+ primitiveType + } + } else { + // It's a complex type name, it can be only resolved through provider + provider.resolveType(typeName) + } + + // Wrap it in array type, if needed + if (typeName.isArray) { + ArrayType(singleType) + } else { + singleType + } + } } object TypeDetector { @@ -220,9 +278,11 @@ object TypeDetector { case (_: NumericType, _: NumericType) => // ok case (_: BooleanType, _: BooleanType) => // ok case (_: BytesType, _: BytesType) => // ok - case (EnumType(name1, _), EnumType(name2, _)) => - if (name1 != name2) { - throw new TypeMismatchError(s"can't compare different enums '$name1' and '$name2'") + case (et1: EnumType, et2: EnumType) => + val et1Spec = et1.enumSpec.get + val et2Spec = et2.enumSpec.get + if (et1Spec != et2Spec) { + throw new TypeMismatchError(s"can't compare different enums '${et1Spec.nameAsStr}' and '${et2Spec.nameAsStr}'") } op match { case Ast.cmpop.Eq | Ast.cmpop.NotEq => // ok @@ -278,19 +338,31 @@ object TypeDetector { if (t1.name == t2.name) { t1 } else { - KaitaiStructType + if (t1.isOwning || t2.isOwning) { + KaitaiStructType + } else { + CalcKaitaiStructType + } } case (Some(cs1), Some(cs2)) => if (cs1 == cs2) { t1 } else { - KaitaiStructType + if (t1.isOwning || t2.isOwning) { + KaitaiStructType + } else { + CalcKaitaiStructType + } } case (_, _) => - KaitaiStructType + if (t1.isOwning || t2.isOwning) { + KaitaiStructType + } else { + CalcKaitaiStructType + } } - case (_: UserType, KaitaiStructType) => KaitaiStructType - case (KaitaiStructType, _: UserType) => KaitaiStructType + case (_: UserType, _: ComplexDataType) => CalcKaitaiStructType + case (_: ComplexDataType, _: UserType) => CalcKaitaiStructType case _ => AnyType } } @@ -341,6 +413,7 @@ object TypeDetector { case (_: BooleanType, _: BooleanType) => true case (_: StrType, _: StrType) => true case (_: UserType, KaitaiStructType) => true + case (_: UserType, CalcKaitaiStructType) => true case (t1: UserType, t2: UserType) => (t1.classSpec, t2.classSpec) match { 
case (None, None) => @@ -352,6 +425,9 @@ object TypeDetector { case (_, _) => false } + case (t1: EnumType, t2: EnumType) => + // enums are assignable if their enumSpecs match + t1.enumSpec.get == t2.enumSpec.get case (_, _) => false } } diff --git a/shared/src/main/scala/io/kaitai/struct/translators/TypeProvider.scala b/shared/src/main/scala/io/kaitai/struct/translators/TypeProvider.scala index 74f9207d5..1964575ce 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/TypeProvider.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/TypeProvider.scala @@ -1,6 +1,7 @@ package io.kaitai.struct.translators import io.kaitai.struct.datatype.DataType +import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.format.{ClassSpec, EnumSpec} /** @@ -12,8 +13,8 @@ trait TypeProvider { def nowClass: ClassSpec def determineType(attrName: String): DataType def determineType(inClass: ClassSpec, attrName: String): DataType - def resolveEnum(enumName: String): EnumSpec - def resolveType(typeName: String): DataType + def resolveEnum(typeName: Ast.typeId, enumName: String): EnumSpec + def resolveType(typeName: Ast.typeId): DataType def isLazy(attrName: String): Boolean def isLazy(inClass: ClassSpec, attrName: String): Boolean }