diff --git a/README.md b/README.md index 48d99ed88e9b..953f08bd451a 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ DataFusion offers SQL and Dataframe APIs, excellent [performance](https://benchm - Written in [Rust](https://www.rust-lang.org/), a modern system language with development productivity similar to Java or Golang, the performance of C++, and [loved by programmers everywhere](https://insights.stackoverflow.com/survey/2021#technology-most-loved-dreaded-and-wanted). -- Support for [Substrait](https://substrait.io/) for query plan serialization, making it easier to integrate DataFusion +- Support for [Substrait](https://substrait.io/) for query plan serialization, making it easier to integrate DataFusion with other projects, and to pass plans across language boundaries. ## Use Cases @@ -79,6 +79,13 @@ execution plans, file format support, etc. ## Comparisons with other projects +When compared to similar systems, DataFusion typically is: + +1. Targeted at developers, rather than end users / data scientists. +2. Designed to be embedded, rather than a complete file-based SQL system. +3. Governed by the [Apache Software Foundation](https://www.apache.org/) process, rather than a single company or individual. +4. Implemented in `Rust`, rather than `C/C++`. + Here is a comparison with similar projects that may help understand when DataFusion might be suitable and unsuitable for your needs: diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index d01617ac76d5..6fe28cabee06 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -74,16 +74,16 @@ checksum = "f410d3907b6b3647b9e7bca4551274b2e3d716aa940afb67b7287257401da921" dependencies = [ "ahash", "arrow-arith", - "arrow-array", - "arrow-buffer", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", "arrow-cast", "arrow-csv", - "arrow-data", + "arrow-data 34.0.0", "arrow-ipc", "arrow-json", "arrow-ord", "arrow-row", - "arrow-schema", + "arrow-schema 34.0.0", "arrow-select", "arrow-string", "comfy-table", @@ -95,10 +95,10 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f87391cf46473c9bc53dab68cb8872c3a81d4dfd1703f1c8aa397dba9880a043" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "chrono", "half", "num", @@ -111,15 +111,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d35d5475e65c57cffba06d0022e3006b677515f99b54af33a7cd54f6cdd4a5b5" dependencies = [ "ahash", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "chrono", "half", "hashbrown 0.13.2", "num", ] +[[package]] +name = "arrow-array" +version = "35.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43489bbff475545b78b0e20bde1d22abd6c99e54499839f9e815a2fa5134a51b" +dependencies = [ + "ahash", + "arrow-buffer 35.0.0", + "arrow-data 35.0.0", + "arrow-schema 35.0.0", + "chrono", + "chrono-tz", + "half", + "hashbrown 0.13.2", + "num", +] + [[package]] name = "arrow-buffer" version = "34.0.0" @@ -130,16 +147,26 @@ dependencies = [ "num", ] +[[package]] +name = "arrow-buffer" +version = "35.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3759e4a52c593281184787af5435671dc8b1e78333e5a30242b2e2d6e3c9d1f" +dependencies = [ + "half", + "num", +] + [[package]] name = 
"arrow-cast" version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a7285272c9897321dfdba59de29f5b05aeafd3cdedf104a941256d155f6d304" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "arrow-select", "chrono", "lexical-core", @@ -152,11 +179,11 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "981ee4e7f6a120da04e00d0b39182e1eeacccb59c8da74511de753c56b7fddf7" dependencies = [ - "arrow-array", - "arrow-buffer", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", "arrow-cast", - "arrow-data", - "arrow-schema", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "chrono", "csv", "csv-core", @@ -171,8 +198,20 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27cc673ee6989ea6e4b4e8c7d461f7e06026a096c8f0b1a7288885ff71ae1e56" dependencies = [ - "arrow-buffer", - "arrow-schema", + "arrow-buffer 34.0.0", + "arrow-schema 34.0.0", + "half", + "num", +] + +[[package]] +name = "arrow-data" +version = "35.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19c7787c6cdbf9539b1ffb860bfc18c5848926ec3d62cbd52dc3b1ea35c874fd" +dependencies = [ + "arrow-buffer 35.0.0", + "arrow-schema 35.0.0", "half", "num", ] @@ -183,11 +222,11 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e37b8b69d9e59116b6b538e8514e0ec63a30f08b617ce800d31cb44e3ef64c1a" dependencies = [ - "arrow-array", - "arrow-buffer", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", "arrow-cast", - "arrow-data", - "arrow-schema", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "flatbuffers", ] @@ -197,11 +236,11 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "80c3fa0bed7cfebf6d18e46b733f9cb8a1cb43ce8e6539055ca3e1e48a426266" dependencies = [ - "arrow-array", - "arrow-buffer", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", "arrow-cast", - "arrow-data", - "arrow-schema", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "chrono", "half", "indexmap", @@ -216,10 +255,10 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d247dce7bed6a8d6a3c6debfa707a3a2f694383f0c692a39d736a593eae5ef94" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "arrow-select", "num", ] @@ -231,10 +270,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d609c0181f963cea5c70fddf9a388595b5be441f3aa1d1cdbf728ca834bbd3a" dependencies = [ "ahash", - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "half", "hashbrown 0.13.2", ] @@ -245,16 +284,22 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64951898473bfb8e22293e83a44f02874d2257514d49cd95f9aa4afcff183fbc" +[[package]] +name = "arrow-schema" +version = "35.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf6b26f6a6f8410e3b9531cbd1886399b99842701da77d4b4cf2013f7708f20f" + [[package]] name = "arrow-select" version = "34.0.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "2a513d89c2e1ac22b28380900036cf1f3992c6443efc5e079de631dcf83c6888" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "num", ] @@ -264,10 +309,10 @@ version = "34.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5288979b2705dae1114c864d73150629add9153b9b8f1d7ee3963db94c372ba5" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", + "arrow-data 34.0.0", + "arrow-schema 34.0.0", "arrow-select", "regex", "regex-syntax", @@ -440,9 +485,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.23" +version = "0.4.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f" +checksum = "4e3c5919066adf22df73762e50cffcde3a758f2a848b113b586d1f86728b673b" dependencies = [ "iana-time-zone", "num-integer", @@ -451,6 +496,28 @@ dependencies = [ "winapi", ] +[[package]] +name = "chrono-tz" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa48fa079165080f11d7753fd0bc175b7d391f276b965fe4b55bfad67856e463" +dependencies = [ + "chrono", + "chrono-tz-build", + "phf", +] + +[[package]] +name = "chrono-tz-build" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9998fb9f7e9b2111641485bf8beb32f92945f97f92a3d061f744cfef335f751" +dependencies = [ + "parse-zoneinfo", + "phf", + "phf_codegen", +] + [[package]] name = "clap" version = "3.2.23" @@ -546,9 +613,9 @@ dependencies = [ [[package]] name = "constant_time_eq" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3ad85c1f65dc7b37604eb0e89748faf0b9653065f2a8ef69f96a687ec1e9279" +checksum = "13418e745008f7349ec7e449155f419a61b92b58a99cc3616942b926825ec76b" [[package]] name = "core-foundation-sys" @@ -737,6 +804,7 @@ name = "datafusion-common" version = "20.0.0" dependencies = [ "arrow", + "arrow-array 35.0.0", "chrono", "num_cpus", "object_store", @@ -792,8 +860,8 @@ version = "20.0.0" dependencies = [ "ahash", "arrow", - "arrow-buffer", - "arrow-schema", + "arrow-buffer 34.0.0", + "arrow-schema 34.0.0", "blake2", "blake3", "chrono", @@ -829,7 +897,7 @@ dependencies = [ name = "datafusion-sql" version = "20.0.0" dependencies = [ - "arrow-schema", + "arrow-schema 34.0.0", "datafusion-common", "datafusion-expr", "log", @@ -1022,9 +1090,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e2792b0ff0340399d58445b88fd9770e3489eff258a4cbc1523418f12abf84" +checksum = "531ac96c6ff5fd7c62263c5e3c67a603af4fcaee2e1a0ae5565ba3a11e69e549" dependencies = [ "futures-channel", "futures-core", @@ -1037,9 +1105,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5" +checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac" dependencies = [ "futures-core", "futures-sink", @@ 
-1047,15 +1115,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" +checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd" [[package]] name = "futures-executor" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8de0a35a6ab97ec8869e32a2473f4b1324459e14c29275d14b10cb1fd19b50e" +checksum = "1997dd9df74cdac935c76252744c1ed5794fac083242ea4fe77ef3ed60ba0f83" dependencies = [ "futures-core", "futures-task", @@ -1064,15 +1132,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531" +checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91" [[package]] name = "futures-macro" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70" +checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6" dependencies = [ "proc-macro2", "quote", @@ -1081,21 +1149,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364" +checksum = "ec93083a4aecafb2a80a885c9de1f0ccae9dbd32c2bb54b0c3a65690e0b8d2f2" [[package]] name = "futures-task" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366" +checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879" [[package]] name = "futures-util" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" +checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab" dependencies = [ "futures-channel", "futures-core", @@ -1204,6 +1272,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" + [[package]] name = "http" version = "0.2.9" @@ -1342,10 +1416,11 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "io-lifetimes" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfa919a82ea574332e2de6e74b4c36e74d41982b335080fa59d4ef31be20fdf3" +checksum = "76e86b86ae312accbf05ade23ce76b625e0e47a255712b7414037385a1c05380" dependencies = [ + "hermit-abi 0.3.1", "libc", "windows-sys 0.45.0", ] @@ -1784,12 +1859,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ac135ecf63ebb5f53dda0921b0b76d6048b3ef631a5f4760b9e8f863ff00cfa" dependencies = [ "ahash", - "arrow-array", - "arrow-buffer", + "arrow-array 34.0.0", + "arrow-buffer 34.0.0", "arrow-cast", - "arrow-data", + "arrow-data 34.0.0", "arrow-ipc", - "arrow-schema", + 
"arrow-schema 34.0.0", "arrow-select", "base64", "brotli", @@ -1810,6 +1885,15 @@ dependencies = [ "zstd 0.12.3+zstd.1.5.2", ] +[[package]] +name = "parse-zoneinfo" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c705f256449c60da65e11ff6626e0c16a0a0b96aaa348de61376b249bc340f41" +dependencies = [ + "regex", +] + [[package]] name = "paste" version = "1.0.12" @@ -1832,6 +1916,44 @@ dependencies = [ "indexmap", ] +[[package]] +name = "phf" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928c6535de93548188ef63bb7c4036bd415cd8f36ad25af44b9789b2ee72a48c" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56ac890c5e3ca598bbdeaa99964edb5b0258a583a9eb6ef4e89fc85d9224770" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1181c94580fa345f50f19d738aaa39c0ed30a600d95cb2d3e23f94266f14fbf" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1fb5f6f826b772a8d4c0394209441e7d37cbbb967ae9c7e0e8134365c9ee676" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project-lite" version = "0.2.9" @@ -1888,9 +2010,9 @@ checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" -version = "1.0.51" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +checksum = "1d0e1ae9e836cc3beddd63db0df682593d7e2d3d891ae8c9083d2113e1744224" dependencies = [ "unicode-ident", ] @@ -1907,9 +2029,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.23" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ "proc-macro2", ] @@ -2159,9 +2281,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58bc9567378fc7690d6b2addae4e60ac2eeea07becb2c64b9f218b53865cba2a" +checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" [[package]] name = "seq-macro" @@ -2171,18 +2293,18 @@ checksum = "e6b44e8fc93a14e66336d230954dda83d18b4605ccace8fe09bc7514a71ad0bc" [[package]] name = "serde" -version = "1.0.154" +version = "1.0.156" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cdd151213925e7f1ab45a9bbfb129316bd00799784b174b7cc7bcd16961c49e" +checksum = "314b5b092c0ade17c00142951e50ced110ec27cea304b1037c6969246c2469a4" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.154" +version = "1.0.156" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fc80d722935453bcafdc2c9a73cd6fac4dc1938f0346035d84bf99fa9e33217" +checksum = "d7e29c4601e36bcec74a223228dce795f4cd3616341a4af93520ca1a837c087d" dependencies = [ "proc-macro2", "quote", @@ -2223,6 +2345,12 @@ dependencies = [ "digest", ] 
+[[package]] +name = "siphasher" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" + [[package]] name = "slab" version = "0.4.8" @@ -2639,12 +2767,11 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "walkdir" -version = "2.3.2" +version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" +checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" dependencies = [ "same-file", - "winapi", "winapi-util", ] @@ -2829,9 +2956,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -2844,45 +2971,45 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" [[package]] name = "windows_aarch64_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" [[package]] name = "windows_i686_gnu" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" [[package]] name = "windows_i686_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" [[package]] name = "windows_x86_64_gnu" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" [[package]] name = "windows_x86_64_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" [[package]] name = "winreg" diff --git a/datafusion/CHANGELOG.md b/datafusion/CHANGELOG.md index 5b919a1f21d7..38fccecb8707 100644 --- a/datafusion/CHANGELOG.md +++ b/datafusion/CHANGELOG.md @@ -21,101 
+21,144 @@ ## [20.0.0](https://github.com/apache/arrow-datafusion/tree/20.0.0) (2023-03-10) -**Merged PRs** - -Note that the changelog generator is failing due to https://github.com/apache/arrow-datafusion/issues/5549 so -the changelog for this release was produced manually by running `git log --oneline 19.0.0..main` and only shows -merged PRs. - -- 8c34ca4fa (origin/main, main) Add UserDefinedLogicalNodeCore (#5521) -- 7f84503bb revert accidently deleted size code in count_distinct (#5533) -- 6a8c4a601 fix: return schema of ExtensionPlan instead of its children's (#5514) -- 55c7c4d8e Minor: Move `ObjectStoreRegistry` to datafusion_execution crate (#5478) -- 07b07d52f Add dbbenchmark URL to db-benchmark readme (#5503) -- 9878ee011 minor: fix clippy problem in new version. (#5532) -- 8a1cb7e3c fixed hash_join tests after passing boolean by value (#5531) -- f5d23ff58 memory limited hash join (#5490) -- ff013e245 minor: improve error style (#5510) -- e46924d80 feat: add `arrow_cast` function to support supports arbitrary arrow types (#5166) -- 84530a2f5 build(deps): update sqlparser requirement from 0.30 to 0.32 w/ API update (#5457) -- 8a1b13390 Allow setting config extensions for TaskContext (#5497) -- 0ead640e8 Minor: Improve docs for UserDefinedLogicalNode `dyn_eq` and `dyn_hash` (#5515) -- 89ca58326 feat: interval add timestamp (#5491) -- 9e32de344 Pass booleans by value instead of by reference (#5487) -- deeaa5632 Minor: Move TableProviderFactories up out of RuntimeEnv and into SessionState (#5477) -- 50e9d78da feat: `ParquetExec` predicate preservation (#5495) -- d7b221213 feat: add optimization rules for bitwise operations (#5423) -- 928662bb1 chore: Remove references from SessionState from physical_plan (#5455) -- 8d8b765f4 Implement `Debug` for `ExecutionProps` and `VarProvider` (#5489) -- a6f4a7719 feat: Support bitwise operations for unsigned integer types (#5476) -- 99ef98931 Apply workaround for #5444 to DataFrame::describe (#5468) -- d0bd28eef feat: eliminate the duplicated sort keys in Order By clause (#5462) -- 21e33a3d1 Propagate timezone to created arrays (#5481) -- 978ec2dc8 refactor: make GeometricMean not to have update and merge (#5469) -- c56b94765 feat: add name() method to UserDefinedLogicalNode (#5450) -- eb9541380 Comment out description fields in issue templates (#5482) -- c37ddf72e feat: express unsigned literal in substrait (#5448) -- 53a638ef3 fix: build union schema with child has same column name but qualifier is different (#5452) -- 8d195a8c0 refactor: make sum_distinct not to have update and merge (#5474) -- 132ae42f4 Fix decimal dyn scalar op (#5465) -- e9852074b feat: `extensions_options` macro (#5442) -- ac618f034 enable hash joins on FixedSizeBinary columns (#5461) -- 8de7ea45e Fix is_distinct from for float NaN values (#5446) -- 61fc51446 Implement/fix Eq and Hash for Expr and LogicalPlan (#5421) -- be6efbc93 [feat]:fast check has column (#5328) -- 1455b025c Parquet sorting benchmark (#5433) -- d11820aa2 refactor count_distinct to not to have update and merge (#5408) -- ddd64e7ed build(deps): update zstd requirement from 0.11 to 0.12 (#5458) -- db4610e1d Upgrade bytes to 1.4 (#5460) -- 49473d68b add std,median result to describe method (#5445) -- a95e0ec2f minor: Port more window tests to sqlogictests (#5434) -- f68214dc6 Use compute_op_dyn_scalar for datatime (#5315) -- a4b47d8c8 add a unit test that cover cast bug. 
(#5443) -- 17b2f11b5 create new `datafusion-execution` crate, start splitting code out (#5432) -- f9f40bf70 minor: fix clippy in nightly. (#5440) -- 3c1e4c0fd Support for Sliding Windows Joins with Symmetric Hash Join (SHJ) (#5322) -- 03fbf9fec refactor: ParquetExec logical expr. => phys. expr. (#5419) -- c4b895822 Update README.md fix [DataFusion] links (#5438) -- 793feda23 add mean result for describe method (#5435) -- eef046441 add expr_fn::median (#5437) -- d076ab3b2 Bug/union wrong casting (#5342) -- 0000d4fce reimplement `push_down_projection` and `prune_column`. (#4465) -- 25b4f673d Add (#5409) -- d9669841c fix nested loop join with literal join filter (#5431) -- 96aa2a677 add a describe method on DataFrame like Polars (#5226) -- ea3b965dd Memory reservation & metrics for cross join (#5339) -- 20d08ab1f Optimize count_distinct.size (#5377) -- c477fc0ca Fix filter pushdown for extension plans (#5425) -- c676d1026 Also push down all filters in TableProvider (#5420) -- 06fecacb3 Update arrow 34 (#5375) -- 411807609 Parquet limit pushdown (#5404) (#5416) -- 58cd1bf0b Move file format config.rs to live with the rest of the datasource code (#5406) -- 8202a395a Support Zstd compressed files (#5397) -- 5ffa8d78f Add example of catalog API usage (#5291) (#5326) -- a7a824ad9 Add support for protobuf serialisation of Arrow Map type (#5359) -- cb09142c7 minor: port window tests to slt (part 2) (#5399) -- 69503ea0f fix(docs): fix typos (#5403) -- fad360df0 rebase (#5367) -- 38185cacd enhance: remove more projection. (#5402) -- 0b77ec27c refactor `push_down_filter` to fix dead-loop and use optimizer_recurse. (#5337) -- 8b92b9b6c feat: add eliminate_unnecessary_projection rule (#5366) -- 645428f59 minor: support forgotten large_utf8 (#5393) -- 248d6fc1b Minor: add tests for subquery to join (#5363) -- 85ed38673 bugfix: fix master `bors` problem. (#5395) -- a870e460d feat: rule ReplaceDistinctWithAggregate (#5354) -- 4d4c5678e chore: add known project ZincObserve (#5376) -- 1841736d9 refactor: parquet pruning simplifications (#5386) -- 722490121 Optimization of the function "intersect" (#5388) -- d64076b46 docs: clarify spark (#5391) -- 47bdda60b UDF zero params #5378 (#5380) -- 32a238c50 Added tests for "like_for_type_coercion" and "test_type_coercion_rewrite" (#5389) -- 9fd6efa51 minor: make table resolution an independent function ... (#5373) -- 05d8a2527 minor: port predicates tests to sqllogictests (#5374) -- 49237a21a Bug fix: Window frame range value outside the type range (#5384) -- d6ef46374 Fixed small typos in files of the optimizer (#5356) -- cef119da9 fix: misc phys. expression display bugs (#5387) -- ac33ebc8f Prepare for 19.0.0 release (#5381) -- 1309267e7 minor: disable tpcds-q41 due to not support decorrelate disjunction subquery. 
(#5369) +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/19.0.0...20.0.0-rc1) + +**Breaking changes:** + +- Minor: Move TableProviderFactories up out of `RuntimeEnv` and into `SessionState` [#5477](https://github.com/apache/arrow-datafusion/pull/5477) (alamb) +- chore: Remove references from SessionState from physical_plan [#5455](https://github.com/apache/arrow-datafusion/pull/5455) (alamb) +- Implement `Debug` for `ExecutionProps` and `VarProvider` [#5489](https://github.com/apache/arrow-datafusion/pull/5489) (alamb) + +**Implemented enhancements:** + +- Add UserDefinedLogicalNodeCore [#5521](https://github.com/apache/arrow-datafusion/pull/5521) (mslapek) +- feat: add `arrow_cast` function to support supports arbitrary arrow types [#5166](https://github.com/apache/arrow-datafusion/pull/5166) (alamb) +- feat: interval add timestamp [#5491](https://github.com/apache/arrow-datafusion/pull/5491) (Weijun-H) +- feat: `ParquetExec` predicate preservation [#5495](https://github.com/apache/arrow-datafusion/pull/5495) (crepererum) +- feat: add optimization rules for bitwise operations [#5423](https://github.com/apache/arrow-datafusion/pull/5423) (izveigor) +- feat: Support bitwise operations for unsigned integer types [#5476](https://github.com/apache/arrow-datafusion/pull/5476) (izveigor) +- feat: eliminate the duplicated sort keys in Order By clause [#5462](https://github.com/apache/arrow-datafusion/pull/5462) (jackwener) +- feat: add name() method to UserDefinedLogicalNode [#5450](https://github.com/apache/arrow-datafusion/pull/5450) (waynexia) +- feat: express unsigned literal in substrait [#5448](https://github.com/apache/arrow-datafusion/pull/5448) (waynexia) +- feat: `extensions_options` macro [#5442](https://github.com/apache/arrow-datafusion/pull/5442) (crepererum) +- [feat]:fast check has column [#5328](https://github.com/apache/arrow-datafusion/pull/5328) (suxiaogang223) +- feat: eliminate unnecessary projection. [#5366](https://github.com/apache/arrow-datafusion/pull/5366) (jackwener) + +**Fixed bugs:** + +- revert accidently deleted size code in count_distinct [#5533](https://github.com/apache/arrow-datafusion/pull/5533) (comphead) +- fix: return schema of ExtensionPlan instead of its children's [#5514](https://github.com/apache/arrow-datafusion/pull/5514) (waynexia) +- fix: logical merge conflict -- hash_join tests with passing boolean by value [#5531](https://github.com/apache/arrow-datafusion/pull/5531) (korowa) +- fix: build union schema with child has same column name but qualifier… [#5452](https://github.com/apache/arrow-datafusion/pull/5452) (yukkit) +- Fix is_distinct from for float NaN values [#5446](https://github.com/apache/arrow-datafusion/pull/5446) (comphead) +- Bug/union wrong casting [#5342](https://github.com/apache/arrow-datafusion/pull/5342) (berkaysynnada) +- fix nested loop join with literal join filter [#5431](https://github.com/apache/arrow-datafusion/pull/5431) (ygf11) +- Fix filter pushdown for extension plans [#5425](https://github.com/apache/arrow-datafusion/pull/5425) (thinkharderdev) +- Bug fix: Window frame range value outside the type range [#5384](https://github.com/apache/arrow-datafusion/pull/5384) (mustafasrepo) +- fix: misc phys. 
expression display bugs [#5387](https://github.com/apache/arrow-datafusion/pull/5387) (crepererum) + +**Documentation updates:** + +- Minor: Improve docs for UserDefinedLogicalNode `dyn_eq` and `dyn_hash` [#5515](https://github.com/apache/arrow-datafusion/pull/5515) (alamb) +- chore: add known project ZincObserve [#5376](https://github.com/apache/arrow-datafusion/pull/5376) (hengfeiyang) +- docs: clarify spark [#5391](https://github.com/apache/arrow-datafusion/pull/5391) (hyoklee) + +**Merged pull requests:** + +- Manual changelog for 20.0.0 [#5551](https://github.com/apache/arrow-datafusion/pull/5551) (andygrove) +- Prepare for 20.0.0 release [Part 1] [#5539](https://github.com/apache/arrow-datafusion/pull/5539) (andygrove) +- chore: deduplicate workspace fields in Cargo.toml [#5519](https://github.com/apache/arrow-datafusion/pull/5519) (waynexia) +- Add necessary features to optimizer [#5540](https://github.com/apache/arrow-datafusion/pull/5540) (viirya) +- Minor: add the concise way for matching numerics [#5537](https://github.com/apache/arrow-datafusion/pull/5537) (izveigor) +- Add UserDefinedLogicalNodeCore [#5521](https://github.com/apache/arrow-datafusion/pull/5521) (mslapek) +- revert accidently deleted size code in count_distinct [#5533](https://github.com/apache/arrow-datafusion/pull/5533) (comphead) +- fix: return schema of ExtensionPlan instead of its children's [#5514](https://github.com/apache/arrow-datafusion/pull/5514) (waynexia) +- Minor: Move `ObjectStoreRegistry` to datafusion_execution crate [#5478](https://github.com/apache/arrow-datafusion/pull/5478) (alamb) +- Minor: Add db-benchmark URL to db-benchmark readme [#5503](https://github.com/apache/arrow-datafusion/pull/5503) (alamb) +- minor: fix clippy problem in new version. 
[#5532](https://github.com/apache/arrow-datafusion/pull/5532) (jackwener) +- fix: logical merge conflict -- hash_join tests with passing boolean by value [#5531](https://github.com/apache/arrow-datafusion/pull/5531) (korowa) +- Memory limited hash join [#5490](https://github.com/apache/arrow-datafusion/pull/5490) (korowa) +- minor: improve error style [#5510](https://github.com/apache/arrow-datafusion/pull/5510) (alamb) +- feat: add `arrow_cast` function to support supports arbitrary arrow types [#5166](https://github.com/apache/arrow-datafusion/pull/5166) (alamb) +- build(deps): update sqlparser requirement from 0.30 to 0.32 w/ API update [#5457](https://github.com/apache/arrow-datafusion/pull/5457) (alamb) +- Allow setting config extensions for TaskContext [#5497](https://github.com/apache/arrow-datafusion/pull/5497) (mpurins-coralogix) +- Minor: Improve docs for UserDefinedLogicalNode `dyn_eq` and `dyn_hash` [#5515](https://github.com/apache/arrow-datafusion/pull/5515) (alamb) +- feat: interval add timestamp [#5491](https://github.com/apache/arrow-datafusion/pull/5491) (Weijun-H) +- Pass booleans by value instead of by reference [#5487](https://github.com/apache/arrow-datafusion/pull/5487) (maxburke) +- Minor: Move TableProviderFactories up out of `RuntimeEnv` and into `SessionState` [#5477](https://github.com/apache/arrow-datafusion/pull/5477) (alamb) +- feat: `ParquetExec` predicate preservation [#5495](https://github.com/apache/arrow-datafusion/pull/5495) (crepererum) +- feat: add optimization rules for bitwise operations [#5423](https://github.com/apache/arrow-datafusion/pull/5423) (izveigor) +- chore: Remove references from SessionState from physical_plan [#5455](https://github.com/apache/arrow-datafusion/pull/5455) (alamb) +- Implement `Debug` for `ExecutionProps` and `VarProvider` [#5489](https://github.com/apache/arrow-datafusion/pull/5489) (alamb) +- feat: Support bitwise operations for unsigned integer types [#5476](https://github.com/apache/arrow-datafusion/pull/5476) (izveigor) +- Apply workaround for #5444 to `DataFrame::describe` [#5468](https://github.com/apache/arrow-datafusion/pull/5468) (jiangzhx) +- feat: eliminate the duplicated sort keys in Order By clause [#5462](https://github.com/apache/arrow-datafusion/pull/5462) (jackwener) +- Propagate timezone to created arrays [#5481](https://github.com/apache/arrow-datafusion/pull/5481) (maxburke) +- refactor: make GeometricMean not to have update and merge [#5469](https://github.com/apache/arrow-datafusion/pull/5469) (Weijun-H) +- feat: add name() method to UserDefinedLogicalNode [#5450](https://github.com/apache/arrow-datafusion/pull/5450) (waynexia) +- Comment out description text in issue templates [#5482](https://github.com/apache/arrow-datafusion/pull/5482) (Jefffrey) +- feat: express unsigned literal in substrait [#5448](https://github.com/apache/arrow-datafusion/pull/5448) (waynexia) +- fix: build union schema with child has same column name but qualifier… [#5452](https://github.com/apache/arrow-datafusion/pull/5452) (yukkit) +- refactor: make sum_distinct not to have update and merge [#5474](https://github.com/apache/arrow-datafusion/pull/5474) (Weijun-H) +- `compute_decimal_op_dyn_scalar` should not cast lhs array to decimal array [#5465](https://github.com/apache/arrow-datafusion/pull/5465) (viirya) +- feat: 
`extensions_options` macro [#5442](https://github.com/apache/arrow-datafusion/pull/5442) (crepererum) +- Enable hash joins on FixedSizeBinary columns [#5461](https://github.com/apache/arrow-datafusion/pull/5461) (maxburke) +- Fix is_distinct from for float NaN values [#5446](https://github.com/apache/arrow-datafusion/pull/5446) (comphead) +- Implement/fix Eq and Hash for Expr and LogicalPlan [#5421](https://github.com/apache/arrow-datafusion/pull/5421) (mslapek) +- [feat]:fast check has column [#5328](https://github.com/apache/arrow-datafusion/pull/5328) (suxiaogang223) +- Parquet sorting benchmark [#5433](https://github.com/apache/arrow-datafusion/pull/5433) (jaylmiller) +- refactor count_distinct to not to have update and merge [#5408](https://github.com/apache/arrow-datafusion/pull/5408) (Weijun-H) +- build(deps): update zstd requirement from 0.11 to 0.12 [#5458](https://github.com/apache/arrow-datafusion/pull/5458) (alamb) +- Upgrade bytes to 1.4 [#5460](https://github.com/apache/arrow-datafusion/pull/5460) (viirya) +- add std,median result to describe method [#5445](https://github.com/apache/arrow-datafusion/pull/5445) (jiangzhx) +- minor: Port more window tests to sqlogictests [#5434](https://github.com/apache/arrow-datafusion/pull/5434) (alamb) +- Use compute_op_dyn_scalar for datatime [#5315](https://github.com/apache/arrow-datafusion/pull/5315) (viirya) +- add a unit test that cover cast bug. [#5443](https://github.com/apache/arrow-datafusion/pull/5443) (jackwener) +- create new `datafusion-execution` crate, start splitting code out [#5432](https://github.com/apache/arrow-datafusion/pull/5432) (alamb) +- minor: fix clippy in nightly. [#5440](https://github.com/apache/arrow-datafusion/pull/5440) (jackwener) +- Support for Sliding Windows Joins with Symmetric Hash Join (SHJ) [#5322](https://github.com/apache/arrow-datafusion/pull/5322) (metesynnada) +- refactor: ParquetExec logical expr. => phys. expr. [#5419](https://github.com/apache/arrow-datafusion/pull/5419) (crepererum) +- Update README.md fix [DataFusion] links [#5438](https://github.com/apache/arrow-datafusion/pull/5438) (jiangzhx) +- add mean result for describe method [#5435](https://github.com/apache/arrow-datafusion/pull/5435) (jiangzhx) +- add expr_fn::median [#5437](https://github.com/apache/arrow-datafusion/pull/5437) (jiangzhx) +- Bug/union wrong casting [#5342](https://github.com/apache/arrow-datafusion/pull/5342) (berkaysynnada) +- reimplement `push_down_projection` and `prune_column`. 
[#4465](https://github.com/apache/arrow-datafusion/pull/4465) (jackwener) +- Add `expr_fn::stddev` [#5409](https://github.com/apache/arrow-datafusion/pull/5409) (jiangzhx) +- fix nested loop join with literal join filter [#5431](https://github.com/apache/arrow-datafusion/pull/5431) (ygf11) +- add a describe method on DataFrame like Polars [#5226](https://github.com/apache/arrow-datafusion/pull/5226) (jiangzhx) +- Memory reservation & metrics for cross join [#5339](https://github.com/apache/arrow-datafusion/pull/5339) (korowa) +- Optimize count_distinct.size [#5377](https://github.com/apache/arrow-datafusion/pull/5377) (comphead) +- Fix filter pushdown for extension plans [#5425](https://github.com/apache/arrow-datafusion/pull/5425) (thinkharderdev) +- Also push down all filters in TableProvider [#5420](https://github.com/apache/arrow-datafusion/pull/5420) (avantgardnerio) +- Update arrow 34 [#5375](https://github.com/apache/arrow-datafusion/pull/5375) (tustvold) +- Parquet limit pushdown (#5404) [#5416](https://github.com/apache/arrow-datafusion/pull/5416) (tustvold) +- Move file format config.rs to live with the rest of the datasource code [#5406](https://github.com/apache/arrow-datafusion/pull/5406) (alamb) +- Support Zstd compressed files [#5397](https://github.com/apache/arrow-datafusion/pull/5397) (dennybritz) +- Add example of catalog API usage (#5291) [#5326](https://github.com/apache/arrow-datafusion/pull/5326) (jaylmiller) +- Add support for protobuf serialisation of Arrow Map type [#5359](https://github.com/apache/arrow-datafusion/pull/5359) (ahmedriza) +- minor: port window tests to slt (part 2) [#5399](https://github.com/apache/arrow-datafusion/pull/5399) (alamb) +- fix(docs): fix typos [#5403](https://github.com/apache/arrow-datafusion/pull/5403) (WenyXu) +- Try to push down full filter before break-up [#5367](https://github.com/apache/arrow-datafusion/pull/5367) (avantgardnerio) +- enhance: remove more projection. [#5402](https://github.com/apache/arrow-datafusion/pull/5402) (jackwener) +- refactor `push_down_filter` to fix dead-loop and use optimizer_recurse. [#5337](https://github.com/apache/arrow-datafusion/pull/5337) (jackwener) +- feat: eliminate unnecessary projection. [#5366](https://github.com/apache/arrow-datafusion/pull/5366) (jackwener) +- minor: add forgotten large_utf8 [#5393](https://github.com/apache/arrow-datafusion/pull/5393) (jackwener) +- Minor: add tests for subquery to join [#5363](https://github.com/apache/arrow-datafusion/pull/5363) (ygf11) +- bugfix: fix master `bors` problem. 
[#5395](https://github.com/apache/arrow-datafusion/pull/5395) (jackwener) +- Rule ReplaceDistinctWithAggregate [#5354](https://github.com/apache/arrow-datafusion/pull/5354) (mingmwang) +- chore: add known project ZincObserve [#5376](https://github.com/apache/arrow-datafusion/pull/5376) (hengfeiyang) +- refactor: parquet pruning simplifications [#5386](https://github.com/apache/arrow-datafusion/pull/5386) (crepererum) +- Minor: intersect expressions optimization [#5388](https://github.com/apache/arrow-datafusion/pull/5388) (izveigor) +- docs: clarify spark [#5391](https://github.com/apache/arrow-datafusion/pull/5391) (hyoklee) +- UDF zero params #5378 [#5380](https://github.com/apache/arrow-datafusion/pull/5380) (jaylmiller) +- Minor: added some tests for coercion type [#5389](https://github.com/apache/arrow-datafusion/pull/5389) (izveigor) +- minor: make table resolution an independent function ... [#5373](https://github.com/apache/arrow-datafusion/pull/5373) (MichaelScofield) +- minor: port predicates tests to sqllogictests [#5374](https://github.com/apache/arrow-datafusion/pull/5374) (jackwener) +- Bug fix: Window frame range value outside the type range [#5384](https://github.com/apache/arrow-datafusion/pull/5384) (mustafasrepo) +- Fixed small typos in files of the optimizer [#5356](https://github.com/apache/arrow-datafusion/pull/5356) (izveigor) +- fix: misc phys. expression display bugs [#5387](https://github.com/apache/arrow-datafusion/pull/5387) (crepererum) +- Prepare for 19.0.0 release [#5381](https://github.com/apache/arrow-datafusion/pull/5381) (andygrove) +- minor: disable tpcds-q41 due to not support decorrelate disjunction subquery [#5369](https://github.com/apache/arrow-datafusion/pull/5369) (jackwener) ## [19.0.0](https://github.com/apache/arrow-datafusion/tree/19.0.0) (2023-02-24) diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 8a0a7042fcba..7d78ed70eb35 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -41,6 +41,7 @@ pyarrow = ["pyo3", "arrow/pyarrow"] [dependencies] apache-avro = { version = "0.14", default-features = false, features = ["snappy"], optional = true } arrow = { workspace = true, default-features = false } +arrow-array = { version = "35.0.0", default-features = false, features = ["chrono-tz"] } chrono = { version = "0.4", default-features = false } cranelift-module = { version = "0.92.0", optional = true } num_cpus = "1.13.0" @@ -48,3 +49,6 @@ object_store = { version = "0.5.4", default-features = false, optional = true } parquet = { workspace = true, default-features = false, optional = true } pyo3 = { version = "0.18.0", optional = true } sqlparser = "0.32" + +[dev-dependencies] +rand = "0.8.4" diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index 4a0e392f6e76..78c1857238f9 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -17,7 +17,8 @@ //! Column -use crate::{DFSchema, DataFusionError, Result, SchemaError}; +use crate::utils::{parse_identifiers_normalized, quote_identifier}; +use crate::{DFSchema, DataFusionError, OwnedTableReference, Result, SchemaError}; use std::collections::HashSet; use std::convert::Infallible; use std::fmt; @@ -27,21 +28,37 @@ use std::sync::Arc; /// A named reference to a qualified field in a schema. 
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct Column { - /// relation/table name. - pub relation: Option<String>, + /// relation/table reference. + pub relation: Option<OwnedTableReference>, /// field/column name. pub name: String, } impl Column { - /// Create Column from optional qualifier and name - pub fn new(relation: Option<impl Into<String>>, name: impl Into<String>) -> Self { + /// Create Column from optional qualifier and name. The optional qualifier, if present, + /// will be parsed and normalized by default. + /// + /// See full details on [`TableReference::parse_str`] + /// + /// [`TableReference::parse_str`]: crate::TableReference::parse_str + pub fn new( + relation: Option<impl Into<OwnedTableReference>>, + name: impl Into<String>, + ) -> Self { Self { relation: relation.map(|r| r.into()), name: name.into(), } } + /// Convenience method for when there is no qualifier + pub fn new_unqualified(name: impl Into<String>) -> Self { + Self { + relation: None, + name: name.into(), + } + } + /// Create Column from unqualified name. pub fn from_name(name: impl Into<String>) -> Self { Self { @@ -53,26 +70,36 @@ impl Column { /// Deserialize a fully qualified name string into a column pub fn from_qualified_name(flat_name: impl Into<String>) -> Self { let flat_name = flat_name.into(); - use sqlparser::tokenizer::Token; - - let dialect = sqlparser::dialect::GenericDialect {}; - let mut tokenizer = sqlparser::tokenizer::Tokenizer::new(&dialect, &flat_name); - if let Ok(tokens) = tokenizer.tokenize() { - if let [Token::Word(relation), Token::Period, Token::Word(name)] = - tokens.as_slice() - { - return Column { - relation: Some(relation.value.clone()), - name: name.value.clone(), - }; - } - } - // any expression that's not in the form of `foo.bar` will be treated as unqualified column - // name - Column { - relation: None, - name: flat_name, - } + let mut idents = parse_identifiers_normalized(&flat_name); + + let (relation, name) = match idents.len() { + 1 => (None, idents.remove(0)), + 2 => ( + Some(OwnedTableReference::Bare { + table: idents.remove(0).into(), + }), + idents.remove(0), + ), + 3 => ( + Some(OwnedTableReference::Partial { + schema: idents.remove(0).into(), + table: idents.remove(0).into(), + }), + idents.remove(0), + ), + 4 => ( + Some(OwnedTableReference::Full { + catalog: idents.remove(0).into(), + schema: idents.remove(0).into(), + table: idents.remove(0).into(), + }), + idents.remove(0), + ), + // any expression that failed to parse or has more than 4 period delimited + // identifiers will be treated as an unqualified column name + _ => (None, flat_name), + }; + Self { relation, name } } /// Serialize column into a flat name string @@ -83,6 +110,18 @@ impl Column { } } + /// Serialize column into a quoted flat name string + pub fn quoted_flat_name(&self) -> String { + // TODO: quote identifiers only when special characters present + // see: https://github.com/apache/arrow-datafusion/issues/5523 + match &self.relation { + Some(r) => { + format!("{}.{}", r.to_quoted_string(), quote_identifier(&self.name)) + } + None => quote_identifier(&self.name), + } + } + /// Qualify column if not done yet. 
/// /// If this column already has a [relation](Self::relation), it will be returned as is and the given parameters are @@ -151,7 +190,7 @@ impl Column { } Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { - field: Column::new(self.relation.clone(), self.name), + field: Box::new(Column::new(self.relation.clone(), self.name)), valid_fields: schemas .iter() .flat_map(|s| s.fields().iter().map(|f| f.qualified_column())) @@ -240,8 +279,7 @@ impl Column { // If not due to USING columns then due to ambiguous column name return Err(DataFusionError::SchemaError( SchemaError::AmbiguousReference { - qualifier: None, - name: self.name, + field: Column::new_unqualified(self.name), }, )); } @@ -249,7 +287,7 @@ impl Column { } Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { - field: self, + field: Box::new(self), valid_fields: schemas .iter() .flat_map(|s| s.iter()) @@ -304,7 +342,12 @@ mod tests { let fields = names .iter() .map(|(qualifier, name)| { - DFField::new(qualifier.to_owned(), name, DataType::Boolean, true) + DFField::new( + qualifier.to_owned().map(|s| s.to_string()), + name, + DataType::Boolean, + true, + ) }) .collect::>(); DFSchema::new_with_metadata(fields, HashMap::new()) @@ -362,9 +405,7 @@ mod tests { &[], ) .expect_err("should've failed to find field"); - let expected = "Schema error: No field named 'z'. \ - Valid fields are 't1'.'a', 't1'.'b', 't2'.'c', \ - 't2'.'d', 't3'.'a', 't3'.'b', 't3'.'c', 't3'.'d', 't3'.'e'."; + let expected = r#"Schema error: No field named "z". Valid fields are "t1"."a", "t1"."b", "t2"."c", "t2"."d", "t3"."a", "t3"."b", "t3"."c", "t3"."d", "t3"."e"."#; assert_eq!(err.to_string(), expected); // ambiguous column reference @@ -375,7 +416,7 @@ mod tests { &[], ) .expect_err("should've found ambiguous field"); - let expected = "Schema error: Ambiguous reference to unqualified field 'a'"; + let expected = "Schema error: Ambiguous reference to unqualified field \"a\""; assert_eq!(err.to_string(), expected); Ok(()) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index bb443f8aa564..67a367c5ef3f 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -23,8 +23,9 @@ use std::convert::TryFrom; use std::hash::Hash; use std::sync::Arc; -use crate::error::{DataFusionError, Result, SchemaError}; -use crate::{field_not_found, Column, TableReference}; +use crate::error::{unqualified_field_not_found, DataFusionError, Result, SchemaError}; +use crate::utils::quote_identifier; +use crate::{field_not_found, Column, OwnedTableReference, TableReference}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -70,7 +71,7 @@ impl DFSchema { if !qualified_names.insert((qualifier, field.name())) { return Err(DataFusionError::SchemaError( SchemaError::DuplicateQualifiedField { - qualifier: qualifier.to_string(), + qualifier: Box::new(qualifier.clone()), name: field.name().to_string(), }, )); @@ -90,18 +91,16 @@ impl DFSchema { let mut qualified_names = qualified_names .iter() .map(|(l, r)| (l.to_owned(), r.to_owned())) - .collect::>(); - qualified_names.sort_by(|a, b| { - let a = format!("{}.{}", a.0, a.1); - let b = format!("{}.{}", b.0, b.1); - a.cmp(&b) - }); + .collect::>(); + qualified_names.sort(); for (qualifier, name) in &qualified_names { if unqualified_names.contains(name) { return Err(DataFusionError::SchemaError( SchemaError::AmbiguousReference { - qualifier: Some(qualifier.to_string()), - name: name.to_string(), + field: Column { + 
relation: Some((*qualifier).clone()), + name: name.to_string(), + }, }, )); } @@ -109,13 +108,17 @@ impl DFSchema { Ok(Self { fields, metadata }) } - /// Create a `DFSchema` from an Arrow schema - pub fn try_from_qualified_schema(qualifier: &str, schema: &Schema) -> Result { + /// Create a `DFSchema` from an Arrow schema and a given qualifier + pub fn try_from_qualified_schema<'a>( + qualifier: impl Into>, + schema: &Schema, + ) -> Result { + let qualifier = qualifier.into(); Self::new_with_metadata( schema .fields() .iter() - .map(|f| DFField::from_qualified(qualifier, f.clone())) + .map(|f| DFField::from_qualified(qualifier.clone(), f.clone())) .collect(), schema.metadata().clone(), ) @@ -140,7 +143,7 @@ impl DFSchema { for field in other_schema.fields() { // skip duplicate columns let duplicated_field = match field.qualifier() { - Some(q) => self.field_with_name(Some(q.as_str()), field.name()).is_ok(), + Some(q) => self.field_with_name(Some(q), field.name()).is_ok(), // for unqualified columns, check as unqualified name None => self.field_with_unqualified_name(field.name()).is_ok(), }; @@ -173,7 +176,7 @@ impl DFSchema { // a fully qualified field name is provided. match &self.fields[i].qualifier { Some(qualifier) => { - if (qualifier.to_owned() + "." + self.fields[i].name()) == name { + if (qualifier.to_string() + "." + self.fields[i].name()) == name { return Err(DataFusionError::Plan(format!( "Fully qualified field name '{name}' was supplied to `index_of` \ which is deprecated. Please use `index_of_column_by_name` instead" @@ -185,12 +188,12 @@ impl DFSchema { } } - Err(field_not_found(None, name, self)) + Err(unqualified_field_not_found(name, self)) } pub fn index_of_column_by_name( &self, - qualifier: Option<&str>, + qualifier: Option<&TableReference>, name: &str, ) -> Result> { let mut matches = self @@ -201,19 +204,19 @@ impl DFSchema { // field to lookup is qualified. // current field is qualified and not shared between relations, compare both // qualifier and name. - (Some(q), Some(field_q)) => q == field_q && field.name() == name, + (Some(q), Some(field_q)) => { + q.resolved_eq(field_q) && field.name() == name + } // field to lookup is qualified but current field is unqualified. (Some(qq), None) => { // the original field may now be aliased with a name that matches the // original qualified name - let table_ref = TableReference::parse_str(field.name().as_str()); - match table_ref { - TableReference::Partial { schema, table } => { - schema == qq && table == name - } - TableReference::Full { schema, table, .. } => { - schema == qq && table == name - } + let column = Column::from_qualified_name(field.name()); + match column { + Column { + relation: Some(r), + name: column_name, + } => &r == qq && column_name == name, _ => false, } } @@ -227,9 +230,11 @@ impl DFSchema { None => Ok(Some(idx)), // found more than one matches Some(_) => Err(DataFusionError::Internal(format!( - "Ambiguous reference to qualified field named '{}.{}'", - qualifier.unwrap_or(""), - name + "Ambiguous reference to qualified field named {}.{}", + qualifier + .map(|q| q.to_quoted_string()) + .unwrap_or("".to_string()), + quote_identifier(name) ))), }, } @@ -237,23 +242,20 @@ impl DFSchema { /// Find the index of the column with the given qualifier and name pub fn index_of_column(&self, col: &Column) -> Result { - let qualifier = col.relation.as_deref(); - self.index_of_column_by_name(col.relation.as_deref(), &col.name)? 
- .ok_or_else(|| { - field_not_found(qualifier.map(|s| s.to_string()), &col.name, self) - }) + self.index_of_column_by_name(col.relation.as_ref(), &col.name)? + .ok_or_else(|| field_not_found(col.relation.clone(), &col.name, self)) } /// Check if the column is in the current schema pub fn is_column_from_schema(&self, col: &Column) -> Result { - self.index_of_column_by_name(col.relation.as_deref(), &col.name) + self.index_of_column_by_name(col.relation.as_ref(), &col.name) .map(|idx| idx.is_some()) } /// Find the field with the given name pub fn field_with_name( &self, - qualifier: Option<&str>, + qualifier: Option<&TableReference>, name: &str, ) -> Result<&DFField> { if let Some(qualifier) = qualifier { @@ -264,7 +266,7 @@ impl DFSchema { } /// Find all fields having the given qualifier - pub fn fields_with_qualified(&self, qualifier: &str) -> Vec<&DFField> { + pub fn fields_with_qualified(&self, qualifier: &TableReference) -> Vec<&DFField> { self.fields .iter() .filter(|field| field.qualifier().map(|q| q.eq(qualifier)).unwrap_or(false)) @@ -283,7 +285,7 @@ impl DFSchema { pub fn field_with_unqualified_name(&self, name: &str) -> Result<&DFField> { let matches = self.fields_with_unqualified_name(name); match matches.len() { - 0 => Err(field_not_found(None, name, self)), + 0 => Err(unqualified_field_not_found(name, self)), 1 => Ok(matches[0]), _ => { // When `matches` size > 1, it doesn't necessarily mean an `ambiguous name` problem. @@ -302,8 +304,10 @@ impl DFSchema { } else { Err(DataFusionError::SchemaError( SchemaError::AmbiguousReference { - qualifier: None, - name: name.to_string(), + field: Column { + relation: None, + name: name.to_string(), + }, }, )) } @@ -314,7 +318,7 @@ impl DFSchema { /// Find the field with the given qualified name pub fn field_with_qualified_name( &self, - qualifier: &str, + qualifier: &TableReference, name: &str, ) -> Result<&DFField> { let idx = self @@ -338,7 +342,11 @@ impl DFSchema { } /// Find if the field exists with the given qualified name - pub fn has_column_with_qualified_name(&self, qualifier: &str, name: &str) -> bool { + pub fn has_column_with_qualified_name( + &self, + qualifier: &TableReference, + name: &str, + ) -> bool { self.fields().iter().any(|field| { field.qualifier().map(|q| q.eq(qualifier)).unwrap_or(false) && field.name() == name @@ -452,12 +460,13 @@ impl DFSchema { } /// Replace all field qualifier with new value in schema - pub fn replace_qualifier(self, qualifier: &str) -> Self { + pub fn replace_qualifier(self, qualifier: impl Into) -> Self { + let qualifier = qualifier.into(); DFSchema { fields: self .fields .into_iter() - .map(|f| DFField::from_qualified(qualifier, f.field)) + .map(|f| DFField::from_qualified(qualifier.clone(), f.field)) .collect(), ..self } @@ -621,21 +630,29 @@ impl ExprSchema for DFSchema { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DFField { /// Optional qualifier (usually a table or relation name) - qualifier: Option, + qualifier: Option, /// Arrow field definition field: Field, } impl DFField { /// Creates a new `DFField` - pub fn new( - qualifier: Option<&str>, + pub fn new>( + qualifier: Option, name: &str, data_type: DataType, nullable: bool, ) -> Self { DFField { - qualifier: qualifier.map(|s| s.to_owned()), + qualifier: qualifier.map(|s| s.into()), + field: Field::new(name, data_type, nullable), + } + } + + /// Convenience method for creating new `DFField` without a qualifier + pub fn new_unqualified(name: &str, data_type: DataType, nullable: bool) -> Self { + DFField { + qualifier: 
None, field: Field::new(name, data_type, nullable), } } @@ -649,9 +666,12 @@ impl DFField { } /// Create a qualified field from an existing Arrow field - pub fn from_qualified(qualifier: &str, field: Field) -> Self { + pub fn from_qualified<'a>( + qualifier: impl Into>, + field: Field, + ) -> Self { Self { - qualifier: Some(qualifier.to_owned()), + qualifier: Some(qualifier.into().to_owned_reference()), field, } } @@ -697,7 +717,7 @@ impl DFField { } /// Get the optional qualifier - pub fn qualifier(&self) -> Option<&String> { + pub fn qualifier(&self) -> Option<&OwnedTableReference> { self.qualifier.as_ref() } @@ -723,9 +743,9 @@ mod tests { let col = Column::from_name("t1.c0"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; // lookup with unqualified name "t1.c0" - let err = schema.index_of_column(&col).err().unwrap(); + let err = schema.index_of_column(&col).unwrap_err(); assert_eq!( - "Schema error: No field named 't1.c0'. Valid fields are 't1'.'c0', 't1'.'c1'.", + r#"Schema error: No field named "t1.c0". Valid fields are "t1"."c0", "t1"."c1"."#, &format!("{err}") ); Ok(()) @@ -781,8 +801,12 @@ mod tests { join.to_string() ); // test valid access - assert!(join.field_with_qualified_name("t1", "c0").is_ok()); - assert!(join.field_with_qualified_name("t2", "c0").is_ok()); + assert!(join + .field_with_qualified_name(&TableReference::bare("t1"), "c0") + .is_ok()); + assert!(join + .field_with_qualified_name(&TableReference::bare("t2"), "c0") + .is_ok()); // test invalid access assert!(join.field_with_unqualified_name("c0").is_err()); assert!(join.field_with_unqualified_name("t1.c0").is_err()); @@ -798,7 +822,7 @@ mod tests { assert!(join.is_err()); assert_eq!( "Schema error: Schema contains duplicate \ - qualified field name \'t1\'.\'c0\'", + qualified field name \"t1\".\"c0\"", &format!("{}", join.err().unwrap()) ); Ok(()) @@ -812,7 +836,7 @@ mod tests { assert!(join.is_err()); assert_eq!( "Schema error: Schema contains duplicate \ - unqualified field name \'c0\'", + unqualified field name \"c0\"", &format!("{}", join.err().unwrap()) ); Ok(()) @@ -828,14 +852,18 @@ mod tests { join.to_string() ); // test valid access - assert!(join.field_with_qualified_name("t1", "c0").is_ok()); + assert!(join + .field_with_qualified_name(&TableReference::bare("t1"), "c0") + .is_ok()); assert!(join.field_with_unqualified_name("c0").is_ok()); assert!(join.field_with_unqualified_name("c100").is_ok()); assert!(join.field_with_name(None, "c100").is_ok()); // test invalid access assert!(join.field_with_unqualified_name("t1.c0").is_err()); assert!(join.field_with_unqualified_name("t1.c100").is_err()); - assert!(join.field_with_qualified_name("", "c100").is_err()); + assert!(join + .field_with_qualified_name(&TableReference::bare(""), "c100") + .is_err()); Ok(()) } @@ -847,7 +875,7 @@ mod tests { assert!(join.is_err()); assert_eq!( "Schema error: Schema contains qualified \ - field name \'t1\'.\'c0\' and unqualified field name \'c0\' which would be ambiguous", + field name \"t1\".\"c0\" and unqualified field name \"c0\" which would be ambiguous", &format!("{}", join.err().unwrap()) ); Ok(()) @@ -857,11 +885,11 @@ mod tests { #[test] fn helpful_error_messages() -> Result<()> { let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let expected_help = "Valid fields are \'t1\'.\'c0\', \'t1\'.\'c1\'."; + let expected_help = "Valid fields are \"t1\".\"c0\", \"t1\".\"c1\"."; // Pertinent message parts - let expected_err_msg = "Fully qualified field name \'t1.c0\'"; + 
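A small sketch of the two `DFField` constructors after this change (assuming the PR's `datafusion-common`); the qualifier is now stored as an `OwnedTableReference` rather than a `String`:

```rust
use arrow::datatypes::DataType;
use datafusion_common::DFField;

fn main() {
    // The qualifier can be anything convertible into an OwnedTableReference.
    let qualified = DFField::new(Some("t1"), "c0", DataType::Int64, true);
    // `new_unqualified` creates a field with no relation attached.
    let unqualified = DFField::new_unqualified("c0", DataType::Int64, true);

    assert_eq!(
        qualified.qualifier().map(|q| q.to_string()),
        Some("t1".to_string())
    );
    assert!(unqualified.qualifier().is_none());
}
```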
let expected_err_msg = "Fully qualified field name 't1.c0'"; assert!(schema - .field_with_qualified_name("x", "y") + .field_with_qualified_name(&TableReference::bare("x"), "y") .unwrap_err() .to_string() .contains(expected_help)); @@ -889,12 +917,15 @@ mod tests { let col = Column::from_qualified_name("t1.c0"); let err = schema.index_of_column(&col).err().unwrap(); - assert_eq!("Schema error: No field named 't1'.'c0'.", &format!("{err}")); + assert_eq!( + r#"Schema error: No field named "t1"."c0"."#, + &format!("{err}") + ); // the same check without qualifier let col = Column::from_name("c0"); let err = schema.index_of_column(&col).err().unwrap(); - assert_eq!("Schema error: No field named 'c0'.", &format!("{err}")); + assert_eq!(r#"Schema error: No field named "c0"."#, &format!("{err}")); } #[test] @@ -1127,7 +1158,7 @@ mod tests { let arrow_schema_ref = Arc::new(arrow_schema.clone()); let df_schema = DFSchema::new_with_metadata( - vec![DFField::new(None, "c0", DataType::Int64, true)], + vec![DFField::new_unqualified("c0", DataType::Int64, true)], metadata, ) .unwrap(); diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 3f10c1261cff..c1e11fe93bbb 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -23,7 +23,8 @@ use std::io; use std::result; use std::sync::Arc; -use crate::{Column, DFSchema}; +use crate::utils::quote_identifier; +use crate::{Column, DFSchema, OwnedTableReference}; #[cfg(feature = "avro")] use apache_avro::Error as AvroError; use arrow::error::ArrowError; @@ -97,10 +98,7 @@ pub enum DataFusionError { #[macro_export] macro_rules! context { ($desc:expr, $err:expr) => { - datafusion_common::DataFusionError::Context( - format!("{} at {}:{}", $desc, file!(), line!()), - Box::new($err), - ) + $err.context(format!("{} at {}:{}", $desc, file!(), line!())) }; } @@ -120,29 +118,41 @@ macro_rules! 
plan_err { #[derive(Debug)] pub enum SchemaError { /// Schema contains a (possibly) qualified and unqualified field with same unqualified name - AmbiguousReference { - qualifier: Option, + AmbiguousReference { field: Column }, + /// Schema contains duplicate qualified field name + DuplicateQualifiedField { + qualifier: Box, name: String, }, - /// Schema contains duplicate qualified field name - DuplicateQualifiedField { qualifier: String, name: String }, /// Schema contains duplicate unqualified field name DuplicateUnqualifiedField { name: String }, /// No field with this name FieldNotFound { - field: Column, + field: Box, valid_fields: Vec, }, } /// Create a "field not found" DataFusion::SchemaError -pub fn field_not_found( - qualifier: Option, +pub fn field_not_found>( + qualifier: Option, name: &str, schema: &DFSchema, ) -> DataFusionError { DataFusionError::SchemaError(SchemaError::FieldNotFound { - field: Column::new(qualifier, name), + field: Box::new(Column::new(qualifier, name)), + valid_fields: schema + .fields() + .iter() + .map(|f| f.qualified_column()) + .collect(), + }) +} + +/// Convenience wrapper over [`field_not_found`] for when there is no qualifier +pub fn unqualified_field_not_found(name: &str, schema: &DFSchema) -> DataFusionError { + DataFusionError::SchemaError(SchemaError::FieldNotFound { + field: Box::new(Column::new_unqualified(name)), valid_fields: schema .fields() .iter() @@ -158,25 +168,14 @@ impl Display for SchemaError { field, valid_fields, } => { - write!(f, "No field named ")?; - if let Some(q) = &field.relation { - write!(f, "'{}'.'{}'", q, field.name)?; - } else { - write!(f, "'{}'", field.name)?; - } + write!(f, "No field named {}", field.quoted_flat_name())?; if !valid_fields.is_empty() { write!( f, ". Valid fields are {}", valid_fields .iter() - .map(|field| { - if let Some(q) = &field.relation { - format!("'{}'.'{}'", q, field.name) - } else { - format!("'{}'", field.name) - } - }) + .map(|field| field.quoted_flat_name()) .collect::>() .join(", ") )?; @@ -186,20 +185,32 @@ impl Display for SchemaError { Self::DuplicateQualifiedField { qualifier, name } => { write!( f, - "Schema contains duplicate qualified field name '{qualifier}'.'{name}'" + "Schema contains duplicate qualified field name {}.{}", + qualifier.to_quoted_string(), + quote_identifier(name) ) } Self::DuplicateUnqualifiedField { name } => { write!( f, - "Schema contains duplicate unqualified field name '{name}'" + "Schema contains duplicate unqualified field name {}", + quote_identifier(name) ) } - Self::AmbiguousReference { qualifier, name } => { - if let Some(q) = qualifier { - write!(f, "Schema contains qualified field name '{q}'.'{name}' and unqualified field name '{name}' which would be ambiguous") + Self::AmbiguousReference { field } => { + if field.relation.is_some() { + write!( + f, + "Schema contains qualified field name {} and unqualified field name {} which would be ambiguous", + field.quoted_flat_name(), + quote_identifier(&field.name) + ) } else { - write!(f, "Ambiguous reference to unqualified field '{name}'") + write!( + f, + "Ambiguous reference to unqualified field {}", + field.quoted_flat_name() + ) } } } @@ -403,6 +414,11 @@ impl DataFusionError { // return last checkpoint (which may be the original error) last_datafusion_error } + + /// wraps self in Self::Context with a description + pub fn context(self, description: impl Into) -> Self { + Self::Context(description.into(), Box::new(self)) + } } #[cfg(test)] diff --git a/datafusion/common/src/lib.rs 
b/datafusion/common/src/lib.rs index 636feb21a489..4af8720b009a 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -34,7 +34,10 @@ pub mod utils; use arrow::compute::SortOptions; pub use column::Column; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; -pub use error::{field_not_found, DataFusionError, Result, SchemaError, SharedResult}; +pub use error::{ + field_not_found, unqualified_field_not_found, DataFusionError, Result, SchemaError, + SharedResult, +}; pub use parsers::parse_interval; pub use scalar::{ScalarType, ScalarValue}; pub use stats::{ColumnStatistics, Statistics}; diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 5dc0425869c2..92cdab3ebba3 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -43,7 +43,14 @@ use arrow::{ DECIMAL128_MAX_PRECISION, }, }; -use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime}; +use arrow_array::timezone::Tz; +use chrono::{DateTime, Datelike, Duration, NaiveDate, NaiveDateTime, TimeZone}; + +// Constants we use throughout this file: +const MILLISECS_IN_ONE_DAY: i64 = 86_400_000; +const NANOSECS_IN_ONE_DAY: i64 = 86_400_000_000_000; +const MILLISECS_IN_ONE_MONTH: i64 = 2_592_000_000; // assuming 30 days. +const NANOSECS_IN_ONE_MONTH: i128 = 2_592_000_000_000_000; // assuming 30 days. /// Represents a dynamically typed, nullable single value. /// This is the single-valued counter-part to arrow's [`Array`]. @@ -199,10 +206,28 @@ impl PartialEq for ScalarValue { (TimestampNanosecond(v1, _), TimestampNanosecond(v2, _)) => v1.eq(v2), (TimestampNanosecond(_, _), _) => false, (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2), + (IntervalYearMonth(v1), IntervalDayTime(v2)) => { + ym_to_milli(v1).eq(&dt_to_milli(v2)) + } + (IntervalYearMonth(v1), IntervalMonthDayNano(v2)) => { + ym_to_nano(v1).eq(&mdn_to_nano(v2)) + } (IntervalYearMonth(_), _) => false, (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2), + (IntervalDayTime(v1), IntervalYearMonth(v2)) => { + dt_to_milli(v1).eq(&ym_to_milli(v2)) + } + (IntervalDayTime(v1), IntervalMonthDayNano(v2)) => { + dt_to_nano(v1).eq(&mdn_to_nano(v2)) + } (IntervalDayTime(_), _) => false, (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2), + (IntervalMonthDayNano(v1), IntervalYearMonth(v2)) => { + mdn_to_nano(v1).eq(&ym_to_nano(v2)) + } + (IntervalMonthDayNano(v1), IntervalDayTime(v2)) => { + mdn_to_nano(v1).eq(&dt_to_nano(v2)) + } (IntervalMonthDayNano(_), _) => false, (Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2), (Struct(_, _), _) => false, @@ -304,10 +329,28 @@ impl PartialOrd for ScalarValue { } (TimestampNanosecond(_, _), _) => None, (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), + (IntervalYearMonth(v1), IntervalDayTime(v2)) => { + ym_to_milli(v1).partial_cmp(&dt_to_milli(v2)) + } + (IntervalYearMonth(v1), IntervalMonthDayNano(v2)) => { + ym_to_nano(v1).partial_cmp(&mdn_to_nano(v2)) + } (IntervalYearMonth(_), _) => None, (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.partial_cmp(v2), + (IntervalDayTime(v1), IntervalYearMonth(v2)) => { + dt_to_milli(v1).partial_cmp(&ym_to_milli(v2)) + } + (IntervalDayTime(v1), IntervalMonthDayNano(v2)) => { + dt_to_nano(v1).partial_cmp(&mdn_to_nano(v2)) + } (IntervalDayTime(_), _) => None, (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.partial_cmp(v2), + (IntervalMonthDayNano(v1), IntervalYearMonth(v2)) => { + mdn_to_nano(v1).partial_cmp(&ym_to_nano(v2)) + } + 
(IntervalMonthDayNano(v1), IntervalDayTime(v2)) => { + mdn_to_nano(v1).partial_cmp(&dt_to_nano(v2)) + } (IntervalMonthDayNano(_), _) => None, (Struct(v1, t1), Struct(v2, t2)) => { if t1.eq(t2) { @@ -332,6 +375,52 @@ impl PartialOrd for ScalarValue { } } +/// This function computes the duration (in milliseconds) of the given +/// year-month-interval. +#[inline] +fn ym_to_milli(val: &Option) -> Option { + val.map(|value| (value as i64) * MILLISECS_IN_ONE_MONTH) +} + +/// This function computes the duration (in nanoseconds) of the given +/// year-month-interval. +#[inline] +fn ym_to_nano(val: &Option) -> Option { + val.map(|value| (value as i128) * NANOSECS_IN_ONE_MONTH) +} + +/// This function computes the duration (in milliseconds) of the given +/// daytime-interval. +#[inline] +fn dt_to_milli(val: &Option) -> Option { + val.map(|val| { + let (days, millis) = IntervalDayTimeType::to_parts(val); + (days as i64) * MILLISECS_IN_ONE_DAY + (millis as i64) + }) +} + +/// This function computes the duration (in nanoseconds) of the given +/// daytime-interval. +#[inline] +fn dt_to_nano(val: &Option) -> Option { + val.map(|val| { + let (days, millis) = IntervalDayTimeType::to_parts(val); + (days as i128) * (NANOSECS_IN_ONE_DAY as i128) + (millis as i128) * 1_000_000 + }) +} + +/// This function computes the duration (in nanoseconds) of the given +/// month-day-nano-interval. Assumes a month is 30 days long. +#[inline] +fn mdn_to_nano(val: &Option) -> Option { + val.map(|val| { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(val); + (months as i128) * NANOSECS_IN_ONE_MONTH + + (days as i128) * (NANOSECS_IN_ONE_DAY as i128) + + (nanos as i128) + }) +} + impl Eq for ScalarValue {} // TODO implement this in arrow-rs with simd @@ -464,6 +553,71 @@ macro_rules! unsigned_subtraction_error { } macro_rules! impl_op { + ($LHS:expr, $RHS:expr, +) => { + impl_op_arithmetic!($LHS, $RHS, +) + }; + ($LHS:expr, $RHS:expr, -) => { + match ($LHS, $RHS) { + ( + ScalarValue::TimestampSecond(Some(ts_lhs), tz_lhs), + ScalarValue::TimestampSecond(Some(ts_rhs), tz_rhs), + ) => { + let err = || { + DataFusionError::Execution( + "Overflow while converting seconds to milliseconds".to_string(), + ) + }; + ts_sub_to_interval( + ts_lhs.checked_mul(1_000).ok_or_else(err)?, + ts_rhs.checked_mul(1_000).ok_or_else(err)?, + &tz_lhs, + &tz_rhs, + IntervalMode::Milli, + ) + }, + ( + ScalarValue::TimestampMillisecond(Some(ts_lhs), tz_lhs), + ScalarValue::TimestampMillisecond(Some(ts_rhs), tz_rhs), + ) => ts_sub_to_interval( + *ts_lhs, + *ts_rhs, + tz_lhs, + tz_rhs, + IntervalMode::Milli, + ), + ( + ScalarValue::TimestampMicrosecond(Some(ts_lhs), tz_lhs), + ScalarValue::TimestampMicrosecond(Some(ts_rhs), tz_rhs), + ) => { + let err = || { + DataFusionError::Execution( + "Overflow while converting microseconds to nanoseconds".to_string(), + ) + }; + ts_sub_to_interval( + ts_lhs.checked_mul(1_000).ok_or_else(err)?, + ts_rhs.checked_mul(1_000).ok_or_else(err)?, + tz_lhs, + tz_rhs, + IntervalMode::Nano, + ) + }, + ( + ScalarValue::TimestampNanosecond(Some(ts_lhs), tz_lhs), + ScalarValue::TimestampNanosecond(Some(ts_rhs), tz_rhs), + ) => ts_sub_to_interval( + *ts_lhs, + *ts_rhs, + tz_lhs, + tz_rhs, + IntervalMode::Nano, + ), + _ => impl_op_arithmetic!($LHS, $RHS, -) + } + }; +} + +macro_rules! impl_op_arithmetic { ($LHS:expr, $RHS:expr, $OPERATION:tt) => { match ($LHS, $RHS) { // Binary operations on arguments with the same type: @@ -503,6 +657,40 @@ macro_rules! 
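The following sketch shows the effect of these conversion helpers (assuming this PR's `ScalarValue` API): intervals with different underlying types now compare by normalizing to a common unit, with a month assumed to be 30 days long.

```rust
use datafusion_common::ScalarValue;

fn main() {
    let one_month = ScalarValue::new_interval_ym(0, 1);
    let thirty_days = ScalarValue::new_interval_dt(30, 0);
    let thirty_days_nanos = ScalarValue::new_interval_mdn(0, 30, 0);

    // 1 month == 30 days == 30 days expressed as MonthDayNano.
    assert_eq!(one_month, thirty_days);
    assert_eq!(thirty_days, thirty_days_nanos);
    // Ordering also works across variants.
    assert!(ScalarValue::new_interval_dt(31, 0) > one_month);
}
```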
impl_op { (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { primitive_op!(lhs, rhs, Int8, $OPERATION) } + ( + ScalarValue::IntervalYearMonth(Some(lhs)), + ScalarValue::IntervalYearMonth(Some(rhs)), + ) => Ok(ScalarValue::new_interval_ym( + 0, + lhs + rhs * get_sign!($OPERATION), + )), + ( + ScalarValue::IntervalDayTime(Some(lhs)), + ScalarValue::IntervalDayTime(Some(rhs)), + ) => { + let sign = get_sign!($OPERATION); + let (lhs_days, lhs_millis) = IntervalDayTimeType::to_parts(*lhs); + let (rhs_days, rhs_millis) = IntervalDayTimeType::to_parts(*rhs); + Ok(ScalarValue::new_interval_dt( + lhs_days + rhs_days * sign, + lhs_millis + rhs_millis * sign, + )) + } + ( + ScalarValue::IntervalMonthDayNano(Some(lhs)), + ScalarValue::IntervalMonthDayNano(Some(rhs)), + ) => { + let sign = get_sign!($OPERATION); + let (lhs_months, lhs_days, lhs_nanos) = + IntervalMonthDayNanoType::to_parts(*lhs); + let (rhs_months, rhs_days, rhs_nanos) = + IntervalMonthDayNanoType::to_parts(*rhs); + Ok(ScalarValue::new_interval_mdn( + lhs_months + rhs_months * sign, + lhs_days + rhs_days * sign, + lhs_nanos + rhs_nanos * (sign as i64), + )) + } // Binary operations on arguments with different types: (ScalarValue::Date32(Some(days)), _) => { let value = date32_add(*days, $RHS, get_sign!($OPERATION))?; @@ -544,6 +732,30 @@ macro_rules! impl_op { let value = nanoseconds_add(*ts_ns, $LHS, get_sign!($OPERATION))?; Ok(ScalarValue::TimestampNanosecond(Some(value), zone.clone())) } + ( + ScalarValue::IntervalYearMonth(Some(lhs)), + ScalarValue::IntervalDayTime(Some(rhs)), + ) => op_ym_dt(*lhs, *rhs, get_sign!($OPERATION), false), + ( + ScalarValue::IntervalYearMonth(Some(lhs)), + ScalarValue::IntervalMonthDayNano(Some(rhs)), + ) => op_ym_mdn(*lhs, *rhs, get_sign!($OPERATION), false), + ( + ScalarValue::IntervalDayTime(Some(lhs)), + ScalarValue::IntervalYearMonth(Some(rhs)), + ) => op_ym_dt(*rhs, *lhs, get_sign!($OPERATION), true), + ( + ScalarValue::IntervalDayTime(Some(lhs)), + ScalarValue::IntervalMonthDayNano(Some(rhs)), + ) => op_dt_mdn(*lhs, *rhs, get_sign!($OPERATION), false), + ( + ScalarValue::IntervalMonthDayNano(Some(lhs)), + ScalarValue::IntervalYearMonth(Some(rhs)), + ) => op_ym_mdn(*rhs, *lhs, get_sign!($OPERATION), true), + ( + ScalarValue::IntervalMonthDayNano(Some(lhs)), + ScalarValue::IntervalDayTime(Some(rhs)), + ) => op_dt_mdn(*rhs, *lhs, get_sign!($OPERATION), true), _ => Err(DataFusionError::Internal(format!( "Operator {} is not implemented for types {:?} and {:?}", stringify!($OPERATION), @@ -554,6 +766,68 @@ macro_rules! impl_op { }; } +/// This function adds/subtracts two "raw" intervals (`lhs` and `rhs`) of different +/// types ([`IntervalYearMonthType`] and [`IntervalDayTimeType`], respectively). +/// The argument `sign` chooses between addition and subtraction, the argument +/// `commute` swaps `lhs` and `rhs`. The return value is an interval [`ScalarValue`] +/// with type data type [`IntervalMonthDayNanoType`]. +#[inline] +fn op_ym_dt(mut lhs: i32, rhs: i64, sign: i32, commute: bool) -> Result { + let (mut days, millis) = IntervalDayTimeType::to_parts(rhs); + let mut nanos = (millis as i64) * 1_000_000; + if commute { + lhs *= sign; + } else { + days *= sign; + nanos *= sign as i64; + }; + Ok(ScalarValue::new_interval_mdn(lhs, days, nanos)) +} + +/// This function adds/subtracts two "raw" intervals (`lhs` and `rhs`) of different +/// types ([`IntervalYearMonthType`] and [`IntervalMonthDayNanoType`], respectively). 
+/// The argument `sign` chooses between addition and subtraction, the argument +/// `commute` swaps `lhs` and `rhs`. The return value is an interval [`ScalarValue`] +/// with type data type [`IntervalMonthDayNanoType`]. +#[inline] +fn op_ym_mdn(lhs: i32, rhs: i128, sign: i32, commute: bool) -> Result { + let (mut months, mut days, mut nanos) = IntervalMonthDayNanoType::to_parts(rhs); + if commute { + months += lhs * sign; + } else { + months = lhs + (months * sign); + days *= sign; + nanos *= sign as i64; + } + Ok(ScalarValue::new_interval_mdn(months, days, nanos)) +} + +/// This function adds/subtracts two "raw" intervals (`lhs` and `rhs`) of different +/// types ([`IntervalDayTimeType`] and [`IntervalMonthDayNanoType`], respectively). +/// The argument `sign` chooses between addition and subtraction, the argument +/// `commute` swaps `lhs` and `rhs`. The return value is an interval [`ScalarValue`] +/// with type data type [`IntervalMonthDayNanoType`]. +#[inline] +fn op_dt_mdn(lhs: i64, rhs: i128, sign: i32, commute: bool) -> Result { + let (lhs_days, lhs_millis) = IntervalDayTimeType::to_parts(lhs); + let (rhs_months, rhs_days, rhs_nanos) = IntervalMonthDayNanoType::to_parts(rhs); + + let result = if commute { + IntervalMonthDayNanoType::make_value( + rhs_months, + lhs_days * sign + rhs_days, + (lhs_millis * sign) as i64 * 1_000_000 + rhs_nanos, + ) + } else { + IntervalMonthDayNanoType::make_value( + rhs_months * sign, + lhs_days + rhs_days * sign, + (lhs_millis as i64) * 1_000_000 + rhs_nanos * (sign as i64), + ) + }; + Ok(ScalarValue::IntervalMonthDayNano(Some(result))) +} + macro_rules! get_sign { (+) => { 1 @@ -563,46 +837,138 @@ macro_rules! get_sign { }; } +#[derive(Clone, Copy)] +enum IntervalMode { + Milli, + Nano, +} + +/// This function computes subtracts `rhs_ts` from `lhs_ts`, taking timezones +/// into account when given. Units of the resulting interval is specified by +/// the argument `mode`. +/// The default behavior of Datafusion is the following: +/// - When subtracting timestamps at seconds/milliseconds precision, the output +/// interval will have the type [`IntervalDayTimeType`]. +/// - When subtracting timestamps at microseconds/nanoseconds precision, the +/// output interval will have the type [`IntervalMonthDayNanoType`]. +fn ts_sub_to_interval( + lhs_ts: i64, + rhs_ts: i64, + lhs_tz: &Option, + rhs_tz: &Option, + mode: IntervalMode, +) -> Result { + let lhs_dt = with_timezone_to_naive_datetime(lhs_ts, lhs_tz, mode)?; + let rhs_dt = with_timezone_to_naive_datetime(rhs_ts, rhs_tz, mode)?; + let delta_secs = lhs_dt.signed_duration_since(rhs_dt); + + match mode { + IntervalMode::Milli => { + let as_millisecs = delta_secs.num_milliseconds(); + Ok(ScalarValue::new_interval_dt( + (as_millisecs / MILLISECS_IN_ONE_DAY) as i32, + (as_millisecs % MILLISECS_IN_ONE_DAY) as i32, + )) + } + IntervalMode::Nano => { + let as_nanosecs = delta_secs.num_nanoseconds().ok_or_else(|| { + DataFusionError::Execution(String::from( + "Can not compute timestamp differences with nanosecond precision", + )) + })?; + Ok(ScalarValue::new_interval_mdn( + 0, + (as_nanosecs / NANOSECS_IN_ONE_DAY) as i32, + as_nanosecs % NANOSECS_IN_ONE_DAY, + )) + } + } +} + +/// This function creates the [`NaiveDateTime`] object corresponding to the +/// given timestamp using the units (tick size) implied by argument `mode`. 
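A brief usage sketch of the interval arithmetic added above, assuming `ScalarValue::add` keeps its existing `Result`-returning signature: same-type intervals add component-wise, while mixed-type operands are widened to `IntervalMonthDayNano` by these helpers.

```rust
use datafusion_common::{Result, ScalarValue};

fn main() -> Result<()> {
    // Same type: components are added directly.
    let a = ScalarValue::new_interval_dt(1, 500);
    let b = ScalarValue::new_interval_dt(2, 250);
    assert_eq!(a.add(&b)?, ScalarValue::new_interval_dt(3, 750));

    // Mixed types: 1 month + 2 days widens to an IntervalMonthDayNano.
    let ym = ScalarValue::new_interval_ym(0, 1);
    let dt = ScalarValue::new_interval_dt(2, 0);
    assert_eq!(ym.add(&dt)?, ScalarValue::new_interval_mdn(1, 2, 0));
    Ok(())
}
```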
+#[inline] +fn with_timezone_to_naive_datetime( + ts: i64, + tz: &Option, + mode: IntervalMode, +) -> Result { + let datetime = if let IntervalMode::Milli = mode { + ticks_to_naive_datetime::<1_000_000>(ts) + } else { + ticks_to_naive_datetime::<1>(ts) + }?; + + if let Some(tz) = tz { + let parsed_tz: Tz = FromStr::from_str(tz).map_err(|_| { + DataFusionError::Execution("cannot parse given timezone".to_string()) + })?; + let offset = parsed_tz + .offset_from_local_datetime(&datetime) + .single() + .ok_or_else(|| { + DataFusionError::Execution( + "error conversion result of timezone offset".to_string(), + ) + })?; + return Ok(DateTime::::from_local(datetime, offset).naive_utc()); + } + Ok(datetime) +} + +/// This function creates the [`NaiveDateTime`] object corresponding to the +/// given timestamp, whose tick size is specified by `UNIT_NANOS`. +#[inline] +fn ticks_to_naive_datetime(ticks: i64) -> Result { + NaiveDateTime::from_timestamp_opt( + (ticks * UNIT_NANOS) / 1_000_000_000, + ((ticks * UNIT_NANOS) % 1_000_000_000) as u32, + ) + .ok_or_else(|| { + DataFusionError::Execution( + "Can not convert given timestamp to a NaiveDateTime".to_string(), + ) + }) +} + #[inline] pub fn date32_add(days: i32, scalar: &ScalarValue, sign: i32) -> Result { let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); let prior = epoch.add(Duration::days(days as i64)); - let posterior = do_date_math(prior, scalar, sign)?; - Ok(posterior.sub(epoch).num_days() as i32) + do_date_math(prior, scalar, sign).map(|d| d.sub(epoch).num_days() as i32) } #[inline] pub fn date64_add(ms: i64, scalar: &ScalarValue, sign: i32) -> Result { let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); let prior = epoch.add(Duration::milliseconds(ms)); - let posterior = do_date_math(prior, scalar, sign)?; - Ok(posterior.sub(epoch).num_milliseconds()) + do_date_math(prior, scalar, sign).map(|d| d.sub(epoch).num_milliseconds()) } #[inline] pub fn seconds_add(ts_s: i64, scalar: &ScalarValue, sign: i32) -> Result { - Ok(do_date_time_math(ts_s, 0, scalar, sign)?.timestamp()) + do_date_time_math(ts_s, 0, scalar, sign).map(|dt| dt.timestamp()) } #[inline] pub fn milliseconds_add(ts_ms: i64, scalar: &ScalarValue, sign: i32) -> Result { let secs = ts_ms / 1000; let nsecs = ((ts_ms % 1000) * 1_000_000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_millis()) + do_date_time_math(secs, nsecs, scalar, sign).map(|dt| dt.timestamp_millis()) } #[inline] pub fn microseconds_add(ts_us: i64, scalar: &ScalarValue, sign: i32) -> Result { let secs = ts_us / 1_000_000; let nsecs = ((ts_us % 1_000_000) * 1000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos() / 1000) + do_date_time_math(secs, nsecs, scalar, sign).map(|dt| dt.timestamp_nanos() / 1000) } #[inline] pub fn nanoseconds_add(ts_ns: i64, scalar: &ScalarValue, sign: i32) -> Result { let secs = ts_ns / 1_000_000_000; let nsecs = (ts_ns % 1_000_000_000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos()) + do_date_time_math(secs, nsecs, scalar, sign).map(|dt| dt.timestamp_nanos()) } #[inline] @@ -2358,7 +2724,7 @@ impl ScalarValue { None => v.is_null(), } } - ScalarValue::Null => array.data().is_null(index), + ScalarValue::Null => array.is_null(index), } } @@ -2921,6 +3287,7 @@ mod tests { use arrow::compute::kernels; use arrow::datatypes::ArrowPrimitiveType; + use rand::Rng; use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; use crate::from_slice::FromSlice; @@ -3707,6 +4074,53 @@ mod tests { ])), None 
); + // Different type of intervals can be compared. + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(1, 2))) + < IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( + 14, 0, 1 + ))), + ); + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 4))) + >= IntervalDayTime(Some(IntervalDayTimeType::make_value(119, 1))) + ); + assert!( + IntervalDayTime(Some(IntervalDayTimeType::make_value(12, 86_399_999))) + >= IntervalDayTime(Some(IntervalDayTimeType::make_value(12, 0))) + ); + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(2, 12))) + == IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( + 36, 0, 0 + ))), + ); + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 0))) + != IntervalDayTime(Some(IntervalDayTimeType::make_value(0, 1))) + ); + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(1, 4))) + == IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 16))), + ); + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 3))) + > IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( + 2, + 28, + 999_999_999 + ))), + ); + assert!( + IntervalYearMonth(Some(IntervalYearMonthType::make_value(0, 1))) + > IntervalDayTime(Some(IntervalDayTimeType::make_value(29, 9_999))), + ); + assert!( + IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value(1, 12, 34))) + > IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( + 0, 142, 34 + ))) + ); } #[test] @@ -4486,4 +4900,513 @@ mod tests { assert!(distance.is_none()); } } + + #[test] + fn test_scalar_interval_add() { + let cases = [ + ( + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(2, 24), + ), + ( + ScalarValue::new_interval_dt(1, 999), + ScalarValue::new_interval_dt(1, 999), + ScalarValue::new_interval_dt(2, 1998), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_mdn(24, 30, 246_912), + ), + ( + ScalarValue::new_interval_ym(0, 1), + ScalarValue::new_interval_dt(29, 86_390), + ScalarValue::new_interval_mdn(1, 29, 86_390_000_000), + ), + ( + ScalarValue::new_interval_ym(0, 1), + ScalarValue::new_interval_mdn(2, 10, 999_999_999), + ScalarValue::new_interval_mdn(3, 10, 999_999_999), + ), + ( + ScalarValue::new_interval_dt(400, 123_456), + ScalarValue::new_interval_ym(1, 1), + ScalarValue::new_interval_mdn(13, 400, 123_456_000_000), + ), + ( + ScalarValue::new_interval_dt(65, 321), + ScalarValue::new_interval_mdn(2, 5, 1_000_000), + ScalarValue::new_interval_mdn(2, 70, 322_000_000), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_ym(2, 0), + ScalarValue::new_interval_mdn(36, 15, 123_456), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 100_000), + ScalarValue::new_interval_dt(370, 1), + ScalarValue::new_interval_mdn(12, 385, 1_100_000), + ), + ]; + for (lhs, rhs, expected) in cases.iter() { + let result = lhs.add(rhs).unwrap(); + let result_commute = rhs.add(lhs).unwrap(); + assert_eq!(*expected, result, "lhs:{:?} + rhs:{:?}", lhs, rhs); + assert_eq!(*expected, result_commute, "lhs:{:?} + rhs:{:?}", rhs, lhs); + } + } + + #[test] + fn test_scalar_interval_sub() { + let cases = [ + ( + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(1, 12), + ScalarValue::new_interval_ym(0, 0), + ), + ( + ScalarValue::new_interval_dt(1, 999), + ScalarValue::new_interval_dt(1, 999), + 
ScalarValue::new_interval_dt(0, 0), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_mdn(0, 0, 0), + ), + ( + ScalarValue::new_interval_ym(0, 1), + ScalarValue::new_interval_dt(29, 999_999), + ScalarValue::new_interval_mdn(1, -29, -999_999_000_000), + ), + ( + ScalarValue::new_interval_ym(0, 1), + ScalarValue::new_interval_mdn(2, 10, 999_999_999), + ScalarValue::new_interval_mdn(-1, -10, -999_999_999), + ), + ( + ScalarValue::new_interval_dt(400, 123_456), + ScalarValue::new_interval_ym(1, 1), + ScalarValue::new_interval_mdn(-13, 400, 123_456_000_000), + ), + ( + ScalarValue::new_interval_dt(65, 321), + ScalarValue::new_interval_mdn(2, 5, 1_000_000), + ScalarValue::new_interval_mdn(-2, 60, 320_000_000), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 123_456), + ScalarValue::new_interval_ym(2, 0), + ScalarValue::new_interval_mdn(-12, 15, 123_456), + ), + ( + ScalarValue::new_interval_mdn(12, 15, 100_000), + ScalarValue::new_interval_dt(370, 1), + ScalarValue::new_interval_mdn(12, -355, -900_000), + ), + ]; + for (lhs, rhs, expected) in cases.iter() { + let result = lhs.sub(rhs).unwrap(); + assert_eq!(*expected, result, "lhs:{:?} - rhs:{:?}", lhs, rhs); + } + } + + #[test] + fn timestamp_op_tests() { + // positive interval, edge cases + let test_data = get_timestamp_test_data(1); + for (lhs, rhs, expected) in test_data.into_iter() { + assert_eq!(expected, lhs.sub(rhs).unwrap()) + } + + // negative interval, edge cases + let test_data = get_timestamp_test_data(-1); + for (rhs, lhs, expected) in test_data.into_iter() { + assert_eq!(expected, lhs.sub(rhs).unwrap()); + } + } + #[test] + fn timestamp_op_random_tests() { + // timestamp1 + (or -) interval = timestamp2 + // timestamp2 - timestamp1 (or timestamp1 - timestamp2) = interval ? + let sample_size = 1000000; + let timestamps1 = get_random_timestamps(sample_size); + let intervals = get_random_intervals(sample_size); + // ts(sec) + interval(ns) = ts(sec); however, + // ts(sec) - ts(sec) cannot be = interval(ns). Therefore, + // timestamps are more precise than intervals in tests. 
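To make the expected semantics concrete, a small sketch assuming the PR's `ScalarValue::sub`: subtracting second- or millisecond-precision timestamps yields an `IntervalDayTime`, while microsecond/nanosecond timestamps yield an `IntervalMonthDayNano`.

```rust
use chrono::NaiveDate;
use datafusion_common::{Result, ScalarValue};

fn main() -> Result<()> {
    let ts = |y, m, d| {
        NaiveDate::from_ymd_opt(y, m, d)
            .unwrap()
            .and_hms_opt(0, 0, 0)
            .unwrap()
            .timestamp()
    };

    // Second precision: the difference comes back as an IntervalDayTime.
    let later = ScalarValue::TimestampSecond(Some(ts(2023, 1, 2)), None);
    let earlier = ScalarValue::TimestampSecond(Some(ts(2023, 1, 1)), None);
    assert_eq!(later.sub(&earlier)?, ScalarValue::new_interval_dt(1, 0));
    Ok(())
}
```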
+ for (idx, ts1) in timestamps1.iter().enumerate() { + if idx % 2 == 0 { + let timestamp2 = ts1.add(intervals[idx].clone()).unwrap(); + assert_eq!( + intervals[idx], + timestamp2.sub(ts1).unwrap(), + "index:{}, operands: {:?} (-) {:?}", + idx, + timestamp2, + ts1 + ); + } else { + let timestamp2 = ts1.sub(intervals[idx].clone()).unwrap(); + assert_eq!( + intervals[idx], + ts1.sub(timestamp2.clone()).unwrap(), + "index:{}, operands: {:?} (-) {:?}", + idx, + ts1, + timestamp2 + ); + }; + } + } + + fn get_timestamp_test_data( + sign: i32, + ) -> Vec<(ScalarValue, ScalarValue, ScalarValue)> { + vec![ + ( + // 1st test case, having the same time but different with timezones + // Since they are timestamps with nanosecond precision, expected type is + // [`IntervalMonthDayNanoType`] + ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_nano_opt(12, 0, 0, 000_000_000) + .unwrap() + .timestamp_nanos(), + ), + Some("+12:00".to_string()), + ), + ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_nano_opt(0, 0, 0, 000_000_000) + .unwrap() + .timestamp_nanos(), + ), + Some("+00:00".to_string()), + ), + ScalarValue::new_interval_mdn(0, 0, 0), + ), + // 2nd test case, january with 31 days plus february with 28 days, with timezone + ( + ScalarValue::TimestampMicrosecond( + Some( + NaiveDate::from_ymd_opt(2023, 3, 1) + .unwrap() + .and_hms_micro_opt(2, 0, 0, 000_000) + .unwrap() + .timestamp_micros(), + ), + Some("+01:00".to_string()), + ), + ScalarValue::TimestampMicrosecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_micro_opt(0, 0, 0, 000_000) + .unwrap() + .timestamp_micros(), + ), + Some("-01:00".to_string()), + ), + ScalarValue::new_interval_mdn(0, sign * 59, 0), + ), + // 3rd test case, 29-days long february minus previous, year with timezone + ( + ScalarValue::TimestampMillisecond( + Some( + NaiveDate::from_ymd_opt(2024, 2, 29) + .unwrap() + .and_hms_milli_opt(10, 10, 0, 000) + .unwrap() + .timestamp_millis(), + ), + Some("+10:10".to_string()), + ), + ScalarValue::TimestampMillisecond( + Some( + NaiveDate::from_ymd_opt(2023, 12, 31) + .unwrap() + .and_hms_milli_opt(1, 0, 0, 000) + .unwrap() + .timestamp_millis(), + ), + Some("+01:00".to_string()), + ), + ScalarValue::new_interval_dt(sign * 60, 0), + ), + // 4th test case, leap years occur mostly every 4 years, but every 100 years + // we skip a leap year unless the year is divisible by 400, so 31 + 28 = 59 + ( + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2100, 3, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .timestamp(), + ), + Some("-11:59".to_string()), + ), + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2100, 1, 1) + .unwrap() + .and_hms_opt(23, 58, 0) + .unwrap() + .timestamp(), + ), + Some("+11:59".to_string()), + ), + ScalarValue::new_interval_dt(sign * 59, 0), + ), + // 5th test case, without timezone positively seemed, but with timezone, + // negative resulting interval + ( + ScalarValue::TimestampMillisecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_milli_opt(6, 00, 0, 000) + .unwrap() + .timestamp_millis(), + ), + Some("+06:00".to_string()), + ), + ScalarValue::TimestampMillisecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_milli_opt(0, 0, 0, 000) + .unwrap() + .timestamp_millis(), + ), + Some("-12:00".to_string()), + ), + ScalarValue::new_interval_dt(0, sign * -43_200_000), + ), + // 6th test case, no problem 
before unix epoch beginning + ( + ScalarValue::TimestampMicrosecond( + Some( + NaiveDate::from_ymd_opt(1970, 1, 1) + .unwrap() + .and_hms_micro_opt(1, 2, 3, 15) + .unwrap() + .timestamp_micros(), + ), + None, + ), + ScalarValue::TimestampMicrosecond( + Some( + NaiveDate::from_ymd_opt(1969, 1, 1) + .unwrap() + .and_hms_micro_opt(0, 0, 0, 000_000) + .unwrap() + .timestamp_micros(), + ), + None, + ), + ScalarValue::new_interval_mdn( + 0, + 365 * sign, + sign as i64 * 3_723_000_015_000, + ), + ), + // 7th test case, no problem with big intervals + ( + ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(2100, 1, 1) + .unwrap() + .and_hms_nano_opt(0, 0, 0, 0) + .unwrap() + .timestamp_nanos(), + ), + None, + ), + ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(2000, 1, 1) + .unwrap() + .and_hms_nano_opt(0, 0, 0, 000_000_000) + .unwrap() + .timestamp_nanos(), + ), + None, + ), + ScalarValue::new_interval_mdn(0, sign * 36525, 0), + ), + // 8th test case, no problem detecting 366-days long years + ( + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2041, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .timestamp(), + ), + None, + ), + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2040, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .timestamp(), + ), + None, + ), + ScalarValue::new_interval_dt(sign * 366, 0), + ), + // 9th test case, no problem with unrealistic timezones + ( + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 3) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .timestamp(), + ), + Some("+23:59".to_string()), + ), + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_opt(0, 2, 0) + .unwrap() + .timestamp(), + ), + Some("-23:59".to_string()), + ), + ScalarValue::new_interval_dt(0, 0), + ), + // 10th test case, parsing different types of timezone input + ( + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2023, 3, 17) + .unwrap() + .and_hms_opt(14, 10, 0) + .unwrap() + .timestamp(), + ), + Some("Europe/Istanbul".to_string()), + ), + ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2023, 3, 17) + .unwrap() + .and_hms_opt(4, 10, 0) + .unwrap() + .timestamp(), + ), + Some("America/Los_Angeles".to_string()), + ), + ScalarValue::new_interval_dt(0, 0), + ), + ] + } + + fn get_random_timestamps(sample_size: u64) -> Vec { + let vector_size = sample_size; + let mut timestamp = vec![]; + let mut rng = rand::thread_rng(); + for i in 0..vector_size { + let year = rng.gen_range(1995..=2050); + let month = rng.gen_range(1..=12); + let day = rng.gen_range(1..=28); // to exclude invalid dates + let hour = rng.gen_range(0..=23); + let minute = rng.gen_range(0..=59); + let second = rng.gen_range(0..=59); + if i % 4 == 0 { + timestamp.push(ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(year, month, day) + .unwrap() + .and_hms_opt(hour, minute, second) + .unwrap() + .timestamp(), + ), + None, + )) + } else if i % 4 == 1 { + let millisec = rng.gen_range(0..=999); + timestamp.push(ScalarValue::TimestampMillisecond( + Some( + NaiveDate::from_ymd_opt(year, month, day) + .unwrap() + .and_hms_milli_opt(hour, minute, second, millisec) + .unwrap() + .timestamp_millis(), + ), + None, + )) + } else if i % 4 == 2 { + let microsec = rng.gen_range(0..=999_999); + timestamp.push(ScalarValue::TimestampMicrosecond( + Some( + NaiveDate::from_ymd_opt(year, month, day) + .unwrap() + .and_hms_micro_opt(hour, minute, 
second, microsec) + .unwrap() + .timestamp_micros(), + ), + None, + )) + } else if i % 4 == 3 { + let nanosec = rng.gen_range(0..=999_999_999); + timestamp.push(ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(year, month, day) + .unwrap() + .and_hms_nano_opt(hour, minute, second, nanosec) + .unwrap() + .timestamp_nanos(), + ), + None, + )) + } + } + timestamp + } + + fn get_random_intervals(sample_size: u64) -> Vec { + let vector_size = sample_size; + let mut intervals = vec![]; + let mut rng = rand::thread_rng(); + const SECS_IN_ONE_DAY: i32 = 86_400; + const MICROSECS_IN_ONE_DAY: i64 = 86_400_000_000; + for i in 0..vector_size { + if i % 4 == 0 { + let days = rng.gen_range(0..5000); + // to not break second precision + let millis = rng.gen_range(0..SECS_IN_ONE_DAY) * 1000; + intervals.push(ScalarValue::new_interval_dt(days, millis)); + } else if i % 4 == 1 { + let days = rng.gen_range(0..5000); + let millisec = rng.gen_range(0..(MILLISECS_IN_ONE_DAY as i32)); + intervals.push(ScalarValue::new_interval_dt(days, millisec)); + } else if i % 4 == 2 { + let days = rng.gen_range(0..5000); + // to not break microsec precision + let nanosec = rng.gen_range(0..MICROSECS_IN_ONE_DAY) * 1000; + intervals.push(ScalarValue::new_interval_mdn(0, days, nanosec)); + } else { + let days = rng.gen_range(0..5000); + let nanosec = rng.gen_range(0..NANOSECS_IN_ONE_DAY); + intervals.push(ScalarValue::new_interval_mdn(0, days, nanosec)); + } + } + intervals + } } diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index 34656bc114a6..d0829702d07f 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -15,13 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::error::Result; -use sqlparser::{ - ast::Ident, - dialect::GenericDialect, - parser::{Parser, ParserError}, - tokenizer::{Token, TokenWithLocation}, -}; +use crate::utils::{parse_identifiers_normalized, quote_identifier}; use std::borrow::Cow; /// A resolved path to a table of the form "catalog.schema.table" @@ -41,8 +35,39 @@ impl<'a> std::fmt::Display for ResolvedTableReference<'a> { } } -/// Represents a path to a table that may require further resolution -#[derive(Debug, Clone, PartialEq, Eq)] +/// [`TableReference`]s represent a multi part identifier (path) to a +/// table that may require further resolution. +/// +/// # Creating [`TableReference`] +/// +/// When converting strings to [`TableReference`]s, the string is +/// parsed as though it were a SQL identifier, normalizing (convert to +/// lowercase) any unquoted identifiers. 
+/// +/// See [`TableReference::bare`] to create references without applying +/// normalization semantics +/// +/// # Examples +/// ``` +/// # use datafusion_common::TableReference; +/// // Get a table reference to 'mytable' +/// let table_reference = TableReference::from("mytable"); +/// assert_eq!(table_reference, TableReference::bare("mytable")); +/// +/// // Get a table reference to 'mytable' (note the capitalization) +/// let table_reference = TableReference::from("MyTable"); +/// assert_eq!(table_reference, TableReference::bare("mytable")); +/// +/// // Get a table reference to 'MyTable' (note the capitalization) using double quotes +/// // (programatically it is better to use `TableReference::bare` for this) +/// let table_reference = TableReference::from(r#""MyTable""#); +/// assert_eq!(table_reference, TableReference::bare("MyTable")); +/// +/// // Get a table reference to 'myschema.mytable' (note the capitalization) +/// let table_reference = TableReference::from("MySchema.MyTable"); +/// assert_eq!(table_reference, TableReference::partial("myschema", "mytable")); +///``` +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum TableReference<'a> { /// An unqualified table reference, e.g. "table" Bare { @@ -67,53 +92,76 @@ pub enum TableReference<'a> { }, } -/// Represents a path to a table that may require further resolution -/// that owns the underlying names -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum OwnedTableReference { - /// An unqualified table reference, e.g. "table" - Bare { - /// The table name - table: String, - }, - /// A partially resolved table reference, e.g. "schema.table" - Partial { - /// The schema containing the table - schema: String, - /// The table name - table: String, - }, - /// A fully resolved table reference, e.g. 
"catalog.schema.table" - Full { - /// The catalog (aka database) containing the table - catalog: String, - /// The schema containing the table - schema: String, - /// The table name - table: String, - }, -} +/// This is a [`TableReference`] that has 'static lifetime (aka it +/// owns the underlying string) +/// +/// To convert a [`TableReference`] to an [`OwnedTableReference`], use +/// +/// ``` +/// # use datafusion_common::{OwnedTableReference, TableReference}; +/// let table_reference = TableReference::from("mytable"); +/// let owned_reference = table_reference.to_owned_reference(); +/// ``` +pub type OwnedTableReference = TableReference<'static>; -impl OwnedTableReference { - /// Return a `TableReference` view of this `OwnedTableReference` - pub fn as_table_reference(&self) -> TableReference<'_> { +impl std::fmt::Display for TableReference<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Bare { table } => TableReference::Bare { - table: table.into(), - }, - Self::Partial { schema, table } => TableReference::Partial { - schema: schema.into(), - table: table.into(), - }, - Self::Full { + TableReference::Bare { table } => write!(f, "{table}"), + TableReference::Partial { schema, table } => { + write!(f, "{schema}.{table}") + } + TableReference::Full { catalog, schema, table, - } => TableReference::Full { - catalog: catalog.into(), - schema: schema.into(), - table: table.into(), - }, + } => write!(f, "{catalog}.{schema}.{table}"), + } + } +} + +impl<'a> TableReference<'a> { + /// Convenience method for creating a typed none `None` + pub fn none() -> Option> { + None + } + + /// Convenience method for creating a [`TableReference::Bare`] + /// + /// As described on [`TableReference`] this does *NO* parsing at + /// all, so "Foo.Bar" stays as a reference to the table named + /// "Foo.Bar" (rather than "foo"."bar") + pub fn bare(table: impl Into>) -> TableReference<'a> { + TableReference::Bare { + table: table.into(), + } + } + + /// Convenience method for creating a [`TableReference::Partial`]. + /// + /// As described on [`TableReference`] this does *NO* parsing at all. + pub fn partial( + schema: impl Into>, + table: impl Into>, + ) -> TableReference<'a> { + TableReference::Partial { + schema: schema.into(), + table: table.into(), + } + } + + /// Convenience method for creating a [`TableReference::Full`] + /// + /// As described on [`TableReference`] this does *NO* parsing at all. + pub fn full( + catalog: impl Into>, + schema: impl Into>, + table: impl Into>, + ) -> TableReference<'a> { + TableReference::Full { + catalog: catalog.into(), + schema: schema.into(), + table: table.into(), } } @@ -125,39 +173,44 @@ impl OwnedTableReference { | Self::Bare { table } => table, } } -} -impl std::fmt::Display for OwnedTableReference { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + /// Retrieve the schema name if in the `Partial` or `Full` qualification + pub fn schema(&self) -> Option<&str> { match self { - OwnedTableReference::Bare { table } => write!(f, "{table}"), - OwnedTableReference::Partial { schema, table } => { - write!(f, "{schema}.{table}") - } - OwnedTableReference::Full { - catalog, - schema, - table, - } => write!(f, "{catalog}.{schema}.{table}"), + Self::Full { schema, .. } | Self::Partial { schema, .. } => Some(schema), + _ => None, } } -} -/// Convert `OwnedTableReference` into a `TableReference`. 
Somewhat -/// awkward to use but 'idiomatic': `(&table_ref).into()` -impl<'a> From<&'a OwnedTableReference> for TableReference<'a> { - fn from(r: &'a OwnedTableReference) -> Self { - r.as_table_reference() + /// Retrieve the catalog name if in the `Full` qualification + pub fn catalog(&self) -> Option<&str> { + match self { + Self::Full { catalog, .. } => Some(catalog), + _ => None, + } } -} -impl<'a> TableReference<'a> { - /// Retrieve the actual table name, regardless of qualification - pub fn table(&self) -> &str { + /// Compare with another [`TableReference`] as if both are resolved. + /// This allows comparing across variants, where if a field is not present + /// in both variants being compared then it is ignored in the comparison. + /// + /// e.g. this allows a [`TableReference::Bare`] to be considered equal to a + /// fully qualified [`TableReference::Full`] if the table names match. + pub fn resolved_eq(&self, other: &Self) -> bool { match self { - Self::Full { table, .. } - | Self::Partial { table, .. } - | Self::Bare { table } => table, + TableReference::Bare { table } => table == other.table(), + TableReference::Partial { schema, table } => { + table == other.table() && other.schema().map_or(true, |s| s == schema) + } + TableReference::Full { + catalog, + schema, + table, + } => { + table == other.table() + && other.schema().map_or(true, |s| s == schema) + && other.catalog().map_or(true, |c| c == catalog) + } } } @@ -190,23 +243,63 @@ impl<'a> TableReference<'a> { } } - /// Forms a [`TableReference`] by attempting to parse `s` as a multipart identifier, - /// failing that then taking the entire unnormalized input as the identifier itself. + /// Converts directly into an [`OwnedTableReference`] by cloning + /// the underlying data. + pub fn to_owned_reference(&self) -> OwnedTableReference { + match self { + Self::Full { + catalog, + schema, + table, + } => OwnedTableReference::Full { + catalog: catalog.to_string().into(), + schema: schema.to_string().into(), + table: table.to_string().into(), + }, + Self::Partial { schema, table } => OwnedTableReference::Partial { + schema: schema.to_string().into(), + table: table.to_string().into(), + }, + Self::Bare { table } => OwnedTableReference::Bare { + table: table.to_string().into(), + }, + } + } + + /// Forms a string where the identifiers are quoted /// - /// Will normalize (convert to lowercase) any unquoted identifiers. + /// # Example + /// ``` + /// # use datafusion_common::TableReference; + /// let table_reference = TableReference::partial("myschema", "mytable"); + /// assert_eq!(table_reference.to_quoted_string(), r#""myschema"."mytable""#); /// - /// e.g. `Foo` will be parsed as `foo`, and `"Foo"".bar"` will be parsed as - /// `Foo".bar` (note the preserved case and requiring two double quotes to represent - /// a single double quote in the identifier) + /// let table_reference = TableReference::partial("MySchema", "MyTable"); + /// assert_eq!(table_reference.to_quoted_string(), r#""MySchema"."MyTable""#); + /// ``` + pub fn to_quoted_string(&self) -> String { + match self { + TableReference::Bare { table } => quote_identifier(table), + TableReference::Partial { schema, table } => { + format!("{}.{}", quote_identifier(schema), quote_identifier(table)) + } + TableReference::Full { + catalog, + schema, + table, + } => format!( + "{}.{}.{}", + quote_identifier(catalog), + quote_identifier(schema), + quote_identifier(table) + ), + } + } + + /// Forms a [`TableReference`] by parsing `s` as a multipart SQL + /// identifier. 
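For illustration, a sketch of the `resolved_eq` semantics described above (assuming the PR's API): parts missing from one side are ignored, so a bare reference can match a fully qualified one by table name alone, while present-but-different parts still cause a mismatch.

```rust
use datafusion_common::TableReference;

fn main() {
    let bare = TableReference::bare("mytable");
    let partial = TableReference::partial("myschema", "mytable");
    let full = TableReference::full("mycatalog", "myschema", "mytable");

    // Missing parts are ignored in the comparison.
    assert!(bare.resolved_eq(&full));
    assert!(partial.resolved_eq(&full));
    // A differing schema is still a mismatch.
    assert!(!partial.resolved_eq(&TableReference::partial("other", "mytable")));
}
```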
See docs on [`TableReference`] for more details. pub fn parse_str(s: &'a str) -> Self { - let mut parts = parse_identifiers(s) - .unwrap_or_default() - .into_iter() - .map(|id| match id.quote_style { - Some(_) => id.value, - None => id.value.to_ascii_lowercase(), - }) - .collect::>(); + let mut parts = parse_identifiers_normalized(s); match parts.len() { 1 => Self::Bare { @@ -226,57 +319,34 @@ impl<'a> TableReference<'a> { } } -// TODO: remove when can use https://github.com/sqlparser-rs/sqlparser-rs/issues/805 -fn parse_identifiers(s: &str) -> Result> { - let dialect = GenericDialect; - let mut parser = Parser::new(&dialect).try_with_sql(s)?; - let mut idents = vec![]; - - // expecting at least one word for identifier - match parser.next_token_no_skip() { - Some(TokenWithLocation { - token: Token::Word(w), - .. - }) => idents.push(w.to_ident()), - Some(TokenWithLocation { token, .. }) => { - return Err(ParserError::ParserError(format!( - "Unexpected token in identifier: {token}" - )))? - } - None => { - return Err(ParserError::ParserError( - "Empty input when parsing identifier".to_string(), - ))? - } - }; +/// Parse a `String` into a OwnedTableReference as a multipart SQL identifier. +impl From for OwnedTableReference { + fn from(s: String) -> Self { + TableReference::parse_str(&s).to_owned_reference() + } +} - while let Some(TokenWithLocation { token, .. }) = parser.next_token_no_skip() { - match token { - // ensure that optional period is succeeded by another identifier - Token::Period => match parser.next_token_no_skip() { - Some(TokenWithLocation { - token: Token::Word(w), - .. - }) => idents.push(w.to_ident()), - Some(TokenWithLocation { token, .. }) => { - return Err(ParserError::ParserError(format!( - "Unexpected token following period in identifier: {token}" - )))? - } - None => { - return Err(ParserError::ParserError( - "Trailing period in identifier".to_string(), - ))? - } +impl<'a> From<&'a OwnedTableReference> for TableReference<'a> { + fn from(value: &'a OwnedTableReference) -> Self { + match value { + OwnedTableReference::Bare { table } => TableReference::Bare { + table: Cow::Borrowed(table), + }, + OwnedTableReference::Partial { schema, table } => TableReference::Partial { + schema: Cow::Borrowed(schema), + table: Cow::Borrowed(table), + }, + OwnedTableReference::Full { + catalog, + schema, + table, + } => TableReference::Full { + catalog: Cow::Borrowed(catalog), + schema: Cow::Borrowed(schema), + table: Cow::Borrowed(table), }, - _ => { - return Err(ParserError::ParserError(format!( - "Unexpected token in identifier: {token}" - )))? - } } } - Ok(idents) } /// Parse a string into a TableReference, normalizing where appropriate @@ -288,6 +358,12 @@ impl<'a> From<&'a str> for TableReference<'a> { } } +impl<'a> From<&'a String> for TableReference<'a> { + fn from(s: &'a String) -> Self { + Self::parse_str(s) + } +} + impl<'a> From> for TableReference<'a> { fn from(resolved: ResolvedTableReference<'a>) -> Self { Self::Full { @@ -302,64 +378,6 @@ impl<'a> From> for TableReference<'a> { mod tests { use super::*; - #[test] - fn test_parse_identifiers() -> Result<()> { - let s = "CATALOG.\"F(o)o. \"\"bar\".table"; - let actual = parse_identifiers(s)?; - let expected = vec![ - Ident { - value: "CATALOG".to_string(), - quote_style: None, - }, - Ident { - value: "F(o)o. 
\"bar".to_string(), - quote_style: Some('"'), - }, - Ident { - value: "table".to_string(), - quote_style: None, - }, - ]; - assert_eq!(expected, actual); - - let s = ""; - let err = parse_identifiers(s).expect_err("didn't fail to parse"); - assert_eq!( - "SQL(ParserError(\"Empty input when parsing identifier\"))", - format!("{err:?}") - ); - - let s = "*schema.table"; - let err = parse_identifiers(s).expect_err("didn't fail to parse"); - assert_eq!( - "SQL(ParserError(\"Unexpected token in identifier: *\"))", - format!("{err:?}") - ); - - let s = "schema.table*"; - let err = parse_identifiers(s).expect_err("didn't fail to parse"); - assert_eq!( - "SQL(ParserError(\"Unexpected token in identifier: *\"))", - format!("{err:?}") - ); - - let s = "schema.table."; - let err = parse_identifiers(s).expect_err("didn't fail to parse"); - assert_eq!( - "SQL(ParserError(\"Trailing period in identifier\"))", - format!("{err:?}") - ); - - let s = "schema.*"; - let err = parse_identifiers(s).expect_err("didn't fail to parse"); - assert_eq!( - "SQL(ParserError(\"Unexpected token following period in identifier: *\"))", - format!("{err:?}") - ); - - Ok(()) - } - #[test] fn test_table_reference_from_str_normalizes() { let expected = TableReference::Full { diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 3c073015343c..a1226def8e56 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -20,6 +20,10 @@ use crate::{DataFusionError, Result, ScalarValue}; use arrow::array::ArrayRef; use arrow::compute::SortOptions; +use sqlparser::ast::Ident; +use sqlparser::dialect::GenericDialect; +use sqlparser::parser::{Parser, ParserError}; +use sqlparser::tokenizer::{Token, TokenWithLocation}; use std::cmp::Ordering; /// Given column vectors, returns row at `idx`. @@ -158,6 +162,78 @@ where Ok(low) } +/// Wraps identifier string in double quotes, escaping any double quotes in +/// the identifier by replacing it with two double quotes +/// +/// e.g. identifier `tab.le"name` becomes `"tab.le""name"` +pub fn quote_identifier(s: &str) -> String { + format!("\"{}\"", s.replace('"', "\"\"")) +} + +// TODO: remove when can use https://github.com/sqlparser-rs/sqlparser-rs/issues/805 +pub(crate) fn parse_identifiers(s: &str) -> Result> { + let dialect = GenericDialect; + let mut parser = Parser::new(&dialect).try_with_sql(s)?; + let mut idents = vec![]; + + // expecting at least one word for identifier + match parser.next_token_no_skip() { + Some(TokenWithLocation { + token: Token::Word(w), + .. + }) => idents.push(w.to_ident()), + Some(TokenWithLocation { token, .. }) => { + return Err(ParserError::ParserError(format!( + "Unexpected token in identifier: {token}" + )))? + } + None => { + return Err(ParserError::ParserError( + "Empty input when parsing identifier".to_string(), + ))? + } + }; + + while let Some(TokenWithLocation { token, .. }) = parser.next_token_no_skip() { + match token { + // ensure that optional period is succeeded by another identifier + Token::Period => match parser.next_token_no_skip() { + Some(TokenWithLocation { + token: Token::Word(w), + .. + }) => idents.push(w.to_ident()), + Some(TokenWithLocation { token, .. }) => { + return Err(ParserError::ParserError(format!( + "Unexpected token following period in identifier: {token}" + )))? + } + None => { + return Err(ParserError::ParserError( + "Trailing period in identifier".to_string(), + ))? 
+ } + }, + _ => { + return Err(ParserError::ParserError(format!( + "Unexpected token in identifier: {token}" + )))? + } + } + } + Ok(idents) +} + +pub(crate) fn parse_identifiers_normalized(s: &str) -> Vec { + parse_identifiers(s) + .unwrap_or_default() + .into_iter() + .map(|id| match id.quote_style { + Some(_) => id.value, + None => id.value.to_ascii_lowercase(), + }) + .collect::>() +} + #[cfg(test)] mod tests { use arrow::array::Float64Array; @@ -330,4 +406,62 @@ mod tests { assert_eq!(res, 2); Ok(()) } + + #[test] + fn test_parse_identifiers() -> Result<()> { + let s = "CATALOG.\"F(o)o. \"\"bar\".table"; + let actual = parse_identifiers(s)?; + let expected = vec![ + Ident { + value: "CATALOG".to_string(), + quote_style: None, + }, + Ident { + value: "F(o)o. \"bar".to_string(), + quote_style: Some('"'), + }, + Ident { + value: "table".to_string(), + quote_style: None, + }, + ]; + assert_eq!(expected, actual); + + let s = ""; + let err = parse_identifiers(s).expect_err("didn't fail to parse"); + assert_eq!( + "SQL(ParserError(\"Empty input when parsing identifier\"))", + format!("{err:?}") + ); + + let s = "*schema.table"; + let err = parse_identifiers(s).expect_err("didn't fail to parse"); + assert_eq!( + "SQL(ParserError(\"Unexpected token in identifier: *\"))", + format!("{err:?}") + ); + + let s = "schema.table*"; + let err = parse_identifiers(s).expect_err("didn't fail to parse"); + assert_eq!( + "SQL(ParserError(\"Unexpected token in identifier: *\"))", + format!("{err:?}") + ); + + let s = "schema.table."; + let err = parse_identifiers(s).expect_err("didn't fail to parse"); + assert_eq!( + "SQL(ParserError(\"Trailing period in identifier\"))", + format!("{err:?}") + ); + + let s = "schema.*"; + let err = parse_identifiers(s).expect_err("didn't fail to parse"); + assert_eq!( + "SQL(ParserError(\"Unexpected token following period in identifier: *\"))", + format!("{err:?}") + ); + + Ok(()) + } } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index e65cf99a4293..f0544db1d934 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -115,7 +115,7 @@ env_logger = "0.10" half = "2.2.1" postgres-protocol = "0.6.4" postgres-types = { version = "0.2.4", features = ["derive", "with-chrono-0_4"] } -rstest = "0.16.0" +rstest = "0.17.0" rust_decimal = { version = "1.27.0", features = ["tokio-pg"] } sqllogictest = "0.13.0" test-utils = { path = "../../test-utils" } diff --git a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs index 1f06078e4627..313c2a1596ea 100644 --- a/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/avro_to_arrow/arrow_array_reader.rs @@ -435,7 +435,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { }); let valid_len = cur_offset.to_usize().unwrap(); let array_data = match list_field.data_type() { - DataType::Null => NullArray::new(valid_len).data().clone(), + DataType::Null => NullArray::new(valid_len).into_data(), DataType::Boolean => { let num_bytes = bit_util::ceil(valid_len, 8); let mut bool_values = MutableBuffer::from_len_zeroed(num_bytes); @@ -496,13 +496,11 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::Utf8 => flatten_string_values(rows) .into_iter() .collect::() - .data() - .clone(), + .into_data(), DataType::LargeUtf8 => flatten_string_values(rows) .into_iter() .collect::() - .data() - .clone(), + .into_data(), DataType::List(field) => { let child = self.build_nested_list_array::(&flatten_values(rows), field)?; 
diff --git a/datafusion/core/src/catalog/information_schema.rs b/datafusion/core/src/catalog/information_schema.rs index 7119558e8763..7759011b54f4 100644 --- a/datafusion/core/src/catalog/information_schema.rs +++ b/datafusion/core/src/catalog/information_schema.rs @@ -644,10 +644,7 @@ impl PartitionStream for InformationSchemaDfSettings { // TODO: Stream this futures::stream::once(async move { // create a mem table with the names of tables - config.make_df_settings( - ctx.session_config().config_options(), - &mut builder, - ); + config.make_df_settings(ctx.session_config().options(), &mut builder); Ok(builder.finish()) }), )) diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog/listing_schema.rs index 32ee9f62ac3d..b39c3693c706 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog/listing_schema.rs @@ -128,9 +128,7 @@ impl ListingSchemaProvider { if !self.table_exist(table_name) { let table_url = format!("{}/{}", self.authority, table_path); - let name = OwnedTableReference::Bare { - table: table_name.to_string(), - }; + let name = OwnedTableReference::bare(table_name.to_string()); let provider = self .factory .create( @@ -146,6 +144,7 @@ impl ListingSchemaProvider { if_not_exists: false, definition: None, file_compression_type: CompressionTypeVariant::UNCOMPRESSED, + order_exprs: vec![], options: Default::default(), }, ) diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 7cdbaa43ca59..e3e2987d3dfe 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -30,8 +30,8 @@ use parquet::file::properties::WriterProperties; use datafusion_common::from_slice::FromSlice; use datafusion_common::{Column, DFSchema, ScalarValue}; use datafusion_expr::{ - avg, count, is_null, max, median, min, stddev, TableProviderFilterPushDown, - UNNAMED_TABLE, + avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, + TableProviderFilterPushDown, UNNAMED_TABLE, }; use crate::arrow::datatypes::Schema; @@ -332,142 +332,128 @@ impl DataFrame { let original_schema_fields = self.schema().fields().iter(); //define describe column - let mut describe_schemas = original_schema_fields - .clone() - .map(|field| { - if field.data_type().is_numeric() { - Field::new(field.name(), DataType::Float64, true) - } else { - Field::new(field.name(), DataType::Utf8, true) - } - }) - .collect::>(); - describe_schemas.insert(0, Field::new("describe", DataType::Utf8, false)); + let mut describe_schemas = vec![Field::new("describe", DataType::Utf8, false)]; + describe_schemas.extend(original_schema_fields.clone().map(|field| { + if field.data_type().is_numeric() { + Field::new(field.name(), DataType::Float64, true) + } else { + Field::new(field.name(), DataType::Utf8, true) + } + })); - //count aggregation - let cnt = self.clone().aggregate( - vec![], - original_schema_fields - .clone() - .map(|f| count(col(f.name()))) - .collect::>(), - )?; - // The optimization of AggregateStatistics will rewrite the physical plan - // for the count function and ignore alias functions, - // as shown in https://github.com/apache/arrow-datafusion/issues/5444. - // This logic should be removed when #5444 is fixed. 
- let cnt = cnt.clone().select( - cnt.schema() - .fields() - .iter() - .zip(original_schema_fields.clone()) - .map(|(count_field, orgin_field)| { - col(count_field.name()).alias(orgin_field.name()) - }) - .collect::>(), - )?; - //should be removed when #5444 is fixed //collect recordBatch let describe_record_batch = vec![ // count aggregation - cnt.collect().await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .map(|f| count(col(f.name())).alias(f.name())) + .collect::>(), + ), // null_count aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .map(|f| count(is_null(col(f.name()))).alias(f.name())) - .collect::>(), - )? - .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .map(|f| count(is_null(col(f.name()))).alias(f.name())) + .collect::>(), + ), // mean aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .filter(|f| f.data_type().is_numeric()) - .map(|f| avg(col(f.name())).alias(f.name())) - .collect::>(), - )? - .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .filter(|f| f.data_type().is_numeric()) + .map(|f| avg(col(f.name())).alias(f.name())) + .collect::>(), + ), // std aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .filter(|f| f.data_type().is_numeric()) - .map(|f| stddev(col(f.name())).alias(f.name())) - .collect::>(), - )? - .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .filter(|f| f.data_type().is_numeric()) + .map(|f| stddev(col(f.name())).alias(f.name())) + .collect::>(), + ), // min aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .filter(|f| { - !matches!(f.data_type(), DataType::Binary | DataType::Boolean) - }) - .map(|f| min(col(f.name())).alias(f.name())) - .collect::>(), - )? - .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .filter(|f| { + !matches!(f.data_type(), DataType::Binary | DataType::Boolean) + }) + .map(|f| min(col(f.name())).alias(f.name())) + .collect::>(), + ), // max aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .filter(|f| { - !matches!(f.data_type(), DataType::Binary | DataType::Boolean) - }) - .map(|f| max(col(f.name())).alias(f.name())) - .collect::>(), - )? - .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .filter(|f| { + !matches!(f.data_type(), DataType::Binary | DataType::Boolean) + }) + .map(|f| max(col(f.name())).alias(f.name())) + .collect::>(), + ), // median aggregation - self.clone() - .aggregate( - vec![], - original_schema_fields - .clone() - .filter(|f| f.data_type().is_numeric()) - .map(|f| median(col(f.name())).alias(f.name())) - .collect::>(), - )? 
- .collect() - .await?, + self.clone().aggregate( + vec![], + original_schema_fields + .clone() + .filter(|f| f.data_type().is_numeric()) + .map(|f| median(col(f.name())).alias(f.name())) + .collect::>(), + ), ]; - let mut array_ref_vec: Vec = vec![]; + // first column with function names + let mut array_ref_vec: Vec = vec![Arc::new(StringArray::from_slice( + supported_describe_functions.clone(), + ))]; for field in original_schema_fields { let mut array_datas = vec![]; - for record_batch in describe_record_batch.iter() { - let column = record_batch.get(0).unwrap().column_by_name(field.name()); - match column { - Some(c) => { - if field.data_type().is_numeric() { - array_datas.push(cast(c, &DataType::Float64)?); - } else { - array_datas.push(cast(c, &DataType::Utf8)?); + for result in describe_record_batch.iter() { + let array_ref = match result { + Ok(df) => { + let batchs = df.clone().collect().await; + match batchs { + Ok(batchs) + if batchs.len() == 1 + && batchs[0] + .column_by_name(field.name()) + .is_some() => + { + let column = + batchs[0].column_by_name(field.name()).unwrap(); + if field.data_type().is_numeric() { + cast(column, &DataType::Float64)? + } else { + cast(column, &DataType::Utf8)? + } + } + _ => Arc::new(StringArray::from_slice(["null"])), } } - //if None mean the column cannot be min/max aggregation - None => { - array_datas.push(Arc::new(StringArray::from_slice(["null"]))); + //Handling error when only boolean/binary column, and in other cases + Err(err) + if err.to_string().contains( + "Error during planning: \ + Aggregate requires at least one grouping \ + or aggregate expression", + ) => + { + Arc::new(StringArray::from_slice(["null"])) } - } + Err(other_err) => { + panic!("{other_err}") + } + }; + array_datas.push(array_ref); } - array_ref_vec.push(concat( array_datas .iter() @@ -477,14 +463,6 @@ impl DataFrame { )?); } - //insert first column with function names - array_ref_vec.insert( - 0, - Arc::new(StringArray::from_slice( - supported_describe_functions.clone(), - )), - ); - let describe_record_batch = RecordBatch::try_new(Arc::new(Schema::new(describe_schemas)), array_ref_vec)?; @@ -651,7 +629,7 @@ impl DataFrame { let rows = self .aggregate( vec![], - vec![datafusion_expr::count(Expr::Literal(ScalarValue::Null))], + vec![datafusion_expr::count(Expr::Literal(COUNT_STAR_EXPANSION))], )? 
.collect() .await?; @@ -1393,7 +1371,7 @@ mod tests { let join = left .join_on(right, JoinType::Inner, [col("c1").eq(col("c1"))]) .expect_err("join didn't fail check"); - let expected = "Schema error: Ambiguous reference to unqualified field 'c1'"; + let expected = "Schema error: Ambiguous reference to unqualified field \"c1\""; assert_eq!(join.to_string(), expected); Ok(()) diff --git a/datafusion/core/src/datasource/datasource.rs b/datafusion/core/src/datasource/datasource.rs index 6277ce146adf..8db075a30a79 100644 --- a/datafusion/core/src/datasource/datasource.rs +++ b/datafusion/core/src/datasource/datasource.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use async_trait::async_trait; -use datafusion_common::Statistics; +use datafusion_common::{DataFusionError, Statistics}; use datafusion_expr::{CreateExternalTable, LogicalPlan}; pub use datafusion_expr::{TableProviderFilterPushDown, TableType}; @@ -97,6 +97,16 @@ pub trait TableProvider: Sync + Send { fn statistics(&self) -> Option { None } + + /// Insert into this table + async fn insert_into( + &self, + _state: &SessionState, + _input: &LogicalPlan, + ) -> Result<()> { + let msg = "Insertion not implemented for this table".to_owned(); + Err(DataFusionError::NotImplemented(msg)) + } } /// A factory which creates [`TableProvider`]s at runtime given a URL. diff --git a/datafusion/core/src/datasource/file_format/file_type.rs b/datafusion/core/src/datasource/file_format/file_type.rs index 59c95962a992..e07eb8a3d7a6 100644 --- a/datafusion/core/src/datasource/file_format/file_type.rs +++ b/datafusion/core/src/datasource/file_format/file_type.rs @@ -30,10 +30,10 @@ use async_compression::tokio::bufread::{ }; use bytes::Bytes; #[cfg(feature = "compression")] -use bzip2::read::BzDecoder; +use bzip2::read::MultiBzDecoder; use datafusion_common::parsers::CompressionTypeVariant; #[cfg(feature = "compression")] -use flate2::read::GzDecoder; +use flate2::read::MultiGzDecoder; use futures::Stream; #[cfg(feature = "compression")] use futures::TryStreamExt; @@ -168,11 +168,11 @@ impl FileCompressionType { ) -> Result> { Ok(match self.variant { #[cfg(feature = "compression")] - GZIP => Box::new(GzDecoder::new(r)), + GZIP => Box::new(MultiGzDecoder::new(r)), #[cfg(feature = "compression")] - BZIP2 => Box::new(BzDecoder::new(r)), + BZIP2 => Box::new(MultiBzDecoder::new(r)), #[cfg(feature = "compression")] - XZ => Box::new(XzDecoder::new(r)), + XZ => Box::new(XzDecoder::new_multi_decoder(r)), #[cfg(feature = "compression")] ZSTD => match ZstdDecoder::new(r) { Ok(decoder) => Box::new(decoder), diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 53e94167d845..ba18e9f62c25 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -620,6 +620,7 @@ mod tests { use super::*; use crate::datasource::file_format::parquet::test_util::store_parquet; + use crate::physical_plan::file_format::get_scan_files; use crate::physical_plan::metrics::MetricValue; use crate::prelude::{SessionConfig, SessionContext}; use arrow::array::{Array, ArrayRef, StringArray}; @@ -1215,6 +1216,25 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_get_scan_files() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let projection = Some(vec![9]); + let exec = get_exec(&state, "alltypes_plain.parquet", projection, None).await?; + let scan_files = get_scan_files(exec)?; + 
assert_eq!(scan_files.len(), 1); + assert_eq!(scan_files[0].len(), 1); + assert_eq!(scan_files[0][0].len(), 1); + assert!(scan_files[0][0][0] + .object_meta + .location + .to_string() + .contains("alltypes_plain.parquet")); + + Ok(()) + } + fn check_page_index_validation( page_index: Option<&ParquetColumnIndex>, offset_index: Option<&ParquetOffsetIndex>, diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index fa7f7070c3bc..7ed6326a907a 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -81,6 +81,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> { } Expr::Literal(_) | Expr::Alias(_, _) + | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) | Expr::Not(_) | Expr::IsNotNull(_) diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index d5374d1edcea..1e6b40a3485a 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -55,7 +55,16 @@ pub struct FileRange { pub struct PartitionedFile { /// Path for the file (e.g. URL, filesystem path, etc) pub object_meta: ObjectMeta, - /// Values of partition columns to be appended to each row + /// Values of partition columns to be appended to each row. + /// + /// These MUST have the same count, order, and type as the [`table_partition_cols`]. + /// + /// You may use [`wrap_partition_value_in_dict`] to wrap them if you have used [`wrap_partition_type_in_dict`] to wrap the column type. + /// + /// + /// [`wrap_partition_type_in_dict`]: crate::physical_plan::file_format::wrap_partition_type_in_dict + /// [`wrap_partition_value_in_dict`]: crate::physical_plan::file_format::wrap_partition_value_in_dict + /// [`table_partition_cols`]: table::ListingOptions::table_partition_cols pub partition_values: Vec, /// An optional file range for a more fine-grained parallel execution pub range: Option, diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index f6d9c959eb1a..f85492d8c2ed 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -44,7 +44,6 @@ use crate::datasource::{ }; use crate::logical_expr::TableProviderFilterPushDown; use crate::physical_plan; -use crate::physical_plan::file_format::partition_type_wrap; use crate::{ error::{DataFusionError, Result}, execution::context::SessionState, @@ -213,10 +212,7 @@ pub struct ListingOptions { /// The file format pub format: Arc, /// The expected partition column names in the folder structure. - /// For example `Vec["a", "b"]` means that the two first levels of - /// partitioning expected should be named "a" and "b": - /// - If there is a third level of partitioning it will be ignored. - /// - Files that don't follow this partitioning will be ignored. + /// See [Self::with_table_partition_cols] for details pub table_partition_cols: Vec<(String, DataType)>, /// Set true to try to guess statistics from the files. /// This can add a lot of overhead as it will usually require files @@ -298,7 +294,45 @@ impl ListingOptions { self } - /// Set table partition column names on [`ListingOptions`] and returns self. + /// Set `table partition columns` on [`ListingOptions`] and returns self.
+ /// + /// "partition columns," used to support [Hive Partitioning], are + /// columns added to the data that is read, based on the folder + /// structure where the data resides. + /// + /// For example, given the following files in your filesystem: + /// + /// ```text + /// /mnt/nyctaxi/year=2022/month=01/tripdata.parquet + /// /mnt/nyctaxi/year=2021/month=12/tripdata.parquet + /// /mnt/nyctaxi/year=2021/month=11/tripdata.parquet + /// ``` + /// + /// A [`ListingTable`] created at `/mnt/nyctaxi/` with partition + /// columns "year" and "month" will include new `year` and `month` + /// columns while reading the files. The `year` column would have + /// value `2022` and the `month` column would have value `01` for + /// the rows read from + /// `/mnt/nyctaxi/year=2022/month=01/tripdata.parquet` + /// + /// # Notes + /// + /// - If only one level (e.g. `year` in the example above) is + /// specified, the other levels are ignored but the files are + /// still read. + /// + /// - Files that don't follow this partitioning scheme will be + /// ignored. + /// + /// - Since the columns have the same value for all rows read from + /// each individual file (such as dates), they are typically + /// dictionary encoded for efficiency. You may use + /// [`wrap_partition_type_in_dict`] to request a + /// dictionary-encoded type. + /// + /// - The partition columns are extracted solely from the file path. In particular, they are NOT part of the Parquet files themselves. + /// + /// # Example /// /// ``` /// # use std::sync::Arc; /// # use datafusion::prelude::col; /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; /// + /// // listing options for files with paths such as `/mnt/data/col_a=x/col_b=y/data.parquet` + /// // `col_a` and `col_b` will be included in the data read from those files /// let listing_options = ListingOptions::new(Arc::new( /// ParquetFormat::default() /// )) @@ -315,6 +351,9 @@ impl ListingOptions { /// assert_eq!(listing_options.table_partition_cols, vec![("col_a".to_string(), DataType::Utf8), /// ("col_b".to_string(), DataType::Utf8)]); /// ``` + /// + /// [Hive Partitioning]: https://docs.cloudera.com/HDPDocuments/HDP2/HDP-2.1.3/bk_system-admin-guide/content/hive_partitioned_tables.html + /// [`wrap_partition_type_in_dict`]: crate::physical_plan::file_format::wrap_partition_type_in_dict pub fn with_table_partition_cols( mut self, table_partition_cols: Vec<(String, DataType)>, @@ -538,11 +577,7 @@ impl ListingTable { // Add the partition columns to the file schema let mut table_fields = file_schema.fields().clone(); for (part_col_name, part_col_type) in &options.table_partition_cols { - table_fields.push(Field::new( - part_col_name, - partition_type_wrap(part_col_type.clone()), - false, - )); + table_fields.push(Field::new(part_col_name, part_col_type.clone(), false)); } let infinite_source = options.infinite_source; @@ -1012,10 +1047,7 @@ mod tests { let opt = ListingOptions::new(Arc::new(AvroFormat {})) .with_file_extension(FileType::AVRO.get_ext()) - .with_table_partition_cols(vec![( - String::from("p1"), - partition_type_wrap(DataType::Utf8), - )]) + .with_table_partition_cols(vec![(String::from("p1"), DataType::Utf8)]) .with_target_partitions(4); let table_path = ListingTableUrl::parse("test:///table/").unwrap(); diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index fe4393cb29a5..6c7b058520d8 100644 ---
a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -86,7 +86,15 @@ impl TableProviderFactory for ListingTableFactory { None, cmd.table_partition_cols .iter() - .map(|x| (x.clone(), DataType::Utf8)) + .map(|x| { + ( + x.clone(), + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + ) + }) .collect::>(), ) } else { @@ -116,12 +124,18 @@ impl TableProviderFactory for ListingTableFactory { (Some(schema), table_partition_cols) }; + let file_sort_order = if cmd.order_exprs.is_empty() { + None + } else { + Some(cmd.order_exprs.clone()) + }; + let options = ListingOptions::new(file_format) .with_collect_stat(state.config().collect_statistics()) .with_file_extension(file_extension) .with_target_partitions(state.config().target_partitions()) .with_table_partition_cols(table_partition_cols) - .with_file_sort_order(None); + .with_file_sort_order(file_sort_order); let table_path = ListingTableUrl::parse(&cmd.location)?; let resolved_schema = match provided_schema { diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index ac1f4947f87d..b5fa33e38827 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -19,18 +19,22 @@ //! queried by DataFusion. This allows data to be pre-loaded into memory and then //! repeatedly queried without incurring additional file I/O overhead. -use futures::StreamExt; +use futures::{StreamExt, TryStreamExt}; use std::any::Any; use std::sync::Arc; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use async_trait::async_trait; +use datafusion_expr::LogicalPlan; +use tokio::sync::RwLock; +use tokio::task; use crate::datasource::{TableProvider, TableType}; use crate::error::{DataFusionError, Result}; use crate::execution::context::SessionState; use crate::logical_expr::Expr; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::memory::MemoryExec; @@ -41,7 +45,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning}; #[derive(Debug)] pub struct MemTable { schema: SchemaRef, - batches: Vec>, + batches: Arc>>>, } impl MemTable { @@ -54,7 +58,7 @@ impl MemTable { { Ok(Self { schema, - batches: partitions, + batches: Arc::new(RwLock::new(partitions)), }) } else { Err(DataFusionError::Plan( @@ -143,22 +147,102 @@ impl TableProvider for MemTable { _filters: &[Expr], _limit: Option, ) -> Result> { + let batches = &self.batches.read().await; Ok(Arc::new(MemoryExec::try_new( - &self.batches.clone(), + batches, self.schema(), projection.cloned(), )?)) } + + /// Inserts the execution results of a given [LogicalPlan] into this [MemTable]. + /// The `LogicalPlan` must have the same schema as this `MemTable`. + /// + /// # Arguments + /// + /// * `state` - The [SessionState] containing the context for executing the plan. + /// * `input` - The [LogicalPlan] to execute and insert. + /// + /// # Returns + /// + /// * A `Result` indicating success or failure. + async fn insert_into(&self, state: &SessionState, input: &LogicalPlan) -> Result<()> { + // Create a physical plan from the logical plan. + let plan = state.create_physical_plan(input).await?; + + // Check that the schema of the plan matches the schema of this table. 
+ if !plan.schema().eq(&self.schema) { + return Err(DataFusionError::Plan( + "Inserting query must have the same schema with the table.".to_string(), + )); + } + + // Get the number of partitions in the plan and the table. + let plan_partition_count = plan.output_partitioning().partition_count(); + let table_partition_count = self.batches.read().await.len(); + + // Adjust the plan as necessary to match the number of partitions in the table. + let plan: Arc = if plan_partition_count + == table_partition_count + || table_partition_count == 0 + { + plan + } else if table_partition_count == 1 { + // If the table has only one partition, coalesce the partitions in the plan. + Arc::new(CoalescePartitionsExec::new(plan)) + } else { + // Otherwise, repartition the plan using a round-robin partitioning scheme. + Arc::new(RepartitionExec::try_new( + plan, + Partitioning::RoundRobinBatch(table_partition_count), + )?) + }; + + // Get the task context from the session state. + let task_ctx = state.task_ctx(); + + // Execute the plan and collect the results into batches. + let mut tasks = vec![]; + for idx in 0..plan.output_partitioning().partition_count() { + let stream = plan.execute(idx, task_ctx.clone())?; + let handle = task::spawn(async move { + stream.try_collect().await.map_err(DataFusionError::from) + }); + tasks.push(AbortOnDropSingle::new(handle)); + } + let results = futures::future::join_all(tasks) + .await + .into_iter() + .map(|result| { + result.map_err(|e| DataFusionError::Execution(format!("{e}")))? + }) + .collect::>>>()?; + + // Write the results into the table. + let mut all_batches = self.batches.write().await; + + if all_batches.is_empty() { + *all_batches = results + } else { + for (batches, result) in all_batches.iter_mut().zip(results.into_iter()) { + batches.extend(result); + } + } + + Ok(()) + } } #[cfg(test)] mod tests { use super::*; + use crate::datasource::provider_as_source; use crate::from_slice::FromSlice; use crate::prelude::SessionContext; use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; + use datafusion_expr::LogicalPlanBuilder; use futures::StreamExt; use std::collections::HashMap; @@ -388,4 +472,135 @@ mod tests { Ok(()) } + + fn create_mem_table_scan( + schema: SchemaRef, + data: Vec>, + ) -> Result> { + // Convert the table into a provider so that it can be used in a query + let provider = provider_as_source(Arc::new(MemTable::try_new(schema, data)?)); + // Create a table scan logical plan to read from the table + Ok(Arc::new( + LogicalPlanBuilder::scan("source", provider, None)?.build()?, + )) + } + + fn create_initial_ctx() -> Result<(SessionContext, SchemaRef, RecordBatch)> { + // Create a new session context + let session_ctx = SessionContext::new(); + // Create a new schema with one field called "a" of type Int32 + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a new batch of data to insert into the table + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_slice([1, 2, 3]))], + )?; + Ok((session_ctx, schema, batch)) + } + + #[tokio::test] + async fn test_insert_into_single_partition() -> Result<()> { + let (session_ctx, schema, batch) = create_initial_ctx()?; + let initial_table = Arc::new(MemTable::try_new( + schema.clone(), + vec![vec![batch.clone()]], + )?); + // Create a table scan logical plan to read from the table + let single_partition_table_scan = + create_mem_table_scan(schema.clone(), 
vec![vec![batch.clone()]])?; + // Insert the data from the provider into the table + initial_table + .insert_into(&session_ctx.state(), &single_partition_table_scan) + .await?; + // Ensure that the table now contains two batches of data in the same partition + assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2); + + // Create a new provider with 2 partitions + let multi_partition_table_scan = create_mem_table_scan( + schema.clone(), + vec![vec![batch.clone()], vec![batch]], + )?; + + // Insert the data from the provider into the table. We expect coalescing partitions. + initial_table + .insert_into(&session_ctx.state(), &multi_partition_table_scan) + .await?; + // Ensure that the table now contains 4 batches of data with only 1 partition + assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4); + assert_eq!(initial_table.batches.read().await.len(), 1); + Ok(()) + } + + #[tokio::test] + async fn test_insert_into_multiple_partition() -> Result<()> { + let (session_ctx, schema, batch) = create_initial_ctx()?; + // create a memory table with two partitions, each having one batch with the same data + let initial_table = Arc::new(MemTable::try_new( + schema.clone(), + vec![vec![batch.clone()], vec![batch.clone()]], + )?); + + // scan a data source provider from a memory table with a single partition + let single_partition_table_scan = create_mem_table_scan( + schema.clone(), + vec![vec![batch.clone(), batch.clone()]], + )?; + + // insert the data from the 1 partition data source provider into the initial table + initial_table + .insert_into(&session_ctx.state(), &single_partition_table_scan) + .await?; + + // We expect round robin repartition here, each partition gets 1 batch. + assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2); + assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 2); + + // scan a data source provider from a memory table with 2 partition + let multi_partition_table_scan = create_mem_table_scan( + schema.clone(), + vec![vec![batch.clone()], vec![batch]], + )?; + // We expect one-to-one partition mapping. + initial_table + .insert_into(&session_ctx.state(), &multi_partition_table_scan) + .await?; + // Ensure that the table now contains 3 batches of data with 2 partitions. + assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 3); + assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 3); + Ok(()) + } + + #[tokio::test] + async fn test_insert_into_empty_table() -> Result<()> { + let (session_ctx, schema, batch) = create_initial_ctx()?; + // create empty memory table + let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![])?); + + // scan a data source provider from a memory table with a single partition + let single_partition_table_scan = create_mem_table_scan( + schema.clone(), + vec![vec![batch.clone(), batch.clone()]], + )?; + + // insert the data from the 1 partition data source provider into the initial table + initial_table + .insert_into(&session_ctx.state(), &single_partition_table_scan) + .await?; + + assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2); + + // scan a data source provider from a memory table with 2 partition + let single_partition_table_scan = create_mem_table_scan( + schema.clone(), + vec![vec![batch.clone()], vec![batch]], + )?; + // We expect coalesce partitions here. 
+ initial_table + .insert_into(&session_ctx.state(), &single_partition_table_scan) + .await?; + // Ensure that the table now contains 3 batches of data with 2 partitions. + assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4); + Ok(()) + } } diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index c985ccdc5e3a..397f4109b62d 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -31,17 +31,13 @@ use crate::{ optimizer::PhysicalOptimizerRule, }, }; -use datafusion_expr::{DescribeTable, StringifiedPlan}; +use datafusion_expr::{DescribeTable, DmlStatement, StringifiedPlan, WriteOp}; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; use parking_lot::RwLock; use std::collections::hash_map::Entry; +use std::string::String; use std::sync::Arc; -use std::{ - any::{Any, TypeId}, - hash::{BuildHasherDefault, Hasher}, - string::String, -}; use std::{ collections::{HashMap, HashSet}, fmt::Debug, @@ -87,7 +83,7 @@ use crate::physical_plan::PhysicalPlanner; use crate::variable::{VarProvider, VarType}; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use datafusion_common::{config::Extensions, OwnedTableReference, ScalarValue}; +use datafusion_common::{config::Extensions, OwnedTableReference}; use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, SqlToRel}, @@ -107,6 +103,9 @@ use datafusion_optimizer::OptimizerConfig; use datafusion_sql::planner::object_name_to_table_reference; use uuid::Uuid; +// backwards compatibility +pub use datafusion_execution::config::SessionConfig; + use super::options::{ AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, ReadOptions, }; @@ -282,7 +281,7 @@ impl SessionContext { self.session_id.clone() } - /// Return the [`TableFactoryProvider`] that is registered for the + /// Return the [`TableProviderFactory`] that is registered for the /// specified file type, if any. pub fn table_factory( &self, @@ -296,7 +295,7 @@ impl SessionContext { self.state .read() .config - .options + .options() .sql_parser .enable_ident_normalization } @@ -308,7 +307,8 @@ impl SessionContext { /// Creates a [`DataFrame`] that will execute a SQL query. /// - /// Note: This API implements DDL such as `CREATE TABLE` and `CREATE VIEW` with in-memory + /// Note: This API implements DDL statements such as `CREATE TABLE` and + /// `CREATE VIEW` and DML statements such as `INSERT INTO` with in-memory /// default implementations. /// /// If this is not desirable, consider using [`SessionState::create_logical_plan()`] which @@ -318,6 +318,24 @@ impl SessionContext { let plan = self.state().create_logical_plan(sql).await?; match plan { + LogicalPlan::Dml(DmlStatement { + table_name, + op: WriteOp::Insert, + input, + .. + }) => { + if self.table_exist(&table_name)? { + let name = table_name.table(); + let provider = self.table_provider(name).await?; + provider.insert_into(&self.state(), &input).await?; + } else { + return Err(DataFusionError::Execution(format!( + "Table '{}' does not exist", + table_name + ))); + } + self.return_empty_dataframe() + } LogicalPlan::CreateExternalTable(cmd) => { self.create_external_table(&cmd).await } @@ -423,8 +441,7 @@ impl SessionContext { variable, value, .. 
}) => { let mut state = self.state.write(); - let config_options = &mut state.config.options; - config_options.set(&variable, &value)?; + state.config.options_mut().set(&variable, &value)?; drop(state); self.return_empty_dataframe() @@ -445,7 +462,7 @@ impl SessionContext { let (catalog, schema_name) = match tokens.len() { 1 => { let state = self.state.read(); - let name = &state.config.options.catalog.default_catalog; + let name = &state.config.options().catalog.default_catalog; let catalog = state.catalog_list.catalog(name).ok_or_else(|| { DataFusionError::Execution(format!( @@ -1001,10 +1018,9 @@ impl SessionContext { table_ref: impl Into>, ) -> Result { let table_ref = table_ref.into(); - let table = table_ref.table().to_owned(); - let provider = self.table_provider(table_ref).await?; + let provider = self.table_provider(table_ref.to_owned_reference()).await?; let plan = LogicalPlanBuilder::scan( - &table, + table_ref.to_owned_reference(), provider_as_source(Arc::clone(&provider)), None, )? @@ -1018,7 +1034,7 @@ impl SessionContext { table_ref: impl Into>, ) -> Result> { let table_ref = table_ref.into(); - let table = table_ref.table().to_owned(); + let table = table_ref.table().to_string(); let schema = self.state.read().schema_for_ref(table_ref)?; match schema.table(&table).await { Some(ref provider) => Ok(Arc::clone(provider)), @@ -1161,344 +1177,6 @@ impl QueryPlanner for DefaultQueryPlanner { } } -/// Map that holds opaque objects indexed by their type. -/// -/// Data is wrapped into an [`Arc`] to enable [`Clone`] while still being [object safe]. -/// -/// [object safe]: https://doc.rust-lang.org/reference/items/traits.html#object-safety -type AnyMap = - HashMap, BuildHasherDefault>; - -/// Hasher for [`AnyMap`]. -/// -/// With [`TypeId`]s as keys, there's no need to hash them. They are already hashes themselves, coming from the compiler. -/// The [`IdHasher`] just holds the [`u64`] of the [`TypeId`], and then returns it, instead of doing any bit fiddling. -#[derive(Default)] -struct IdHasher(u64); - -impl Hasher for IdHasher { - fn write(&mut self, _: &[u8]) { - unreachable!("TypeId calls write_u64"); - } - - #[inline] - fn write_u64(&mut self, id: u64) { - self.0 = id; - } - - #[inline] - fn finish(&self) -> u64 { - self.0 - } -} - -/// Configuration options for session context -#[derive(Clone)] -pub struct SessionConfig { - /// Configuration options - options: ConfigOptions, - /// Opaque extensions. - extensions: AnyMap, -} - -impl Default for SessionConfig { - fn default() -> Self { - Self { - options: ConfigOptions::new(), - // Assume no extensions by default. 
- extensions: HashMap::with_capacity_and_hasher( - 0, - BuildHasherDefault::default(), - ), - } - } -} - -impl SessionConfig { - /// Create an execution config with default setting - pub fn new() -> Self { - Default::default() - } - - /// Create an execution config with config options read from the environment - pub fn from_env() -> Result { - Ok(ConfigOptions::from_env()?.into()) - } - - /// Set a configuration option - pub fn set(mut self, key: &str, value: ScalarValue) -> Self { - self.options.set(key, &value.to_string()).unwrap(); - self - } - - /// Set a boolean configuration option - pub fn set_bool(self, key: &str, value: bool) -> Self { - self.set(key, ScalarValue::Boolean(Some(value))) - } - - /// Set a generic `u64` configuration option - pub fn set_u64(self, key: &str, value: u64) -> Self { - self.set(key, ScalarValue::UInt64(Some(value))) - } - - /// Set a generic `usize` configuration option - pub fn set_usize(self, key: &str, value: usize) -> Self { - let value: u64 = value.try_into().expect("convert usize to u64"); - self.set(key, ScalarValue::UInt64(Some(value))) - } - - /// Set a generic `str` configuration option - pub fn set_str(self, key: &str, value: &str) -> Self { - self.set(key, ScalarValue::Utf8(Some(value.to_string()))) - } - - /// Customize batch size - pub fn with_batch_size(mut self, n: usize) -> Self { - // batch size must be greater than zero - assert!(n > 0); - self.options.execution.batch_size = n; - self - } - - /// Customize [`target_partitions`] - /// - /// [`target_partitions`]: crate::config::ExecutionOptions::target_partitions - pub fn with_target_partitions(mut self, n: usize) -> Self { - // partition count must be greater than zero - assert!(n > 0); - self.options.execution.target_partitions = n; - self - } - - /// Get [`target_partitions`] - /// - /// [`target_partitions`]: crate::config::ExecutionOptions::target_partitions - pub fn target_partitions(&self) -> usize { - self.options.execution.target_partitions - } - - /// Is the information schema enabled? - pub fn information_schema(&self) -> bool { - self.options.catalog.information_schema - } - - /// Should the context create the default catalog and schema? - pub fn create_default_catalog_and_schema(&self) -> bool { - self.options.catalog.create_default_catalog_and_schema - } - - /// Are joins repartitioned during execution? - pub fn repartition_joins(&self) -> bool { - self.options.optimizer.repartition_joins - } - - /// Are aggregates repartitioned during execution? - pub fn repartition_aggregations(&self) -> bool { - self.options.optimizer.repartition_aggregations - } - - /// Are window functions repartitioned during execution? - pub fn repartition_window_functions(&self) -> bool { - self.options.optimizer.repartition_windows - } - - /// Do we execute sorts in a per-partition fashion and merge afterwards, - /// or do we coalesce partitions first and sort globally? - pub fn repartition_sorts(&self) -> bool { - self.options.optimizer.repartition_sorts - } - - /// Are statistics collected during execution? 
- pub fn collect_statistics(&self) -> bool { - self.options.execution.collect_statistics - } - - /// Selects a name for the default catalog and schema - pub fn with_default_catalog_and_schema( - mut self, - catalog: impl Into, - schema: impl Into, - ) -> Self { - self.options.catalog.default_catalog = catalog.into(); - self.options.catalog.default_schema = schema.into(); - self - } - - /// Controls whether the default catalog and schema will be automatically created - pub fn with_create_default_catalog_and_schema(mut self, create: bool) -> Self { - self.options.catalog.create_default_catalog_and_schema = create; - self - } - - /// Enables or disables the inclusion of `information_schema` virtual tables - pub fn with_information_schema(mut self, enabled: bool) -> Self { - self.options.catalog.information_schema = enabled; - self - } - - /// Enables or disables the use of repartitioning for joins to improve parallelism - pub fn with_repartition_joins(mut self, enabled: bool) -> Self { - self.options.optimizer.repartition_joins = enabled; - self - } - - /// Enables or disables the use of repartitioning for aggregations to improve parallelism - pub fn with_repartition_aggregations(mut self, enabled: bool) -> Self { - self.options.optimizer.repartition_aggregations = enabled; - self - } - - /// Sets minimum file range size for repartitioning scans - pub fn with_repartition_file_min_size(mut self, size: usize) -> Self { - self.options.optimizer.repartition_file_min_size = size; - self - } - - /// Enables or disables the use of repartitioning for file scans - pub fn with_repartition_file_scans(mut self, enabled: bool) -> Self { - self.options.optimizer.repartition_file_scans = enabled; - self - } - - /// Enables or disables the use of repartitioning for window functions to improve parallelism - pub fn with_repartition_windows(mut self, enabled: bool) -> Self { - self.options.optimizer.repartition_windows = enabled; - self - } - - /// Enables or disables the use of per-partition sorting to improve parallelism - pub fn with_repartition_sorts(mut self, enabled: bool) -> Self { - self.options.optimizer.repartition_sorts = enabled; - self - } - - /// Enables or disables the use of pruning predicate for parquet readers to skip row groups - pub fn with_parquet_pruning(mut self, enabled: bool) -> Self { - self.options.execution.parquet.pruning = enabled; - self - } - - /// Returns true if pruning predicate should be used to skip parquet row groups - pub fn parquet_pruning(&self) -> bool { - self.options.execution.parquet.pruning - } - - /// Enables or disables the collection of statistics after listing files - pub fn with_collect_statistics(mut self, enabled: bool) -> Self { - self.options.execution.collect_statistics = enabled; - self - } - - /// Get the currently configured batch size - pub fn batch_size(&self) -> usize { - self.options.execution.batch_size - } - - /// Convert configuration options to name-value pairs with values - /// converted to strings. - /// - /// Note that this method will eventually be deprecated and - /// replaced by [`config_options`]. - /// - /// [`config_options`]: Self::config_options - pub fn to_props(&self) -> HashMap { - let mut map = HashMap::new(); - // copy configs from config_options - for entry in self.options.entries() { - map.insert(entry.key, entry.value.unwrap_or_default()); - } - - map - } - - /// Return a handle to the configuration options. 
- pub fn config_options(&self) -> &ConfigOptions { - &self.options - } - - /// Return a mutable handle to the configuration options. - pub fn config_options_mut(&mut self) -> &mut ConfigOptions { - &mut self.options - } - - /// Add extensions. - /// - /// Extensions can be used to attach extra data to the session config -- e.g. tracing information or caches. - /// Extensions are opaque and the types are unknown to DataFusion itself, which makes them extremely flexible. [^1] - /// - /// Extensions are stored within an [`Arc`] so they do NOT require [`Clone`]. The are immutable. If you need to - /// modify their state over their lifetime -- e.g. for caches -- you need to establish some for of interior mutability. - /// - /// Extensions are indexed by their type `T`. If multiple values of the same type are provided, only the last one - /// will be kept. - /// - /// You may use [`get_extension`](Self::get_extension) to retrieve extensions. - /// - /// # Example - /// ``` - /// use std::sync::Arc; - /// use datafusion::execution::context::SessionConfig; - /// - /// // application-specific extension types - /// struct Ext1(u8); - /// struct Ext2(u8); - /// struct Ext3(u8); - /// - /// let ext1a = Arc::new(Ext1(10)); - /// let ext1b = Arc::new(Ext1(11)); - /// let ext2 = Arc::new(Ext2(2)); - /// - /// let cfg = SessionConfig::default() - /// // will only remember the last Ext1 - /// .with_extension(Arc::clone(&ext1a)) - /// .with_extension(Arc::clone(&ext1b)) - /// .with_extension(Arc::clone(&ext2)); - /// - /// let ext1_received = cfg.get_extension::().unwrap(); - /// assert!(!Arc::ptr_eq(&ext1_received, &ext1a)); - /// assert!(Arc::ptr_eq(&ext1_received, &ext1b)); - /// - /// let ext2_received = cfg.get_extension::().unwrap(); - /// assert!(Arc::ptr_eq(&ext2_received, &ext2)); - /// - /// assert!(cfg.get_extension::().is_none()); - /// ``` - /// - /// [^1]: Compare that to [`ConfigOptions`] which only supports [`ScalarValue`] payloads. - pub fn with_extension(mut self, ext: Arc) -> Self - where - T: Send + Sync + 'static, - { - let ext = ext as Arc; - let id = TypeId::of::(); - self.extensions.insert(id, ext); - self - } - - /// Get extension, if any for the specified type `T` exists. - /// - /// See [`with_extension`](Self::with_extension) on how to add attach extensions. 
- pub fn get_extension(&self) -> Option> - where - T: Send + Sync + 'static, - { - let id = TypeId::of::(); - self.extensions - .get(&id) - .cloned() - .map(|ext| Arc::downcast(ext).expect("TypeId unique")) - } -} - -impl From for SessionConfig { - fn from(options: ConfigOptions) -> Self { - Self { - options, - ..Default::default() - } - } -} - /// Execution context for registering data sources and executing queries #[derive(Clone)] pub struct SessionState { @@ -1575,7 +1253,7 @@ impl SessionState { default_catalog .register_schema( - &config.config_options().catalog.default_schema, + &config.options().catalog.default_schema, Arc::new(MemorySchemaProvider::new()), ) .expect("memory catalog provider can register schema"); @@ -1588,7 +1266,7 @@ impl SessionState { ); catalog_list.register_catalog( - config.config_options().catalog.default_catalog.clone(), + config.options().catalog.default_catalog.clone(), Arc::new(default_catalog), ); } @@ -1664,8 +1342,8 @@ impl SessionState { runtime: &Arc, default_catalog: &MemoryCatalogProvider, ) { - let url = config.options.catalog.location.as_ref(); - let format = config.options.catalog.format.as_ref(); + let url = config.options().catalog.location.as_ref(); + let format = config.options().catalog.format.as_ref(); let (url, format) = match (url, format) { (Some(url), Some(format)) => (url, format), _ => return, @@ -1673,7 +1351,7 @@ impl SessionState { let url = url.to_string(); let format = format.to_string(); - let has_header = config.options.catalog.has_header; + let has_header = config.options().catalog.has_header; let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); let authority = match url.host_str() { Some(host) => format!("{}://{}", url.scheme(), host), @@ -1889,7 +1567,7 @@ impl SessionState { } let enable_ident_normalization = - self.config.options.sql_parser.enable_ident_normalization; + self.config.options().sql_parser.enable_ident_normalization; relations .into_iter() .map(|x| object_name_to_table_reference(x, enable_ident_normalization)) @@ -1909,12 +1587,12 @@ impl SessionState { }; let enable_ident_normalization = - self.config.options.sql_parser.enable_ident_normalization; + self.config.options().sql_parser.enable_ident_normalization; let parse_float_as_decimal = - self.config.options.sql_parser.parse_float_as_decimal; + self.config.options().sql_parser.parse_float_as_decimal; for reference in references { let table = reference.table(); - let resolved = self.resolve_table_ref(reference.as_table_reference()); + let resolved = self.resolve_table_ref(&reference); if let Entry::Vacant(v) = provider.tables.entry(resolved.to_string()) { if let Ok(schema) = self.schema_for_ref(resolved) { if let Some(table) = schema.table(table).await { @@ -2020,7 +1698,7 @@ impl SessionState { /// return the configuration options pub fn config_options(&self) -> &ConfigOptions { - self.config.config_options() + self.config.options() } /// Get a new TaskContext to run in this session @@ -2958,7 +2636,7 @@ mod tests { let test = task_context .session_config() - .config_options() + .options() .extensions .get::(); assert!(test.is_some()); diff --git a/datafusion/core/src/execution/mod.rs b/datafusion/core/src/execution/mod.rs index 5586c2ce3ce7..ad9b9ce2125b 100644 --- a/datafusion/core/src/execution/mod.rs +++ b/datafusion/core/src/execution/mod.rs @@ -43,12 +43,12 @@ pub mod context; // backwards compatibility pub use crate::datasource::file_format::options; -pub mod runtime_env; // backwards compatibility pub use 
datafusion_execution::disk_manager; pub use datafusion_execution::memory_pool; pub use datafusion_execution::registry; +pub use datafusion_execution::runtime_env; pub use disk_manager::DiskManager; pub use registry::FunctionRegistry; diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index a1566c37cd2d..9fd29cf89674 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -24,7 +24,7 @@ use crate::{ physical_optimizer::PhysicalOptimizerRule, physical_plan::{ coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec, - repartition::RepartitionExec, rewrite::TreeNodeRewritable, Partitioning, + repartition::RepartitionExec, tree_node::TreeNodeRewritable, Partitioning, }, }; use std::sync::Arc; diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs index 919273af7467..95d4427e6dc4 100644 --- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs @@ -28,8 +28,8 @@ use crate::physical_plan::joins::{ }; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::rewrite::TreeNodeRewritable; use crate::physical_plan::sorts::sort::SortOptions; +use crate::physical_plan::tree_node::TreeNodeRewritable; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::Partitioning; use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; diff --git a/datafusion/core/src/physical_optimizer/global_sort_selection.rs b/datafusion/core/src/physical_optimizer/global_sort_selection.rs index 81b4b59e3a14..647558fbff12 100644 --- a/datafusion/core/src/physical_optimizer/global_sort_selection.rs +++ b/datafusion/core/src/physical_optimizer/global_sort_selection.rs @@ -22,9 +22,9 @@ use std::sync::Arc; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::rewrite::TreeNodeRewritable; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use crate::physical_plan::tree_node::TreeNodeRewritable; use crate::physical_plan::ExecutionPlan; /// Currently for a sort operator, if diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index d9787881f161..2308b2c85dda 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -32,7 +32,7 @@ use crate::physical_plan::{ExecutionPlan, PhysicalExpr}; use super::optimizer::PhysicalOptimizerRule; use crate::error::Result; -use crate::physical_plan::rewrite::TreeNodeRewritable; +use crate::physical_plan::tree_node::TreeNodeRewritable; /// For hash join with the partition mode [PartitionMode::Auto], JoinSelection rule will make /// a cost based decision to select which PartitionMode mode(Partitioned/CollectLeft) is optimal diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 8a6b0e003adf..f097196cd1fc 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ 
-22,7 +22,7 @@ use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::rewrite::TreeNodeRewritable; +use crate::physical_plan::tree_node::TreeNodeRewritable; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use std::sync::Arc; diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index 7e85f0c0d5f7..7532914c1258 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -34,7 +34,7 @@ use crate::physical_plan::joins::{ convert_sort_expr_with_filter_schema, HashJoinExec, PartitionMode, SymmetricHashJoinExec, }; -use crate::physical_plan::rewrite::TreeNodeRewritable; +use crate::physical_plan::tree_node::TreeNodeRewritable; use crate::physical_plan::ExecutionPlan; use datafusion_common::DataFusionError; use datafusion_expr::logical_plan::JoinType; diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 80b72e68f6ea..9185bf04df82 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -46,7 +46,7 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_physical_expr::rewrite::{TreeNodeRewritable, TreeNodeRewriter}; +use datafusion_physical_expr::rewrite::TreeNodeRewritable; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use log::trace; @@ -643,28 +643,15 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - let mut rewriter = RewriteColumnExpr { - column_old, - column_new, - }; - e.transform_using(&mut rewriter) -} - -struct RewriteColumnExpr<'a> { - column_old: &'a phys_expr::Column, - column_new: &'a phys_expr::Column, -} - -impl<'a> TreeNodeRewriter> for RewriteColumnExpr<'a> { - fn mutate(&mut self, expr: Arc) -> Result> { + e.transform(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { - if column == self.column_old { - return Ok(Arc::new(self.column_new.clone())); + if column == column_old { + return Ok(Some(Arc::new(column_new.clone()))); } } - Ok(expr) - } + Ok(None) + }) } fn reverse_operator(op: Operator) -> Result { diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs index 70880b750552..261c19600c81 100644 --- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs @@ -39,9 +39,9 @@ use crate::physical_optimizer::utils::add_sort_above; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; -use crate::physical_plan::rewrite::TreeNodeRewritable; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use crate::physical_plan::tree_node::TreeNodeRewritable; use crate::physical_plan::union::UnionExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan}; diff --git a/datafusion/core/src/physical_plan/common.rs 
b/datafusion/core/src/physical_plan/common.rs index 2f02aaa272c8..7fb67d758f3d 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -42,11 +42,6 @@ use tokio::task::JoinHandle; /// [`MemoryReservation`] used across query execution streams pub(crate) type SharedMemoryReservation = Arc>; -/// [`MemoryReservation`] used at query operator level -/// `Option` wrapper allows to initialize empty reservation in operator constructor, -/// and set it to actual reservation at stream level. -pub(crate) type OperatorMemoryReservation = Arc>>; - /// Stream of record batches pub struct SizedRecordBatchStream { schema: SchemaRef, diff --git a/datafusion/core/src/physical_plan/file_format/avro.rs b/datafusion/core/src/physical_plan/file_format/avro.rs index 9ac1d23bac83..e4fc570feb52 100644 --- a/datafusion/core/src/physical_plan/file_format/avro.rs +++ b/datafusion/core/src/physical_plan/file_format/avro.rs @@ -213,7 +213,6 @@ mod tests { use crate::datasource::listing::PartitionedFile; use crate::datasource::object_store::ObjectStoreUrl; use crate::physical_plan::file_format::chunked_store::ChunkedStore; - use crate::physical_plan::file_format::partition_type_wrap; use crate::prelude::SessionContext; use crate::scalar::ScalarValue; use crate::test::object_store::local_unpartitioned_file; @@ -409,10 +408,7 @@ mod tests { file_schema, statistics: Statistics::default(), limit: None, - table_partition_cols: vec![( - "date".to_owned(), - partition_type_wrap(DataType::Utf8), - )], + table_partition_cols: vec![("date".to_owned(), DataType::Utf8)], output_ordering: None, infinite_source: false, }); diff --git a/datafusion/core/src/physical_plan/file_format/csv.rs b/datafusion/core/src/physical_plan/file_format/csv.rs index 346c9915815d..3fc7df4f1047 100644 --- a/datafusion/core/src/physical_plan/file_format/csv.rs +++ b/datafusion/core/src/physical_plan/file_format/csv.rs @@ -326,7 +326,6 @@ mod tests { use super::*; use crate::datasource::file_format::file_type::FileType; use crate::physical_plan::file_format::chunked_store::ChunkedStore; - use crate::physical_plan::file_format::partition_type_wrap; use crate::prelude::*; use crate::test::{partitioned_csv_config, partitioned_file_groups}; use crate::test_util::{aggr_test_schema_with_missing_col, arrow_test_data}; @@ -580,8 +579,7 @@ mod tests { let mut config = partitioned_csv_config(file_schema, file_groups)?; // Add partition columns - config.table_partition_cols = - vec![("date".to_owned(), partition_type_wrap(DataType::Utf8))]; + config.table_partition_cols = vec![("date".to_owned(), DataType::Utf8)]; config.file_groups[0][0].partition_values = vec![ScalarValue::Utf8(Some("2021-10-26".to_owned()))]; diff --git a/datafusion/core/src/physical_plan/file_format/json.rs b/datafusion/core/src/physical_plan/file_format/json.rs index 6f26080122c6..146227335c99 100644 --- a/datafusion/core/src/physical_plan/file_format/json.rs +++ b/datafusion/core/src/physical_plan/file_format/json.rs @@ -73,6 +73,11 @@ impl NdJsonExec { file_compression_type, } } + + /// Ref to the base configs + pub fn base_config(&self) -> &FileScanConfig { + &self.base_config + } } impl ExecutionPlan for NdJsonExec { diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs b/datafusion/core/src/physical_plan/file_format/mod.rs index 97b091cedfef..d4ef60a41588 100644 --- a/datafusion/core/src/physical_plan/file_format/mod.rs +++ b/datafusion/core/src/physical_plan/file_format/mod.rs @@ -30,9 +30,9 @@ pub use self::csv::CsvExec; 
pub(crate) use self::parquet::plan_to_parquet; pub use self::parquet::{ParquetExec, ParquetFileMetrics, ParquetFileReaderFactory}; use arrow::{ - array::{ArrayData, ArrayRef, DictionaryArray}, + array::{ArrayData, ArrayRef, BufferBuilder, DictionaryArray}, buffer::Buffer, - datatypes::{DataType, Field, Schema, SchemaRef, UInt16Type}, + datatypes::{ArrowNativeType, DataType, Field, Schema, SchemaRef, UInt16Type}, record_batch::RecordBatch, }; pub use avro::AvroExec; @@ -45,29 +45,95 @@ use crate::datasource::{ listing::{FileRange, PartitionedFile}, object_store::ObjectStoreUrl, }; +use crate::physical_plan::tree_node::{ + TreeNodeVisitable, TreeNodeVisitor, VisitRecursion, +}; +use crate::physical_plan::ExecutionPlan; use crate::{ error::{DataFusionError, Result}, scalar::ScalarValue, }; -use arrow::array::{new_null_array, UInt16BufferBuilder}; +use arrow::array::new_null_array; use arrow::record_batch::RecordBatchOptions; -use log::{debug, info}; +use log::{debug, info, warn}; use object_store::path::Path; use object_store::ObjectMeta; use std::{ + borrow::Cow, collections::HashMap, fmt::{Display, Formatter, Result as FmtResult}, + marker::PhantomData, sync::Arc, vec, }; use super::{ColumnStatistics, Statistics}; -/// Convert logical type of partition column to physical type: `Dictionary(UInt16, val_type)` -pub fn partition_type_wrap(val_type: DataType) -> DataType { +/// Convert a type to a type suitable for use as a [`ListingTable`] +/// partition column. Returns `Dictionary(UInt16, val_type)`, which is +/// a reasonable trade-off between the number of distinct partition +/// values and space efficiency. +/// +/// Use this to specify types for partition columns. However, +/// you MAY also choose not to dictionary-encode the data or to use a +/// different dictionary type. +/// +/// Use [`wrap_partition_value_in_dict`] to wrap a [`ScalarValue`] in the same way. +pub fn wrap_partition_type_in_dict(val_type: DataType) -> DataType { DataType::Dictionary(Box::new(DataType::UInt16), Box::new(val_type)) } +/// Convert a partition column's [`ScalarValue`] to the dictionary-encoded +/// form described in the documentation of [`wrap_partition_type_in_dict`], +/// which wraps the corresponding type.
+pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(DataType::UInt16), Box::new(val)) +} + +/// Get all of the [`PartitionedFile`] to be scanned for an [`ExecutionPlan`] +pub fn get_scan_files( + plan: Arc, +) -> Result>>> { + let mut collector = FileScanCollector::new(); + plan.accept(&mut collector)?; + Ok(collector.file_groups) +} + +struct FileScanCollector { + file_groups: Vec>>, +} + +impl FileScanCollector { + fn new() -> Self { + Self { + file_groups: vec![], + } + } +} + +impl TreeNodeVisitor for FileScanCollector { + type N = Arc; + + fn pre_visit(&mut self, node: &Self::N) -> Result { + let plan_any = node.as_any(); + let file_groups = + if let Some(parquet_exec) = plan_any.downcast_ref::() { + parquet_exec.base_config().file_groups.clone() + } else if let Some(avro_exec) = plan_any.downcast_ref::() { + avro_exec.base_config().file_groups.clone() + } else if let Some(json_exec) = plan_any.downcast_ref::() { + json_exec.base_config().file_groups.clone() + } else if let Some(csv_exec) = plan_any.downcast_ref::() { + csv_exec.base_config().file_groups.clone() + } else { + return Ok(VisitRecursion::Continue); + }; + + self.file_groups.push(file_groups); + Ok(VisitRecursion::Stop) + } +} + /// The base configurations to provide when creating a physical plan for /// any given file format. #[derive(Debug, Clone)] @@ -346,7 +412,7 @@ struct PartitionColumnProjector { /// An Arrow buffer initialized to zeros that represents the key array of all partition /// columns (partition columns are materialized by dictionary arrays with only one /// value in the dictionary, thus all the keys are equal to zero). - key_buffer_cache: Option, + key_buffer_cache: ZeroBufferGenerators, /// Mapping between the indexes in the list of partition columns and the target /// schema. Sorted by index in the target schema so that we can iterate on it to /// insert the partition columns in the target record batch. 
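A minimal usage sketch for the two helpers above (not part of this patch): it assumes the `datafusion` crate as built from this branch, where both functions are exported from `datafusion::physical_plan::file_format` and `ScalarValue::get_datatype` is available.

```rust
use arrow::datatypes::DataType;
use datafusion::physical_plan::file_format::{
    wrap_partition_type_in_dict, wrap_partition_value_in_dict,
};
use datafusion::scalar::ScalarValue;

fn main() {
    // Declare a partition column "date" as Dictionary(UInt16, Utf8)
    let col_type = wrap_partition_type_in_dict(DataType::Utf8);
    assert_eq!(
        col_type,
        DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8))
    );

    // Wrap the per-file partition value the same way so it matches the
    // declared column type
    let value = wrap_partition_value_in_dict(ScalarValue::Utf8(Some(
        "2021-10-26".to_string(),
    )));
    assert_eq!(value.get_datatype(), col_type);
}
```

As the `PartitionColumnProjector` change below shows, values that arrive un-wrapped are now auto-fixed with a warning, but wrapping them up front avoids the extra clone.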
@@ -372,7 +438,7 @@ impl PartitionColumnProjector { Self { projected_partition_indexes, - key_buffer_cache: None, + key_buffer_cache: Default::default(), projected_schema, } } @@ -398,11 +464,27 @@ impl PartitionColumnProjector { } let mut cols = file_batch.columns().to_vec(); for &(pidx, sidx) in &self.projected_partition_indexes { + let mut partition_value = Cow::Borrowed(&partition_values[pidx]); + + // check if user forgot to dict-encode the partition value + let field = self.projected_schema.field(sidx); + let expected_data_type = field.data_type(); + let actual_data_type = partition_value.get_datatype(); + if let DataType::Dictionary(key_type, _) = expected_data_type { + if !matches!(actual_data_type, DataType::Dictionary(_, _)) { + warn!("Partition value for column {} was not dictionary-encoded, applied auto-fix.", field.name()); + partition_value = Cow::Owned(ScalarValue::Dictionary( + key_type.clone(), + Box::new(partition_value.as_ref().clone()), + )); + } + } + cols.insert( sidx, - create_dict_array( + create_output_array( &mut self.key_buffer_cache, - &partition_values[pidx], + partition_value.as_ref(), file_batch.num_rows(), ), ) @@ -411,26 +493,60 @@ impl PartitionColumnProjector { } } -fn create_dict_array( - key_buffer_cache: &mut Option, - val: &ScalarValue, - len: usize, -) -> ArrayRef { - // build value dictionary - let dict_vals = val.to_array(); - - // build keys array - let sliced_key_buffer = match key_buffer_cache { - Some(buf) if buf.len() >= len * 2 => buf.slice(buf.len() - len * 2), - _ => { - let mut key_buffer_builder = UInt16BufferBuilder::new(len * 2); - key_buffer_builder.advance(len * 2); // keys are all 0 - key_buffer_cache.insert(key_buffer_builder.finish()).clone() +#[derive(Debug, Default)] +struct ZeroBufferGenerators { + gen_i8: ZeroBufferGenerator, + gen_i16: ZeroBufferGenerator, + gen_i32: ZeroBufferGenerator, + gen_i64: ZeroBufferGenerator, + gen_u8: ZeroBufferGenerator, + gen_u16: ZeroBufferGenerator, + gen_u32: ZeroBufferGenerator, + gen_u64: ZeroBufferGenerator, +} + +/// Generate a arrow [`Buffer`] that contains zero values. 
+#[derive(Debug, Default)] +struct ZeroBufferGenerator +where + T: ArrowNativeType, +{ + cache: Option, + _t: PhantomData, +} + +impl ZeroBufferGenerator +where + T: ArrowNativeType, +{ + const SIZE: usize = std::mem::size_of::(); + + fn get_buffer(&mut self, n_vals: usize) -> Buffer { + match &mut self.cache { + Some(buf) if buf.len() >= n_vals * Self::SIZE => { + buf.slice_with_length(0, n_vals * Self::SIZE) + } + _ => { + let mut key_buffer_builder = BufferBuilder::::new(n_vals); + key_buffer_builder.advance(n_vals); // keys are all 0 + self.cache.insert(key_buffer_builder.finish()).clone() + } } - }; + } +} + +fn create_dict_array( + buffer_gen: &mut ZeroBufferGenerator, + dict_val: &ScalarValue, + len: usize, + data_type: DataType, +) -> ArrayRef +where + T: ArrowNativeType, +{ + let dict_vals = dict_val.to_array(); - // create data type - let data_type = partition_type_wrap(val.get_datatype()); + let sliced_key_buffer = buffer_gen.get_buffer(len); // assemble pieces together let mut builder = ArrayData::builder(data_type) @@ -442,6 +558,84 @@ fn create_dict_array( )) } +fn create_output_array( + key_buffer_cache: &mut ZeroBufferGenerators, + val: &ScalarValue, + len: usize, +) -> ArrayRef { + if let ScalarValue::Dictionary(key_type, dict_val) = &val { + match key_type.as_ref() { + DataType::Int8 => { + return create_dict_array( + &mut key_buffer_cache.gen_i8, + dict_val, + len, + val.get_datatype(), + ); + } + DataType::Int16 => { + return create_dict_array( + &mut key_buffer_cache.gen_i16, + dict_val, + len, + val.get_datatype(), + ); + } + DataType::Int32 => { + return create_dict_array( + &mut key_buffer_cache.gen_i32, + dict_val, + len, + val.get_datatype(), + ); + } + DataType::Int64 => { + return create_dict_array( + &mut key_buffer_cache.gen_i64, + dict_val, + len, + val.get_datatype(), + ); + } + DataType::UInt8 => { + return create_dict_array( + &mut key_buffer_cache.gen_u8, + dict_val, + len, + val.get_datatype(), + ); + } + DataType::UInt16 => { + return create_dict_array( + &mut key_buffer_cache.gen_u16, + dict_val, + len, + val.get_datatype(), + ); + } + DataType::UInt32 => { + return create_dict_array( + &mut key_buffer_cache.gen_u32, + dict_val, + len, + val.get_datatype(), + ); + } + DataType::UInt64 => { + return create_dict_array( + &mut key_buffer_cache.gen_u64, + dict_val, + len, + val.get_datatype(), + ); + } + _ => {} + } + } + + val.to_array_of_size(len) +} + /// A single file or part of a file that should be read, along with its schema, statistics pub struct FileMeta { /// Path for the file (e.g. 
URL, filesystem path, etc) @@ -559,7 +753,10 @@ mod tests { Arc::clone(&file_schema), None, Statistics::default(), - vec![("date".to_owned(), partition_type_wrap(DataType::Utf8))], + vec![( + "date".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + )], ); let (proj_schema, proj_statistics) = conf.project(); @@ -605,7 +802,10 @@ mod tests { ), ..Default::default() }, - vec![("date".to_owned(), partition_type_wrap(DataType::Utf8))], + vec![( + "date".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + )], ); let (proj_schema, proj_statistics) = conf.project(); @@ -636,9 +836,18 @@ mod tests { ("c", &vec![10, 11, 12]), ); let partition_cols = vec![ - ("year".to_owned(), partition_type_wrap(DataType::Utf8)), - ("month".to_owned(), partition_type_wrap(DataType::Utf8)), - ("day".to_owned(), partition_type_wrap(DataType::Utf8)), + ( + "year".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ( + "month".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ( + "day".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), ]; // create a projected schema let conf = config_for_projection( @@ -670,9 +879,15 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - ScalarValue::Utf8(Some("2021".to_owned())), - ScalarValue::Utf8(Some("10".to_owned())), - ScalarValue::Utf8(Some("26".to_owned())), + wrap_partition_value_in_dict(ScalarValue::Utf8(Some( + "2021".to_owned(), + ))), + wrap_partition_value_in_dict(ScalarValue::Utf8(Some( + "10".to_owned(), + ))), + wrap_partition_value_in_dict(ScalarValue::Utf8(Some( + "26".to_owned(), + ))), ], ) .expect("Projection of partition columns into record batch failed"); @@ -698,9 +913,15 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - ScalarValue::Utf8(Some("2021".to_owned())), - ScalarValue::Utf8(Some("10".to_owned())), - ScalarValue::Utf8(Some("27".to_owned())), + wrap_partition_value_in_dict(ScalarValue::Utf8(Some( + "2021".to_owned(), + ))), + wrap_partition_value_in_dict(ScalarValue::Utf8(Some( + "10".to_owned(), + ))), + wrap_partition_value_in_dict(ScalarValue::Utf8(Some( + "27".to_owned(), + ))), ], ) .expect("Projection of partition columns into record batch failed"); @@ -728,9 +949,15 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - ScalarValue::Utf8(Some("2021".to_owned())), - ScalarValue::Utf8(Some("10".to_owned())), - ScalarValue::Utf8(Some("28".to_owned())), + wrap_partition_value_in_dict(ScalarValue::Utf8(Some( + "2021".to_owned(), + ))), + wrap_partition_value_in_dict(ScalarValue::Utf8(Some( + "10".to_owned(), + ))), + wrap_partition_value_in_dict(ScalarValue::Utf8(Some( + "28".to_owned(), + ))), ], ) .expect("Projection of partition columns into record batch failed"); @@ -744,6 +971,34 @@ mod tests { "+---+---+---+------+-----+", ]; crate::assert_batches_eq!(expected, &[projected_batch]); + + // forgot to dictionary-wrap the scalar value + let file_batch = build_table_i32( + ("a", &vec![0, 1, 2]), + ("b", &vec![-2, -1, 0]), + ("c", &vec![10, 11, 12]), + ); + let projected_batch = proj + .project( + // file_batch is ok here because we kept all the file cols in the projection + file_batch, + &[ + ScalarValue::Utf8(Some("2021".to_owned())), + ScalarValue::Utf8(Some("10".to_owned())), + ScalarValue::Utf8(Some("26".to_owned())), + ], + ) + .expect("Projection of partition columns into record batch failed"); + let expected = vec![ 
+ "+---+----+----+------+-----+", + "| a | b | c | year | day |", + "+---+----+----+------+-----+", + "| 0 | -2 | 10 | 2021 | 26 |", + "| 1 | -1 | 11 | 2021 | 26 |", + "| 2 | 0 | 12 | 2021 | 26 |", + "+---+----+----+------+-----+", + ]; + crate::assert_batches_eq!(expected, &[projected_batch]); } #[test] diff --git a/datafusion/core/src/physical_plan/file_format/parquet.rs b/datafusion/core/src/physical_plan/file_format/parquet.rs index 2c1bfc9caa51..92be32f47649 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet.rs @@ -204,6 +204,8 @@ impl ParquetExec { /// `ParquetRecordBatchStream`. These filters are applied by the /// parquet decoder to skip unecessairly decoding other columns /// which would not pass the predicate. Defaults to false + /// + /// [`Expr`]: datafusion_expr::Expr pub fn with_pushdown_filters(mut self, pushdown_filters: bool) -> Self { self.pushdown_filters = Some(pushdown_filters); self @@ -219,6 +221,8 @@ impl ParquetExec { /// minimize the cost of filter evaluation by reordering the /// predicate [`Expr`]s. If false, the predicates are applied in /// the same order as specified in the query. Defaults to false. + /// + /// [`Expr`]: datafusion_expr::Expr pub fn with_reorder_filters(mut self, reorder_filters: bool) -> Self { self.reorder_filters = Some(reorder_filters); self @@ -372,7 +376,7 @@ impl ExecutionPlan for ParquetExec { }) })?; - let config_options = ctx.session_config().config_options(); + let config_options = ctx.session_config().options(); let opener = ParquetOpener { partition_index, @@ -810,7 +814,6 @@ mod tests { use crate::execution::context::SessionState; use crate::execution::options::CsvReadOptions; use crate::physical_plan::displayable; - use crate::physical_plan::file_format::partition_type_wrap; use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use crate::{ @@ -1656,26 +1659,50 @@ mod tests { object_meta: meta, partition_values: vec![ ScalarValue::Utf8(Some("2021".to_owned())), - ScalarValue::Utf8(Some("10".to_owned())), - ScalarValue::Utf8(Some("26".to_owned())), + ScalarValue::UInt8(Some(10)), + ScalarValue::Dictionary( + Box::new(DataType::UInt16), + Box::new(ScalarValue::Utf8(Some("26".to_owned()))), + ), ], range: None, extensions: None, }; + let expected_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("bool_col", DataType::Boolean, true), + Field::new("tinyint_col", DataType::Int32, true), + Field::new("month", DataType::UInt8, false), + Field::new( + "day", + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + false, + ), + ]); + let parquet_exec = ParquetExec::new( FileScanConfig { object_store_url, file_groups: vec![vec![partitioned_file]], file_schema: schema, statistics: Statistics::default(), - // file has 10 cols so index 12 should be month - projection: Some(vec![0, 1, 2, 12]), + // file has 10 cols so index 12 should be month and 13 should be day + projection: Some(vec![0, 1, 2, 12, 13]), limit: None, table_partition_cols: vec![ - ("year".to_owned(), partition_type_wrap(DataType::Utf8)), - ("month".to_owned(), partition_type_wrap(DataType::Utf8)), - ("day".to_owned(), partition_type_wrap(DataType::Utf8)), + ("year".to_owned(), DataType::Utf8), + ("month".to_owned(), DataType::UInt8), + ( + "day".to_owned(), + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + ), ], 
output_ordering: None, infinite_source: false, @@ -1684,22 +1711,24 @@ mod tests { None, ); assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); + assert_eq!(parquet_exec.schema().as_ref(), &expected_schema); let mut results = parquet_exec.execute(0, task_ctx)?; let batch = results.next().await.unwrap()?; + assert_eq!(batch.schema().as_ref(), &expected_schema); let expected = vec![ - "+----+----------+-------------+-------+", - "| id | bool_col | tinyint_col | month |", - "+----+----------+-------------+-------+", - "| 4 | true | 0 | 10 |", - "| 5 | false | 1 | 10 |", - "| 6 | true | 0 | 10 |", - "| 7 | false | 1 | 10 |", - "| 2 | true | 0 | 10 |", - "| 3 | false | 1 | 10 |", - "| 0 | true | 0 | 10 |", - "| 1 | false | 1 | 10 |", - "+----+----------+-------------+-------+", + "+----+----------+-------------+-------+-----+", + "| id | bool_col | tinyint_col | month | day |", + "+----+----------+-------------+-------+-----+", + "| 4 | true | 0 | 10 | 26 |", + "| 5 | false | 1 | 10 | 26 |", + "| 6 | true | 0 | 10 | 26 |", + "| 7 | false | 1 | 10 | 26 |", + "| 2 | true | 0 | 10 | 26 |", + "| 3 | false | 1 | 10 | 26 |", + "| 0 | true | 0 | 10 | 26 |", + "| 1 | false | 1 | 10 | 26 |", + "+----+----------+-------------+-------+-----+", ]; crate::assert_batches_eq!(expected, &[batch]); diff --git a/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs b/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs index e1feafec1588..91ae8afe6f29 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, BooleanArray}; +use arrow::array::BooleanArray; use arrow::datatypes::{DataType, Schema}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; @@ -131,8 +131,7 @@ impl ArrowPredicate for DatafusionArrowPredicate { .map(|v| v.into_array(batch.num_rows())) { Ok(array) => { - let mask = as_boolean_array(&array)?; - let bool_arr = BooleanArray::from(mask.data().clone()); + let bool_arr = as_boolean_array(&array)?.clone(); let num_filtered = bool_arr.len() - bool_arr.true_count(); self.rows_filtered.add(num_filtered); timer.stop(); @@ -219,13 +218,16 @@ impl<'a> TreeNodeRewriter> for FilterCandidateBuilder<'a> if DataType::is_nested(self.file_schema.field(idx).data_type()) { self.non_primitive_columns = true; + return Ok(RewriteRecursion::Stop); } } else if self.table_schema.index_of(column.name()).is_err() { // If the column does not exist in the (un-projected) table schema then // it must be a projected column. 
self.projected_columns = true; + return Ok(RewriteRecursion::Stop); } } + Ok(RewriteRecursion::Continue) } diff --git a/datafusion/core/src/physical_plan/joins/cross_join.rs b/datafusion/core/src/physical_plan/joins/cross_join.rs index d4933b9d6e0e..8492e5e6b20d 100644 --- a/datafusion/core/src/physical_plan/joins/cross_join.rs +++ b/datafusion/core/src/physical_plan/joins/cross_join.rs @@ -26,8 +26,7 @@ use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use crate::execution::context::TaskContext; -use crate::execution::memory_pool::MemoryConsumer; -use crate::physical_plan::common::{OperatorMemoryReservation, SharedMemoryReservation}; +use crate::execution::memory_pool::{SharedOptionalMemoryReservation, TryGrow}; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::physical_plan::{ coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, @@ -38,7 +37,6 @@ use crate::physical_plan::{ use crate::{error::Result, scalar::ScalarValue}; use async_trait::async_trait; use datafusion_common::DataFusionError; -use parking_lot::Mutex; use super::utils::{ adjust_right_output_partitioning, cross_join_equivalence_properties, @@ -61,7 +59,7 @@ pub struct CrossJoinExec { /// Build-side data left_fut: OnceAsync, /// Memory reservation for build-side data - reservation: OperatorMemoryReservation, + reservation: SharedOptionalMemoryReservation, /// Execution plan metrics metrics: ExecutionPlanMetricsSet, } @@ -106,7 +104,7 @@ async fn load_left_input( left: Arc, context: Arc, metrics: BuildProbeJoinMetrics, - reservation: SharedMemoryReservation, + reservation: SharedOptionalMemoryReservation, ) -> Result { // merge all left parts into a single stream let merge = { @@ -125,7 +123,7 @@ async fn load_left_input( |mut acc, batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch - acc.3.lock().try_grow(batch_size)?; + acc.3.try_grow(batch_size)?; // Update metrics acc.2.build_mem_used.add(batch_size); acc.2.build_input_batches.add(1); @@ -226,27 +224,15 @@ impl ExecutionPlan for CrossJoinExec { let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); // Initialization of operator-level reservation - { - let mut reservation_lock = self.reservation.lock(); - if reservation_lock.is_none() { - *reservation_lock = Some(Arc::new(Mutex::new( - MemoryConsumer::new("CrossJoinExec").register(context.memory_pool()), - ))); - }; - } - - let reservation = self.reservation.lock().clone().ok_or_else(|| { - DataFusionError::Internal( - "Operator-level memory reservation is not initialized".to_string(), - ) - })?; + self.reservation + .initialize("CrossJoinExec", context.memory_pool()); let left_fut = self.left_fut.once(|| { load_left_input( self.left.clone(), context, join_metrics.clone(), - reservation, + self.reservation.clone(), ) }); diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs index 0d2b897dd267..39acffa203ac 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join.rs @@ -58,7 +58,6 @@ use hashbrown::raw::RawTable; use crate::physical_plan::{ coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, - common::{OperatorMemoryReservation, SharedMemoryReservation}, expressions::Column, expressions::PhysicalSortExpr, hash_utils::create_hashes, @@ -78,7 +77,12 @@ use crate::logical_expr::JoinType; use 
crate::arrow::array::BooleanBufferBuilder; use crate::arrow::datatypes::TimeUnit; -use crate::execution::{context::TaskContext, memory_pool::MemoryConsumer}; +use crate::execution::{ + context::TaskContext, + memory_pool::{ + MemoryConsumer, SharedMemoryReservation, SharedOptionalMemoryReservation, TryGrow, + }, +}; use super::{ utils::{OnceAsync, OnceFut}, @@ -88,7 +92,6 @@ use crate::physical_plan::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, get_final_indices_from_bit_map, need_produce_result_in_final, JoinSide, }; -use parking_lot::Mutex; use std::fmt; use std::task::Poll; @@ -137,7 +140,7 @@ pub struct HashJoinExec { /// Build-side data left_fut: OnceAsync, /// Operator-level memory reservation for left data - reservation: OperatorMemoryReservation, + reservation: SharedOptionalMemoryReservation, /// Shares the `RandomState` for the hashing algorithm random_state: RandomState, /// Partitioning mode to use @@ -378,26 +381,14 @@ impl ExecutionPlan for HashJoinExec { let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); // Initialization of operator-level reservation - { - let mut operator_reservation_lock = self.reservation.lock(); - if operator_reservation_lock.is_none() { - *operator_reservation_lock = Some(Arc::new(Mutex::new( - MemoryConsumer::new("HashJoinExec").register(context.memory_pool()), - ))); - }; - } - - let operator_reservation = self.reservation.lock().clone().ok_or_else(|| { - DataFusionError::Internal( - "Operator-level memory reservation is not initialized".to_string(), - ) - })?; + self.reservation + .initialize("HashJoinExec", context.memory_pool()); // Inititalization of stream-level reservation - let reservation = Arc::new(Mutex::new( + let reservation = SharedMemoryReservation::from( MemoryConsumer::new(format!("HashJoinStream[{partition}]")) .register(context.memory_pool()), - )); + ); // Memory reservation for left-side data depends on PartitionMode: // - operator-level for `CollectLeft` mode @@ -415,7 +406,7 @@ impl ExecutionPlan for HashJoinExec { on_left.clone(), context.clone(), join_metrics.clone(), - operator_reservation.clone(), + Arc::new(self.reservation.clone()), ) }), PartitionMode::Partitioned => OnceFut::new(collect_left_input( @@ -425,7 +416,7 @@ impl ExecutionPlan for HashJoinExec { on_left.clone(), context.clone(), join_metrics.clone(), - reservation.clone(), + Arc::new(reservation.clone()), )), PartitionMode::Auto => { return Err(DataFusionError::Plan(format!( @@ -497,7 +488,7 @@ async fn collect_left_input( on_left: Vec, context: Arc, metrics: BuildProbeJoinMetrics, - reservation: SharedMemoryReservation, + reservation: Arc, ) -> Result { let schema = left.schema(); @@ -526,7 +517,7 @@ async fn collect_left_input( .try_fold(initial, |mut acc, batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch - acc.3.lock().try_grow(batch_size)?; + acc.3.try_grow(batch_size)?; // Update metrics acc.2.build_mem_used.add(batch_size); acc.2.build_input_batches.add(1); @@ -555,7 +546,7 @@ async fn collect_left_input( // + 16 bytes fixed let estimated_hastable_size = 32 * estimated_buckets + estimated_buckets + 16; - reservation.lock().try_grow(estimated_hastable_size)?; + reservation.try_grow(estimated_hastable_size)?; metrics.build_mem_used.add(estimated_hastable_size); let mut hashmap = JoinHashMap(RawTable::with_capacity(num_rows)); @@ -1157,7 +1148,7 @@ impl HashJoinStream { // TODO: Replace `ceil` wrapper with stable `div_cell` after // 
https://github.com/rust-lang/rust/issues/88581 let visited_bitmap_size = bit_util::ceil(left_data.1.num_rows(), 8); - self.reservation.lock().try_grow(visited_bitmap_size)?; + self.reservation.try_grow(visited_bitmap_size)?; self.join_metrics.build_mem_used.add(visited_bitmap_size); } diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs index c283b11f8f6a..e04e86d0d3d7 100644 --- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs +++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs @@ -24,8 +24,10 @@ use crate::physical_plan::joins::utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, combine_join_equivalence_properties, estimate_join_statistics, get_anti_indices, get_anti_u64_indices, get_final_indices_from_bit_map, get_semi_indices, - get_semi_u64_indices, ColumnIndex, JoinFilter, JoinSide, OnceAsync, OnceFut, + get_semi_u64_indices, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinSide, + OnceAsync, OnceFut, }; +use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::physical_plan::{ DisplayFormatType, Distribution, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, @@ -35,19 +37,21 @@ use arrow::array::{ }; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow::util::bit_util; use datafusion_common::{DataFusionError, Statistics}; use datafusion_expr::JoinType; use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortExpr}; use futures::{ready, Stream, StreamExt, TryStreamExt}; -use log::debug; use std::any::Any; use std::fmt::Formatter; use std::sync::Arc; use std::task::Poll; -use std::time::Instant; use crate::error::Result; use crate::execution::context::TaskContext; +use crate::execution::memory_pool::{ + MemoryConsumer, SharedMemoryReservation, SharedOptionalMemoryReservation, TryGrow, +}; use crate::physical_plan::coalesce_batches::concat_batches; /// Data of the inner table side @@ -87,6 +91,10 @@ pub struct NestedLoopJoinExec { inner_table: OnceAsync, /// Information of index and left / right placement of columns column_indices: Vec, + /// Operator-level memory reservation for left data + reservation: SharedOptionalMemoryReservation, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, } impl NestedLoopJoinExec { @@ -110,6 +118,8 @@ impl NestedLoopJoinExec { schema: Arc::new(schema), inner_table: Default::default(), column_indices, + reservation: Default::default(), + metrics: Default::default(), }) } } @@ -189,17 +199,41 @@ impl ExecutionPlan for NestedLoopJoinExec { partition: usize, context: Arc, ) -> Result { + let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); + + // Initialization of operator-level reservation + self.reservation + .initialize("NestedLoopJoinExec", context.memory_pool()); + + // Inititalization of stream-level reservation + let reservation = SharedMemoryReservation::from( + MemoryConsumer::new(format!("NestedLoopJoinStream[{partition}]")) + .register(context.memory_pool()), + ); + let (outer_table, inner_table) = if left_is_build_side(self.join_type) { // left must be single partition let inner_table = self.inner_table.once(|| { - load_specified_partition_of_input(0, self.left.clone(), context.clone()) + load_specified_partition_of_input( + 0, + self.left.clone(), + context.clone(), + join_metrics.clone(), + Arc::new(self.reservation.clone()), + ) }); let outer_table = 
self.right.execute(partition, context)?; (outer_table, inner_table) } else { // right must be single partition let inner_table = self.inner_table.once(|| { - load_specified_partition_of_input(0, self.right.clone(), context.clone()) + load_specified_partition_of_input( + 0, + self.right.clone(), + context.clone(), + join_metrics.clone(), + Arc::new(self.reservation.clone()), + ) }); let outer_table = self.left.execute(partition, context)?; (outer_table, inner_table) @@ -214,6 +248,8 @@ impl ExecutionPlan for NestedLoopJoinExec { is_exhausted: false, visited_left_side: None, column_indices: self.column_indices.clone(), + join_metrics, + reservation, })) } @@ -233,6 +269,10 @@ impl ExecutionPlan for NestedLoopJoinExec { } } + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + fn statistics(&self) -> Statistics { estimate_join_statistics( self.left.clone(), @@ -273,28 +313,34 @@ async fn load_specified_partition_of_input( partition: usize, input: Arc, context: Arc, + join_metrics: BuildProbeJoinMetrics, + reservation: Arc, ) -> Result { - let start = Instant::now(); let stream = input.execute(partition, context)?; // Load all batches and count the rows - let (batches, num_rows) = stream - .try_fold((Vec::new(), 0usize), |mut acc, batch| async { - acc.1 += batch.num_rows(); - acc.0.push(batch); - Ok(acc) - }) + let (batches, num_rows, _, _) = stream + .try_fold( + (Vec::new(), 0usize, join_metrics, reservation), + |mut acc, batch| async { + let batch_size = batch.get_array_memory_size(); + // Reserve memory for incoming batch + acc.3.try_grow(batch_size)?; + // Update metrics + acc.2.build_mem_used.add(batch_size); + acc.2.build_input_batches.add(1); + acc.2.build_input_rows.add(batch.num_rows()); + // Update rowcount + acc.1 += batch.num_rows(); + // Push batch to output + acc.0.push(batch); + Ok(acc) + }, + ) .await?; let merged_batch = concat_batches(&input.schema(), &batches, num_rows)?; - debug!( - "Built input of nested loop join containing {} rows in {} ms for partition {}", - num_rows, - start.elapsed().as_millis(), - partition - ); - Ok(merged_batch) } @@ -326,6 +372,10 @@ struct NestedLoopJoinStream { column_indices: Vec, // TODO: support null aware equal // null_equals_null: bool + /// Join execution metrics + join_metrics: BuildProbeJoinMetrics, + /// Memory reservation for visited_left_side + reservation: SharedMemoryReservation, } fn build_join_indices( @@ -362,10 +412,20 @@ impl NestedLoopJoinStream { cx: &mut std::task::Context<'_>, ) -> Poll>> { // all left row + let build_timer = self.join_metrics.build_time.timer(); let left_data = match ready!(self.inner_table.get(cx)) { Ok(data) => data, Err(e) => return Poll::Ready(Some(Err(e))), }; + build_timer.done(); + + if self.visited_left_side.is_none() && self.join_type == JoinType::Full { + // TODO: Replace `ceil` wrapper with stable `div_cell` after + // https://github.com/rust-lang/rust/issues/88581 + let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8); + self.reservation.try_grow(visited_bitmap_size)?; + self.join_metrics.build_mem_used.add(visited_bitmap_size); + } // add a bitmap for full join. 
let visited_left_side = self.visited_left_side.get_or_insert_with(|| { @@ -384,6 +444,11 @@ impl NestedLoopJoinStream { .poll_next_unpin(cx) .map(|maybe_batch| match maybe_batch { Some(Ok(right_batch)) => { + // Setting up timer & updating input metrics + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(right_batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + let result = join_left_and_right_batch( left_data, &right_batch, @@ -393,11 +458,22 @@ impl NestedLoopJoinStream { &self.schema, visited_left_side, ); + + // Recording time & updating output metrics + if let Ok(batch) = &result { + timer.done(); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + Some(result) } Some(err) => Some(err), None => { if self.join_type == JoinType::Full && !self.is_exhausted { + // Only setting up timer, input is exhausted + let timer = self.join_metrics.join_time.timer(); + // use the global left bitmap to produce the left indices and right indices let (left_side, right_side) = get_final_indices_from_bit_map( visited_left_side, @@ -416,6 +492,14 @@ impl NestedLoopJoinStream { JoinSide::Left, ); self.is_exhausted = true; + + // Recording time & updating output metrics + if let Ok(batch) = &result { + timer.done(); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + Some(result) } else { // end of the join loop @@ -431,10 +515,12 @@ impl NestedLoopJoinStream { cx: &mut std::task::Context<'_>, ) -> Poll>> { // all right row + let build_timer = self.join_metrics.build_time.timer(); let right_data = match ready!(self.inner_table.get(cx)) { Ok(data) => data, Err(e) => return Poll::Ready(Some(Err(e))), }; + build_timer.done(); // for build right, bitmap is not needed. 
let mut empty_visited_left_side = BooleanBufferBuilder::new(0); @@ -442,6 +528,12 @@ impl NestedLoopJoinStream { .poll_next_unpin(cx) .map(|maybe_batch| match maybe_batch { Some(Ok(left_batch)) => { + // Setting up timer & updating input metrics + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(left_batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + + // Actual join execution let result = join_left_and_right_batch( &left_batch, right_data, @@ -451,6 +543,14 @@ impl NestedLoopJoinStream { &self.schema, &mut empty_visited_left_side, ); + + // Recording time & updating output metrics + if let Ok(batch) = &result { + timer.done(); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + Some(result) } Some(err) => Some(err), @@ -633,6 +733,11 @@ mod tests { use crate::physical_expr::expressions::BinaryExpr; use crate::{ assert_batches_sorted_eq, + common::assert_contains, + execution::{ + context::SessionConfig, + runtime_env::{RuntimeConfig, RuntimeEnv}, + }, physical_plan::{ common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, }, @@ -1016,4 +1121,56 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_overallocation() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), + ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), + ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), + ); + let right = build_table( + ("a2", &vec![10, 11]), + ("b2", &vec![12, 13]), + ("c2", &vec![14, 15]), + ); + let filter = prepare_join_filter(); + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightSemi, + JoinType::RightAnti, + ]; + + for join_type in join_types { + let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_ctx = + SessionContext::with_config_rt(SessionConfig::default(), runtime); + let task_ctx = session_ctx.task_ctx(); + + let err = multi_partitioned_join_collect( + left.clone(), + right.clone(), + &join_type, + Some(filter.clone()), + task_ctx, + ) + .await + .unwrap_err(); + + assert_contains!( + err.to_string(), + "External error: Resources exhausted: Failed to allocate additional" + ); + assert_contains!(err.to_string(), "NestedLoopJoinExec"); + } + + Ok(()) + } } diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs index 8fa5145938c4..50179896f648 100644 --- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs @@ -24,6 +24,7 @@ use std::any::Any; use std::cmp::Ordering; use std::collections::VecDeque; use std::fmt::Formatter; +use std::mem; use std::ops::Range; use std::pin::Pin; use std::sync::Arc; @@ -39,6 +40,7 @@ use futures::{Stream, StreamExt}; use crate::error::DataFusionError; use crate::error::Result; use crate::execution::context::TaskContext; +use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation}; use crate::logical_expr::JoinType; use crate::physical_plan::expressions::Column; use crate::physical_plan::expressions::PhysicalSortExpr; @@ -305,6 +307,10 @@ impl ExecutionPlan for SortMergeJoinExec { // create output buffer let batch_size = context.session_config().batch_size(); + // create memory reservation + let reservation = 
MemoryConsumer::new(format!("SMJStream[{partition}]")) + .register(context.memory_pool()); + // create join stream Ok(Box::pin(SMJStream::try_new( self.schema.clone(), @@ -317,6 +323,7 @@ impl ExecutionPlan for SortMergeJoinExec { self.join_type, batch_size, SortMergeJoinMetrics::new(partition, &self.metrics), + reservation, )?)) } @@ -362,6 +369,9 @@ struct SortMergeJoinMetrics { output_batches: metrics::Count, /// Number of rows produced by this operator output_rows: metrics::Count, + /// Peak memory used for buffered data. + /// Calculated as sum of peak memory values across partitions + peak_mem_used: metrics::Gauge, } impl SortMergeJoinMetrics { @@ -374,6 +384,7 @@ impl SortMergeJoinMetrics { let output_batches = MetricBuilder::new(metrics).counter("output_batches", partition); let output_rows = MetricBuilder::new(metrics).output_rows(partition); + let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition); Self { join_time, @@ -381,6 +392,7 @@ impl SortMergeJoinMetrics { input_rows, output_batches, output_rows, + peak_mem_used, } } } @@ -505,15 +517,34 @@ struct BufferedBatch { pub join_arrays: Vec, /// Buffered joined index (null joining buffered) pub null_joined: Vec, + /// Size estimation used for reserving / releasing memory + pub size_estimation: usize, } impl BufferedBatch { fn new(batch: RecordBatch, range: Range, on_column: &[Column]) -> Self { let join_arrays = join_arrays(&batch, on_column); + + // Estimation is calculated as + // inner batch size + // + join keys size + // + worst case null_joined (as vector capacity * element size) + // + Range size + // + size of this estimation + let size_estimation = batch.get_array_memory_size() + + join_arrays + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::() + + batch.num_rows().next_power_of_two() * mem::size_of::() + + mem::size_of::>() + + mem::size_of::(); + BufferedBatch { batch, range, join_arrays, null_joined: vec![], + size_estimation, } } } @@ -565,6 +596,8 @@ struct SMJStream { pub join_type: JoinType, /// Metrics pub join_metrics: SortMergeJoinMetrics, + /// Memory reservation + pub reservation: MemoryReservation, } impl RecordBatchStream for SMJStream { @@ -682,6 +715,7 @@ impl SMJStream { join_type: JoinType, batch_size: usize, join_metrics: SortMergeJoinMetrics, + reservation: MemoryReservation, ) -> Result { let streamed_schema = streamed.schema(); let buffered_schema = buffered.schema(); @@ -708,6 +742,7 @@ impl SMJStream { batch_size, join_type, join_metrics, + reservation, }) } @@ -763,7 +798,11 @@ impl SMJStream { let head_batch = self.buffered_data.head_batch(); if head_batch.range.end == head_batch.batch.num_rows() { self.freeze_dequeuing_buffered()?; - self.buffered_data.batches.pop_front(); + if let Some(buffered_batch) = + self.buffered_data.batches.pop_front() + { + self.reservation.shrink(buffered_batch.size_estimation); + } } else { break; } @@ -789,11 +828,14 @@ impl SMJStream { self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); if batch.num_rows() > 0 { - self.buffered_data.batches.push_back(BufferedBatch::new( - batch, - 0..1, - &self.on_buffered, - )); + let buffered_batch = + BufferedBatch::new(batch, 0..1, &self.on_buffered); + self.reservation.try_grow(buffered_batch.size_estimation)?; + self.join_metrics + .peak_mem_used + .set_max(self.reservation.size()); + + self.buffered_data.batches.push_back(buffered_batch); self.buffered_state = BufferedState::PollingRest; } } @@ -827,15 +869,19 @@ impl SMJStream { } 
Poll::Ready(Some(batch)) => { self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); if batch.num_rows() > 0 { - self.join_metrics.input_rows.add(batch.num_rows()); - self.buffered_data.batches.push_back( - BufferedBatch::new( - batch, - 0..0, - &self.on_buffered, - ), + let buffered_batch = BufferedBatch::new( + batch, + 0..0, + &self.on_buffered, ); + self.reservation + .try_grow(buffered_batch.size_estimation)?; + self.join_metrics + .peak_mem_used + .set_max(self.reservation.size()); + self.buffered_data.batches.push_back(buffered_batch); } } } @@ -1315,7 +1361,9 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use crate::common::assert_contains; use crate::error::Result; + use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use crate::logical_expr::JoinType; use crate::physical_plan::expressions::Column; use crate::physical_plan::joins::utils::JoinOn; @@ -2212,4 +2260,135 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); Ok(()) } + + #[tokio::test] + async fn overallocation_single_batch() -> Result<()> { + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4, 5]), + ("b1", &vec![1, 2, 3, 4, 5, 6]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![1, 3, 4, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + + for join_type in join_types { + let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_config = SessionConfig::default().with_batch_size(50); + let session_ctx = SessionContext::with_config_rt(session_config, runtime); + let task_ctx = session_ctx.task_ctx(); + let join = join_with_options( + left.clone(), + right.clone(), + on.clone(), + join_type, + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let err = common::collect(stream).await.unwrap_err(); + + assert_contains!( + err.to_string(), + "Resources exhausted: Failed to allocate additional" + ); + assert_contains!(err.to_string(), "SMJStream[0]"); + } + + Ok(()) + } + + #[tokio::test] + async fn overallocation_multi_batch() -> Result<()> { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![2, 3]), + ("b1", &vec![1, 1]), + ("c1", &vec![6, 7]), + ); + let left_batch_3 = build_table_i32( + ("a1", &vec![4, 5]), + ("b1", &vec![1, 1]), + ("c1", &vec![8, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10]), + ("b2", &vec![1, 1]), + ("c2", &vec![50, 60]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![20, 30]), + ("b2", &vec![1, 1]), + ("c2", &vec![70, 80]), + ); + let right_batch_3 = + build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); + let left = + build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); + let right = + build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + 
)]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + + for join_type in join_types { + let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_config = SessionConfig::default().with_batch_size(50); + let session_ctx = SessionContext::with_config_rt(session_config, runtime); + let task_ctx = session_ctx.task_ctx(); + let join = join_with_options( + left.clone(), + right.clone(), + on.clone(), + join_type, + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let err = common::collect(stream).await.unwrap_err(); + + assert_contains!( + err.to_string(), + "Resources exhausted: Failed to allocate additional" + ); + assert_contains!(err.to_string(), "SMJStream[0]"); + } + + Ok(()) + } } diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs index a756d2ba8938..f10ad7a15ae3 100644 --- a/datafusion/core/src/physical_plan/joins/utils.rs +++ b/datafusion/core/src/physical_plan/joins/utils.rs @@ -18,11 +18,11 @@ //! Join related functionality used both on logical and physical plans use arrow::array::{ - new_null_array, Array, BooleanBufferBuilder, PrimitiveArray, UInt32Array, + downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array, UInt32Builder, UInt64Array, }; use arrow::compute; -use arrow::datatypes::{Field, Schema, UInt32Type, UInt64Type}; +use arrow::datatypes::{Field, Schema}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; @@ -783,8 +783,8 @@ pub(crate) fn apply_join_filter_to_indices( filter.schema(), build_input_buffer, probe_batch, - PrimitiveArray::from(build_indices.data().clone()), - PrimitiveArray::from(probe_indices.data().clone()), + build_indices.clone(), + probe_indices.clone(), filter.column_indices(), build_side, )?; @@ -794,13 +794,12 @@ pub(crate) fn apply_join_filter_to_indices( .into_array(intermediate_batch.num_rows()); let mask = as_boolean_array(&filter_result)?; - let left_filtered = PrimitiveArray::::from( - compute::filter(&build_indices, mask)?.data().clone(), - ); - let right_filtered = PrimitiveArray::::from( - compute::filter(&probe_indices, mask)?.data().clone(), - ); - Ok((left_filtered, right_filtered)) + let left_filtered = compute::filter(&build_indices, mask)?; + let right_filtered = compute::filter(&probe_indices, mask)?; + Ok(( + downcast_array(left_filtered.as_ref()), + downcast_array(right_filtered.as_ref()), + )) } /// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`. 
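A small sketch of the `downcast_array` pattern that the `joins/utils.rs` hunk above switches to (not part of this patch): filter a typed index array with the `filter` kernel and get the concrete array type back without cloning its `ArrayData`. It assumes the `arrow` version pinned by this branch, which re-exports `downcast_array` from `arrow::array`.

```rust
use arrow::array::{downcast_array, BooleanArray, UInt64Array};
use arrow::compute;
use arrow::error::Result;

/// Keep only the indices selected by `mask`, preserving the concrete array type.
fn filter_indices(indices: &UInt64Array, mask: &BooleanArray) -> Result<UInt64Array> {
    // `filter` returns a type-erased ArrayRef; `downcast_array` turns it back
    // into a UInt64Array by reusing the underlying buffers
    let filtered = compute::filter(indices, mask)?;
    Ok(downcast_array(filtered.as_ref()))
}

fn main() -> Result<()> {
    let indices = UInt64Array::from(vec![0u64, 1, 2, 3]);
    let mask = BooleanArray::from(vec![true, false, true, false]);
    let kept = filter_indices(&indices, &mask)?;
    assert_eq!(kept.len(), 2);
    assert_eq!(kept.value(0), 0);
    assert_eq!(kept.value(1), 2);
    Ok(())
}
```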
diff --git a/datafusion/core/src/physical_plan/metrics/value.rs b/datafusion/core/src/physical_plan/metrics/value.rs index 4df4e7567536..59b012f25a27 100644 --- a/datafusion/core/src/physical_plan/metrics/value.rs +++ b/datafusion/core/src/physical_plan/metrics/value.rs @@ -129,6 +129,11 @@ impl Gauge { self.value.fetch_sub(n, Ordering::Relaxed); } + /// Set metric's value to maximum of `n` and current value + pub fn set_max(&self, n: usize) { + self.value.fetch_max(n, Ordering::Relaxed); + } + /// Set the metric's value to `n` and return the previous value pub fn set(&self, n: usize) -> usize { // relaxed ordering for operations on `value` poses no issues diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs index dbd1024ae482..9e0e03a77d04 100644 --- a/datafusion/core/src/physical_plan/mod.rs +++ b/datafusion/core/src/physical_plan/mod.rs @@ -653,10 +653,10 @@ pub mod metrics; pub mod planner; pub mod projection; pub mod repartition; -pub mod rewrite; pub mod sorts; pub mod stream; pub mod streaming; +pub mod tree_node; pub mod udaf; pub mod union; pub mod unnest; diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 38830c6856b0..51653450a699 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -346,6 +346,9 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::Placeholder { .. } => Err(DataFusionError::Internal( "Create physical name does not support placeholder".to_string(), )), + Expr::OuterReferenceColumn(_, _) => Err(DataFusionError::Internal( + "Create physical name does not support OuterReferenceColumn".to_string(), + )), } } @@ -1876,7 +1879,7 @@ mod tests { use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type, SchemaRef}; use arrow::record_batch::RecordBatch; - use datafusion_common::assert_contains; + use datafusion_common::{assert_contains, TableReference}; use datafusion_common::{DFField, DFSchema, DFSchemaRef}; use datafusion_expr::{ col, lit, sum, Extension, GroupingSet, LogicalPlanBuilder, @@ -2388,7 +2391,7 @@ Internal error: Optimizer rule 'type_coercion' failed due to unexpected error: E Self { schema: DFSchemaRef::new( DFSchema::new_with_metadata( - vec![DFField::new(None, "a", DataType::Int32, false)], + vec![DFField::new_unqualified("a", DataType::Int32, false)], HashMap::new(), ) .unwrap(), @@ -2518,12 +2521,13 @@ Internal error: Optimizer rule 'type_coercion' failed due to unexpected error: E match ctx.read_csv(path, options).await?.into_optimized_plan()? { LogicalPlan::TableScan(ref scan) => { let mut scan = scan.clone(); - scan.table_name = name.to_string(); + let table_reference = TableReference::from(name).to_owned_reference(); + scan.table_name = table_reference; let new_schema = scan .projected_schema .as_ref() .clone() - .replace_qualifier(name); + .replace_qualifier(name.to_string()); scan.projected_schema = Arc::new(new_schema); LogicalPlan::TableScan(scan) } diff --git a/datafusion/core/src/physical_plan/tree_node/mod.rs b/datafusion/core/src/physical_plan/tree_node/mod.rs new file mode 100644 index 000000000000..327d938d4ec4 --- /dev/null +++ b/datafusion/core/src/physical_plan/tree_node/mod.rs @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module provides common traits for visiting or rewriting tree nodes easily. + +pub mod rewritable; +pub mod visitable; + +use datafusion_common::Result; + +/// Implements the [visitor +/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNodeVisitable`]s. +/// +/// [`TreeNodeVisitor`] allows keeping the algorithms +/// separate from the code to traverse the structure of the `TreeNodeVisitable` +/// tree and makes it easier to add new types of tree node and +/// algorithms by. +/// +/// When passed to[`TreeNodeVisitable::accept`], [`TreeNodeVisitor::pre_visit`] +/// and [`TreeNodeVisitor::post_visit`] are invoked recursively +/// on an node tree. +/// +/// If an [`Err`] result is returned, recursion is stopped +/// immediately. +/// +/// If [`Recursion::Stop`] is returned on a call to pre_visit, no +/// children of that tree node are visited, nor is post_visit +/// called on that tree node +pub trait TreeNodeVisitor: Sized { + /// The node type which is visitable. + type N: TreeNodeVisitable; + + /// Invoked before any children of `node` are visited. + fn pre_visit(&mut self, node: &Self::N) -> Result; + + /// Invoked after all children of `node` are visited. Default + /// implementation does nothing. + fn post_visit(&mut self, _node: &Self::N) -> Result<()> { + Ok(()) + } +} + +/// Trait for types that can be visited by [`TreeNodeVisitor`] +pub trait TreeNodeVisitable: Sized { + /// Return the children of this tree node + fn get_children(&self) -> Vec; + + /// Accept a visitor, calling `visit` on all children of this + fn accept>(&self, visitor: &mut V) -> Result<()> { + match visitor.pre_visit(self)? { + VisitRecursion::Continue => {} + // If the recursion should stop, do not visit children + VisitRecursion::Stop => return Ok(()), + }; + + for child in self.get_children() { + child.accept(visitor)?; + } + + visitor.post_visit(self) + } +} + +/// Controls how the visitor recursion should proceed. +pub enum VisitRecursion { + /// Attempt to visit all the children, recursively. + Continue, + /// Do not visit the children of this tree node, though the walk + /// of parents of this tree node will not be affected + Stop, +} + +/// Trait for marking tree node as rewritable +pub trait TreeNodeRewritable: Clone { + /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. + /// When `op` does not apply to a given node, it is left unchanged. + /// The default tree traversal direction is transform_up(Postorder Traversal). + fn transform(self, op: &F) -> Result + where + F: Fn(Self) -> Result>, + { + self.transform_up(op) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its + /// children(Preorder Traversal). 
+ /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_down(self, op: &F) -> Result + where + F: Fn(Self) -> Result>, + { + let node_cloned = self.clone(); + let after_op = match op(node_cloned)? { + Some(value) => value, + None => self, + }; + after_op.map_children(|node| node.transform_down(op)) + } + + /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its + /// children and then itself(Postorder Traversal). + /// When the `op` does not apply to a given node, it is left unchanged. + fn transform_up(self, op: &F) -> Result + where + F: Fn(Self) -> Result>, + { + let after_op_children = self.map_children(|node| node.transform_up(op))?; + + let after_op_children_clone = after_op_children.clone(); + let new_node = match op(after_op_children)? { + Some(value) => value, + None => after_op_children_clone, + }; + Ok(new_node) + } + + /// Transform the tree node using the given [TreeNodeRewriter] + /// It performs a depth first walk of an node and its children. + /// + /// For an node tree such as + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// The nodes are visited using the following order + /// ```text + /// pre_visit(ParentNode) + /// pre_visit(ChildNode1) + /// mutate(ChildNode1) + /// pre_visit(ChildNode2) + /// mutate(ChildNode2) + /// mutate(ParentNode) + /// ``` + /// + /// If an Err result is returned, recursion is stopped immediately + /// + /// If [`false`] is returned on a call to pre_visit, no + /// children of that node are visited, nor is mutate + /// called on that node + /// + fn transform_using>( + self, + rewriter: &mut R, + ) -> Result { + let need_mutate = match rewriter.pre_visit(&self)? { + RewriteRecursion::Mutate => return rewriter.mutate(self), + RewriteRecursion::Stop => return Ok(self), + RewriteRecursion::Continue => true, + RewriteRecursion::Skip => false, + }; + + let after_op_children = + self.map_children(|node| node.transform_using(rewriter))?; + + // now rewrite this node itself + if need_mutate { + rewriter.mutate(after_op_children) + } else { + Ok(after_op_children) + } + } + + /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result; +} + +/// Trait for potentially recursively transform an [`TreeNodeRewritable`] node +/// tree. When passed to `TreeNodeRewritable::transform_using`, `TreeNodeRewriter::mutate` is +/// invoked recursively on all nodes of a tree. +pub trait TreeNodeRewriter: Sized { + /// The node type which is rewritable. + type N: TreeNodeRewritable; + + /// Invoked before (Preorder) any children of `node` are rewritten / + /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)` + fn pre_visit(&mut self, _node: &Self::N) -> Result { + Ok(RewriteRecursion::Continue) + } + + /// Invoked after (Postorder) all children of `node` have been mutated and + /// returns a potentially modified node. + fn mutate(&mut self, node: Self::N) -> Result; +} + +/// Controls how the [TreeNodeRewriter] recursion should proceed. +#[allow(dead_code)] +pub enum RewriteRecursion { + /// Continue rewrite / visit this node tree. + Continue, + /// Call 'op' immediately and return. + Mutate, + /// Do not rewrite / visit the children of this node. 
+ Stop, + /// Keep recursive but skip apply op on this node + Skip, +} diff --git a/datafusion/core/src/physical_plan/tree_node/rewritable.rs b/datafusion/core/src/physical_plan/tree_node/rewritable.rs new file mode 100644 index 000000000000..004fc47fd7ac --- /dev/null +++ b/datafusion/core/src/physical_plan/tree_node/rewritable.rs @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tree node rewritable implementations + +use crate::physical_plan::tree_node::TreeNodeRewritable; +use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use datafusion_common::Result; +use std::sync::Arc; + +impl TreeNodeRewritable for Arc { + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if !children.is_empty() { + let new_children: Result> = + children.into_iter().map(transform).collect(); + with_new_children_if_necessary(self, new_children?) + } else { + Ok(self) + } + } +} diff --git a/datafusion/core/src/physical_plan/tree_node/visitable.rs b/datafusion/core/src/physical_plan/tree_node/visitable.rs new file mode 100644 index 000000000000..935c8adb7ea7 --- /dev/null +++ b/datafusion/core/src/physical_plan/tree_node/visitable.rs @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Tree node visitable implementations + +use crate::physical_plan::tree_node::TreeNodeVisitable; +use crate::physical_plan::ExecutionPlan; +use std::sync::Arc; + +impl TreeNodeVisitable for Arc { + fn get_children(&self) -> Vec { + self.children() + } +} diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 0ac836cf28ed..fa276c423879 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -353,12 +353,12 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { // For instance, if `n_out` number of rows are calculated, we can remove // first `n_out` rows from `self.input_buffer_record_batch`. fn prune_state(&mut self, n_out: usize) -> Result<()> { + // Prune `self.window_agg_states`: + self.prune_out_columns(n_out)?; // Prune `self.partition_batches`: self.prune_partition_batches()?; // Prune `self.input_buffer_record_batch`: self.prune_input_batch(n_out)?; - // Prune `self.window_agg_states`: - self.prune_out_columns(n_out)?; Ok(()) } @@ -548,9 +548,9 @@ impl SortedPartitionByBoundedWindowStream { for (partition_row, WindowState { state: value, .. }) in window_agg_state { let n_prune = min(value.window_frame_range.start, value.last_calculated_index); - if let Some(state) = n_prune_each_partition.get_mut(partition_row) { - if n_prune < *state { - *state = n_prune; + if let Some(current) = n_prune_each_partition.get_mut(partition_row) { + if n_prune < *current { + *current = n_prune; } } else { n_prune_each_partition.insert(partition_row.clone(), n_prune); @@ -571,15 +571,7 @@ impl SortedPartitionByBoundedWindowStream { // Update state indices since we have pruned some rows from the beginning: for window_agg_state in self.window_agg_states.iter_mut() { - let window_state = - window_agg_state.get_mut(partition_row).ok_or_else(err)?; - let mut state = &mut window_state.state; - state.window_frame_range = Range { - start: state.window_frame_range.start - n_prune, - end: state.window_frame_range.end - n_prune, - }; - state.last_calculated_index -= n_prune; - state.offset_pruned_rows += n_prune; + window_agg_state[partition_row].state.prune_state(*n_prune); } } Ok(()) diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index 61cdc188f823..8c570cff8fd0 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -39,7 +39,7 @@ use crate::prelude::{CsvReadOptions, SessionContext}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use async_trait::async_trait; -use datafusion_common::{DataFusionError, Statistics}; +use datafusion_common::{DataFusionError, Statistics, TableReference}; use datafusion_expr::{CreateExternalTable, Expr, TableType}; use datafusion_physical_expr::PhysicalSortExpr; use futures::Stream; @@ -219,11 +219,8 @@ pub fn scan_empty( ) -> Result { let table_schema = Arc::new(table_schema.clone()); let provider = Arc::new(EmptyTable::new(table_schema)); - LogicalPlanBuilder::scan( - name.unwrap_or(UNNAMED_TABLE), - provider_as_source(provider), - projection, - ) + let name = TableReference::bare(name.unwrap_or(UNNAMED_TABLE).to_string()); + LogicalPlanBuilder::scan(name, provider_as_source(provider), projection) } /// Scan an empty data source with configured partition, mainly used in tests. 
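Stepping back to the `tree_node` additions above: with `TreeNodeRewritable` implemented for `Arc<dyn ExecutionPlan>`, plan rewrites can be expressed as closures passed to `transform_up` / `transform_down`. The sketch below is hypothetical (the `add_coalesce` rule and the batch size `4096` are made up, and the import path assumes the module is exposed as `datafusion::physical_plan::tree_node`); it only demonstrates the contract that returning `Some` replaces a node while `None` keeps it:

```rust
use std::sync::Arc;

use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::tree_node::TreeNodeRewritable;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::Result;

/// Hypothetical rewrite: wrap every `RepartitionExec` in a `CoalesceBatchesExec`.
fn add_coalesce(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
    // `transform_up` rewrites the children first, then offers each node to the closure.
    plan.transform_up(&|node: Arc<dyn ExecutionPlan>| {
        if node.as_any().downcast_ref::<RepartitionExec>().is_some() {
            // Returning `Some(new_node)` substitutes the node ...
            let wrapped: Arc<dyn ExecutionPlan> =
                Arc::new(CoalesceBatchesExec::new(node, 4096));
            Ok(Some(wrapped))
        } else {
            // ... while `None` leaves it untouched.
            Ok(None)
        }
    })
}
```

The closure receives the node by value; since nodes are `Arc`s, cloning inside `transform_up` / `map_children` is cheap.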
@@ -235,11 +232,8 @@ pub fn scan_empty_with_partitions( ) -> Result { let table_schema = Arc::new(table_schema.clone()); let provider = Arc::new(EmptyTable::new(table_schema).with_partitions(partitions)); - LogicalPlanBuilder::scan( - name.unwrap_or(UNNAMED_TABLE), - provider_as_source(provider), - projection, - ) + let name = TableReference::bare(name.unwrap_or(UNNAMED_TABLE).to_string()); + LogicalPlanBuilder::scan(name, provider_as_source(provider), projection) } /// Get the schema for the aggregate_test_* csv files diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs index b619262e07eb..9da5bf2b86ad 100644 --- a/datafusion/core/tests/dataframe.rs +++ b/datafusion/core/tests/dataframe.rs @@ -16,6 +16,7 @@ // under the License. use arrow::datatypes::{DataType, Field, Schema}; +use arrow::util::pretty::pretty_format_batches; use arrow::{ array::{ ArrayRef, Int32Array, Int32Builder, ListBuilder, StringArray, StringBuilder, @@ -35,18 +36,77 @@ use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable}; +#[tokio::test] +async fn count_wildcard() -> Result<()> { + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + + ctx.register_parquet( + "alltypes_tiny_pages", + &format!("{testdata}/alltypes_tiny_pages.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + let sql_results = ctx + .sql("select count(*) from alltypes_tiny_pages") + .await? + .explain(false, false)? + .collect() + .await?; + + let df_results = ctx + .table("alltypes_tiny_pages") + .await? + .aggregate(vec![], vec![count(Expr::Wildcard)])? + .explain(false, false) + .unwrap() + .collect() + .await?; + + //make sure sql plan same with df plan + assert_eq!( + pretty_format_batches(&sql_results)?.to_string(), + pretty_format_batches(&df_results)?.to_string() + ); + + let results = ctx + .table("alltypes_tiny_pages") + .await? + .aggregate(vec![], vec![count(Expr::Wildcard)])? + .collect() + .await?; + + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 7300 |", + "+-----------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} #[tokio::test] async fn describe() -> Result<()> { let ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); + ctx.register_parquet( + "alltypes_tiny_pages", + &format!("{testdata}/alltypes_tiny_pages.parquet"), + ParquetReadOptions::default(), + ) + .await?; - let df = ctx - .read_parquet( - &format!("{testdata}/alltypes_tiny_pages.parquet"), - ParquetReadOptions::default(), - ) + let describe_record_batch = ctx + .table("alltypes_tiny_pages") + .await? + .describe() + .await? + .collect() .await?; - let describe_record_batch = df.describe().await.unwrap().collect().await.unwrap(); + #[rustfmt::skip] let expected = vec![ "+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+", @@ -63,6 +123,30 @@ async fn describe() -> Result<()> { ]; assert_batches_eq!(expected, &describe_record_batch); + //add test case for only boolean boolean/binary column + let result = ctx + .sql("select 'a' as a,true as b") + .await? + .describe() + .await? 
+ .collect() + .await?; + #[rustfmt::skip] + let expected = vec![ + "+------------+------+------+", + "| describe | a | b |", + "+------------+------+------+", + "| count | 1 | 1 |", + "| null_count | 1 | 1 |", + "| mean | null | null |", + "| std | null | null |", + "| min | a | null |", + "| max | a | null |", + "| median | null | null |", + "+------------+------+------+", + ]; + assert_batches_eq!(expected, &result); + Ok(()) } @@ -245,7 +329,7 @@ async fn sort_on_ambiguous_column() -> Result<()> { .sort(vec![col("b").sort(true, true)]) .unwrap_err(); - let expected = "Schema error: Ambiguous reference to unqualified field 'b'"; + let expected = "Schema error: Ambiguous reference to unqualified field \"b\""; assert_eq!(err.to_string(), expected); Ok(()) } @@ -264,7 +348,7 @@ async fn group_by_ambiguous_column() -> Result<()> { .aggregate(vec![col("b")], vec![max(col("a"))]) .unwrap_err(); - let expected = "Schema error: Ambiguous reference to unqualified field 'b'"; + let expected = "Schema error: Ambiguous reference to unqualified field \"b\""; assert_eq!(err.to_string(), expected); Ok(()) } @@ -283,7 +367,7 @@ async fn filter_on_ambiguous_column() -> Result<()> { .filter(col("b").eq(lit(1))) .unwrap_err(); - let expected = "Schema error: Ambiguous reference to unqualified field 'b'"; + let expected = "Schema error: Ambiguous reference to unqualified field \"b\""; assert_eq!(err.to_string(), expected); Ok(()) } @@ -302,7 +386,7 @@ async fn select_ambiguous_column() -> Result<()> { .select(vec![col("b")]) .unwrap_err(); - let expected = "Schema error: Ambiguous reference to unqualified field 'b'"; + let expected = "Schema error: Ambiguous reference to unqualified field \"b\""; assert_eq!(err.to_string(), expected); Ok(()) } diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs index 392e5494151a..034505f52c96 100644 --- a/datafusion/core/tests/memory_limit.rs +++ b/datafusion/core/tests/memory_limit.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use datafusion::datasource::MemTable; +use datafusion::execution::context::SessionState; use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_common::assert_contains; @@ -37,7 +38,9 @@ fn init() { async fn oom_sort() { run_limit_test( "select * from t order by host DESC", - "Resources exhausted: Memory Exhausted while Sorting (DiskManager is disabled)", + vec![ + "Resources exhausted: Memory Exhausted while Sorting (DiskManager is disabled)", + ], 200_000, ) .await @@ -47,7 +50,10 @@ async fn oom_sort() { async fn group_by_none() { run_limit_test( "select median(image) from t", - "Resources exhausted: Failed to allocate additional", + vec![ + "Resources exhausted: Failed to allocate additional", + "AggregateStream", + ], 20_000, ) .await @@ -57,7 +63,10 @@ async fn group_by_none() { async fn group_by_row_hash() { run_limit_test( "select count(*) from t GROUP BY response_bytes", - "Resources exhausted: Failed to allocate additional", + vec![ + "Resources exhausted: Failed to allocate additional", + "GroupedHashAggregateStream", + ], 2_000, ) .await @@ -68,17 +77,53 @@ async fn group_by_hash() { run_limit_test( // group by dict column "select count(*) from t GROUP BY service, host, pod, container", - "Resources exhausted: Failed to allocate additional", + vec![ + "Resources exhausted: Failed to allocate additional", + "GroupedHashAggregateStream", + ], 1_000, ) .await } #[tokio::test] -async fn join_by_key() { - run_limit_test( 
+async fn join_by_key_multiple_partitions() { + let config = SessionConfig::new().with_target_partitions(2); + run_limit_test_with_config( "select t1.* from t t1 JOIN t t2 ON t1.service = t2.service", - "Resources exhausted: Failed to allocate additional", + vec![ + "Resources exhausted: Failed to allocate additional", + "HashJoinStream", + ], + 1_000, + config, + ) + .await +} + +#[tokio::test] +async fn join_by_key_single_partition() { + let config = SessionConfig::new().with_target_partitions(1); + run_limit_test_with_config( + "select t1.* from t t1 JOIN t t2 ON t1.service = t2.service", + vec![ + "Resources exhausted: Failed to allocate additional", + "HashJoinExec", + ], + 1_000, + config, + ) + .await +} + +#[tokio::test] +async fn join_by_expression() { + run_limit_test( + "select t1.* from t t1 JOIN t t2 ON t1.service != t2.service", + vec![ + "Resources exhausted: Failed to allocate additional", + "NestedLoopJoinExec", + ], 1_000, ) .await @@ -88,8 +133,30 @@ async fn join_by_key() { async fn cross_join() { run_limit_test( "select t1.* from t t1 CROSS JOIN t t2", - "Resources exhausted: Failed to allocate additional", + vec![ + "Resources exhausted: Failed to allocate additional", + "CrossJoinExec", + ], + 1_000, + ) + .await +} + +#[tokio::test] +async fn merge_join() { + // Planner chooses MergeJoin only if number of partitions > 1 + let config = SessionConfig::new() + .with_target_partitions(2) + .set_bool("datafusion.optimizer.prefer_hash_join", false); + + run_limit_test_with_config( + "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", + vec![ + "Resources exhausted: Failed to allocate additional", + "SMJStream", + ], 1_000, + config, ) .await } @@ -98,8 +165,26 @@ async fn cross_join() { const MEMORY_FRACTION: f64 = 0.95; /// runs the specified query against 1000 rows with a 50 -/// byte memory limit and no disk manager enabled. -async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize) { +/// byte memory limit and no disk manager enabled +/// with default SessionConfig. 
+async fn run_limit_test( + query: &str, + expected_error_contains: Vec<&str>, + memory_limit: usize, +) { + let config = SessionConfig::new(); + run_limit_test_with_config(query, expected_error_contains, memory_limit, config).await +} + +/// runs the specified query against 1000 rows with a 50 +/// byte memory limit and no disk manager enabled +/// with specified SessionConfig instance +async fn run_limit_test_with_config( + query: &str, + expected_error_contains: Vec<&str>, + memory_limit: usize, + config: SessionConfig, +) { let batches: Vec<_> = AccessLogGenerator::new() .with_row_limit(1000) .with_max_batch_size(50) @@ -115,11 +200,12 @@ async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize) let runtime = RuntimeEnv::new(rt_config).unwrap(); - let ctx = SessionContext::with_config_rt( - // do NOT re-partition (since RepartitionExec has also has a memory budget which we'll likely hit first) - SessionConfig::new().with_target_partitions(1), - Arc::new(runtime), - ); + // Disabling physical optimizer rules to avoid sorts / repartitions + // (since RepartitionExec / SortExec also has a memory budget which we'll likely hit first) + let state = SessionState::with_config_rt(config, Arc::new(runtime)) + .with_physical_optimizer_rules(vec![]); + + let ctx = SessionContext::with_state(state); ctx.register_table("t", Arc::new(table)) .expect("registering table"); @@ -130,7 +216,9 @@ async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize) panic!("Unexpected success when running, expected memory limit failure") } Err(e) => { - assert_contains!(e.to_string(), expected_error); + for error_substring in expected_error_contains { + assert_contains!(e.to_string(), error_substring); + } } } } diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 123095186ed2..9ce5b182af10 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -139,7 +139,7 @@ impl ContextWithParquet { let file = match unit { Unit::RowGroup => make_test_file_rg(scenario).await, Unit::Page => { - let config = config.config_options_mut(); + let config = config.options_mut(); config.execution.parquet.enable_page_index = true; make_test_file_page(scenario).await } diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs index 670b508c3347..6c7f1431afb3 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/path_partition.rs @@ -59,7 +59,13 @@ async fn parquet_distinct_partition_col() -> Result<()> { ], &[ ("year", DataType::Int32), - ("month", DataType::Utf8), + ( + "month", + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + ), ("day", DataType::Utf8), ], "mirror:///", diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 3eaa2fd5d141..48dd53830ae3 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -70,194 +70,6 @@ async fn csv_query_array_agg_distinct() -> Result<()> { Ok(()) } -#[tokio::test] -async fn aggregate_timestamps_sum() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_timestamps()).unwrap(); - - let results = plan_and_collect( - &ctx, - "SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t", - ) - .await - .unwrap_err(); - - assert_eq!(results.to_string(), "Error during planning: The function Sum does not support inputs of type Timestamp(Nanosecond, None)."); 
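The reworked memory-limit tests above build their `SessionContext` from a `SessionState` so the physical optimizer rules can be cleared; otherwise `RepartitionExec` / `SortExec` would hit their own memory budgets before the operator under test does. A condensed sketch of that setup (the helper name `memory_limited_context` is hypothetical, the `0.95` fraction comes from `MEMORY_FRACTION`, and `RuntimeConfig::with_memory_limit` is assumed to be the builder the tests already rely on):

```rust
use std::sync::Arc;

use datafusion::execution::context::SessionState;
use datafusion::execution::disk_manager::DiskManagerConfig;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::prelude::{SessionConfig, SessionContext};

/// Build a context with a hard memory cap, no spilling, and no physical
/// optimizer rules, mirroring the shape of `run_limit_test_with_config`.
fn memory_limited_context(config: SessionConfig, memory_limit: usize) -> SessionContext {
    let rt_config = RuntimeConfig::new()
        // Disable spilling so the limit surfaces as a "Resources exhausted" error.
        .with_disk_manager(DiskManagerConfig::Disabled)
        .with_memory_limit(memory_limit, 0.95);
    let runtime = Arc::new(RuntimeEnv::new(rt_config).unwrap());

    // Clearing the physical optimizer rules keeps extra sorts / repartitions
    // (which have their own budgets) out of the plan under test.
    let state = SessionState::with_config_rt(config, runtime)
        .with_physical_optimizer_rules(vec![]);
    SessionContext::with_state(state)
}
```

Checking several error substrings per query (e.g. `"Resources exhausted"` plus the operator name such as `GroupedHashAggregateStream`) is what motivates the switch from a single `expected_error` string to a `Vec<&str>`.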
- - Ok(()) -} - -#[tokio::test] -async fn aggregate_timestamps_count() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_timestamps()).unwrap(); - - let results = execute_to_batches( - &ctx, - "SELECT count(nanos), count(micros), count(millis), count(secs) FROM t", - ) - .await; - - let expected = vec![ - "+----------------+-----------------+-----------------+---------------+", - "| COUNT(t.nanos) | COUNT(t.micros) | COUNT(t.millis) | COUNT(t.secs) |", - "+----------------+-----------------+-----------------+---------------+", - "| 3 | 3 | 3 | 3 |", - "+----------------+-----------------+-----------------+---------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_timestamps_min() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_timestamps()).unwrap(); - - let results = execute_to_batches( - &ctx, - "SELECT min(nanos), min(micros), min(millis), min(secs) FROM t", - ) - .await; - - let expected = vec![ - "+----------------------------+----------------------------+-------------------------+---------------------+", - "| MIN(t.nanos) | MIN(t.micros) | MIN(t.millis) | MIN(t.secs) |", - "+----------------------------+----------------------------+-------------------------+---------------------+", - "| 2011-12-13T11:13:10.123450 | 2011-12-13T11:13:10.123450 | 2011-12-13T11:13:10.123 | 2011-12-13T11:13:10 |", - "+----------------------------+----------------------------+-------------------------+---------------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_timestamps_max() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_timestamps()).unwrap(); - - let results = execute_to_batches( - &ctx, - "SELECT max(nanos), max(micros), max(millis), max(secs) FROM t", - ) - .await; - - let expected = vec![ - "+-------------------------+-------------------------+-------------------------+---------------------+", - "| MAX(t.nanos) | MAX(t.micros) | MAX(t.millis) | MAX(t.secs) |", - "+-------------------------+-------------------------+-------------------------+---------------------+", - "| 2021-01-01T05:11:10.432 | 2021-01-01T05:11:10.432 | 2021-01-01T05:11:10.432 | 2021-01-01T05:11:10 |", - "+-------------------------+-------------------------+-------------------------+---------------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_times_sum() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_times()).unwrap(); - - let results = plan_and_collect( - &ctx, - "SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t", - ) - .await - .unwrap_err(); - - assert_eq!(results.to_string(), "Error during planning: The function Sum does not support inputs of type Time64(Nanosecond)."); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_times_count() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_times()).unwrap(); - - let results = execute_to_batches( - &ctx, - "SELECT count(nanos), count(micros), count(millis), count(secs) FROM t", - ) - .await; - - let expected = vec![ - "+----------------+-----------------+-----------------+---------------+", - "| COUNT(t.nanos) | COUNT(t.micros) | COUNT(t.millis) | COUNT(t.secs) |", - "+----------------+-----------------+-----------------+---------------+", - "| 4 | 4 | 4 | 4 |", - 
"+----------------+-----------------+-----------------+---------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_times_min() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_times()).unwrap(); - - let results = execute_to_batches( - &ctx, - "SELECT min(nanos), min(micros), min(millis), min(secs) FROM t", - ) - .await; - - let expected = vec![ - "+--------------------+-----------------+---------------+-------------+", - "| MIN(t.nanos) | MIN(t.micros) | MIN(t.millis) | MIN(t.secs) |", - "+--------------------+-----------------+---------------+-------------+", - "| 18:06:30.243620451 | 18:06:30.243620 | 18:06:30.243 | 18:06:30 |", - "+--------------------+-----------------+---------------+-------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_times_max() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_times()).unwrap(); - - let results = execute_to_batches( - &ctx, - "SELECT max(nanos), max(micros), max(millis), max(secs) FROM t", - ) - .await; - - let expected = vec![ - "+--------------------+-----------------+---------------+-------------+", - "| MAX(t.nanos) | MAX(t.micros) | MAX(t.millis) | MAX(t.secs) |", - "+--------------------+-----------------+---------------+-------------+", - "| 21:06:28.247821084 | 21:06:28.247821 | 21:06:28.247 | 21:06:28 |", - "+--------------------+-----------------+---------------+-------------+", - ]; - assert_batches_sorted_eq!(expected, &results); - - Ok(()) -} - -#[tokio::test] -async fn aggregate_timestamps_avg() -> Result<()> { - let ctx = SessionContext::new(); - ctx.register_table("t", table_with_timestamps()).unwrap(); - - let results = plan_and_collect( - &ctx, - "SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t", - ) - .await - .unwrap_err(); - - assert_eq!(results.to_string(), "Error during planning: The function Avg does not support inputs of type Timestamp(Nanosecond, None)."); - Ok(()) -} - #[tokio::test] async fn aggregate_decimal_sum() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 24017f9cd274..3d5f134a8783 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -1313,6 +1313,24 @@ async fn test_extract_date_part() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_extract_epoch() -> Result<()> { + test_expression!( + "extract(epoch from '1870-01-01T07:29:10.256'::timestamp)", + "-3155646649.744" + ); + test_expression!( + "extract(epoch from '2000-01-01T00:00:00.000'::timestamp)", + "946684800.0" + ); + test_expression!( + "extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00'))", + "946684800.0" + ); + test_expression!("extract(epoch from NULL::timestamp)", "NULL"); + Ok(()) +} + #[tokio::test] async fn test_extract_date_part_func() -> Result<()> { test_expression!( diff --git a/datafusion/core/tests/sql/idenfifers.rs b/datafusion/core/tests/sql/idenfifers.rs index a305f23b4944..1b57f60bd435 100644 --- a/datafusion/core/tests/sql/idenfifers.rs +++ b/datafusion/core/tests/sql/idenfifers.rs @@ -211,28 +211,28 @@ async fn case_insensitive_in_sql_errors() { .await .unwrap_err() .to_string(); - assert_contains!(actual, "No field named 'column1'"); + assert_contains!(actual, r#"No field named "column1""#); let actual = ctx .sql("SELECT Column1 from test") .await .unwrap_err() .to_string(); 
- assert_contains!(actual, "No field named 'column1'"); + assert_contains!(actual, r#"No field named "column1""#); let actual = ctx .sql("SELECT column1 from test") .await .unwrap_err() .to_string(); - assert_contains!(actual, "No field named 'column1'"); + assert_contains!(actual, r#"No field named "column1""#); let actual = ctx .sql(r#"SELECT "column1" from test"#) .await .unwrap_err() .to_string(); - assert_contains!(actual, "No field named 'column1'"); + assert_contains!(actual, r#"No field named "column1""#); // This should pass (note the quotes) ctx.sql(r#"SELECT "Column1" from test"#).await.unwrap(); diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index b9b4c5cf31d5..dc6f741c5b47 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -20,240 +20,21 @@ use datafusion::from_slice::FromSlice; #[tokio::test] async fn equijoin() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id ORDER BY t1_id", - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - - let ctx = create_join_context_qualified("t1", "t2")?; - let equivalent_sql = [ - "SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t1.a = t2.a ORDER BY t1.a", - "SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t2.a = t1.a ORDER BY t1.a", - ]; - let expected = vec![ - "+---+-----+", - "| a | b |", - "+---+-----+", - "| 1 | 100 |", - "| 2 | 200 |", - "| 4 | 400 |", - "+---+-----+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - } - Ok(()) -} - -#[tokio::test] -async fn equijoin_multiple_condition_ordering() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t1_name <> t2_name ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name <> t1_name ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id AND t1_name <> t2_name ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id AND t2_name <> t1_name ORDER BY t1_id", - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - } - Ok(()) -} - -#[tokio::test] -async fn equijoin_and_other_condition() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let sql = - 
"SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn equijoin_left_and_condition_from_right() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; - let dataframe = ctx.sql(sql).await.unwrap(); - let actual = dataframe.collect().await.unwrap(); - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 33 | c | |", - "| 44 | d | |", - "+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - - Ok(()) -} - -#[tokio::test] -async fn equijoin_left_and_not_null_condition_from_right() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t2_name is not null ORDER BY t1_id"; - let dataframe = ctx.sql(sql).await.unwrap(); - let actual = dataframe.collect().await.unwrap(); - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 33 | c | |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - - Ok(()) -} - -#[tokio::test] -async fn full_join_sub_query() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let sql = " - SELECT t1_id, t1_name, t2_name FROM (SELECT * from (t1) AS t1) FULL JOIN (SELECT * from (t2) AS t2) ON t1_id = t2_id AND t2_name >= 'y' - ORDER BY t1_id, t2_name"; - let dataframe = ctx.sql(sql).await.unwrap(); - let actual = dataframe.collect().await.unwrap(); - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 33 | c | |", - "| 44 | d | |", - "| | | w |", - "| | | x |", - "+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - - Ok(()) -} - -#[tokio::test] -async fn equijoin_right_and_condition_from_left() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t1_id >= 22 ORDER BY t2_name"; - let dataframe = ctx.sql(sql).await.unwrap(); - let actual = dataframe.collect().await.unwrap(); - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| | | w |", - "| 44 | d | x |", - "| 22 | b | y |", - "| | | z |", - "+-------+---------+---------+", - 
]; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn equijoin_left_and_condition_from_left() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_id >= 44 ORDER BY t1_id"; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | |", - "| 22 | b | |", - "| 33 | c | |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn equijoin_left_and_condition_from_both() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let sql = - "SELECT t1_id, t1_int, t2_int FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_int >= t2_int ORDER BY t1_id"; - let expected = vec![ - "+-------+--------+--------+", - "| t1_id | t1_int | t2_int |", - "+-------+--------+--------+", - "| 11 | 1 | |", - "| 22 | 2 | 1 |", - "| 33 | 3 | |", - "| 44 | 4 | 3 |", - "+-------+--------+--------+", - ]; + let ctx = create_join_context_qualified("t1", "t2")?; + let equivalent_sql = [ + "SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t1.a = t2.a ORDER BY t1.a", + "SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t2.a = t1.a ORDER BY t1.a", + ]; + let expected = vec![ + "+---+-----+", + "| a | b |", + "+---+-----+", + "| 1 | 100 |", + "| 2 | 200 |", + "| 4 | 400 |", + "+---+-----+", + ]; + for sql in equivalent_sql.iter() { let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } @@ -261,81 +42,6 @@ async fn equijoin_left_and_condition_from_both() -> Result<()> { Ok(()) } -#[tokio::test] -async fn equijoin_right_and_condition_from_right() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_id >= 22 ORDER BY t2_name"; - let dataframe = ctx.sql(sql).await.unwrap(); - let actual = dataframe.collect().await.unwrap(); - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| | | w |", - "| 44 | d | x |", - "| 22 | b | y |", - "| | | z |", - "+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn equijoin_right_and_condition_from_both() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let sql = - "SELECT t1_int, t2_int, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int ORDER BY t2_id"; - let dataframe = ctx.sql(sql).await.unwrap(); - let actual = dataframe.collect().await.unwrap(); - let expected = vec![ - "+--------+--------+-------+", - "| t1_int | t2_int | t2_id |", - "+--------+--------+-------+", - "| | 3 | 11 |", - "| 2 | 1 | 22 |", - "| 4 | 3 | 44 |", - "| | 3 | 55 |", - "+--------+--------+-------+", - ]; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - 
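Most of the equijoin tests in this file are removed here, and the retained plan-level tests now call `create_join_context(_, _, false)`, i.e. with join repartitioning disabled, instead of looping over both settings. For context, that boolean presumably maps onto the standard `SessionConfig` builder; a minimal sketch of toggling it directly (the helper name is hypothetical):

```rust
use datafusion::prelude::{SessionConfig, SessionContext};

/// Build a context with join repartitioning disabled, matching the `false`
/// flag the retained tests now hard-code.
fn join_ctx_without_repartitioning() -> SessionContext {
    let config = SessionConfig::new()
        .with_target_partitions(2)
        .with_repartition_joins(false);
    SessionContext::with_config(config)
}
```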
-#[tokio::test] -async fn left_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id", - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 33 | c | |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - } - Ok(()) -} - #[tokio::test] async fn left_join_unbalanced() -> Result<()> { // the t1_id is larger than t2_id so the join_selection optimizer should kick in @@ -576,71 +282,6 @@ async fn full_join_not_null_filter() -> Result<()> { Ok(()) } -#[tokio::test] -async fn right_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t2_id = t1_id ORDER BY t1_id" - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "| | | w |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - } - Ok(()) -} - -#[tokio::test] -async fn full_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t2_id = t1_id ORDER BY t1_id", - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 33 | c | |", - "| 44 | d | x |", - "| | | w |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - - let equivalent_sql = [ - "SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t2_id = t1_id ORDER BY t1_id", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - } - - Ok(()) -} - #[tokio::test] async fn left_join_using() -> Result<()> { let test_repartition_joins = vec![true, false]; @@ -664,80 +305,6 @@ async fn left_join_using() -> Result<()> { Ok(()) } -#[tokio::test] -async fn equijoin_implicit_syntax() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let equivalent_sql = [ - "SELECT 
t1_id, t1_name, t2_name FROM t1, t2 WHERE t1_id = t2_id ORDER BY t1_id", - "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id ORDER BY t1_id", - ]; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&ctx, sql).await; - assert_batches_eq!(expected, &actual); - } - } - Ok(()) -} - -#[tokio::test] -async fn equijoin_implicit_syntax_with_filter() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let sql = "SELECT t1_id, t1_name, t2_name \ - FROM t1, t2 \ - WHERE t1_id > 0 \ - AND t1_id = t2_id \ - AND t2_id < 99 \ - ORDER BY t1_id"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - -#[tokio::test] -async fn equijoin_implicit_syntax_reversed() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id ORDER BY t1_id"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | z |", - "| 22 | b | y |", - "| 44 | d | x |", - "+-------+---------+---------+", - ]; - assert_batches_eq!(expected, &actual); - } - Ok(()) -} - #[tokio::test] async fn cross_join() { let test_repartition_joins = vec![true, false]; @@ -1496,17 +1063,14 @@ async fn hash_join_with_dictionary() -> Result<()> { #[tokio::test] async fn reduce_left_join_1() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + let ctx = create_join_context("t1_id", "t2_id", false)?; - // reduce to inner join - let sql = - "select * from t1 left join t2 on t1.t1_id = t2.t2_id where t2.t2_id < 100"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ + // reduce to inner join + let sql = "select * from t1 left join t2 on t1.t1_id = t2.t2_id where t2.t2_id < 100"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", @@ -1514,46 +1078,31 @@ async fn reduce_left_join_1() -> Result<()> { " Filter: t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, 
t2_name:Utf8;N, t2_int:UInt32;N]", ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+-------+---------+--------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", - "+-------+---------+--------+-------+---------+--------+", - "| 11 | a | 1 | 11 | z | 3 |", - "| 22 | b | 2 | 22 | y | 1 |", - "| 44 | d | 4 | 44 | x | 3 |", - "+-------+---------+--------+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); Ok(()) } #[tokio::test] async fn reduce_left_join_2() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + let ctx = create_join_context("t1_id", "t2_id", false)?; - // reduce to inner join - let sql = "select * from t1 left join t2 on t1.t1_id = t2.t2_id where t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; + // reduce to inner join + let sql = "select * from t1 left join t2 on t1.t1_id = t2.t2_id where t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; - // filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')` - // could be write to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)` - // the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push down left join side and remove in filter. + // filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')` + // could be write to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)` + // the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push down left join side and remove in filter. 
- let expected = vec![ + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", @@ -1561,41 +1110,26 @@ async fn reduce_left_join_2() -> Result<()> { " Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+-------+---------+--------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", - "+-------+---------+--------+-------+---------+--------+", - "| 11 | a | 1 | 11 | z | 3 |", - "| 22 | b | 2 | 22 | y | 1 |", - "| 44 | d | 4 | 44 | x | 3 |", - "+-------+---------+--------+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); Ok(()) } #[tokio::test] async fn reduce_left_join_3() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + let ctx = create_join_context("t1_id", "t2_id", false)?; - // reduce subquery to inner join - let sql = "select * from (select t1.* from t1 left join t2 on t1.t1_id = t2.t2_id where t2.t2_int < 3) t3 left join t2 on t3.t1_int = t2.t2_int where t3.t1_id < 100"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ + // reduce subquery to inner join + let sql = "select * from (select t1.* from t1 left join t2 on t1.t1_id = t2.t2_id where t2.t2_int < 3) t3 left join t2 on t3.t1_int = t2.t2_int where t3.t1_id < 100"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Left Join: t3.t1_int = t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " SubqueryAlias: t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", @@ -1608,204 +1142,131 @@ async fn reduce_left_join_3() -> Result<()> { " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - 
"+-------+---------+--------+-------+---------+--------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", - "+-------+---------+--------+-------+---------+--------+", - "| 22 | b | 2 | | | |", - "+-------+---------+--------+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); Ok(()) } #[tokio::test] async fn reduce_right_join_1() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + let ctx = create_join_context("t1_id", "t2_id", false)?; - // reduce to inner join - let sql = "select * from t1 right join t2 on t1.t1_id = t2.t2_id where t1.t1_int is not null"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ + // reduce to inner join + let sql = "select * from t1 right join t2 on t1.t1_id = t2.t2_id where t1.t1_int is not null"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Filter: t1.t1_int IS NOT NULL [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+-------+---------+--------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", - "+-------+---------+--------+-------+---------+--------+", - "| 11 | a | 1 | 11 | z | 3 |", - "| 22 | b | 2 | 22 | y | 1 |", - "| 44 | d | 4 | 44 | x | 3 |", - "+-------+---------+--------+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); Ok(()) } #[tokio::test] async fn reduce_right_join_2() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - // reduce to inner join - let sql = "select * from t1 right join t2 on t1.t1_id = t2.t2_id where not(t1.t1_int = t2.t2_int)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ + let ctx = 
create_join_context("t1_id", "t2_id", false)?; + + // reduce to inner join + let sql = "select * from t1 right join t2 on t1.t1_id = t2.t2_id where not(t1.t1_int = t2.t2_int)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Filter: t1.t1_int != t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+-------+---------+--------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", - "+-------+---------+--------+-------+---------+--------+", - "| 11 | a | 1 | 11 | z | 3 |", - "| 22 | b | 2 | 22 | y | 1 |", - "| 44 | d | 4 | 44 | x | 3 |", - "+-------+---------+--------+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); Ok(()) } #[tokio::test] async fn reduce_full_join_to_right_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + let ctx = create_join_context("t1_id", "t2_id", false)?; - // reduce to right join - let sql = "select * from t1 full join t2 on t1.t1_id = t2.t2_id where t2.t2_name is not null"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ + // reduce to right join + let sql = "select * from t1 full join t2 on t1.t1_id = t2.t2_id where t2.t2_name is not null"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Right Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " Filter: t2.t2_name IS NOT NULL [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - 
"+-------+---------+--------+-------+---------+--------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", - "+-------+---------+--------+-------+---------+--------+", - "| | | | 55 | w | 3 |", - "| 11 | a | 1 | 11 | z | 3 |", - "| 22 | b | 2 | 22 | y | 1 |", - "| 44 | d | 4 | 44 | x | 3 |", - "+-------+---------+--------+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); Ok(()) } #[tokio::test] async fn reduce_full_join_to_left_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + let ctx = create_join_context("t1_id", "t2_id", false)?; - // reduce to left join - let sql = - "select * from t1 full join t2 on t1.t1_id = t2.t2_id where t1.t1_name != 'b'"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ + // reduce to left join + let sql = + "select * from t1 full join t2 on t1.t1_id = t2.t2_id where t1.t1_name != 'b'"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Left Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Filter: t1.t1_name != Utf8(\"b\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+-------+---------+--------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", - "+-------+---------+--------+-------+---------+--------+", - "| 11 | a | 1 | 11 | z | 3 |", - "| 33 | c | 3 | | | |", - "| 44 | d | 4 | 44 | x | 3 |", - "+-------+---------+--------+-------+---------+--------+", - ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } Ok(()) } #[tokio::test] async fn reduce_full_join_to_inner_join() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + let ctx = create_join_context("t1_id", "t2_id", false)?; - // reduce to inner join - let sql = "select * from t1 full join t2 on t1.t1_id = t2.t2_id where t1.t1_name != 'b' and t2.t2_name = 'x'"; - let msg = format!("Creating logical plan 
for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ + // reduce to inner join + let sql = "select * from t1 full join t2 on t1.t1_id = t2.t2_id where t1.t1_name != 'b' and t2.t2_name = 'x'"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " Filter: t1.t1_name != Utf8(\"b\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", @@ -1813,23 +1274,12 @@ async fn reduce_full_join_to_inner_join() -> Result<()> { " Filter: t2.t2_name = Utf8(\"x\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+-------+---------+--------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", - "+-------+---------+--------+-------+---------+--------+", - "| 44 | d | 4 | 44 | x | 3 |", - "+-------+---------+--------+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); Ok(()) } @@ -2324,84 +1774,42 @@ async fn right_semi_join() -> Result<()> { Ok(()) } -#[tokio::test] -async fn left_join_with_nonequal_condition() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins).unwrap(); - - let sql = "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON (t1_id != t2_id and t2_id >= 100) ORDER BY t1_id"; - let actual = execute_to_batches(&ctx, sql).await; - let expected = vec![ - "+-------+---------+---------+", - "| t1_id | t1_name | t2_name |", - "+-------+---------+---------+", - "| 11 | a | |", - "| 22 | b | |", - "| 33 | c | |", - "| 44 | d | |", - "+-------+---------+---------+", - ]; - - assert_batches_eq!(expected, &actual); - } - - Ok(()) -} - #[tokio::test] async fn reduce_cross_join_with_expr_join_key_all() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + let ctx = create_join_context("t1_id", "t2_id", false)?; - // reduce to inner join - let sql = "select * from t1 cross join t2 where t1.t1_id + 12 = t2.t2_id + 1"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ + // reduce to inner join + let sql = "select * from t1 cross join t2 where t1.t1_id + 12 = t2.t2_id + 1"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = 
ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Inner Join: CAST(t1.t1_id AS Int64) + Int64(12) = CAST(t2.t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+---------+--------+-------+---------+--------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", - "+-------+---------+--------+-------+---------+--------+", - "| 11 | a | 1 | 22 | y | 1 |", - "| 33 | c | 3 | 44 | x | 3 |", - "| 44 | d | 4 | 55 | w | 3 |", - "+-------+---------+--------+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); Ok(()) } #[tokio::test] async fn reduce_cross_join_with_cast_expr_join_key() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + let ctx = create_join_context("t1_id", "t2_id", false)?; - let sql = + let sql = "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = cast(t2.t2_id as BIGINT)"; - let msg = format!("Creating logical plan for '{sql}'"); - let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; - let expected = vec![ + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t1.t1_id, t2.t2_id, t1.t1_name [t1_id:UInt32;N, t2_id:UInt32;N, t1_name:Utf8;N]", " Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", @@ -2409,25 +1817,12 @@ async fn reduce_cross_join_with_cast_expr_join_key() -> Result<()> { " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let expected = vec![ - "+-------+-------+---------+", - "| t1_id | t2_id | t1_name |", - "+-------+-------+---------+", - "| 11 | 22 | a |", - "| 33 | 44 | c |", - "| 44 | 55 | d |", - "+-------+-------+---------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + 
); Ok(()) } @@ -2502,20 +1897,6 @@ async fn reduce_cross_join_with_wildcard_and_expr() -> Result<()> { expected, actual, "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - - // assert execution result - let expected = vec![ - "+-------+---------+--------+-------+---------+--------+----------------------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int | t1.t1_id + Int64(11) |", - "+-------+---------+--------+-------+---------+--------+----------------------+", - "| 11 | a | 1 | 22 | y | 1 | 22 |", - "| 33 | c | 3 | 44 | x | 3 | 44 |", - "| 44 | d | 4 | 55 | w | 3 | 55 |", - "+-------+---------+--------+-------+---------+--------+----------------------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); } Ok(()) @@ -2574,19 +1955,6 @@ async fn both_side_expr_key_inner_join() -> Result<()> { expected, actual, "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - - let expected = vec![ - "+-------+-------+---------+", - "| t1_id | t2_id | t1_name |", - "+-------+-------+---------+", - "| 11 | 22 | a |", - "| 33 | 44 | c |", - "| 44 | 55 | d |", - "+-------+-------+---------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); } Ok(()) @@ -2642,19 +2010,6 @@ async fn left_side_expr_key_inner_join() -> Result<()> { expected, actual, "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - - let expected = vec![ - "+-------+-------+---------+", - "| t1_id | t2_id | t1_name |", - "+-------+-------+---------+", - "| 11 | 22 | a |", - "| 33 | 44 | c |", - "| 44 | 55 | d |", - "+-------+-------+---------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); } Ok(()) @@ -2709,19 +2064,6 @@ async fn right_side_expr_key_inner_join() -> Result<()> { expected, actual, "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - - let expected = vec![ - "+-------+-------+---------+", - "| t1_id | t2_id | t1_name |", - "+-------+-------+---------+", - "| 11 | 22 | a |", - "| 33 | 44 | c |", - "| 44 | 55 | d |", - "+-------+-------+---------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); } Ok(()) @@ -2774,19 +2116,6 @@ async fn select_wildcard_with_expr_key_inner_join() -> Result<()> { expected, actual, "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - - let expected = vec![ - "+-------+---------+--------+-------+---------+--------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", - "+-------+---------+--------+-------+---------+--------+", - "| 11 | a | 1 | 22 | y | 1 |", - "| 33 | c | 3 | 44 | x | 3 |", - "| 44 | d | 4 | 55 | w | 3 |", - "+-------+---------+--------+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); } Ok(()) @@ -2817,19 +2146,6 @@ async fn join_with_type_coercion_for_equi_expr() -> Result<()> { "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - let expected = vec![ - "+-------+---------+-------+", - "| t1_id | t1_name | t2_id |", - "+-------+---------+-------+", - "| 11 | a | 22 |", - "| 33 | c | 44 |", - "| 44 | d | 55 |", - "+-------+---------+-------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - Ok(()) } @@ -2858,17 +2174,6 @@ async fn join_only_with_filter() -> Result<()> { 
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - let expected = vec![ - "+-------+---------+-------+", - "| t1_id | t1_name | t2_id |", - "+-------+---------+-------+", - "| 11 | a | 55 |", - "+-------+---------+-------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - Ok(()) } @@ -2900,17 +2205,6 @@ async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> { "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - let expected = vec![ - "+-------+---------+-------+", - "| t1_id | t1_name | t2_id |", - "+-------+---------+-------+", - "| 11 | a | 55 |", - "+-------+---------+-------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - Ok(()) } diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 456dc12a1031..65cbcdfe06d9 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -231,6 +231,82 @@ fn create_join_context( Ok(ctx) } +fn create_sub_query_join_context( + column_outer: &str, + column_inner_left: &str, + column_inner_right: &str, + repartition_joins: bool, +) -> Result { + let ctx = SessionContext::with_config( + SessionConfig::new() + .with_repartition_joins(repartition_joins) + .with_target_partitions(2) + .with_batch_size(4096), + ); + + let t0_schema = Arc::new(Schema::new(vec![ + Field::new(column_outer, DataType::UInt32, true), + Field::new("t0_name", DataType::Utf8, true), + Field::new("t0_int", DataType::UInt32, true), + ])); + let t0_data = RecordBatch::try_new( + t0_schema, + vec![ + Arc::new(UInt32Array::from_slice([11, 22, 33, 44])), + Arc::new(StringArray::from(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + ])), + Arc::new(UInt32Array::from_slice([1, 2, 3, 4])), + ], + )?; + ctx.register_batch("t0", t0_data)?; + + let t1_schema = Arc::new(Schema::new(vec![ + Field::new(column_inner_left, DataType::UInt32, true), + Field::new("t1_name", DataType::Utf8, true), + Field::new("t1_int", DataType::UInt32, true), + ])); + let t1_data = RecordBatch::try_new( + t1_schema, + vec![ + Arc::new(UInt32Array::from_slice([11, 22, 33, 44])), + Arc::new(StringArray::from(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + ])), + Arc::new(UInt32Array::from_slice([1, 2, 3, 4])), + ], + )?; + ctx.register_batch("t1", t1_data)?; + + let t2_schema = Arc::new(Schema::new(vec![ + Field::new(column_inner_right, DataType::UInt32, true), + Field::new("t2_name", DataType::Utf8, true), + Field::new("t2_int", DataType::UInt32, true), + ])); + let t2_data = RecordBatch::try_new( + t2_schema, + vec![ + Arc::new(UInt32Array::from_slice([11, 22, 44, 55])), + Arc::new(StringArray::from(vec![ + Some("z"), + Some("y"), + Some("x"), + Some("w"), + ])), + Arc::new(UInt32Array::from_slice([3, 1, 3, 3])), + ], + )?; + ctx.register_batch("t2", t2_data)?; + + Ok(ctx) +} + fn create_left_semi_anti_join_context_with_null_ids( column_left: &str, column_right: &str, @@ -1438,99 +1514,6 @@ pub fn make_timestamps() -> RecordBatch { .unwrap() } -/// Return a new table provider containing all of the supported timestamp types -pub fn table_with_times() -> Arc { - let batch = make_times(); - let schema = batch.schema(); - let partitions = vec![vec![batch]]; - Arc::new(MemTable::try_new(schema, partitions).unwrap()) -} - -/// Return record batch with all of the supported time types -/// values -/// -/// Columns are named: -/// "nanos" --> Time64NanosecondArray -/// "micros" --> 
Time64MicrosecondArray -/// "millis" --> Time32MillisecondArray -/// "secs" --> Time32SecondArray -/// "names" --> StringArray -pub fn make_times() -> RecordBatch { - let ts_strings = vec![ - Some("18:06:30.243620451"), - Some("20:08:28.161121654"), - Some("19:11:04.156423842"), - Some("21:06:28.247821084"), - ]; - - let ts_nanos = ts_strings - .into_iter() - .map(|t| { - t.map(|t| { - let integer_sec = t - .parse::() - .unwrap() - .num_seconds_from_midnight() as i64; - let extra_nano = - t.parse::().unwrap().nanosecond() as i64; - // Total time in nanoseconds given by integer number of seconds multiplied by 10^9 - // plus number of nanoseconds corresponding to the extra fraction of second - integer_sec * 1_000_000_000 + extra_nano - }) - }) - .collect::>(); - - let ts_micros = ts_nanos - .iter() - .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000)) - .collect::>(); - - let ts_millis = ts_nanos - .iter() - .map(|t| t.as_ref().map(|ts_nanos| { ts_nanos / 1000000 } as i32)) - .collect::>(); - - let ts_secs = ts_nanos - .iter() - .map(|t| t.as_ref().map(|ts_nanos| { ts_nanos / 1000000000 } as i32)) - .collect::>(); - - let names = ts_nanos - .iter() - .enumerate() - .map(|(i, _)| format!("Row {i}")) - .collect::>(); - - let arr_nanos = Time64NanosecondArray::from(ts_nanos); - let arr_micros = Time64MicrosecondArray::from(ts_micros); - let arr_millis = Time32MillisecondArray::from(ts_millis); - let arr_secs = Time32SecondArray::from(ts_secs); - - let names = names.iter().map(|s| s.as_str()).collect::>(); - let arr_names = StringArray::from(names); - - let schema = Schema::new(vec![ - Field::new("nanos", arr_nanos.data_type().clone(), true), - Field::new("micros", arr_micros.data_type().clone(), true), - Field::new("millis", arr_millis.data_type().clone(), true), - Field::new("secs", arr_secs.data_type().clone(), true), - Field::new("name", arr_names.data_type().clone(), true), - ]); - let schema = Arc::new(schema); - - RecordBatch::try_new( - schema, - vec![ - Arc::new(arr_nanos), - Arc::new(arr_micros), - Arc::new(arr_millis), - Arc::new(arr_secs), - Arc::new(arr_names), - ], - ) - .unwrap() -} - #[tokio::test] async fn nyc() -> Result<()> { // schema for nyxtaxi csv files diff --git a/datafusion/core/tests/sql/order.rs b/datafusion/core/tests/sql/order.rs index 2388eebef931..da7663cb8b5c 100644 --- a/datafusion/core/tests/sql/order.rs +++ b/datafusion/core/tests/sql/order.rs @@ -16,6 +16,9 @@ // under the License. use super::*; +use datafusion::datasource::datasource::TableProviderFactory; +use datafusion::datasource::listing::ListingTable; +use datafusion::datasource::listing_table_factory::ListingTableFactory; use test_utils::{batches_to_vec, partitions_to_sorted_vec}; #[tokio::test] @@ -39,3 +42,136 @@ async fn sort_with_lots_of_repetition_values() -> Result<()> { } Ok(()) } + +#[tokio::test] +async fn create_external_table_with_order() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "CREATE EXTERNAL TABLE dt (a_id integer, a_str string, a_bool boolean) STORED AS CSV WITH ORDER (a_id ASC) LOCATION 'file://path/to/table';"; + if let LogicalPlan::CreateExternalTable(cmd) = + ctx.state().create_logical_plan(sql).await? 
+ { + let listing_table_factory = Arc::new(ListingTableFactory::new()); + let table_dyn = listing_table_factory.create(&ctx.state(), &cmd).await?; + let table = table_dyn.as_any().downcast_ref::().unwrap(); + assert_eq!(cmd.order_exprs.len(), 1); + assert_eq!( + &cmd.order_exprs, + table.options().file_sort_order.as_ref().unwrap() + ) + } else { + panic!("Wrong command") + } + Ok(()) +} + +#[tokio::test] +async fn create_external_table_with_ddl_ordered_non_cols() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "CREATE EXTERNAL TABLE dt (a_id integer, a_str string, a_bool boolean) STORED AS CSV WITH ORDER (a ASC) LOCATION 'file://path/to/table';"; + match ctx.state().create_logical_plan(sql).await { + Ok(_) => panic!("Expecting error."), + Err(e) => { + assert_eq!( + e.to_string(), + "Error during planning: Column a is not in schema" + ) + } + } + Ok(()) +} + +#[tokio::test] +async fn create_external_table_with_ddl_ordered_without_schema() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "CREATE EXTERNAL TABLE dt STORED AS CSV WITH ORDER (a ASC) LOCATION 'file://path/to/table';"; + match ctx.state().create_logical_plan(sql).await { + Ok(_) => panic!("Expecting error."), + Err(e) => { + assert_eq!(e.to_string(), "Error during planning: Provide a schema before specifying the order while creating a table.") + } + } + Ok(()) +} + +#[tokio::test] +async fn sort_with_duplicate_sort_exprs() -> Result<()> { + let ctx = SessionContext::new(); + + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ])); + + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![2, 4, 9, 3, 4])), + Arc::new(StringArray::from_slice(["a", "b", "c", "d", "e"])), + ], + )?; + ctx.register_batch("t1", t1_data)?; + + let sql = "select * from t1 order by id desc, id, name, id asc"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + let expected = vec![ + "Sort: t1.id DESC NULLS FIRST, t1.name ASC NULLS LAST [id:Int32;N, name:Utf8;N]", + " TableScan: t1 projection=[id, name] [id:Int32;N, name:Utf8;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+----+------+", + "| id | name |", + "+----+------+", + "| 9 | c |", + "| 4 | b |", + "| 4 | e |", + "| 3 | d |", + "| 2 | a |", + "+----+------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_eq!(expected, &results); + + let sql = "select * from t1 order by id asc, id, name, id desc;"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + let expected = vec![ + "Sort: t1.id ASC NULLS LAST, t1.name ASC NULLS LAST [id:Int32;N, name:Utf8;N]", + " TableScan: t1 projection=[id, name] [id:Int32;N, name:Utf8;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+----+------+", + "| id | name |", + "+----+------+", + "| 2 | a |", + "| 3 | d |", + "| 4 | b |", + "| 4 | e 
|", + "| 9 | c |", + "+----+------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/references.rs b/datafusion/core/tests/sql/references.rs index f006cbb45984..335bc630861c 100644 --- a/datafusion/core/tests/sql/references.rs +++ b/datafusion/core/tests/sql/references.rs @@ -67,7 +67,7 @@ async fn qualified_table_references_and_fields() -> Result<()> { let error = ctx.sql(sql).await.unwrap_err(); assert_contains!( error.to_string(), - "No field named 'f1'.'c1'. Valid fields are 'test'.'f.c1', 'test'.'test.c2'" + r#"No field named "f1"."c1". Valid fields are "test"."f.c1", "test"."test.c2""# ); // however, enclosing it in double quotes is ok diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 06d08d2e9f96..6711384d17b8 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -179,3 +179,262 @@ async fn in_subquery_with_same_table() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn invalid_scalar_subquery() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", true)?; + + let sql = "SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t1.t1_int) FROM t1"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let err = dataframe.into_optimized_plan().err().unwrap(); + assert_eq!( + "Plan(\"Scalar subquery should only return one column\")", + &format!("{err:?}") + ); + + Ok(()) +} + +#[tokio::test] +async fn subquery_not_allowed() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", true)?; + + // In/Exist Subquery is not allowed in ORDER BY clause. 
+ let sql = "SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let err = dataframe.into_optimized_plan().err().unwrap(); + + assert_eq!( + "Plan(\"In/Exist subquery can not be used in Sort plan nodes\")", + &format!("{err:?}") + ); + + Ok(()) +} + +#[tokio::test] +async fn support_agg_correlated_columns() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", true)?; + + let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT sum(t1.t1_int + t2.t2_id) FROM t2 WHERE t1.t1_name = t2.t2_name)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Filter: EXISTS () [t1_id:UInt32;N, t1_name:Utf8;N]", + " Subquery: [SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]", + " Projection: SUM(outer_ref(t1.t1_int) + t2.t2_id) [SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]", + " Aggregate: groupBy=[[]], aggr=[[SUM(outer_ref(t1.t1_int) + t2.t2_id)]] [SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]", + " Filter: outer_ref(t1.t1_name) = t2.t2_name [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + Ok(()) +} + +#[tokio::test] +async fn support_agg_correlated_columns2() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", true)?; + + let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT count(*) FROM t2 WHERE t1.t1_name = t2.t2_name having sum(t1_int + t2_id) >0)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Filter: EXISTS () [t1_id:UInt32;N, t1_name:Utf8;N]", + " Subquery: [COUNT(UInt8(1)):Int64;N]", + " Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]", + " Filter: CAST(SUM(outer_ref(t1.t1_int) + t2.t2_id) AS Int64) > Int64(0) [COUNT(UInt8(1)):Int64;N, SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]", + " Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)), SUM(outer_ref(t1.t1_int) + t2.t2_id)]] [COUNT(UInt8(1)):Int64;N, SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]", + " Filter: outer_ref(t1.t1_name) = t2.t2_name [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + Ok(()) +} + +#[tokio::test] +async fn support_join_correlated_columns() -> Result<()> { + let ctx = create_sub_query_join_context("t0_id", "t1_id", "t2_id", true)?; + let sql = "SELECT t0_id, t0_name FROM t0 WHERE EXISTS (SELECT 1 FROM t1 INNER JOIN t2 ON(t1.t1_id = t2.t2_id and t1.t1_name = t0.t0_name))"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let 
plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Filter: EXISTS () [t0_id:UInt32;N, t0_name:Utf8;N]", + " Subquery: [Int64(1):Int64]", + " Projection: Int64(1) [Int64(1):Int64]", + " Inner Join: Filter: t1.t1_id = t2.t2_id AND t1.t1_name = outer_ref(t0.t0_name) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t1 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t0 projection=[t0_id, t0_name] [t0_id:UInt32;N, t0_name:Utf8;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + Ok(()) +} + +#[tokio::test] +async fn support_join_correlated_columns2() -> Result<()> { + let ctx = create_sub_query_join_context("t0_id", "t1_id", "t2_id", true)?; + let sql = "SELECT t0_id, t0_name FROM t0 WHERE EXISTS (SELECT 1 FROM t1 INNER JOIN (select * from t2 where t2.t2_name = t0.t0_name) as t2 ON(t1.t1_id = t2.t2_id ))"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Filter: EXISTS () [t0_id:UInt32;N, t0_name:Utf8;N]", + " Subquery: [Int64(1):Int64]", + " Projection: Int64(1) [Int64(1):Int64]", + " Inner Join: Filter: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t1 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: t2.t2_name = outer_ref(t0.t0_name) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t0 projection=[t0_id, t0_name] [t0_id:UInt32;N, t0_name:Utf8;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + Ok(()) +} + +#[tokio::test] +async fn support_order_by_correlated_columns() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", true)?; + + let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id >= t1_id order by t1_id)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Filter: EXISTS () [t1_id:UInt32;N, t1_name:Utf8;N]", + " Subquery: [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Sort: outer_ref(t1.t1_id) ASC NULLS LAST [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: t2.t2_id >= outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + 
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + Ok(()) +} + +#[tokio::test] +async fn support_limit_subquery() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", true)?; + + let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 1)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Filter: EXISTS () [t1_id:UInt32;N, t1_name:Utf8;N]", + " Subquery: [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Limit: skip=0, fetch=1 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: t2.t2_id = outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let sql = "SELECT t1_id, t1_name FROM t1 WHERE t1_id in (SELECT t2_id FROM t2 where t1_name = t2_name limit 10)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Filter: t1.t1_id IN () [t1_id:UInt32;N, t1_name:Utf8;N]", + " Subquery: [t2_id:UInt32;N]", + " Limit: skip=0, fetch=10 [t2_id:UInt32;N]", + " Projection: t2.t2_id [t2_id:UInt32;N]", + " Filter: outer_ref(t1.t1_name) = t2.t2_name [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + Ok(()) +} + +#[tokio::test] +async fn support_union_subquery() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", true)?; + + let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS \ + (SELECT * FROM t2 WHERE t2_id = t1_id UNION ALL \ + SELECT * FROM t2 WHERE upper(t2_name) = upper(t1.t1_name))"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Filter: EXISTS () [t1_id:UInt32;N, t1_name:Utf8;N]", + " Subquery: [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Union [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: t2.t2_id = outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: upper(t2.t2_name) = upper(outer_ref(t1.t1_name)) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", + ]; + 
let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 8f95eba572ba..535e1d89170f 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -61,7 +61,20 @@ async fn window_frame_creation_type_checking() -> Result<()> { ).await } -fn write_test_data_to_parquet(tmpdir: &TempDir, n_file: usize) -> Result<()> { +fn split_record_batch(batch: RecordBatch, n_split: usize) -> Vec { + let n_chunk = batch.num_rows() / n_split; + let mut res = vec![]; + for i in 0..n_split - 1 { + let chunk = batch.slice(i * n_chunk, n_chunk); + res.push(chunk); + } + let start = (n_split - 1) * n_chunk; + let len = batch.num_rows() - start; + res.push(batch.slice(start, len)); + res +} + +fn get_test_data(n_split: usize) -> Result> { let ts_field = Field::new("ts", DataType::Int32, false); let inc_field = Field::new("inc_col", DataType::Int32, false); let desc_field = Field::new("desc_col", DataType::Int32, false); @@ -100,19 +113,19 @@ fn write_test_data_to_parquet(tmpdir: &TempDir, n_file: usize) -> Result<()> { ])), ], )?; - let n_chunk = batch.num_rows() / n_file; - for i in 0..n_file { + Ok(split_record_batch(batch, n_split)) +} + +fn write_test_data_to_parquet(tmpdir: &TempDir, n_split: usize) -> Result<()> { + let batches = get_test_data(n_split)?; + for (i, batch) in batches.into_iter().enumerate() { let target_file = tmpdir.path().join(format!("{i}.parquet")); let file = File::create(target_file).unwrap(); // Default writer properties let props = WriterProperties::builder().build(); - let chunks_start = i * n_chunk; - let cur_batch = batch.slice(chunks_start, n_chunk); - // let chunks_end = chunks_start + n_chunk; - let mut writer = - ArrowWriter::try_new(file, cur_batch.schema(), Some(props)).unwrap(); + let mut writer = ArrowWriter::try_new(file, batch.schema(), Some(props)).unwrap(); - writer.write(&cur_batch).expect("Writing batch"); + writer.write(&batch).expect("Writing batch"); // writer must be closed to write footer writer.close().unwrap(); @@ -120,12 +133,11 @@ fn write_test_data_to_parquet(tmpdir: &TempDir, n_file: usize) -> Result<()> { Ok(()) } -async fn get_test_context(tmpdir: &TempDir) -> Result { +async fn get_test_context(tmpdir: &TempDir, n_batch: usize) -> Result { let session_config = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::with_config(session_config); let parquet_read_options = ParquetReadOptions::default(); - // The sort order is specified (not actually correct in this case) let file_sort_order = [col("ts")] .into_iter() .map(|e| { @@ -139,7 +151,7 @@ async fn get_test_context(tmpdir: &TempDir) -> Result { .to_listing_options(&ctx.copied_config()) .with_file_sort_order(Some(file_sort_order)); - write_test_data_to_parquet(tmpdir, 1)?; + write_test_data_to_parquet(tmpdir, n_batch)?; let provided_schema = None; let sql_definition = None; ctx.register_listing_table( @@ -160,7 +172,7 @@ mod tests { #[tokio::test] async fn test_source_sorted_aggregate() -> Result<()> { let tmpdir = TempDir::new().unwrap(); - let ctx = get_test_context(&tmpdir).await?; + let ctx = get_test_context(&tmpdir, 1).await?; let sql = "SELECT SUM(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as sum1, @@ -235,7 +247,7 @@ mod tests { #[tokio::test] async fn 
test_source_sorted_builtin() -> Result<()> { let tmpdir = TempDir::new().unwrap(); - let ctx = get_test_context(&tmpdir).await?; + let ctx = get_test_context(&tmpdir, 1).await?; let sql = "SELECT FIRST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv1, @@ -309,7 +321,7 @@ mod tests { #[tokio::test] async fn test_source_sorted_unbounded_preceding() -> Result<()> { let tmpdir = TempDir::new().unwrap(); - let ctx = get_test_context(&tmpdir).await?; + let ctx = get_test_context(&tmpdir, 1).await?; let sql = "SELECT SUM(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as sum1, @@ -368,7 +380,7 @@ mod tests { #[tokio::test] async fn test_source_sorted_unbounded_preceding_builtin() -> Result<()> { let tmpdir = TempDir::new().unwrap(); - let ctx = get_test_context(&tmpdir).await?; + let ctx = get_test_context(&tmpdir, 1).await?; let sql = "SELECT FIRST_VALUE(inc_col) OVER(ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as first_value1, diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/create_table.rs b/datafusion/core/tests/sqllogictests/src/engines/datafusion/create_table.rs index 981dd75b56d7..2d2e37f4b7e2 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/create_table.rs +++ b/datafusion/core/tests/sqllogictests/src/engines/datafusion/create_table.rs @@ -66,12 +66,9 @@ fn create_new_table( let sql_to_rel = SqlToRel::new_with_options( &LogicTestContextProvider {}, ParserOptions { - parse_float_as_decimal: config - .config_options() - .sql_parser - .parse_float_as_decimal, + parse_float_as_decimal: config.options().sql_parser.parse_float_as_decimal, enable_ident_normalization: config - .config_options() + .options() .sql_parser .enable_ident_normalization, }, diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/insert.rs b/datafusion/core/tests/sqllogictests/src/engines/datafusion/insert.rs deleted file mode 100644 index a8fca3b16c06..000000000000 --- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/insert.rs +++ /dev/null @@ -1,93 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use super::error::Result; -use crate::engines::datafusion::util::LogicTestContextProvider; -use crate::engines::output::DFOutput; -use arrow::record_batch::RecordBatch; -use datafusion::datasource::MemTable; -use datafusion::prelude::SessionContext; -use datafusion_common::{DFSchema, DataFusionError}; -use datafusion_expr::Expr as DFExpr; -use datafusion_sql::planner::{object_name_to_table_reference, PlannerContext, SqlToRel}; -use sqllogictest::DBOutput; -use sqlparser::ast::{Expr, SetExpr, Statement as SQLStatement}; -use std::sync::Arc; - -pub async fn insert(ctx: &SessionContext, insert_stmt: SQLStatement) -> Result { - // First, use sqlparser to get table name and insert values - let table_reference; - let insert_values: Vec>; - match insert_stmt { - SQLStatement::Insert { - table_name, source, .. - } => { - table_reference = object_name_to_table_reference( - table_name, - ctx.enable_ident_normalization(), - )?; - - // Todo: check columns match table schema - match *source.body { - SetExpr::Values(values) => { - insert_values = values.rows; - } - _ => { - // Directly panic: make it easy to find the location of the error. - panic!() - } - } - } - _ => unreachable!(), - } - - // Second, get batches in table and destroy the old table - let mut origin_batches = ctx.table(&table_reference).await?.collect().await?; - let schema = ctx.table_provider(&table_reference).await?.schema(); - ctx.deregister_table(&table_reference)?; - - // Third, transfer insert values to `RecordBatch` - // Attention: schema info can be ignored. (insert values don't contain schema info) - let sql_to_rel = SqlToRel::new(&LogicTestContextProvider {}); - let num_rows = insert_values.len(); - for row in insert_values.into_iter() { - let logical_exprs = row - .into_iter() - .map(|expr| { - sql_to_rel.sql_to_expr( - expr, - &DFSchema::empty(), - &mut PlannerContext::new(), - ) - }) - .collect::, DataFusionError>>()?; - // Directly use `select` to get `RecordBatch` - let dataframe = ctx.read_empty()?; - origin_batches.extend(dataframe.select(logical_exprs)?.collect().await?) - } - - // Replace new batches schema to old schema - for batch in origin_batches.iter_mut() { - *batch = RecordBatch::try_new(schema.clone(), batch.columns().to_vec())?; - } - - // Final, create new memtable with same schema. 
- let new_provider = MemTable::try_new(schema, vec![origin_batches])?; - ctx.register_table(&table_reference, Arc::new(new_provider))?; - - Ok(DBOutput::StatementComplete(num_rows as u64)) -} diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs b/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs index 1f8f7feb36e5..cdd6663a5e0b 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs +++ b/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs @@ -26,13 +26,11 @@ use create_table::create_table; use datafusion::arrow::record_batch::RecordBatch; use datafusion::prelude::SessionContext; use datafusion_sql::parser::{DFParser, Statement}; -use insert::insert; use sqllogictest::DBOutput; use sqlparser::ast::Statement as SQLStatement; mod create_table; mod error; -mod insert; mod normalize; mod util; @@ -85,7 +83,6 @@ async fn run_query(ctx: &SessionContext, sql: impl Into) -> Result return insert(ctx, statement).await, SQLStatement::CreateTable { query, constraints, diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt index 694b062c7f78..0fffb2a5db07 100644 --- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt +++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt @@ -1013,9 +1013,9 @@ select c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count 4 29 1.260869565217 123 -117 23 5 -194 -13.857142857143 118 -101 14 -# TODO: csv_query_array_agg_unsupported -# statement error -# SELECT array_agg(c13 ORDER BY c1) FROM aggregate_test_100; +# csv_query_array_agg_unsupported +statement error This feature is not implemented: ORDER BY not supported in ARRAY_AGG: c1 +SELECT array_agg(c13 ORDER BY c1) FROM aggregate_test_100; # csv_query_array_cube_agg_with_overflow query TIIRIII @@ -1058,42 +1058,15 @@ NULL 4 29 1.260869565217 123 -117 23 NULL 5 -194 -13.857142857143 118 -101 14 NULL NULL 781 7.81 125 -117 100 +# TODO this querys output is non determinisitic (the order of the elements +# differs run to run +# # csv_query_array_agg_distinct # query T # SELECT array_agg(distinct c2) FROM aggregate_test_100 # ---- # [4, 2, 3, 5, 1] -# TODO: aggregate_timestamps_sum - -# aggregate_timestamps_count -# query IIII -# SELECT count(nanos), count(micros), count(millis), count(secs) FROM t -# ---- -# 3 3 3 3 - -# aggregate_timestamps_min -# query TTTT -# SELECT min(nanos), min(micros), min(millis), min(secs) FROM t -# ---- -# 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 - -# # aggregate_timestamps_max -# query TTTT -# SELECT max(nanos), max(micros), max(millis), max(secs) FROM t -# ---- -# 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 - -# TODO: aggregate_times_sum - -# TODO: aggregate_times_count - -# TODO: aggregate_times_min - -# TODO: aggregate_times_max - -# TODO: aggregate_timestamps_avg - # aggregate_time_min_and_max query TT select min(t), max(t) from (select '00:00:00' as t union select '00:00:01' union select '00:00:02') @@ -1101,30 +1074,28 @@ select min(t), max(t) from (select '00:00:00' as t union select '00:00:01' unio 00:00:00 00:00:02 # aggregate_decimal_min -query R -select min(c1) from d_table +query RT +select min(c1), arrow_typeof(min(c1)) from d_table ---- --100.009 +-100.009 Decimal128(10, 3) # aggregate_decimal_max -query R -select max(c1) from d_table +query RT +select max(c1), 
arrow_typeof(max(c1)) from d_table ---- -110.009 +110.009 Decimal128(10, 3) -# FIX: doesn't check datatype # aggregate_decimal_sum -query R -select sum(c1) from d_table +query RT +select sum(c1), arrow_typeof(sum(c1)) from d_table ---- -100 +100 Decimal128(20, 3) -# FIX: doesn't check datatype # aggregate_decimal_avg -query R -select avg(c1) from d_table +query RT +select avg(c1), arrow_typeof(avg(c1)) from d_table ---- -5 +5 Decimal128(14, 7) # FIX: different test table # aggregate @@ -1191,19 +1162,17 @@ SELECT COUNT(DISTINCT c1) FROM test # TODO: aggregate_with_alias -# FIX: CSV Writer error # array_agg_zero -# query I -# SELECT ARRAY_AGG([]) -# ---- -# [] +query ? +SELECT ARRAY_AGG([]) +---- +[] -# FIX: CSV Writer error # array_agg_one -# query I -# SELECT ARRAY_AGG([1]) -# ---- -# [[1]] +query ? +SELECT ARRAY_AGG([1]) +---- +[[1]] # test_approx_percentile_cont_decimal_support query TI @@ -1216,7 +1185,6 @@ d 4 e 4 - # array_agg_zero query ? SELECT ARRAY_AGG([]); @@ -1284,3 +1252,155 @@ NULL 2 statement ok drop table the_nulls; + +# All supported timestamp types + +# "nanos" --> TimestampNanosecondArray +# "micros" --> TimestampMicrosecondArray +# "millis" --> TimestampMillisecondArray +# "secs" --> TimestampSecondArray +# "names" --> StringArray + +statement ok +create table t_source +as values + ('2018-11-13T17:11:10.011375885995', 'Row 0'), + ('2011-12-13T11:13:10.12345', 'Row 1'), + (null, 'Row 2'), + ('2021-1-1T05:11:10.432', 'Row 3'); + + +statement ok +create table t as +select + arrow_cast(column1, 'Timestamp(Nanosecond, None)') as nanos, + arrow_cast(column1, 'Timestamp(Microsecond, None)') as micros, + arrow_cast(column1, 'Timestamp(Millisecond, None)') as millis, + arrow_cast(column1, 'Timestamp(Second, None)') as secs, + column2 as names +from t_source; + +# Demonstate the contents +query PPPPT +select * from t; +---- +2018-11-13T17:11:10.011375885 2018-11-13T17:11:10.011375 2018-11-13T17:11:10.011 2018-11-13T17:11:10 Row 0 +2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 Row 1 +NULL NULL NULL NULL Row 2 +2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 Row 3 + + +# aggregate_timestamps_sum +statement error Error during planning: The function Sum does not support inputs of type Timestamp\(Nanosecond, None\) +SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t; + +# aggregate_timestamps_count +query IIII +SELECT count(nanos), count(micros), count(millis), count(secs) FROM t; +---- +3 3 3 3 + + +# aggregate_timestamps_min +query PPPP +SELECT min(nanos), min(micros), min(millis), min(secs) FROM t; +---- +2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123450 2011-12-13T11:13:10.123 2011-12-13T11:13:10 + +# aggregate_timestamps_max +query PPPP +SELECT max(nanos), max(micros), max(millis), max(secs) FROM t; +---- +2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10.432 2021-01-01T05:11:10 + + + +# aggregate_timestamps_avg +statement error Error during planning: The function Avg does not support inputs of type Timestamp\(Nanosecond, None\). 
+SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t + + +statement ok +drop table t_source; + +statement ok +drop table t; + +# All supported time types + +# Columns are named: +# "nanos" --> Time64NanosecondArray +# "micros" --> Time64MicrosecondArray +# "millis" --> Time32MillisecondArray +# "secs" --> Time32SecondArray +# "names" --> StringArray + +statement ok +create table t_source +as values + ('18:06:30.243620451', 'Row 0'), + ('20:08:28.161121654', 'Row 1'), + ('19:11:04.156423842', 'Row 2'), + ('21:06:28.247821084', 'Row 3'); + + +statement ok +create table t as +select + arrow_cast(column1, 'Time64(Nanosecond)') as nanos, + arrow_cast(column1, 'Time64(Microsecond)') as micros, + arrow_cast(column1, 'Time32(Millisecond)') as millis, + arrow_cast(column1, 'Time32(Second)') as secs, + column2 as names +from t_source; + +# Demonstate the contents +query DDDDT +select * from t; +---- +18:06:30.243620451 18:06:30.243620 18:06:30.243 18:06:30 Row 0 +20:08:28.161121654 20:08:28.161121 20:08:28.161 20:08:28 Row 1 +19:11:04.156423842 19:11:04.156423 19:11:04.156 19:11:04 Row 2 +21:06:28.247821084 21:06:28.247821 21:06:28.247 21:06:28 Row 3 + +# aggregate_times_sum +statement error DataFusion error: Error during planning: The function Sum does not support inputs of type Time64\(Nanosecond\). +SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t + +# aggregate_times_count +query IIII +SELECT count(nanos), count(micros), count(millis), count(secs) FROM t +---- +4 4 4 4 + + +# aggregate_times_min +query DDDD +SELECT min(nanos), min(micros), min(millis), min(secs) FROM t +---- +18:06:30.243620451 18:06:30.243620 18:06:30.243 18:06:30 + +# aggregate_times_max +query DDDD +SELECT max(nanos), max(micros), max(millis), max(secs) FROM t +---- +21:06:28.247821084 21:06:28.247821 21:06:28.247 21:06:28 + + +# aggregate_times_avg +statement error DataFusion error: Error during planning: The function Avg does not support inputs of type Time64\(Nanosecond\). +SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t + +statement ok +drop table t_source; + +statement ok +drop table t; + +query I +select median(a) from (select 1 as a where 1=0); +---- +NULL + +query error DataFusion error: Execution error: aggregate function needs at least one non-null element +select approx_median(a) from (select 1 as a where 1=0); diff --git a/datafusion/core/tests/sqllogictests/test_files/ddl.slt b/datafusion/core/tests/sqllogictests/test_files/ddl.slt index 642093c364f3..a4c90df07253 100644 --- a/datafusion/core/tests/sqllogictests/test_files/ddl.slt +++ b/datafusion/core/tests/sqllogictests/test_files/ddl.slt @@ -63,7 +63,7 @@ statement error Table 'user' doesn't exist. DROP TABLE user; # Can not insert into a undefined table -statement error No table named 'user' +statement error DataFusion error: Error during planning: table 'datafusion.public.user' not found insert into user values(1, 20); ########## @@ -366,7 +366,7 @@ STORED AS CSV PARTITIONED BY (c_date) LOCATION 'tests/data/partitioned_table'; -query TP? 
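+# The c_date partition column is now reported with a concrete date result type (D) rather than an unknown type (?).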
+query TPD SELECT * from csv_with_timestamps where c_date='2018-11-13' ---- Jorge 2018-12-13T12:12:10.011 2018-11-13 @@ -421,9 +421,27 @@ statement ok DROP TABLE aggregate_simple +# sql_table_insert +statement ok +CREATE TABLE abc AS VALUES (1,2,3), (4,5,6); + +statement ok +CREATE TABLE xyz AS VALUES (1,3,3), (5,5,6); + +statement ok +INSERT INTO abc SELECT * FROM xyz; + +query III +SELECT * FROM abc +---- +1 2 3 +4 5 6 +1 3 3 +5 5 6 + # Should create an empty table statement ok -CREATE TABLE table_without_values(field1 BIGINT, field2 BIGINT); +CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL); # Should skip existing table @@ -444,8 +462,8 @@ CREATE OR REPLACE TABLE IF NOT EXISTS table_without_values(field1 BIGINT, field2 statement ok insert into table_without_values values (1, 2), (2, 3), (2, 4); -query II rowsort -select * from table_without_values; +query II +select * from table_without_values ---- 1 2 2 3 @@ -454,7 +472,7 @@ select * from table_without_values; # Should recreate existing table statement ok -CREATE OR REPLACE TABLE table_without_values(field1 BIGINT, field2 BIGINT); +CREATE OR REPLACE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL); # Should insert into a recreated table diff --git a/datafusion/core/tests/sqllogictests/test_files/information_schema.slt b/datafusion/core/tests/sqllogictests/test_files/information_schema.slt index 81ce141a8fe0..25e4195fba6c 100644 --- a/datafusion/core/tests/sqllogictests/test_files/information_schema.slt +++ b/datafusion/core/tests/sqllogictests/test_files/information_schema.slt @@ -84,6 +84,30 @@ datafusion information_schema views VIEW datafusion public t BASE TABLE datafusion public t2 BASE TABLE +query TTTT rowsort +SELECT * from information_schema.tables WHERE tables.table_schema='information_schema'; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW + +query TTTT rowsort +SELECT * from information_schema.tables WHERE information_schema.tables.table_schema='information_schema'; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW + +query TTTT rowsort +SELECT * from information_schema.tables WHERE datafusion.information_schema.tables.table_schema='information_schema'; +---- +datafusion information_schema columns VIEW +datafusion information_schema df_settings VIEW +datafusion information_schema tables VIEW +datafusion information_schema views VIEW + # Cleanup statement ok drop table t diff --git a/datafusion/core/tests/sqllogictests/test_files/join.slt b/datafusion/core/tests/sqllogictests/test_files/join.slt index 8d5f573431bc..f7c291f909f6 100644 --- a/datafusion/core/tests/sqllogictests/test_files/join.slt +++ b/datafusion/core/tests/sqllogictests/test_files/join.slt @@ -19,8 +19,9 @@ ## Join Tests ########## +# Regression test: https://github.com/apache/arrow-datafusion/issues/4844 statement ok -CREATE TABLE students(name TEXT, mark INT) AS VALUES +CREATE TABLE IF NOT EXISTS students(name TEXT, mark INT) AS VALUES ('Stuart', 28), ('Amina', 89), ('Christen', 50), @@ -28,20 +29,13 @@ CREATE TABLE students(name TEXT, mark INT) AS VALUES ('Samantha', 21); statement ok -CREATE TABLE grades(grade INT, min INT, max INT) AS VALUES +CREATE TABLE IF NOT EXISTS grades(grade INT, min INT, max INT) AS VALUES (1, 0, 14), (2, 15, 35), 
(3, 36, 55), (4, 56, 79), (5, 80, 100); -statement ok -CREATE TABLE test1(a int, b int) as select 1 as a, 2 as b; - -statement ok -CREATE TABLE test2(a int, b int) as select 1 as a, 2 as b; - -# Regression test: https://github.com/apache/arrow-datafusion/issues/4844 query TII SELECT s.*, g.grade FROM students s join grades g on s.mark between g.min and g.max WHERE grade > 2 ORDER BY s.mark DESC ---- @@ -49,47 +43,499 @@ Amina 89 5 Salma 77 4 Christen 50 3 +statement ok +drop table IF EXISTS students; + +statement ok +drop table IF EXISTS grades; + +# issue: https://github.com/apache/arrow-datafusion/issues/5382 +statement ok +CREATE TABLE IF NOT EXISTS test1(a int, b int) as select 1 as a, 2 as b; + +statement ok +CREATE TABLE IF NOT EXISTS test2(a int, b int) as select 1 as a, 2 as b; + +query IIII rowsort +SELECT * FROM test2 FULL JOIN test1 ON true; +---- +1 2 1 2 + +statement ok +drop table IF EXISTS test1; + +statement ok +drop table IF EXISTS test2; + # two tables for join statement ok -CREATE TABLE t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES +CREATE TABLE IF NOT EXISTS t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES (11, 'a', 1), (22, 'b', 2), (33, 'c', 3), (44, 'd', 4); statement ok -CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES +CREATE TABLE IF NOT EXISTS t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES (11, 'z', 3), (22, 'y', 1), (44, 'x', 3), (55, 'w', 3); +# batch size +statement ok +set datafusion.execution.batch_size = 4096; + # left semi with wrong where clause -query error DataFusion error: Schema error: No field named 't2'.'t2_id'. Valid fields are 't1'.'t1_id', 't1'.'t1_name', 't1'.'t1_int'. -SELECT t1.t1_id, - t1.t1_name, - t1.t1_int -FROM t1 LEFT SEMI -JOIN t2 -ON ( - t1.t1_id = t2.t2_id) -WHERE t2.t2_id > 1 +query error DataFusion error: Schema error: No field named "t2"."t2_id". Valid fields are "t1"."t1_id", "t1"."t1_name", "t1"."t1_int". 
+SELECT t1.t1_id, t1.t1_name, t1.t1_int +FROM t1 + LEFT SEMI JOIN t2 ON t1.t1_id = t2.t2_id +WHERE t2.t2_id > 1 # left semi join with on-filter query ITI rowsort -SELECT t1.t1_id, - t1.t1_name, - t1.t1_int -FROM t1 LEFT SEMI -JOIN t2 -ON ( - t1.t1_id = t2.t2_id and t2.t2_int > 1) +SELECT t1.t1_id, t1.t1_name, t1.t1_int +FROM t1 + LEFT SEMI JOIN t2 + ON t1.t1_id = t2.t2_id + AND t2.t2_int > 1 ---- 11 a 1 44 d 4 -# issue: https://github.com/apache/arrow-datafusion/issues/5382 -query IIII rowsort -SELECT * FROM test2 FULL JOIN test1 ON true; +# equijoin +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id ---- -1 2 1 2 +11 a z +22 b y +44 d x + +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id +---- +11 a z +22 b y +44 d x + +# equijoin_multiple_condition_ordering +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t1_name <> t2_name +---- +11 a z +22 b y +44 d x + +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name <> t1_name +---- +11 a z +22 b y +44 d x + +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id AND t1_name <> t2_name +---- +11 a z +22 b y +44 d x + +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id AND t2_name <> t1_name +---- +11 a z +22 b y +44 d x + +# equijoin_and_other_condition +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' +---- +11 a z +22 b y + +# equijoin_left_and_condition_from_right +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' +---- +11 a z +22 b y +33 c NULL +44 d NULL + +# equijoin_left_and_not_null_condition_from_right +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t2_name is not null +---- +11 a z +22 b y +33 c NULL +44 d x + +# full_join_sub_query +query ITT rowsort +SELECT t1_id, t1_name, t2_name +FROM ( + SELECT * + FROM (t1) AS t1 +) + FULL JOIN ( + SELECT * + FROM (t2) AS t2 + ) + ON t1_id = t2_id AND t2_name >= 'y' +---- +11 a z +22 b y +33 c NULL +44 d NULL +NULL NULL w +NULL NULL x + +# equijoin_right_and_condition_from_left +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t1_id >= 22 +---- +22 b y +44 d x +NULL NULL w +NULL NULL z + +# equijoin_left_and_condition_from_left +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_id >= 44 +---- +11 a NULL +22 b NULL +33 c NULL +44 d x + +# equijoin_left_and_condition_from_both +query III rowsort +SELECT t1_id, t1_int, t2_int FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_int >= t2_int +---- +11 1 NULL +22 2 1 +33 3 NULL +44 4 3 + +# equijoin_right_and_condition_from_right +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_id >= 22 +---- +22 b y +44 d x +NULL NULL w +NULL NULL z + +# equijoin_right_and_condition_from_both +query III rowsort +SELECT t1_int, t2_int, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int +---- +2 1 22 +4 3 44 +NULL 3 11 +NULL 3 55 + +# left_join +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id +---- +11 a z +22 b y +33 c NULL +44 d x + +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id +---- +11 a z +22 b y +33 c NULL +44 d x + +# right_join +query ITT rowsort +SELECT t1_id, t1_name, 
t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id +---- +11 a z +22 b y +44 d x +NULL NULL w + +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t2_id = t1_id +---- +11 a z +22 b y +44 d x +NULL NULL w + +# full_join +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t1_id = t2_id +---- +11 a z +22 b y +33 c NULL +44 d x +NULL NULL w + +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t2_id = t1_id +---- +11 a z +22 b y +33 c NULL +44 d x +NULL NULL w + +# equijoin_implicit_syntax +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t1_id = t2_id +---- +11 a z +22 b y +44 d x + +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id +---- +11 a z +22 b y +44 d x + +# equijoin_implicit_syntax_with_filter +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t1_id > 0 AND t1_id = t2_id AND t2_id < 99 +---- +11 a z +22 b y +44 d x + +# equijoin_implicit_syntax_reversed +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id +---- +11 a z +22 b y +44 d x + +# reduce_left_join_1 +query ITIITI rowsort +SELECT t1_id, t1_name, t1_int, t2_id, t2_name, t2_int +FROM t1 + LEFT JOIN t2 ON t1.t1_id = t2.t2_id +WHERE t2.t2_id < 100 +---- +11 a 1 11 z 3 +22 b 2 22 y 1 +44 d 4 44 x 3 + +# reduce_left_join_2 +# filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')` +# could be write to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)` +# the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push down left join side and remove in filter. +query ITIITI rowsort +SELECT t1_id, t1_name, t1_int, t2_id, t2_name, t2_int +FROM t1 + LEFT JOIN t2 ON t1.t1_id = t2.t2_id +WHERE t2.t2_int < 10 + OR (t1.t1_int > 2 + AND t2.t2_name != 'w') +---- +11 a 1 11 z 3 +22 b 2 22 y 1 +44 d 4 44 x 3 + +# reduce_left_join_3 +query ITIITI +SELECT * +FROM ( + SELECT t1.* + FROM t1 + LEFT JOIN t2 ON t1.t1_id = t2.t2_id + WHERE t2.t2_int < 3 +) t3 + LEFT JOIN t2 ON t3.t1_int = t2.t2_int +WHERE t3.t1_id < 100 +---- +22 b 2 NULL NULL NULL + +# reduce_right_join_1 +query ITIITI rowsort +SELECT t1_id, t1_name, t1_int, t2_id, t2_name, t2_int +FROM t1 + RIGHT JOIN t2 ON t1.t1_id = t2.t2_id +WHERE t1.t1_int IS NOT NULL +---- +11 a 1 11 z 3 +22 b 2 22 y 1 +44 d 4 44 x 3 + +# reduce_right_join_2 +query ITIITI rowsort +SELECT * +FROM t1 + RIGHT JOIN t2 ON t1.t1_id = t2.t2_id +WHERE NOT t1.t1_int = t2.t2_int +---- +11 a 1 11 z 3 +22 b 2 22 y 1 +44 d 4 44 x 3 + +# reduce_full_join_to_right_join +query ITIITI rowsort +SELECT * +FROM t1 + FULL JOIN t2 ON t1.t1_id = t2.t2_id +WHERE t2.t2_name IS NOT NULL +---- +11 a 1 11 z 3 +22 b 2 22 y 1 +44 d 4 44 x 3 +NULL NULL NULL 55 w 3 + +# reduce_full_join_to_left_join +query ITIITI rowsort +SELECT * +FROM t1 + FULL JOIN t2 ON t1.t1_id = t2.t2_id +WHERE t1.t1_name != 'b' +---- +11 a 1 11 z 3 +33 c 3 NULL NULL NULL +44 d 4 44 x 3 + +# reduce_full_join_to_inner_join +query ITIITI rowsort +SELECT * +FROM t1 + FULL JOIN t2 ON t1.t1_id = t2.t2_id +WHERE t1.t1_name != 'b' + AND t2.t2_name = 'x' +---- +44 d 4 44 x 3 + +# left_join_with_nonequal_condition +query ITT rowsort +SELECT t1_id, t1_name, t2_name +FROM t1 + LEFT JOIN t2 + ON t1_id != t2_id + AND t2_id >= 100 +---- +11 a NULL +22 b NULL +33 c NULL +44 d NULL + +# reduce_cross_join_with_expr_join_key_all +query ITIITI rowsort +SELECT * +FROM t1 + CROSS JOIN t2 +WHERE t1.t1_id + 12 = t2.t2_id + 1 +---- +11 a 1 22 y 1 +33 c 3 44 x 3 +44 d 4 55 w 3 + +# 
reduce_cross_join_with_cast_expr_join_key +query IIT rowsort +SELECT t1.t1_id, t2.t2_id, t1.t1_name +FROM t1 + CROSS JOIN t2 +WHERE t1.t1_id + 11 = CAST(t2.t2_id AS BIGINT) +---- +11 22 a +33 44 c +44 55 d + +# reduce_cross_join_with_wildcard_and_expr +query ITIITII rowsort +SELECT *, t1.t1_id + 11 +FROM t1, t2 +WHERE t1.t1_id + 11 = t2.t2_id +---- +11 a 1 22 y 1 22 +33 c 3 44 x 3 44 +44 d 4 55 w 3 55 + +# both_side_expr_key_inner_join +query IIT rowsort +SELECT t1.t1_id, t2.t2_id, t1.t1_name +FROM t1 + INNER JOIN t2 ON + t1.t1_id + cast(12 as INT UNSIGNED) = t2.t2_id + cast(1 as INT UNSIGNED) +---- +11 22 a +33 44 c +44 55 d + +# left_side_expr_key_inner_join +query IIT rowsort +SELECT t1_id, t2_id, t1_name +FROM t1 + INNER JOIN t2 ON + t1.t1_id + cast(11 as INT UNSIGNED) = t2.t2_id +---- +11 22 a +33 44 c +44 55 d + +# right_side_expr_key_inner_join +query IIT rowsort +SELECT t1.t1_id, t2.t2_id, t1.t1_name +FROM t1 + INNER JOIN t2 ON + t1.t1_id + cast(11 as INT UNSIGNED) = t2.t2_id +---- +11 22 a +33 44 c +44 55 d + +# select_wildcard_with_expr_key_inner_join +query ITIITI rowsort +SELECT * FROM t1 INNER JOIN t2 ON t1.t1_id = t2.t2_id - cast(11 as INT UNSIGNED) +---- +11 a 1 22 y 1 +33 c 3 44 x 3 +44 d 4 55 w 3 + +# join_with_type_coercion_for_equi_expr +query ITI rowsort +SELECT t1.t1_id, t1.t1_name, t2.t2_id +FROM t1 + INNER JOIN t2 ON t1.t1_id + 11 = t2.t2_id +---- +11 a 22 +33 c 44 +44 d 55 + +# join_only_with_filter +query ITI rowsort +select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id * 4 < t2.t2_id +---- +11 a 55 + +# type_coercion_join_with_filter_and_equi_expr +query ITI rowsort +SELECT t1.t1_id, t1.t1_name, t2.t2_id +FROM t1 + INNER JOIN t2 + ON t1.t1_id * 5 = t2.t2_id + AND t1.t1_id * 4 < t2.t2_id +---- +11 a 55 + +statement ok +drop table IF EXISTS t1; + +statement ok +drop table IF EXISTS t2; + +# batch size +statement ok +set datafusion.execution.batch_size = 8192; diff --git a/datafusion/core/tests/sqllogictests/test_files/join_disable_repartition_joins.slt b/datafusion/core/tests/sqllogictests/test_files/join_disable_repartition_joins.slt new file mode 100644 index 000000000000..5f680fcae73f --- /dev/null +++ b/datafusion/core/tests/sqllogictests/test_files/join_disable_repartition_joins.slt @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +########## +## Join Tests +########## + +# turn off repartition_joins +statement ok +set datafusion.optimizer.repartition_joins = false; + +include ./join.slt + +# turn on repartition_joins +statement ok +set datafusion.optimizer.repartition_joins = true; diff --git a/datafusion/core/tests/sqllogictests/test_files/subquery.slt b/datafusion/core/tests/sqllogictests/test_files/subquery.slt index 70c3c3e13df5..0c108e2c810c 100644 --- a/datafusion/core/tests/sqllogictests/test_files/subquery.slt +++ b/datafusion/core/tests/sqllogictests/test_files/subquery.slt @@ -60,3 +60,42 @@ where t1.t1_id + 12 not in ( ) ---- 22 b 2 + +# in subquery with two parentheses, see #5529 +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id in (( + select t2.t2_id from t2 + )) +---- +11 a 1 +22 b 2 +44 d 4 + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id in (( + select t2.t2_id from t2 + )) +and t1.t1_int < 3 +---- +11 a 1 +22 b 2 + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id not in (( + select t2.t2_id from t2 where t2.t2_int = 3 + )) +---- +22 b 2 +33 c 3 diff --git a/datafusion/core/tests/sqllogictests/test_files/window.slt b/datafusion/core/tests/sqllogictests/test_files/window.slt index 148c7a4fd0b9..85a5bc18d4e2 100644 --- a/datafusion/core/tests/sqllogictests/test_files/window.slt +++ b/datafusion/core/tests/sqllogictests/test_files/window.slt @@ -2038,3 +2038,39 @@ SELECT statement ok set datafusion.execution.target_partitions = 2; + +# test_window_agg_with_bounded_group +query TT +EXPLAIN SELECT SUM(c12) OVER(ORDER BY c1, c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as sum1, + SUM(c12) OVER(ORDER BY c1 GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING) as sum2 + FROM aggregate_test_100 ORDER BY c9 LIMIT 5 +---- +logical_plan +Projection: sum1, sum2 + Limit: skip=0, fetch=5 + Sort: aggregate_test_100.c9 ASC NULLS LAST, fetch=5 + Projection: SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS sum1, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING AS sum2, aggregate_test_100.c9 + WindowAggr: windowExpr=[[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING]] + WindowAggr: windowExpr=[[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] + TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] +physical_plan +ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2] + GlobalLimitExec: skip=0, fetch=5 + SortExec: fetch=5, expr=[c9@2 ASC NULLS LAST] + ProjectionExec: expr=[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@13 as sum1, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING@14 as sum2, c9@8 as c9] + BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c12): Ok(Field { name: "SUM(aggregate_test_100.c12)", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)) }] + BoundedWindowAggExec: 
wdw=[SUM(aggregate_test_100.c12): Ok(Field { name: "SUM(aggregate_test_100.c12)", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }] + SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST] + CsvExec: files={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, has_header=true, limit=None, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] + +query RR +SELECT SUM(c12) OVER(ORDER BY c1, c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as sum1, + SUM(c12) OVER(ORDER BY c1 GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING) as sum2 + FROM aggregate_test_100 ORDER BY c9 LIMIT 5 +---- +4.561269874379 18.036183428008 +6.808931568966 10.238448667883 +2.994840293343 NULL +9.674390599321 NULL +7.728066219895 NULL + diff --git a/datafusion/core/tests/window_fuzz.rs b/datafusion/core/tests/window_fuzz.rs index e03758600938..a71aab280ea1 100644 --- a/datafusion/core/tests/window_fuzz.rs +++ b/datafusion/core/tests/window_fuzz.rs @@ -227,9 +227,7 @@ fn get_random_window_frame(rng: &mut StdRng) -> WindowFrame { } else if rand_num < 2 { WindowFrameUnits::Rows } else { - // For now we do not support GROUPS in BoundedWindowAggExec implementation - // TODO: once GROUPS handling is available, use WindowFrameUnits::GROUPS in randomized tests also. - WindowFrameUnits::Range + WindowFrameUnits::Groups }; match units { // In range queries window frame boundaries should match column type @@ -256,8 +254,8 @@ fn get_random_window_frame(rng: &mut StdRng) -> WindowFrame { } window_frame } - // In window queries, window frame boundary should be Uint64 - WindowFrameUnits::Rows => { + // Window frame boundary should be UInt64 for both ROWS and GROUPS frames: + WindowFrameUnits::Rows | WindowFrameUnits::Groups => { let start_bound = if start_bound.is_preceding { WindowFrameBound::Preceding(ScalarValue::UInt64(Some( start_bound.val as u64, @@ -286,10 +284,10 @@ fn get_random_window_frame(rng: &mut StdRng) -> WindowFrame { window_frame.start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); } + // We never use UNBOUNDED FOLLOWING in test. Because that case is not prunable and + // should work only with WindowAggExec window_frame } - // Once GROUPS support is added construct window frame for this case also - _ => todo!(), } } @@ -401,7 +399,7 @@ async fn run_window_test( assert_eq!( (i, usual_line), (i, running_line), - "Inconsistent result for window_fn: {window_fn:?}, args:{args:?}" + "Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}" ); } } diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs new file mode 100644 index 000000000000..dbaaba304bf4 --- /dev/null +++ b/datafusion/execution/src/config.rs @@ -0,0 +1,375 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + any::{Any, TypeId}, + collections::HashMap, + hash::{BuildHasherDefault, Hasher}, + sync::Arc, +}; + +use datafusion_common::{config::ConfigOptions, Result, ScalarValue}; + +/// Configuration options for Execution context +#[derive(Clone)] +pub struct SessionConfig { + /// Configuration options + options: ConfigOptions, + /// Opaque extensions. + extensions: AnyMap, +} + +impl Default for SessionConfig { + fn default() -> Self { + Self { + options: ConfigOptions::new(), + // Assume no extensions by default. + extensions: HashMap::with_capacity_and_hasher( + 0, + BuildHasherDefault::default(), + ), + } + } +} + +impl SessionConfig { + /// Create an execution config with default setting + pub fn new() -> Self { + Default::default() + } + + /// Create an execution config with config options read from the environment + pub fn from_env() -> Result { + Ok(ConfigOptions::from_env()?.into()) + } + + /// Set a configuration option + pub fn set(mut self, key: &str, value: ScalarValue) -> Self { + self.options.set(key, &value.to_string()).unwrap(); + self + } + + /// Set a boolean configuration option + pub fn set_bool(self, key: &str, value: bool) -> Self { + self.set(key, ScalarValue::Boolean(Some(value))) + } + + /// Set a generic `u64` configuration option + pub fn set_u64(self, key: &str, value: u64) -> Self { + self.set(key, ScalarValue::UInt64(Some(value))) + } + + /// Set a generic `usize` configuration option + pub fn set_usize(self, key: &str, value: usize) -> Self { + let value: u64 = value.try_into().expect("convert usize to u64"); + self.set(key, ScalarValue::UInt64(Some(value))) + } + + /// Set a generic `str` configuration option + pub fn set_str(self, key: &str, value: &str) -> Self { + self.set(key, ScalarValue::Utf8(Some(value.to_string()))) + } + + /// Customize batch size + pub fn with_batch_size(mut self, n: usize) -> Self { + // batch size must be greater than zero + assert!(n > 0); + self.options.execution.batch_size = n; + self + } + + /// Customize [`target_partitions`] + /// + /// [`target_partitions`]: crate::config::ExecutionOptions::target_partitions + pub fn with_target_partitions(mut self, n: usize) -> Self { + // partition count must be greater than zero + assert!(n > 0); + self.options.execution.target_partitions = n; + self + } + + /// Get [`target_partitions`] + /// + /// [`target_partitions`]: crate::config::ExecutionOptions::target_partitions + pub fn target_partitions(&self) -> usize { + self.options.execution.target_partitions + } + + /// Is the information schema enabled? + pub fn information_schema(&self) -> bool { + self.options.catalog.information_schema + } + + /// Should the context create the default catalog and schema? + pub fn create_default_catalog_and_schema(&self) -> bool { + self.options.catalog.create_default_catalog_and_schema + } + + /// Are joins repartitioned during execution? + pub fn repartition_joins(&self) -> bool { + self.options.optimizer.repartition_joins + } + + /// Are aggregates repartitioned during execution? 
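A minimal usage sketch (not part of the patch) of how the builder-style `SessionConfig` setters above are meant to be chained; it assumes only the methods shown in this new `datafusion/execution/src/config.rs` plus an existing dotted option key, and the specific values are illustrative:

```rust
use datafusion_execution::config::SessionConfig;

fn main() {
    // Typed knobs via the `with_*` setters, free-form options via `set_*` and a dotted key.
    let config = SessionConfig::new()
        .with_batch_size(4096)
        .with_target_partitions(8)
        .set_bool("datafusion.optimizer.repartition_joins", false);

    assert_eq!(config.target_partitions(), 8);
    assert!(!config.repartition_joins());
}
```

This mirrors what the sqllogictests earlier in this diff do with `SET datafusion.execution.batch_size`, only from Rust instead of SQL.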
+ pub fn repartition_aggregations(&self) -> bool { + self.options.optimizer.repartition_aggregations + } + + /// Are window functions repartitioned during execution? + pub fn repartition_window_functions(&self) -> bool { + self.options.optimizer.repartition_windows + } + + /// Do we execute sorts in a per-partition fashion and merge afterwards, + /// or do we coalesce partitions first and sort globally? + pub fn repartition_sorts(&self) -> bool { + self.options.optimizer.repartition_sorts + } + + /// Are statistics collected during execution? + pub fn collect_statistics(&self) -> bool { + self.options.execution.collect_statistics + } + + /// Selects a name for the default catalog and schema + pub fn with_default_catalog_and_schema( + mut self, + catalog: impl Into, + schema: impl Into, + ) -> Self { + self.options.catalog.default_catalog = catalog.into(); + self.options.catalog.default_schema = schema.into(); + self + } + + /// Controls whether the default catalog and schema will be automatically created + pub fn with_create_default_catalog_and_schema(mut self, create: bool) -> Self { + self.options.catalog.create_default_catalog_and_schema = create; + self + } + + /// Enables or disables the inclusion of `information_schema` virtual tables + pub fn with_information_schema(mut self, enabled: bool) -> Self { + self.options.catalog.information_schema = enabled; + self + } + + /// Enables or disables the use of repartitioning for joins to improve parallelism + pub fn with_repartition_joins(mut self, enabled: bool) -> Self { + self.options.optimizer.repartition_joins = enabled; + self + } + + /// Enables or disables the use of repartitioning for aggregations to improve parallelism + pub fn with_repartition_aggregations(mut self, enabled: bool) -> Self { + self.options.optimizer.repartition_aggregations = enabled; + self + } + + /// Sets minimum file range size for repartitioning scans + pub fn with_repartition_file_min_size(mut self, size: usize) -> Self { + self.options.optimizer.repartition_file_min_size = size; + self + } + + /// Enables or disables the use of repartitioning for file scans + pub fn with_repartition_file_scans(mut self, enabled: bool) -> Self { + self.options.optimizer.repartition_file_scans = enabled; + self + } + + /// Enables or disables the use of repartitioning for window functions to improve parallelism + pub fn with_repartition_windows(mut self, enabled: bool) -> Self { + self.options.optimizer.repartition_windows = enabled; + self + } + + /// Enables or disables the use of per-partition sorting to improve parallelism + pub fn with_repartition_sorts(mut self, enabled: bool) -> Self { + self.options.optimizer.repartition_sorts = enabled; + self + } + + /// Enables or disables the use of pruning predicate for parquet readers to skip row groups + pub fn with_parquet_pruning(mut self, enabled: bool) -> Self { + self.options.execution.parquet.pruning = enabled; + self + } + + /// Returns true if pruning predicate should be used to skip parquet row groups + pub fn parquet_pruning(&self) -> bool { + self.options.execution.parquet.pruning + } + + /// Enables or disables the collection of statistics after listing files + pub fn with_collect_statistics(mut self, enabled: bool) -> Self { + self.options.execution.collect_statistics = enabled; + self + } + + /// Get the currently configured batch size + pub fn batch_size(&self) -> usize { + self.options.execution.batch_size + } + + /// Convert configuration options to name-value pairs with values + /// converted to strings. 
+ /// + /// Note that this method will eventually be deprecated and + /// replaced by [`config_options`]. + /// + /// [`config_options`]: Self::config_options + pub fn to_props(&self) -> HashMap<String, String> { + let mut map = HashMap::new(); + // copy configs from config_options + for entry in self.options.entries() { + map.insert(entry.key, entry.value.unwrap_or_default()); + } + + map + } + + /// Return a handle to the configuration options. + #[deprecated(since = "21.0.0", note = "use options() instead")] + pub fn config_options(&self) -> &ConfigOptions { + &self.options + } + + /// Return a mutable handle to the configuration options. + #[deprecated(since = "21.0.0", note = "use options_mut() instead")] + pub fn config_options_mut(&mut self) -> &mut ConfigOptions { + &mut self.options + } + + /// Return a handle to the configuration options. + pub fn options(&self) -> &ConfigOptions { + &self.options + } + + /// Return a mutable handle to the configuration options. + pub fn options_mut(&mut self) -> &mut ConfigOptions { + &mut self.options + } + + /// Add extensions. + /// + /// Extensions can be used to attach extra data to the session config -- e.g. tracing information or caches. + /// Extensions are opaque and the types are unknown to DataFusion itself, which makes them extremely flexible. [^1] + /// + /// Extensions are stored within an [`Arc`] so they do NOT require [`Clone`]. They are immutable. If you need to + /// modify their state over their lifetime -- e.g. for caches -- you need to establish some form of interior mutability. + /// + /// Extensions are indexed by their type `T`. If multiple values of the same type are provided, only the last one + /// will be kept. + /// + /// You may use [`get_extension`](Self::get_extension) to retrieve extensions. + /// + /// # Example + /// ``` + /// use std::sync::Arc; + /// use datafusion_execution::config::SessionConfig; + /// + /// // application-specific extension types + /// struct Ext1(u8); + /// struct Ext2(u8); + /// struct Ext3(u8); + /// + /// let ext1a = Arc::new(Ext1(10)); + /// let ext1b = Arc::new(Ext1(11)); + /// let ext2 = Arc::new(Ext2(2)); + /// + /// let cfg = SessionConfig::default() + /// // will only remember the last Ext1 + /// .with_extension(Arc::clone(&ext1a)) + /// .with_extension(Arc::clone(&ext1b)) + /// .with_extension(Arc::clone(&ext2)); + /// + /// let ext1_received = cfg.get_extension::<Ext1>().unwrap(); + /// assert!(!Arc::ptr_eq(&ext1_received, &ext1a)); + /// assert!(Arc::ptr_eq(&ext1_received, &ext1b)); + /// + /// let ext2_received = cfg.get_extension::<Ext2>().unwrap(); + /// assert!(Arc::ptr_eq(&ext2_received, &ext2)); + /// + /// assert!(cfg.get_extension::<Ext3>().is_none()); + /// ``` + /// + /// [^1]: Compare that to [`ConfigOptions`] which only supports [`ScalarValue`] payloads. + pub fn with_extension<T>(mut self, ext: Arc<T>) -> Self + where + T: Send + Sync + 'static, + { + let ext = ext as Arc<dyn Any + Send + Sync>; + let id = TypeId::of::<T>(); + self.extensions.insert(id, ext); + self + } + + /// Get the extension, if any, for the specified type `T`. + /// + /// See [`with_extension`](Self::with_extension) on how to attach extensions. + pub fn get_extension<T>(&self) -> Option<Arc<T>> + where + T: Send + Sync + 'static, + { + let id = TypeId::of::<T>(); + self.extensions + .get(&id) + .cloned() + .map(|ext| Arc::downcast(ext).expect("TypeId unique")) + } +} + +impl From<ConfigOptions> for SessionConfig { + fn from(options: ConfigOptions) -> Self { + Self { + options, + ..Default::default() + } + } +} + +/// Map that holds opaque objects indexed by their type.
+/// +/// Data is wrapped into an [`Arc`] to enable [`Clone`] while still being [object safe]. +/// +/// [object safe]: https://doc.rust-lang.org/reference/items/traits.html#object-safety +type AnyMap = + HashMap<TypeId, Arc<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>; + +/// Hasher for [`AnyMap`]. +/// +/// With [`TypeId`]s as keys, there's no need to hash them. They are already hashes themselves, coming from the compiler. +/// The [`IdHasher`] just holds the [`u64`] of the [`TypeId`], and then returns it, instead of doing any bit fiddling. +#[derive(Default)] +struct IdHasher(u64); + +impl Hasher for IdHasher { + fn write(&mut self, _: &[u8]) { + unreachable!("TypeId calls write_u64"); + } + + #[inline] + fn write_u64(&mut self, id: u64) { + self.0 = id; + } + + #[inline] + fn finish(&self) -> u64 { + self.0 + } +} diff --git a/datafusion/execution/src/lib.rs b/datafusion/execution/src/lib.rs index 55db55cf0a6a..8540e456c0ee 100644 --- a/datafusion/execution/src/lib.rs +++ b/datafusion/execution/src/lib.rs @@ -15,7 +15,11 @@ // specific language governing permissions and limitations // under the License. +//! DataFusion execution configuration and runtime structures + +pub mod config; pub mod disk_manager; pub mod memory_pool; pub mod object_store; pub mod registry; +pub mod runtime_env; diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index f68a2565004d..f1f745d4ebc5 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -17,7 +17,8 @@ //! Manages all available memory during query execution -use datafusion_common::Result; +use datafusion_common::{DataFusionError, Result}; +use parking_lot::Mutex; use std::sync::Arc; mod pool; @@ -163,6 +164,64 @@ impl Drop for MemoryReservation { } } +pub trait TryGrow: Send + Sync + std::fmt::Debug { + fn try_grow(&self, capacity: usize) -> Result<()>; +} + +/// Cloneable reference to [`MemoryReservation`] instance with interior mutability support +#[derive(Clone, Debug)] +pub struct SharedMemoryReservation(Arc<Mutex<MemoryReservation>>); + +impl From<MemoryReservation> for SharedMemoryReservation { + /// Creates new [`SharedMemoryReservation`] from [`MemoryReservation`] + fn from(reservation: MemoryReservation) -> Self { + Self(Arc::new(Mutex::new(reservation))) + } +} + +impl TryGrow for SharedMemoryReservation { + /// Try to increase the size of this reservation by `capacity` bytes + fn try_grow(&self, capacity: usize) -> Result<()> { + self.0.lock().try_grow(capacity) + } +} + +/// Cloneable reference to [`MemoryReservation`] instance with interior mutability support. +/// Doesn't require a [`MemoryReservation`] at creation time, and can be initialized later. +#[derive(Clone, Debug)] +pub struct SharedOptionalMemoryReservation(Arc<Mutex<Option<MemoryReservation>>>); + +impl SharedOptionalMemoryReservation { + /// Initialize inner [`MemoryReservation`] if `None`, otherwise -- do nothing + pub fn initialize(&self, name: impl Into<String>, pool: &Arc<dyn MemoryPool>) { + let mut locked = self.0.lock(); + if locked.is_none() { + *locked = Some(MemoryConsumer::new(name).register(pool)); + }; + } +} + +impl TryGrow for SharedOptionalMemoryReservation { + /// Try to increase the size of this reservation by `capacity` bytes + fn try_grow(&self, capacity: usize) -> Result<()> { + self.0 + .lock() + .as_mut() + .ok_or_else(|| { + DataFusionError::Internal( + "inner memory reservation not initialized".to_string(), + ) + })?
+ .try_grow(capacity) + } +} + +impl Default for SharedOptionalMemoryReservation { + fn default() -> Self { + Self(Arc::new(Mutex::new(None))) + } +} + const TB: u64 = 1 << 40; const GB: u64 = 1 << 30; const MB: u64 = 1 << 20; @@ -219,4 +278,63 @@ mod tests { a2.try_grow(25).unwrap(); assert_eq!(pool.reserved(), 25); } + + #[test] + fn test_shared_memory_reservation() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let a1 = SharedMemoryReservation::from(MemoryConsumer::new("a1").register(&pool)); + let a2 = a1.clone(); + + // Reserve from a1 + a1.try_grow(10).unwrap(); + assert_eq!(pool.reserved(), 10); + + // Drop a1 - normally reservation calls `free` on drop. + // Ensure that reservation still alive in a2 + drop(a1); + assert_eq!(pool.reserved(), 10); + + // Ensure that after a2 dropped, memory gets back to the pool + drop(a2); + assert_eq!(pool.reserved(), 0); + } + + #[test] + fn test_optional_shared_memory_reservation() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let a1 = SharedOptionalMemoryReservation::default(); + + // try_grow on empty inner reservation + let err = a1.try_grow(10).unwrap_err(); + assert_eq!( + err.to_string(), + "Internal error: inner memory reservation not initialized. \ + This was likely caused by a bug in DataFusion's code and we \ + would welcome that you file an bug report in our issue tracker" + ); + + // multiple initializations + a1.initialize("a1", &pool); + a1.initialize("a2", &pool); + { + let locked = a1.0.lock(); + let name = locked.as_ref().unwrap().consumer.name(); + assert_eq!(name, "a1"); + } + + let a2 = a1.clone(); + + // Reserve from a1 + a1.try_grow(10).unwrap(); + assert_eq!(pool.reserved(), 10); + + // Drop a1 - normally reservation calls `free` on drop. + // Ensure that reservation still alive in a2 + drop(a1); + assert_eq!(pool.reserved(), 10); + + // Ensure that after a2 dropped, memory gets back to the pool + drop(a2); + assert_eq!(pool.reserved(), 0); + } } diff --git a/datafusion/execution/src/object_store.rs b/datafusion/execution/src/object_store.rs index 9c5cca84799c..9b958e3023f3 100644 --- a/datafusion/execution/src/object_store.rs +++ b/datafusion/execution/src/object_store.rs @@ -79,7 +79,12 @@ impl std::fmt::Display for ObjectStoreUrl { } } -/// Provides a mechanism for lazy, on-demand creation of [`ObjectStore`] +/// Provides a mechanism for lazy, on-demand creation of an [`ObjectStore`] +/// +/// For example, to support reading arbitrary buckets from AWS S3 +/// without instantiating an [`ObjectStore`] for each possible bucket +/// up front, an [`ObjectStoreProvider`] can be used to create the +/// appropriate [`ObjectStore`] instance on demand. /// /// See [`ObjectStoreRegistry::new_with_provider`] pub trait ObjectStoreProvider: Send + Sync + 'static { @@ -89,21 +94,29 @@ pub trait ObjectStoreProvider: Send + Sync + 'static { fn get_by_url(&self, url: &Url) -> Result>; } -/// [`ObjectStoreRegistry`] stores [`ObjectStore`] keyed by url scheme and authority, that is -/// the part of a URL preceding the path +/// [`ObjectStoreRegistry`] maps a URL to an [`ObjectStore`] instance, +/// and allows DataFusion to read from different [`ObjectStore`] +/// instances. For example DataFusion might be configured so that +/// +/// 1. `s3://my_bucket/lineitem/` mapped to the `/lineitem` path on an +/// AWS S3 object store bound to `my_bucket` /// -/// This is used by DataFusion to find an appropriate [`ObjectStore`] for a [`ListingTableUrl`] -/// provided in a query such as +/// 2. 
`s3://my_other_bucket/lineitem/` mapped to the (same) +/// `/lineitem` path on a *different* AWS S3 object store bound to +/// `my_other_bucket` +/// +/// When given a [`ListingTableUrl`], DataFusion tries to find an +/// appropriate [`ObjectStore`]. For example /// /// ```sql /// create external table unicorns stored as parquet location 's3://my_bucket/lineitem/'; /// ``` /// -/// In this particular case the url `s3://my_bucket/lineitem/` will be provided to +/// In this particular case, the url `s3://my_bucket/lineitem/` will be provided to /// [`ObjectStoreRegistry::get_by_url`] and one of three things will happen: /// /// - If an [`ObjectStore`] has been registered with [`ObjectStoreRegistry::register_store`] with -/// scheme `s3` and host `my_bucket`, this [`ObjectStore`] will be returned +/// scheme `s3` and host `my_bucket`, that [`ObjectStore`] will be returned /// /// - If an [`ObjectStoreProvider`] has been associated with this [`ObjectStoreRegistry`] using /// [`ObjectStoreRegistry::new_with_provider`], [`ObjectStoreProvider::get_by_url`] will be invoked, @@ -115,9 +128,10 @@ pub trait ObjectStoreProvider: Send + Sync + 'static { /// /// This allows for two different use-cases: /// -/// * DBMS systems where object store buckets are explicitly created using DDL, can register these +/// 1. Systems where object store buckets are explicitly created using DDL, can register these /// buckets using [`ObjectStoreRegistry::register_store`] -/// * DMBS systems relying on ad-hoc discovery, without corresponding DDL, can create [`ObjectStore`] +/// +/// 2. Systems relying on ad-hoc discovery, without corresponding DDL, can create [`ObjectStore`] /// lazily, on-demand using [`ObjectStoreProvider`] /// /// [`ListingTableUrl`]: crate::datasource::listing::ListingTableUrl diff --git a/datafusion/core/src/execution/runtime_env.rs b/datafusion/execution/src/runtime_env.rs similarity index 97% rename from datafusion/core/src/execution/runtime_env.rs rename to datafusion/execution/src/runtime_env.rs index a38d563fac6d..3163e0e03edf 100644 --- a/datafusion/core/src/execution/runtime_env.rs +++ b/datafusion/execution/src/runtime_env.rs @@ -19,15 +19,12 @@ //! and various system level components that are used during physical plan execution. use crate::{ - error::Result, - execution::disk_manager::{DiskManager, DiskManagerConfig}, -}; - -use datafusion_common::DataFusionError; -use datafusion_execution::{ + disk_manager::{DiskManager, DiskManagerConfig}, memory_pool::{GreedyMemoryPool, MemoryPool, UnboundedMemoryPool}, object_store::ObjectStoreRegistry, }; + +use datafusion_common::{DataFusionError, Result}; use object_store::ObjectStore; use std::fmt::{Debug, Formatter}; use std::path::PathBuf; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 1530513e9637..01860ae7d411 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -21,7 +21,7 @@ use crate::aggregate_function; use crate::built_in_function; use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; -use crate::utils::expr_to_columns; +use crate::utils::{expr_to_columns, find_out_reference_exprs}; use crate::window_frame; use crate::window_function; use crate::AggregateUDF; @@ -220,6 +220,9 @@ pub enum Expr { /// The type the parameter will be filled in with data_type: Option, }, + /// A place holder which hold a reference to a qualified field + /// in the outer query, used for correlated sub queries. 
+ OuterReferenceColumn(DataType, Column), } /// Binary expression @@ -567,6 +570,7 @@ impl Expr { Expr::Case { .. } => "Case", Expr::Cast { .. } => "Cast", Expr::Column(..) => "Column", + Expr::OuterReferenceColumn(_, _) => "Outer", Expr::Exists { .. } => "Exists", Expr::GetIndexedField { .. } => "GetIndexedField", Expr::GroupingSet(..) => "GroupingSet", @@ -785,6 +789,11 @@ impl Expr { Ok(using_columns) } + + /// Return true when the expression contains out reference(correlated) expressions. + pub fn contains_outer(&self) -> bool { + !find_out_reference_exprs(self).is_empty() + } } impl Not for Expr { @@ -830,6 +839,7 @@ impl fmt::Debug for Expr { match self { Expr::Alias(expr, alias) => write!(f, "{expr:?} AS {alias}"), Expr::Column(c) => write!(f, "{c}"), + Expr::OuterReferenceColumn(_, c) => write!(f, "outer_ref({})", c), Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")), Expr::Literal(v) => write!(f, "{v:?}"), Expr::Case(case) => { @@ -1110,6 +1120,7 @@ fn create_name(e: &Expr) -> Result { match e { Expr::Alias(_, name) => Ok(name.clone()), Expr::Column(c) => Ok(c.flat_name()), + Expr::OuterReferenceColumn(_, c) => Ok(format!("outer_ref({})", c.flat_name())), Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")), Expr::Literal(value) => Ok(format!("{value:?}")), Expr::BinaryExpr(binary_expr) => { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 6465ca80b867..b20629946b01 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -28,17 +28,47 @@ use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; use std::sync::Arc; -/// Create a column expression based on a qualified or unqualified column name +/// Create a column expression based on a qualified or unqualified column name. Will +/// normalize unquoted identifiers according to SQL rules (identifiers will become lowercase). /// -/// example: -/// ``` +/// For example: +/// +/// ```rust /// # use datafusion_expr::col; -/// let c = col("my_column"); +/// let c1 = col("a"); +/// let c2 = col("A"); +/// assert_eq!(c1, c2); +/// +/// // note how quoting with double quotes preserves the case +/// let c3 = col(r#""A""#); +/// assert_ne!(c1, c3); /// ``` pub fn col(ident: impl Into) -> Expr { Expr::Column(ident.into()) } +/// Create an unqualified column expression from the provided name, without normalizing +/// the column. 
+/// +/// For example: +/// +/// ```rust +/// # use datafusion_expr::{col, ident}; +/// let c1 = ident("A"); // not normalized staying as column 'A' +/// let c2 = col("A"); // normalized via SQL rules becoming column 'a' +/// assert_ne!(c1, c2); +/// +/// let c3 = col(r#""A""#); +/// assert_eq!(c1, c3); +/// +/// let c4 = col("t1.a"); // parses as relation 't1' column 'a' +/// let c5 = ident("t1.a"); // parses as column 't1.a' +/// assert_ne!(c4, c5); +/// ``` +pub fn ident(name: impl Into) -> Expr { + Expr::Column(Column::from_name(name)) +} + /// Return a new expression `left right` pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) @@ -267,7 +297,10 @@ pub fn approx_percentile_cont_with_weight( /// Create an EXISTS subquery expression pub fn exists(subquery: Arc) -> Expr { Expr::Exists { - subquery: Subquery { subquery }, + subquery: Subquery { + subquery, + outer_ref_columns: vec![], + }, negated: false, } } @@ -275,7 +308,10 @@ pub fn exists(subquery: Arc) -> Expr { /// Create a NOT EXISTS subquery expression pub fn not_exists(subquery: Arc) -> Expr { Expr::Exists { - subquery: Subquery { subquery }, + subquery: Subquery { + subquery, + outer_ref_columns: vec![], + }, negated: true, } } @@ -284,7 +320,10 @@ pub fn not_exists(subquery: Arc) -> Expr { pub fn in_subquery(expr: Expr, subquery: Arc) -> Expr { Expr::InSubquery { expr: Box::new(expr), - subquery: Subquery { subquery }, + subquery: Subquery { + subquery, + outer_ref_columns: vec![], + }, negated: false, } } @@ -293,14 +332,20 @@ pub fn in_subquery(expr: Expr, subquery: Arc) -> Expr { pub fn not_in_subquery(expr: Expr, subquery: Arc) -> Expr { Expr::InSubquery { expr: Box::new(expr), - subquery: Subquery { subquery }, + subquery: Subquery { + subquery, + outer_ref_columns: vec![], + }, negated: true, } } /// Create a scalar subquery expression pub fn scalar_subquery(subquery: Arc) -> Expr { - Expr::ScalarSubquery(Subquery { subquery }) + Expr::ScalarSubquery(Subquery { + subquery, + outer_ref_columns: vec![], + }) } /// Create an expression to represent the stddev() aggregate function @@ -652,7 +697,7 @@ mod test { #[test] fn filter_is_null_and_is_not_null() { let col_null = col("col1"); - let col_not_null = col("col2"); + let col_not_null = ident("col2"); assert_eq!(format!("{:?}", col_null.is_null()), "col1 IS NULL"); assert_eq!( format!("{:?}", col_not_null.is_not_null()), diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index 3ae9e43e0f25..b4e82be5781f 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -121,6 +121,7 @@ impl ExprRewritable for Expr { let expr = match self { Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name), Expr::Column(_) => self.clone(), + Expr::OuterReferenceColumn(_, _) => self.clone(), Expr::Exists { .. } => self.clone(), Expr::InSubquery { expr, @@ -446,6 +447,19 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { exprs.into_iter().map(unnormalize_col).collect() } +/// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column +/// in the expression tree. 
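As a small, hedged sketch (not part of the patch) of what the rewrite documented above does, using only pieces introduced in this diff (`Expr::OuterReferenceColumn`, `Expr::contains_outer`, and `strip_outer_reference`); the `datafusion_expr::expr_rewriter` import path is an assumption based on the module the function is added to:

```rust
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_expr::expr_rewriter::strip_outer_reference; // path assumed
use datafusion_expr::Expr;

fn main() {
    // A correlated reference to `t1.a` as it would appear inside a subquery.
    let correlated =
        Expr::OuterReferenceColumn(DataType::Int32, Column::from_qualified_name("t1.a"));
    assert!(correlated.contains_outer());

    // Decorrelation rules can strip the placeholder back to a plain column reference.
    let plain = strip_outer_reference(correlated);
    assert_eq!(plain, Expr::Column(Column::from_qualified_name("t1.a")));
    assert!(!plain.contains_outer());
}
```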
+pub fn strip_outer_reference(expr: Expr) -> Expr { + rewrite_expr(expr, |expr| { + if let Expr::OuterReferenceColumn(_, col) = expr { + Ok(Expr::Column(col)) + } else { + Ok(expr) + } + }) + .expect("strip_outer_reference is infallable") +} + /// Implementation of [`ExprRewriter`] that calls a function, for use /// with [`rewrite_expr`] struct RewriterAdapter { @@ -664,7 +678,8 @@ mod test { fn normalize_cols_non_exist() { // test normalizing columns when the name doesn't exist let expr = col("a") + col("b"); - let schema_a = make_schema_with_empty_metadata(vec![make_field("tableA", "a")]); + let schema_a = + make_schema_with_empty_metadata(vec![make_field("\"tableA\"", "a")]); let schemas = vec![schema_a]; let schemas = schemas.iter().collect::>(); @@ -674,7 +689,7 @@ mod test { .to_string(); assert_eq!( error, - "Schema error: No field named 'b'. Valid fields are 'tableA'.'a'." + r#"Schema error: No field named "b". Valid fields are "tableA"."a"."# ); } @@ -690,7 +705,7 @@ mod test { } fn make_field(relation: &str, column: &str) -> DFField { - DFField::new(Some(relation), column, DataType::Int8, false) + DFField::new(Some(relation.to_string()), column, DataType::Int8, false) } #[test] diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 84b7193e22ab..a26b3c6741cb 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -19,7 +19,7 @@ use crate::expr::Sort; use crate::expr_rewriter::{normalize_col, rewrite_expr}; -use crate::{Expr, ExprSchemable, LogicalPlan}; +use crate::{Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; use datafusion_common::{Column, Result}; /// Rewrite sort on aggregate expressions to sort on the column of aggregate output @@ -116,7 +116,19 @@ fn rewrite_in_terms_of_projection( // look for the column named the same as this expr if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) { - return Ok((*found).clone()); + let found = found.clone(); + let expr = match normalized_expr { + Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { + expr: Box::new(found), + data_type, + }), + Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast { + expr: Box::new(found), + data_type, + }), + _ => found, + }; + return Ok(expr); } Ok(expr) }) @@ -141,7 +153,8 @@ mod test { use arrow::datatypes::{DataType, Field, Schema}; use crate::{ - avg, col, lit, logical_plan::builder::LogicalTableSource, min, LogicalPlanBuilder, + avg, cast, col, lit, logical_plan::builder::LogicalTableSource, min, try_cast, + LogicalPlanBuilder, }; use super::*; @@ -241,6 +254,34 @@ mod test { } } + #[test] + fn preserve_cast() { + let plan = make_input() + .project(vec![col("c2").alias("c2")]) + .unwrap() + .project(vec![col("c2").alias("c2")]) + .unwrap() + .build() + .unwrap(); + + let cases = vec![ + TestCase { + desc: "Cast is preserved by rewrite_sort_cols_by_aggs", + input: sort(cast(col("c2"), DataType::Int64)), + expected: sort(cast(col("c2").alias("c2"), DataType::Int64)), + }, + TestCase { + desc: "TryCast is preserved by rewrite_sort_cols_by_aggs", + input: sort(try_cast(col("c2"), DataType::Int64)), + expected: sort(try_cast(col("c2").alias("c2"), DataType::Int64)), + }, + ]; + + for case in cases { + case.run(&plan) + } + } + struct TestCase { desc: &'static str, input: Expr, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 493c425d7888..fafda79a6f61 100644 --- 
a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -65,6 +65,7 @@ impl ExprSchemable for Expr { }, Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), + Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), Expr::Literal(l) => Ok(l.get_datatype()), Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), @@ -136,9 +137,10 @@ impl ExprSchemable for Expr { Expr::Placeholder { data_type, .. } => data_type.clone().ok_or_else(|| { DataFusionError::Plan("Placeholder type could not be resolved".to_owned()) }), - Expr::Wildcard => Err(DataFusionError::Internal( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), + Expr::Wildcard => { + // Wildcard does not really have a type and does not appear in projections + Ok(DataType::Null) + } Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal( "QualifiedWildcard expressions are not valid in a logical query plan" .to_owned(), @@ -173,6 +175,7 @@ impl ExprSchemable for Expr { | Expr::InList { expr, .. } => expr.nullable(input_schema), Expr::Between(Between { expr, .. }) => expr.nullable(input_schema), Expr::Column(c) => input_schema.nullable(c), + Expr::OuterReferenceColumn(_, _) => Ok(true), Expr::Literal(value) => Ok(value.is_null()), Expr::Case(case) => { // this expression is nullable if any of the input expressions are nullable @@ -247,13 +250,12 @@ impl ExprSchemable for Expr { fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> { match self { Expr::Column(c) => Ok(DFField::new( - c.relation.as_deref(), + c.relation.clone(), &c.name, self.get_type(input_schema)?, self.nullable(input_schema)?, )), - _ => Ok(DFField::new( - None, + _ => Ok(DFField::new_unqualified( &self.display_name()?, self.get_type(input_schema)?, self.nullable(input_schema)?, diff --git a/datafusion/expr/src/expr_visitor.rs b/datafusion/expr/src/expr_visitor.rs index e3336a8c46d3..84ca6f7ed9df 100644 --- a/datafusion/expr/src/expr_visitor.rs +++ b/datafusion/expr/src/expr_visitor.rs @@ -134,6 +134,8 @@ impl ExprVisitable for Expr { }) } Expr::Column(_) + // Treat OuterReferenceColumn as a leaf expression + | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Exists { ..
} diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index a1fa82cda24e..c5e623785450 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -24,7 +24,7 @@ use crate::expr_rewriter::{ }; use crate::type_coercion::binary::comparison_coercion; use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan}; -use crate::{and, binary_expr, Operator}; +use crate::{and, binary_expr, DmlStatement, Operator, WriteOp}; use crate::{ logical_plan::{ Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join, @@ -40,8 +40,8 @@ use crate::{ }; use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, - ToDFSchema, + Column, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, + ScalarValue, TableReference, ToDFSchema, }; use std::any::Any; use std::cmp::Ordering; @@ -176,8 +176,7 @@ impl LogicalPlanBuilder { .map(|(j, data_type)| { // naming is following convention https://www.postgresql.org/docs/current/queries-values.html let name = &format!("column{}", j + 1); - DFField::new( - None, + DFField::new_unqualified( name, data_type.clone().unwrap_or(DataType::Utf8), true, @@ -193,24 +192,70 @@ impl LogicalPlanBuilder { } /// Convert a table provider into a builder with a TableScan + /// + /// Note that if you pass a string as `table_name`, it is treated + /// as a SQL identifier, as described on [`TableReference`] and + /// thus is normalized + /// + /// # Example: + /// ``` + /// # use datafusion_expr::{lit, col, LogicalPlanBuilder, + /// # logical_plan::builder::LogicalTableSource, logical_plan::table_scan + /// # }; + /// # use std::sync::Arc; + /// # use arrow::datatypes::{Schema, DataType, Field}; + /// # use datafusion_common::TableReference; + /// # + /// # let employee_schema = Arc::new(Schema::new(vec![ + /// # Field::new("id", DataType::Int32, false), + /// # ])) as _; + /// # let table_source = Arc::new(LogicalTableSource::new(employee_schema)); + /// // Scan table_source with the name "mytable" (after normalization) + /// # let table = table_source.clone(); + /// let scan = LogicalPlanBuilder::scan("MyTable", table, None); + /// + /// // Scan table_source with the name "MyTable" by enclosing in quotes + /// # let table = table_source.clone(); + /// let scan = LogicalPlanBuilder::scan(r#""MyTable""#, table, None); + /// + /// // Scan table_source with the name "MyTable" by forming the table reference + /// # let table = table_source.clone(); + /// let table_reference = TableReference::bare("MyTable"); + /// let scan = LogicalPlanBuilder::scan(table_reference, table, None); + /// ``` pub fn scan( - table_name: impl Into, + table_name: impl Into, table_source: Arc, projection: Option>, ) -> Result { Self::scan_with_filters(table_name, table_source, projection, vec![]) } + /// Create a [DmlStatement] for inserting the contents of this builder into the named table + pub fn insert_into( + input: LogicalPlan, + table_name: impl Into, + table_schema: &Schema, + ) -> Result { + let table_schema = table_schema.clone().to_dfschema_ref()?; + Ok(Self::from(LogicalPlan::Dml(DmlStatement { + table_name: table_name.into(), + table_schema, + op: WriteOp::Insert, + input: Arc::new(input), + }))) + } + /// Convert a table provider into a builder with a TableScan pub fn scan_with_filters( - table_name: impl Into, + table_name: impl Into, 
table_source: Arc, projection: Option>, filters: Vec, ) -> Result { let table_name = table_name.into(); - if table_name.is_empty() { + if table_name.table().is_empty() { return Err(DataFusionError::Plan( "table_name cannot be empty".to_string(), )); @@ -224,14 +269,17 @@ impl LogicalPlanBuilder { DFSchema::new_with_metadata( p.iter() .map(|i| { - DFField::from_qualified(&table_name, schema.field(*i).clone()) + DFField::from_qualified( + table_name.clone(), + schema.field(*i).clone(), + ) }) .collect(), schema.metadata().clone(), ) }) .unwrap_or_else(|| { - DFSchema::try_from_qualified_schema(&table_name, &schema) + DFSchema::try_from_qualified_schema(table_name.clone(), &schema) })?; let table_scan = LogicalPlan::TableScan(TableScan { @@ -1090,7 +1138,7 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result, +pub fn table_scan<'a>( + name: Option>>, table_schema: &Schema, projection: Option>, ) -> Result { + let table_source = table_source(table_schema); + let name = name + .map(|n| n.into()) + .unwrap_or_else(|| OwnedTableReference::bare(UNNAMED_TABLE)) + .to_owned_reference(); + LogicalPlanBuilder::scan(name, table_source, projection) +} + +fn table_source(table_schema: &Schema) -> Arc { let table_schema = Arc::new(table_schema.clone()); - let table_source = Arc::new(LogicalTableSource { table_schema }); - LogicalPlanBuilder::scan(name.unwrap_or(UNNAMED_TABLE), table_source, projection) + Arc::new(LogicalTableSource { table_schema }) } /// Wrap projection for a plan, if the join keys contains normal expression. @@ -1276,7 +1332,7 @@ pub fn unnest(input: LogicalPlan, column: Column) -> Result { DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => DFField::new( - unnest_field.qualifier().map(String::as_str), + unnest_field.qualifier().cloned(), unnest_field.name(), field.data_type().clone(), unnest_field.is_nullable(), @@ -1317,7 +1373,7 @@ pub fn unnest(input: LogicalPlan, column: Column) -> Result { mod tests { use crate::{expr, expr_fn::exists}; use arrow::datatypes::{DataType, Field}; - use datafusion_common::SchemaError; + use datafusion_common::{OwnedTableReference, SchemaError, TableReference}; use crate::logical_plan::StringifiedPlan; @@ -1344,12 +1400,36 @@ mod tests { #[test] fn plan_builder_schema() { let schema = employee_schema(); - let plan = table_scan(Some("employee_csv"), &schema, None).unwrap(); + let projection = None; + let plan = + LogicalPlanBuilder::scan("employee_csv", table_source(&schema), projection) + .unwrap(); + let expected = DFSchema::try_from_qualified_schema( + TableReference::bare("employee_csv"), + &schema, + ) + .unwrap(); + assert_eq!(&expected, plan.schema().as_ref()); - let expected = - DFSchema::try_from_qualified_schema("employee_csv", &schema).unwrap(); + // Note scan of "EMPLOYEE_CSV" is treated as a SQL identifer + // (and thus normalized to "employee"csv") as well + let projection = None; + let plan = + LogicalPlanBuilder::scan("EMPLOYEE_CSV", table_source(&schema), projection) + .unwrap(); + assert_eq!(&expected, plan.schema().as_ref()); + } - assert_eq!(&expected, plan.schema().as_ref()) + #[test] + fn plan_builder_empty_name() { + let schema = employee_schema(); + let projection = None; + let err = + LogicalPlanBuilder::scan("", table_source(&schema), projection).unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: table_name cannot be empty" + ); } #[test] @@ -1464,9 +1544,10 @@ mod tests { #[test] fn plan_builder_union_different_num_columns_error() -> Result<()> { 
- let plan1 = table_scan(None, &employee_schema(), Some(vec![3]))?; - - let plan2 = table_scan(None, &employee_schema(), Some(vec![3, 4]))?; + let plan1 = + table_scan(TableReference::none(), &employee_schema(), Some(vec![3]))?; + let plan2 = + table_scan(TableReference::none(), &employee_schema(), Some(vec![3, 4]))?; let expected = "Error during planning: Union queries must have the same number of columns, (left is 1, right is 2)"; let err_msg1 = plan1.clone().union(plan2.clone().build()?).unwrap_err(); @@ -1592,10 +1673,13 @@ mod tests { match plan { Err(DataFusionError::SchemaError(SchemaError::AmbiguousReference { - qualifier, - name, + field: + Column { + relation: Some(OwnedTableReference::Bare { table }), + name, + }, })) => { - assert_eq!("employee_csv", qualifier.unwrap().as_str()); + assert_eq!("employee_csv", table); assert_eq!("id", &name); Ok(()) } @@ -1618,10 +1702,13 @@ mod tests { match plan { Err(DataFusionError::SchemaError(SchemaError::AmbiguousReference { - qualifier, - name, + field: + Column { + relation: Some(OwnedTableReference::Bare { table }), + name, + }, })) => { - assert_eq!("employee_csv", qualifier.unwrap().as_str()); + assert_eq!("employee_csv", table); assert_eq!("state", &name); Ok(()) } @@ -1684,9 +1771,10 @@ mod tests { #[test] fn plan_builder_intersect_different_num_columns_error() -> Result<()> { - let plan1 = table_scan(None, &employee_schema(), Some(vec![3]))?; - - let plan2 = table_scan(None, &employee_schema(), Some(vec![3, 4]))?; + let plan1 = + table_scan(TableReference::none(), &employee_schema(), Some(vec![3]))?; + let plan2 = + table_scan(TableReference::none(), &employee_schema(), Some(vec![3, 4]))?; let expected = "Error during planning: INTERSECT/EXCEPT query must have the same number of columns. \ Left is 1 and right is 2."; @@ -1722,7 +1810,7 @@ mod tests { // Check unnested field is a scalar let field = plan .schema() - .field_with_name(Some("test_table"), "strings") + .field_with_name(Some(&TableReference::bare("test_table")), "strings") .unwrap(); assert_eq!(&DataType::Utf8, field.data_type()); @@ -1741,7 +1829,7 @@ mod tests { // Check unnested struct list field should be a struct. 
let field = plan .schema() - .field_with_name(Some("test_table"), "structs") + .field_with_name(Some(&TableReference::bare("test_table")), "structs") .unwrap(); assert!(matches!(field.data_type(), DataType::Struct(_))); diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 83bda06122b1..0e0cfca92025 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -24,7 +24,7 @@ use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::plan; use crate::utils::{ - self, exprlist_to_fields, from_plan, grouping_set_expr_count, + exprlist_to_fields, find_out_reference_exprs, from_plan, grouping_set_expr_count, grouping_set_to_exprlist, }; use crate::{ @@ -34,7 +34,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, - ScalarValue, + ScalarValue, TableReference, }; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; @@ -266,6 +266,31 @@ impl LogicalPlan { exprs } + /// Returns all the out reference(correlated) expressions (recursively) in the current + /// logical plan nodes and all its descendant nodes. + pub fn all_out_ref_exprs(self: &LogicalPlan) -> Vec { + let mut exprs = vec![]; + self.inspect_expressions(|e| { + find_out_reference_exprs(e).into_iter().for_each(|e| { + if !exprs.contains(&e) { + exprs.push(e) + } + }); + Ok(()) as Result<(), DataFusionError> + }) + // closure always returns OK + .unwrap(); + self.inputs() + .into_iter() + .flat_map(|child| child.all_out_ref_exprs()) + .for_each(|e| { + if !exprs.contains(&e) { + exprs.push(e) + } + }); + exprs + } + /// Calls `f` on all expressions (non-recursively) in the current /// logical plan node. This does not include expressions in any /// children. 
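`all_out_ref_exprs` above walks the plan node and all of its inputs, collecting correlated (outer reference) expressions depth first while keeping only the first occurrence of each duplicate via a linear `contains` check. A std-only sketch of that order-preserving dedup traversal on a toy tree; the `Node` type and the string values are invented for the example:

```rust
struct Node {
    values: Vec<String>,
    children: Vec<Node>,
}

/// Collect values depth first, keeping first occurrences and dropping duplicates,
/// the same shape as the `contains`/`push` pattern used by `all_out_ref_exprs`.
fn collect_unique(node: &Node, out: &mut Vec<String>) {
    for v in &node.values {
        if !out.contains(v) {
            out.push(v.clone());
        }
    }
    for child in &node.children {
        collect_unique(child, out);
    }
}

fn main() {
    let tree = Node {
        values: vec!["outer.a".into(), "outer.b".into()],
        children: vec![Node {
            values: vec!["outer.a".into(), "outer.c".into()],
            children: vec![],
        }],
    };
    let mut out = Vec::new();
    collect_unique(&tree, &mut out);
    assert_eq!(out, vec!["outer.a", "outer.b", "outer.c"]);
}
```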
@@ -632,22 +657,21 @@ impl LogicalPlan { /// params_values pub fn replace_params_with_values( &self, - param_values: &Vec, + param_values: &[ScalarValue], ) -> Result { - let exprs = self.expressions(); - let mut new_exprs = vec![]; - for expr in exprs { - new_exprs.push(Self::replace_placeholders_with_values(expr, param_values)?); - } + let new_exprs = self + .expressions() + .into_iter() + .map(|e| Self::replace_placeholders_with_values(e, param_values)) + .collect::, DataFusionError>>()?; - let new_inputs = self.inputs(); - let mut new_inputs_with_values = vec![]; - for input in new_inputs { - new_inputs_with_values.push(input.replace_params_with_values(param_values)?); - } + let new_inputs_with_values = self + .inputs() + .into_iter() + .map(|inp| inp.replace_params_with_values(param_values)) + .collect::, DataFusionError>>()?; - let new_plan = utils::from_plan(self, &new_exprs, &new_inputs_with_values)?; - Ok(new_plan) + from_plan(self, &new_exprs, &new_inputs_with_values) } /// Walk the logical plan, find any `PlaceHolder` tokens, and return a map of their IDs and DataTypes @@ -748,11 +772,12 @@ impl LogicalPlan { Ok(Expr::Literal(value.clone())) } Expr::ScalarSubquery(qry) => { - let subquery = Arc::new( - qry.subquery - .replace_params_with_values(¶m_values.to_vec())?, - ); - Ok(Expr::ScalarSubquery(plan::Subquery { subquery })) + let subquery = + Arc::new(qry.subquery.replace_params_with_values(param_values)?); + Ok(Expr::ScalarSubquery(plan::Subquery { + subquery, + outer_ref_columns: qry.outer_ref_columns.clone(), + })) } _ => Ok(expr), } @@ -1437,9 +1462,12 @@ impl SubqueryAlias { alias: impl Into, ) -> datafusion_common::Result { let alias = alias.into(); + let table_ref = TableReference::bare(&alias); let schema: Schema = plan.schema().as_ref().clone().into(); - let schema = - DFSchemaRef::new(DFSchema::try_from_qualified_schema(&alias, &schema)?); + let schema = DFSchemaRef::new(DFSchema::try_from_qualified_schema( + table_ref.to_owned_reference(), + &schema, + )?); Ok(SubqueryAlias { input: Arc::new(plan), alias, @@ -1521,7 +1549,7 @@ pub struct Window { #[derive(Clone)] pub struct TableScan { /// The name of the table - pub table_name: String, + pub table_name: OwnedTableReference, /// The source of the table pub source: Arc, /// Optional column indices to use as a projection @@ -1632,6 +1660,8 @@ pub struct CreateExternalTable { pub if_not_exists: bool, /// SQL used to create the table, if available pub definition: Option, + /// Order expressions supplied by user + pub order_exprs: Vec, /// File compression type (GZIP, BZIP2, XZ, ZSTD) pub file_compression_type: CompressionTypeVariant, /// Table(provider) specific options @@ -1652,6 +1682,7 @@ impl Hash for CreateExternalTable { self.if_not_exists.hash(state); self.definition.hash(state); self.file_compression_type.hash(state); + self.order_exprs.hash(state); self.options.len().hash(state); // HashMap is not hashable } } @@ -1919,14 +1950,11 @@ impl Join { pub struct Subquery { /// The subquery pub subquery: Arc, + /// The outer references used in the subquery + pub outer_ref_columns: Vec, } impl Subquery { - pub fn new(plan: LogicalPlan) -> Self { - Subquery { - subquery: Arc::new(plan), - } - } pub fn try_from_expr(plan: &Expr) -> datafusion_common::Result<&Subquery> { match plan { Expr::ScalarSubquery(it) => Ok(it), @@ -2391,7 +2419,7 @@ mod tests { Field::new("state", DataType::Utf8, false), ]); - table_scan(None, &schema, Some(vec![0, 1])) + table_scan(TableReference::none(), &schema, Some(vec![0, 1])) .unwrap() 
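The `replace_params_with_values` rewrite above makes two small but idiomatic changes: the parameter type moves from `&Vec<ScalarValue>` to the more general `&[ScalarValue]`, and the manual push loops become a fallible `collect`. A std-only illustration of both idioms; the function and error types here are invented for the example:

```rust
/// Taking `&[i64]` accepts a `&Vec<i64>`, an array, or a sub-slice with no extra allocation.
fn bind_params(exprs: &[String], params: &[i64]) -> Result<Vec<String>, String> {
    exprs
        .iter()
        .enumerate()
        .map(|(i, e)| {
            // Collecting into `Result<Vec<_>, _>` stops at the first error,
            // which is what the chained `?` in the rewritten method relies on.
            params
                .get(i)
                .map(|p| format!("{e} = {p}"))
                .ok_or_else(|| format!("no value bound for placeholder {e}"))
        })
        .collect()
}

fn main() {
    let exprs = vec!["$1".to_string(), "$2".to_string()];
    let params = vec![10, 20];
    assert_eq!(
        bind_params(&exprs, &params).unwrap(),
        vec!["$1 = 10".to_string(), "$2 = 20".to_string()]
    );
    // A missing parameter surfaces as the first error instead of a partial result.
    assert!(bind_params(&exprs, &params[..1]).is_err());
}
```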
.filter(col("state").eq(lit("CO"))) .unwrap() diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 95b5f881b37a..d55338adc5ed 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -35,6 +35,7 @@ use crate::{ use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, + TableReference, }; use std::cmp::Ordering; use std::collections::HashSet; @@ -96,7 +97,45 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { Expr::ScalarVariable(_, var_names) => { accum.insert(Column::from_name(var_names.join("."))); } - _ => {} + // Use explicit pattern match instead of a default + // implementation, so that in the future if someone adds + // new Expr types, they will check here as well + Expr::Alias(_, _) + | Expr::Literal(_) + | Expr::BinaryExpr { .. } + | Expr::Like { .. } + | Expr::ILike { .. } + | Expr::SimilarTo { .. } + | Expr::Not(_) + | Expr::IsNotNull(_) + | Expr::IsNull(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) + | Expr::Negative(_) + | Expr::Between { .. } + | Expr::Case { .. } + | Expr::Cast { .. } + | Expr::TryCast { .. } + | Expr::Sort { .. } + | Expr::ScalarFunction { .. } + | Expr::ScalarUDF { .. } + | Expr::WindowFunction { .. } + | Expr::AggregateFunction { .. } + | Expr::GroupingSet(_) + | Expr::AggregateUDF { .. } + | Expr::InList { .. } + | Expr::Exists { .. } + | Expr::InSubquery { .. } + | Expr::ScalarSubquery(_) + | Expr::Wildcard + | Expr::QualifiedWildcard { .. } + | Expr::GetIndexedField { .. } + | Expr::Placeholder { .. } + | Expr::OuterReferenceColumn { .. } => {} } Ok(()) }) @@ -154,8 +193,9 @@ pub fn expand_qualified_wildcard( qualifier: &str, schema: &DFSchema, ) -> Result> { + let qualifier = TableReference::from(qualifier); let qualified_fields: Vec = schema - .fields_with_qualified(qualifier) + .fields_with_qualified(&qualifier) .into_iter() .cloned() .collect(); @@ -343,6 +383,14 @@ pub fn find_window_exprs(exprs: &[Expr]) -> Vec { }) } +/// Collect all deeply nested `Expr::OuterReferenceColumn`. They are returned in order of occurrence +/// (depth first), with duplicates omitted. +pub fn find_out_reference_exprs(expr: &Expr) -> Vec { + find_exprs_in_expr(expr, &|nested_expr| { + matches!(nested_expr, Expr::OuterReferenceColumn { .. }) + }) +} + /// Search the provided `Expr`'s, and all of their nested `Expr`, for any that /// pass the provided test. The returned `Expr`'s are deduplicated and returned /// in order of appearance (depth first). @@ -598,21 +646,20 @@ pub fn from_plan( let right = inputs[1].clone(); LogicalPlanBuilder::from(left).cross_join(right)?.build() } - LogicalPlan::Subquery(_) => { + LogicalPlan::Subquery(Subquery { + outer_ref_columns, .. + }) => { let subquery = LogicalPlanBuilder::from(inputs[0].clone()).build()?; Ok(LogicalPlan::Subquery(Subquery { subquery: Arc::new(subquery), + outer_ref_columns: outer_ref_columns.clone(), })) } LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. 
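The `expr_to_columns` hunk above trades a catch-all `_ => {}` arm for an explicit list of every `Expr` variant, so that adding a new variant (such as the new `OuterReferenceColumn`) fails to compile until this function is revisited. A small standalone illustration of that design choice; the enum is invented for the example:

```rust
enum Shape {
    Circle(f64),
    Square(f64),
    Rect(f64, f64),
}

fn area(shape: &Shape) -> f64 {
    // No wildcard arm: adding a `Triangle` variant later stops this match from
    // compiling, forcing the author to decide how to handle it instead of
    // silently falling through a `_ => ...` default.
    match shape {
        Shape::Circle(r) => std::f64::consts::PI * r * r,
        Shape::Square(side) => side * side,
        Shape::Rect(w, h) => w * h,
    }
}

fn main() {
    assert_eq!(area(&Shape::Square(3.0)), 9.0);
}
```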
}) => { - let schema = inputs[0].schema().as_ref().clone().into(); - let schema = - DFSchemaRef::new(DFSchema::try_from_qualified_schema(alias, &schema)?); - Ok(LogicalPlan::SubqueryAlias(SubqueryAlias { - alias: alias.clone(), - input: Arc::new(inputs[0].clone()), - schema, - })) + Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + inputs[0].clone(), + alias.clone(), + )?)) } LogicalPlan::Limit(Limit { skip, fetch, .. }) => Ok(LogicalPlan::Limit(Limit { skip: *skip, @@ -815,6 +862,7 @@ pub fn exprlist_to_fields<'a>( pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { match e { Expr::Column(_) => e, + Expr::OuterReferenceColumn(_, _) => e, Expr::Alias(inner_expr, name) => { columnize_expr(*inner_expr, input_schema).alias(name) } diff --git a/datafusion/optimizer/src/analyzer.rs b/datafusion/optimizer/src/analyzer.rs new file mode 100644 index 000000000000..e999eb2419d0 --- /dev/null +++ b/datafusion/optimizer/src/analyzer.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::count_wildcard_rule::CountWildcardRule; +use crate::rewrite::TreeNodeRewritable; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::expr_visitor::inspect_expr_pre; +use datafusion_expr::{Expr, LogicalPlan}; +use log::{debug, trace}; +use std::sync::Arc; +use std::time::Instant; + +/// `AnalyzerRule` transforms the unresolved ['LogicalPlan']s and unresolved ['Expr']s into +/// the resolved form. +pub trait AnalyzerRule { + /// Rewrite `plan` + fn analyze(&self, plan: &LogicalPlan, config: &ConfigOptions) -> Result; + + /// A human readable name for this analyzer rule + fn name(&self) -> &str; +} +/// A rule-based Analyzer. 
+#[derive(Clone)] +pub struct Analyzer { + /// All rules to apply + pub rules: Vec>, +} + +impl Default for Analyzer { + fn default() -> Self { + Self::new() + } +} + +impl Analyzer { + /// Create a new analyzer using the recommended list of rules + pub fn new() -> Self { + let rules: Vec> = + vec![Arc::new(CountWildcardRule::new())]; + Self::with_rules(rules) + } + + /// Create a new analyzer with the given rules + pub fn with_rules(rules: Vec>) -> Self { + Self { rules } + } + + /// Analyze the logical plan by applying analyzer rules, and + /// do necessary check and fail the invalid plans + pub fn execute_and_check( + &self, + plan: &LogicalPlan, + config: &ConfigOptions, + ) -> Result { + let start_time = Instant::now(); + let mut new_plan = plan.clone(); + + // TODO add common rule executor for Analyzer and Optimizer + for rule in &self.rules { + new_plan = rule.analyze(&new_plan, config)?; + } + check_plan(&new_plan)?; + log_plan("Final analyzed plan", &new_plan); + debug!("Analyzer took {} ms", start_time.elapsed().as_millis()); + Ok(new_plan) + } +} + +/// Log the plan in debug/tracing mode after some part of the optimizer runs +fn log_plan(description: &str, plan: &LogicalPlan) { + debug!("{description}:\n{}\n", plan.display_indent()); + trace!("{description}::\n{}\n", plan.display_indent_schema()); +} + +/// Do necessary check and fail the invalid plan +fn check_plan(plan: &LogicalPlan) -> Result<()> { + plan.for_each_up(&|plan: &LogicalPlan| { + plan.expressions().into_iter().try_for_each(|expr| { + // recursively look for subqueries + inspect_expr_pre(&expr, |expr| match expr { + Expr::Exists { subquery, .. } + | Expr::InSubquery { subquery, .. } + | Expr::ScalarSubquery(subquery) => { + check_subquery_expr(plan, &subquery.subquery, expr) + } + _ => Ok(()), + }) + }) + }) +} + +/// Do necessary check on subquery expressions and fail the invalid plan +/// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions, +/// the allowed while list: [Projection, Filter, Window, Aggregate, Sort, Join]. +/// 2) Check whether the inner plan is in the allowed inner plans list to use correlated(outer) expressions. +/// 3) Check and validate unsupported cases to use the correlated(outer) expressions inside the subquery(inner) plans/inner expressions. +/// For example, we do not want to support to use correlated expressions as the Join conditions in the subquery plan when the Join +/// is a Full Out Join +fn check_subquery_expr( + outer_plan: &LogicalPlan, + inner_plan: &LogicalPlan, + expr: &Expr, +) -> Result<()> { + check_plan(inner_plan)?; + + // Scalar subquery should only return one column + if matches!(expr, Expr::ScalarSubquery(subquery) if subquery.subquery.schema().fields().len() > 1) + { + return Err(DataFusionError::Plan( + "Scalar subquery should only return one column".to_string(), + )); + } + + match outer_plan { + LogicalPlan::Projection(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Window(_) + | LogicalPlan::Aggregate(_) + | LogicalPlan::Join(_) => Ok(()), + LogicalPlan::Sort(_) => match expr { + Expr::InSubquery { .. } | Expr::Exists { .. 
} => Err(DataFusionError::Plan( + "In/Exist subquery can not be used in Sort plan nodes".to_string(), + )), + Expr::ScalarSubquery(_) => Ok(()), + _ => Ok(()), + }, + _ => Err(DataFusionError::Plan( + "Subquery can only be used in Projection, Filter, \ + Window functions, Aggregate, Sort and Join plan nodes" + .to_string(), + )), + }?; + check_correlations_in_subquery(outer_plan, inner_plan, expr, true) +} + +// Recursively check the unsupported outer references in the sub query plan. +fn check_correlations_in_subquery( + outer_plan: &LogicalPlan, + inner_plan: &LogicalPlan, + expr: &Expr, + can_contain_outer_ref: bool, +) -> Result<()> { + // We want to support as many operators as possible inside the correlated subquery + if !can_contain_outer_ref && contains_outer_reference(outer_plan, inner_plan, expr) { + return Err(DataFusionError::Plan( + "Accessing outer reference column is not allowed in the plan".to_string(), + )); + } + match inner_plan { + LogicalPlan::Projection(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Window(_) + | LogicalPlan::Aggregate(_) + | LogicalPlan::Distinct(_) + | LogicalPlan::Sort(_) + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Union(_) + | LogicalPlan::TableScan(_) + | LogicalPlan::EmptyRelation(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) => inner_plan.apply_children(|plan| { + check_correlations_in_subquery(outer_plan, plan, expr, can_contain_outer_ref) + }), + LogicalPlan::Join(_) => { + // TODO support correlation columns in the subquery join + inner_plan.apply_children(|plan| { + check_correlations_in_subquery( + outer_plan, + plan, + expr, + can_contain_outer_ref, + ) + }) + } + _ => Err(DataFusionError::Plan( + "Unsupported operator in the subquery plan.".to_string(), + )), + } +} + +fn contains_outer_reference( + _outer_plan: &LogicalPlan, + _inner_plan: &LogicalPlan, + _expr: &Expr, +) -> bool { + // TODO check outer references + false +} diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index e1830390dee9..33bf676db128 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -312,7 +312,7 @@ fn build_common_expr_project_plan( match expr_set.get(&id) { Some((expr, _, data_type)) => { // todo: check `nullable` - let field = DFField::new(None, &id, data_type.clone(), true); + let field = DFField::new_unqualified(&id, data_type.clone(), true); fields_set.insert(field.name().to_owned()); project_exprs.push(expr.clone().alias(&id)); } @@ -624,8 +624,8 @@ mod test { let schema = Arc::new(DFSchema::new_with_metadata( vec![ - DFField::new(None, "a", DataType::Int64, false), - DFField::new(None, "c", DataType::Int64, false), + DFField::new_unqualified("a", DataType::Int64, false), + DFField::new_unqualified("c", DataType::Int64, false), ], Default::default(), )?); diff --git a/datafusion/optimizer/src/count_wildcard_rule.rs b/datafusion/optimizer/src/count_wildcard_rule.rs new file mode 100644 index 000000000000..416bd0337a4d --- /dev/null +++ b/datafusion/optimizer/src/count_wildcard_rule.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
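The new `analyzer.rs` module above introduces `AnalyzerRule` (an `analyze` method plus a `name`) and an `Analyzer` that applies a list of rules and then validates subquery usage via `check_plan`. A hedged sketch of what wiring up a custom rule might look like, using only the trait and constructors visible in this hunk; the `+ Send + Sync` trait-object bounds and `ConfigOptions::default()` are assumptions, since the generic parameters are elided in this diff view:

```rust
use std::sync::Arc;

use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_expr::LogicalPlan;
use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule};

/// A pass-through rule; a real rule would resolve or rewrite expressions here.
struct NoopRule;

impl AnalyzerRule for NoopRule {
    fn analyze(&self, plan: &LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
        Ok(plan.clone())
    }

    fn name(&self) -> &str {
        "noop_rule"
    }
}

fn analyze(plan: &LogicalPlan) -> Result<LogicalPlan> {
    // Trait-object bounds assumed; `with_rules` replaces the default rule set
    // (which contains `CountWildcardRule`).
    let rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>> = vec![Arc::new(NoopRule)];
    // `execute_and_check` applies each rule in order, then runs `check_plan`
    // to reject invalid subquery placements.
    Analyzer::with_rules(rules).execute_and_check(plan, &ConfigOptions::default())
}
```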
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::analyzer::AnalyzerRule; +use datafusion_common::config::ConfigOptions; +use datafusion_common::Result; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window}; +use std::ops::Deref; +use std::sync::Arc; + +pub struct CountWildcardRule {} + +impl Default for CountWildcardRule { + fn default() -> Self { + CountWildcardRule::new() + } +} + +impl CountWildcardRule { + pub fn new() -> Self { + CountWildcardRule {} + } +} +impl AnalyzerRule for CountWildcardRule { + fn analyze(&self, plan: &LogicalPlan, _: &ConfigOptions) -> Result { + let new_plan = match plan { + LogicalPlan::Window(window) => { + let inputs = plan.inputs(); + let window_expr = window.clone().window_expr; + let window_expr = handle_wildcard(window_expr).unwrap(); + LogicalPlan::Window(Window { + input: Arc::new(inputs.get(0).unwrap().deref().clone()), + window_expr, + schema: plan.schema().clone(), + }) + } + + LogicalPlan::Aggregate(aggregate) => { + let inputs = plan.inputs(); + let aggr_expr = aggregate.clone().aggr_expr; + let aggr_expr = handle_wildcard(aggr_expr).unwrap(); + LogicalPlan::Aggregate( + Aggregate::try_new_with_schema( + Arc::new(inputs.get(0).unwrap().deref().clone()), + aggregate.clone().group_expr, + aggr_expr, + plan.schema().clone(), + ) + .unwrap(), + ) + } + _ => plan.clone(), + }; + Ok(new_plan) + } + + fn name(&self) -> &str { + "count_wildcard_rule" + } +} + +//handle Count(Expr:Wildcard) with DataFrame API +pub fn handle_wildcard(exprs: Vec) -> Result> { + let exprs: Vec = exprs + .iter() + .map(|expr| match expr { + Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::Count, + args, + distinct, + filter, + }) if args.len() == 1 => match args[0] { + Expr::Wildcard => Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::Count, + args: vec![lit(COUNT_STAR_EXPANSION)], + distinct: *distinct, + filter: filter.clone(), + }), + _ => expr.clone(), + }, + _ => expr.clone(), + }) + .collect(); + Ok(exprs) +} diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index ffc12c158235..35d7858ba3fa 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -58,11 +58,14 @@ impl DecorrelateWhereExists { for it in filters.iter() { match it { Expr::Exists { subquery, negated } => { - let subquery = self + let subquery_plan = self .try_optimize(&subquery.subquery, config)? 
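`CountWildcardRule` above is the first analyzer rule: it rewrites `COUNT(*)` aggregates coming from the DataFrame API by swapping the wildcard argument for the `COUNT_STAR_EXPANSION` literal. A hedged sketch of the expression-level rewrite performed by the public `handle_wildcard` helper, constructing the aggregate through the same fields the rule's own pattern match uses; the module paths are taken from this diff and assumed to stay public:

```rust
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{aggregate_function, lit, Expr};
use datafusion_optimizer::count_wildcard_rule::handle_wildcard;

fn main() -> datafusion_common::Result<()> {
    // COUNT(*) as the DataFrame API produces it: a Count aggregate over a wildcard.
    let count_star = Expr::AggregateFunction(AggregateFunction {
        fun: aggregate_function::AggregateFunction::Count,
        args: vec![Expr::Wildcard],
        distinct: false,
        filter: None,
    });

    // The wildcard argument becomes the COUNT_STAR_EXPANSION literal
    // (a COUNT(1)-style rewrite); `distinct` and `filter` are preserved.
    let expected = Expr::AggregateFunction(AggregateFunction {
        fun: aggregate_function::AggregateFunction::Count,
        args: vec![lit(COUNT_STAR_EXPANSION)],
        distinct: false,
        filter: None,
    });
    assert_eq!(handle_wildcard(vec![count_star])?, vec![expected]);
    Ok(())
}
```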
.map(Arc::new) .unwrap_or_else(|| subquery.subquery.clone()); - let subquery = Subquery { subquery }; + let subquery = Subquery { + subquery: subquery_plan, + outer_ref_columns: subquery.outer_ref_columns.clone(), + }; let subquery = SubqueryInfo::new(subquery.clone(), *negated); subqueries.push(subquery); } diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 5eb7e99e7980..3439757f08d8 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -65,11 +65,14 @@ impl DecorrelateWhereIn { subquery, negated, } => { - let subquery = self + let subquery_plan = self .try_optimize(&subquery.subquery, config)? .map(Arc::new) .unwrap_or_else(|| subquery.subquery.clone()); - let subquery = Subquery { subquery }; + let subquery = Subquery { + subquery: subquery_plan, + outer_ref_columns: subquery.outer_ref_columns.clone(), + }; let subquery = SubqueryInfo::new(subquery.clone(), (**expr).clone(), *negated); subqueries.push(subquery); @@ -149,6 +152,7 @@ fn optimize_where_in( let projection = Projection::try_from_plan(&query_info.query.subquery) .map_err(|e| context!("a projection is required", e))?; let subquery_input = projection.input.clone(); + // TODO add the validate logic to Analyzer let subquery_expr = only_or_err(projection.expr.as_slice()) .map_err(|e| context!("single expression projection required", e))?; diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 5a882108e0fd..15f3d8e1d851 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -18,8 +18,9 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; +use datafusion_expr::expr::Sort as ExprSort; use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::Sort; +use datafusion_expr::{Expr, Sort}; use hashbrown::HashSet; /// Optimization rule that eliminate duplicated expr. @@ -41,15 +42,28 @@ impl OptimizerRule for EliminateDuplicatedExpr { ) -> Result> { match plan { LogicalPlan::Sort(sort) => { + let normalized_sort_keys = sort + .expr + .iter() + .map(|e| match e { + Expr::Sort(ExprSort { expr, .. }) => { + Expr::Sort(ExprSort::new(expr.clone(), true, false)) + } + _ => e.clone(), + }) + .collect::>(); + // dedup sort.expr and keep order let mut dedup_expr = Vec::new(); let mut dedup_set = HashSet::new(); - for expr in &sort.expr { - if !dedup_set.contains(expr) { - dedup_expr.push(expr); - dedup_set.insert(expr.clone()); - } - } + sort.expr.iter().zip(normalized_sort_keys.iter()).for_each( + |(expr, normalized_expr)| { + if !dedup_set.contains(normalized_expr) { + dedup_expr.push(expr); + dedup_set.insert(normalized_expr); + } + }, + ); if dedup_expr.len() == sort.expr.len() { Ok(None) } else { @@ -100,4 +114,23 @@ mod tests { \n TableScan: test"; assert_optimized_plan_eq(&plan, expected) } + + #[test] + fn eliminate_sort_exprs_with_options() -> Result<()> { + let table_scan = test_table_scan().unwrap(); + let sort_exprs = vec![ + col("a").sort(true, true), + col("b").sort(true, false), + col("a").sort(false, false), + col("b").sort(false, true), + ]; + let plan = LogicalPlanBuilder::from(table_scan) + .sort(sort_exprs)? + .limit(5, Some(10))? 
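The `EliminateDuplicatedExpr` change above normalizes each `Expr::Sort` key (ignoring its `asc`/`nulls_first` options) before deduplication, so a second sort on the same column with different options is dropped and the first occurrence wins. A std-only sketch of that normalize-then-dedup pattern; the `SortKey` type is a toy stand-in for the real sort expression:

```rust
use std::collections::HashSet;

struct SortKey {
    column: String,
    asc: bool,
    nulls_first: bool,
}

/// Keep the first key per column; `asc`/`nulls_first` are ignored when deciding
/// whether two keys are duplicates, mirroring the rule's normalization step.
fn dedup_sort_keys(keys: Vec<SortKey>) -> Vec<SortKey> {
    let mut seen = HashSet::new();
    keys.into_iter()
        .filter(|k| seen.insert(k.column.clone()))
        .collect()
}

fn main() {
    let keys = vec![
        SortKey { column: "a".into(), asc: true, nulls_first: true },
        SortKey { column: "b".into(), asc: true, nulls_first: false },
        SortKey { column: "a".into(), asc: false, nulls_first: false },
    ];
    let deduped = dedup_sort_keys(keys);
    // Only the first `a` survives, keeping its original ASC NULLS FIRST options,
    // which matches the expected plan in the new test.
    assert_eq!(deduped.len(), 2);
    assert!(deduped[0].asc && deduped[0].nulls_first);
}
```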
+ .build()?; + let expected = "Limit: skip=5, fetch=10\ + \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\ + \n TableScan: test"; + assert_optimized_plan_eq(&plan, expected) + } } diff --git a/datafusion/optimizer/src/inline_table_scan.rs b/datafusion/optimizer/src/inline_table_scan.rs index 6b58399192ee..722a70cb380d 100644 --- a/datafusion/optimizer/src/inline_table_scan.rs +++ b/datafusion/optimizer/src/inline_table_scan.rs @@ -57,7 +57,16 @@ impl OptimizerRule for InlineTableScan { generate_projection_expr(projection, sub_plan)?; let plan = LogicalPlanBuilder::from(sub_plan.clone()) .project(projection_exprs)? - .alias(table_name)?; + // Since this This is creating a subquery like: + //```sql + // ... + // FROM as "table_name" + // ``` + // + // it doesn't make sense to have a qualified + // reference (e.g. "foo"."bar") -- this convert to + // string + .alias(table_name.to_string())?; Ok(Some(plan.build()?)) } else { Ok(None) diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 4bbbb4645af3..3fa1995271dc 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -16,6 +16,7 @@ // under the License. pub mod alias; +pub mod analyzer; pub mod common_subexpr_eliminate; pub mod decorrelate_where_exists; pub mod decorrelate_where_in; @@ -35,6 +36,7 @@ pub mod push_down_filter; pub mod push_down_limit; pub mod push_down_projection; pub mod replace_distinct_aggregate; +pub mod rewrite; pub mod rewrite_disjunctive_predicate; pub mod scalar_subquery_to_join; pub mod simplify_expressions; @@ -43,6 +45,7 @@ pub mod type_coercion; pub mod unwrap_cast_in_comparison; pub mod utils; +pub mod count_wildcard_rule; #[cfg(test)] pub mod test; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index c1baa25d4363..01e945119c46 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -17,6 +17,7 @@ //! 
Query optimizer traits +use crate::analyzer::Analyzer; use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_where_exists::DecorrelateWhereExists; use crate::decorrelate_where_in::DecorrelateWhereIn; @@ -266,9 +267,10 @@ impl Optimizer { F: FnMut(&LogicalPlan, &dyn OptimizerRule), { let options = config.options(); + let analyzed_plan = Analyzer::default().execute_and_check(plan, options)?; let start_time = Instant::now(); - let mut old_plan = Cow::Borrowed(plan); - let mut new_plan = plan.clone(); + let mut old_plan = Cow::Borrowed(&analyzed_plan); + let mut new_plan = analyzed_plan.clone(); let mut i = 0; while i < options.optimizer.max_passes { log_plan(&format!("Optimizer input (pass {i})"), &new_plan); @@ -476,9 +478,9 @@ mod tests { Internal error: Optimizer rule 'get table_scan rule' failed, due to generate a different schema, \ original schema: DFSchema { fields: [], metadata: {} }, \ new schema: DFSchema { fields: [\ - DFField { qualifier: Some(\"test\"), field: Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ - DFField { qualifier: Some(\"test\"), field: Field { name: \"b\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ - DFField { qualifier: Some(\"test\"), field: Field { name: \"c\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }], \ + DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ + DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"b\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ + DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"c\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }], \ metadata: {} }. 
\ This was likely caused by a bug in DataFusion's code \ and we would welcome that you file an bug report in our issue tracker", @@ -521,7 +523,7 @@ mod tests { let new_arrow_field = f.field().clone().with_metadata(metadata); if let Some(qualifier) = f.qualifier() { - DFField::from_qualified(qualifier, new_arrow_field) + DFField::from_qualified(qualifier.clone(), new_arrow_field) } else { DFField::from(new_arrow_field) } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index f910268dfe86..55c77e51e2d3 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -2082,7 +2082,7 @@ mod tests { let test_provider = PushDownProvider { filter_support }; let table_scan = LogicalPlan::TableScan(TableScan { - table_name: "test".to_string(), + table_name: "test".into(), filters: vec![], projected_schema: Arc::new(DFSchema::try_from( (*test_provider.schema()).clone(), @@ -2154,7 +2154,7 @@ mod tests { }; let table_scan = LogicalPlan::TableScan(TableScan { - table_name: "test".to_string(), + table_name: "test".into(), filters: vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))], projected_schema: Arc::new(DFSchema::try_from( (*test_provider.schema()).clone(), @@ -2183,7 +2183,7 @@ mod tests { }; let table_scan = LogicalPlan::TableScan(TableScan { - table_name: "test".to_string(), + table_name: "test".into(), filters: vec![], projected_schema: Arc::new(DFSchema::try_from( (*test_provider.schema()).clone(), diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 6d7ab481b83a..767077aa0c02 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -536,7 +536,9 @@ fn push_down_scan( // create the projected schema let projected_fields: Vec = projection .iter() - .map(|i| DFField::from_qualified(&scan.table_name, schema.fields()[*i].clone())) + .map(|i| { + DFField::from_qualified(scan.table_name.clone(), schema.fields()[*i].clone()) + }) .collect(); let projected_schema = projected_fields.to_dfschema_ref()?; diff --git a/datafusion/core/src/physical_plan/rewrite.rs b/datafusion/optimizer/src/rewrite.rs similarity index 81% rename from datafusion/core/src/physical_plan/rewrite.rs rename to datafusion/optimizer/src/rewrite.rs index 2972b546bb0e..4a2d35de0086 100644 --- a/datafusion/core/src/physical_plan/rewrite.rs +++ b/datafusion/optimizer/src/rewrite.rs @@ -15,13 +15,11 @@ // specific language governing permissions and limitations // under the License. -//! Trait to make Executionplan rewritable +//! 
Trait to make LogicalPlan rewritable -use crate::physical_plan::with_new_children_if_necessary; -use crate::physical_plan::ExecutionPlan; use datafusion_common::Result; -use std::sync::Arc; +use datafusion_expr::LogicalPlan; /// a Trait for marking tree node types that are rewritable pub trait TreeNodeRewritable: Clone { @@ -119,6 +117,29 @@ pub trait TreeNodeRewritable: Clone { fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result; + + /// Apply the given function `func` to this node and recursively apply to the node's children + fn for_each(&self, func: &F) -> Result<()> + where + F: Fn(&Self) -> Result<()>, + { + func(self)?; + self.apply_children(|node| node.for_each(func)) + } + + /// Recursively apply the given function `func` to the node's children and to this node + fn for_each_up(&self, func: &F) -> Result<()> + where + F: Fn(&Self) -> Result<()>, + { + self.apply_children(|node| node.for_each_up(func))?; + func(self) + } + + /// Apply the given function `func` to the node's children + fn apply_children(&self, func: F) -> Result<()> + where + F: Fn(&Self) -> Result<()>; } /// Trait for potentially recursively transform an [`TreeNodeRewritable`] node @@ -149,18 +170,30 @@ pub enum RewriteRecursion { Skip, } -impl TreeNodeRewritable for Arc { +impl TreeNodeRewritable for LogicalPlan { fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, { - let children = self.children(); + let children = self.inputs().into_iter().cloned().collect::>(); if !children.is_empty() { let new_children: Result> = children.into_iter().map(transform).collect(); - with_new_children_if_necessary(self, new_children?) + self.with_new_inputs(new_children?.as_slice()) } else { Ok(self) } } + + fn apply_children(&self, func: F) -> Result<()> + where + F: Fn(&Self) -> Result<()>, + { + let children = self.inputs(); + if !children.is_empty() { + children.into_iter().try_for_each(func) + } else { + Ok(()) + } + } } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index cfec3cb741bd..303d97bb4f94 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -72,11 +72,14 @@ impl ScalarSubqueryToJoin { Ok(subquery) => subquery, _ => return Ok(()), }; - let subquery = self + let subquery_plan = self .try_optimize(&subquery.subquery, config)? .map(Arc::new) .unwrap_or_else(|| subquery.subquery.clone()); - let subquery = Subquery { subquery }; + let subquery = Subquery { + subquery: subquery_plan, + outer_ref_columns: subquery.outer_ref_columns.clone(), + }; let res = SubqueryInfo::new(subquery, expr, *op, lhs); subqueries.push(res); Ok(()) @@ -272,16 +275,10 @@ fn optimize_scalar( // qualify the join columns for outside the subquery let mut subqry_cols: Vec<_> = subqry_cols .iter() - .map(|it| Column { - relation: Some(subqry_alias.clone()), - name: it.name.clone(), - }) + .map(|it| Column::new(Some(subqry_alias.clone()), it.name.clone())) .collect(); - let qry_expr = Expr::Column(Column { - relation: Some(subqry_alias), - name: "__value".to_string(), - }); + let qry_expr = Expr::Column(Column::new(Some(subqry_alias), "__value".to_string())); // if correlated subquery's operation is column equality, put the clause into join on clause. let mut restore_where_clause = true; @@ -330,7 +327,6 @@ fn optimize_scalar( new_plan = new_plan.filter(expr)? 
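The relocated `TreeNodeRewritable` trait gains `for_each` (visit a node, then recurse into its children) and `for_each_up` (recurse first, then visit), both built on `apply_children`; the analyzer's `check_plan` uses the bottom-up variant so inner plans are validated before their parents. A std-only sketch of the same pre-order versus post-order split on a toy tree; the types here are stand-ins, not the DataFusion ones:

```rust
struct Node {
    name: &'static str,
    children: Vec<Node>,
}

impl Node {
    /// Pre-order: visit this node, then its children (the `for_each` shape).
    fn for_each(&self, visit: &mut dyn FnMut(&Node)) {
        visit(self);
        for child in &self.children {
            child.for_each(visit);
        }
    }

    /// Post-order: visit the children first, then this node (the `for_each_up` shape).
    fn for_each_up(&self, visit: &mut dyn FnMut(&Node)) {
        for child in &self.children {
            child.for_each_up(visit);
        }
        visit(self);
    }
}

fn main() {
    let plan = Node {
        name: "Filter",
        children: vec![Node { name: "TableScan", children: vec![] }],
    };

    let mut top_down = vec![];
    plan.for_each(&mut |n| top_down.push(n.name));
    assert_eq!(top_down, ["Filter", "TableScan"]);

    let mut bottom_up = vec![];
    plan.for_each_up(&mut |n| bottom_up.push(n.name));
    assert_eq!(bottom_up, ["TableScan", "Filter"]);
}
```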
} let new_plan = new_plan.build()?; - Ok(Some(new_plan)) } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 033b554402c9..bfac8da643db 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -254,6 +254,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::AggregateUDF { .. } | Expr::ScalarVariable(_, _) | Expr::Column(_) + | Expr::OuterReferenceColumn(_, _) | Expr::Exists { .. } | Expr::InSubquery { .. } | Expr::ScalarSubquery(_) @@ -390,6 +391,22 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { lit(negated) } + // expr IN ((subquery)) -> expr IN (subquery), see ##5529 + Expr::InList { + expr, + mut list, + negated, + } if list.len() == 1 + && matches!(list.first(), Some(Expr::ScalarSubquery { .. })) => + { + let Expr::ScalarSubquery(subquery) = list.remove(0) else { unreachable!() }; + Expr::InSubquery { + expr, + subquery, + negated, + } + } + // if expr is a single column reference: // expr IN (A, B, ...) --> (expr = A) OR (expr = B) OR (expr = C) Expr::InList { @@ -1002,6 +1019,11 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { // Expr::Not(inner) => negate_clause(*inner), + // + // Rules for Negative + // + Expr::Negative(inner) => distribute_negation(*inner), + // // Rules for Case // @@ -1110,6 +1132,7 @@ mod tests { }; use super::*; + use crate::test::test_table_scan_with_name; use arrow::{ array::{ArrayRef, Int32Array}, datatypes::{DataType, Field, Schema}, @@ -2122,6 +2145,37 @@ mod tests { assert_eq!(simplify(expr), expected); } + #[test] + fn test_simplify_by_de_morgan_laws() { + // Laws with logical operations + // !(c3 AND c4) --> !c3 OR !c4 + let expr = and(col("c3"), col("c4")).not(); + let expected = or(col("c3").not(), col("c4").not()); + assert_eq!(simplify(expr), expected); + // !(c3 OR c4) --> !c3 AND !c4 + let expr = or(col("c3"), col("c4")).not(); + let expected = and(col("c3").not(), col("c4").not()); + assert_eq!(simplify(expr), expected); + // !(!c3) --> c3 + let expr = col("c3").not().not(); + let expected = col("c3"); + assert_eq!(simplify(expr), expected); + + // Laws with bitwise operations + // !(c3 & c4) --> !c3 | !c4 + let expr = -bitwise_and(col("c3"), col("c4")); + let expected = bitwise_or(-col("c3"), -col("c4")); + assert_eq!(simplify(expr), expected); + // !(c3 | c4) --> !c3 & !c4 + let expr = -bitwise_or(col("c3"), col("c4")); + let expected = bitwise_and(-col("c3"), -col("c4")); + assert_eq!(simplify(expr), expected); + // !(!c3) --> c3 + let expr = -(-col("c3")); + let expected = col("c3"); + assert_eq!(simplify(expr), expected); + } + #[test] fn test_simplify_null_and_false() { let expr = and(lit_bool_null(), lit(false)); @@ -2410,14 +2464,14 @@ mod tests { Arc::new( DFSchema::new_with_metadata( vec![ - DFField::new(None, "c1", DataType::Utf8, true), - DFField::new(None, "c2", DataType::Boolean, true), - DFField::new(None, "c3", DataType::Int64, true), - DFField::new(None, "c4", DataType::UInt32, true), - DFField::new(None, "c1_non_null", DataType::Utf8, false), - DFField::new(None, "c2_non_null", DataType::Boolean, false), - DFField::new(None, "c3_non_null", DataType::Int64, false), - DFField::new(None, "c4_non_null", DataType::UInt32, false), + DFField::new_unqualified("c1", DataType::Utf8, true), + DFField::new_unqualified("c2", DataType::Boolean, true), + DFField::new_unqualified("c3", DataType::Int64, true), + 
DFField::new_unqualified("c4", DataType::UInt32, true), + DFField::new_unqualified("c1_non_null", DataType::Utf8, false), + DFField::new_unqualified("c2_non_null", DataType::Boolean, false), + DFField::new_unqualified("c3_non_null", DataType::Int64, false), + DFField::new_unqualified("c4_non_null", DataType::UInt32, false), ], HashMap::new(), ) @@ -2425,11 +2479,6 @@ mod tests { ) } - #[test] - fn simplify_expr_not_not() { - assert_eq!(simplify(col("c2").not().not().not()), col("c2").not(),); - } - #[test] fn simplify_expr_null_comparison() { // x = null is always null @@ -2688,6 +2737,51 @@ mod tests { simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)), col("c1").not_eq(lit(2)).and(col("c1").not_eq(lit(1))) ); + + let subquery = Arc::new(test_table_scan_with_name("test").unwrap()); + assert_eq!( + simplify(in_list( + col("c1"), + vec![scalar_subquery(subquery.clone())], + false + )), + in_subquery(col("c1"), subquery.clone()) + ); + assert_eq!( + simplify(in_list( + col("c1"), + vec![scalar_subquery(subquery.clone())], + true + )), + not_in_subquery(col("c1"), subquery) + ); + + let subquery1 = + scalar_subquery(Arc::new(test_table_scan_with_name("test1").unwrap())); + let subquery2 = + scalar_subquery(Arc::new(test_table_scan_with_name("test2").unwrap())); + + // c1 NOT IN (, ) -> c1 != AND c1 != + assert_eq!( + simplify(in_list( + col("c1"), + vec![subquery1.clone(), subquery2.clone()], + true + )), + col("c1") + .not_eq(subquery2.clone()) + .and(col("c1").not_eq(subquery1.clone())) + ); + + // c1 IN (, ) -> c1 == OR c1 == + assert_eq!( + simplify(in_list( + col("c1"), + vec![subquery1.clone(), subquery2.clone()], + false + )), + col("c1").eq(subquery2).or(col("c1").eq(subquery1)) + ); } #[test] diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 352674c3a68e..8b3f437dc233 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -20,7 +20,7 @@ use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{ expr::{Between, BinaryExpr}, - expr_fn::{and, concat_ws, or}, + expr_fn::{and, bitwise_and, bitwise_or, concat_ws, or}, lit, BuiltinScalarFunction, Expr, Like, Operator, }; @@ -311,6 +311,45 @@ pub fn negate_clause(expr: Expr) -> Expr { } } +/// bitwise negate a Negative clause +/// input is the clause to be bitwise negated.(args for Negative clause) +/// For BinaryExpr: +/// ~(A & B) ===> ~A | ~B +/// ~(A | B) ===> ~A & ~B +/// For Negative: +/// ~(~A) ===> A +/// For others, use Negative clause +pub fn distribute_negation(expr: Expr) -> Expr { + match expr { + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + match op { + // ~(A & B) ===> ~A | ~B + Operator::BitwiseAnd => { + let left = distribute_negation(*left); + let right = distribute_negation(*right); + + bitwise_or(left, right) + } + // ~(A | B) ===> ~A & ~B + Operator::BitwiseOr => { + let left = distribute_negation(*left); + let right = distribute_negation(*right); + + bitwise_and(left, right) + } + // use negative clause + _ => Expr::Negative(Box::new(Expr::BinaryExpr(BinaryExpr::new( + left, op, right, + )))), + } + } + // ~(~A) ===> A + Expr::Negative(expr) => *expr, + // use negative clause + _ => Expr::Negative(Box::new(expr)), + } +} + /// Simplify the `concat` function by /// 1. filtering out all `null` literals /// 2. 
concatenating contiguous literal arguments diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index c4c5c29d37a1..437d9cd47d0a 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -125,14 +125,23 @@ impl ExprRewriter for TypeCoercionRewriter { fn mutate(&mut self, expr: Expr) -> Result { match expr { - Expr::ScalarSubquery(Subquery { subquery }) => { + Expr::ScalarSubquery(Subquery { + subquery, + outer_ref_columns, + }) => { let new_plan = optimize_internal(&self.schema, &subquery)?; - Ok(Expr::ScalarSubquery(Subquery::new(new_plan))) + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns, + })) } Expr::Exists { subquery, negated } => { let new_plan = optimize_internal(&self.schema, &subquery.subquery)?; Ok(Expr::Exists { - subquery: Subquery::new(new_plan), + subquery: Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, negated, }) } @@ -144,7 +153,10 @@ impl ExprRewriter for TypeCoercionRewriter { let new_plan = optimize_internal(&self.schema, &subquery.subquery)?; Ok(Expr::InSubquery { expr, - subquery: Subquery::new(new_plan), + subquery: Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, negated, }) } @@ -664,7 +676,7 @@ mod test { produce_one_row: false, schema: Arc::new( DFSchema::new_with_metadata( - vec![DFField::new(None, "a", DataType::Float64, true)], + vec![DFField::new_unqualified("a", DataType::Float64, true)], std::collections::HashMap::new(), ) .unwrap(), @@ -682,7 +694,7 @@ mod test { produce_one_row: false, schema: Arc::new( DFSchema::new_with_metadata( - vec![DFField::new(None, "a", DataType::Float64, true)], + vec![DFField::new_unqualified("a", DataType::Float64, true)], std::collections::HashMap::new(), ) .unwrap(), @@ -881,7 +893,7 @@ mod test { produce_one_row: false, schema: Arc::new( DFSchema::new_with_metadata( - vec![DFField::new(None, "a", DataType::Int64, true)], + vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), ) .unwrap(), @@ -899,7 +911,11 @@ mod test { produce_one_row: false, schema: Arc::new( DFSchema::new_with_metadata( - vec![DFField::new(None, "a", DataType::Decimal128(12, 4), true)], + vec![DFField::new_unqualified( + "a", + DataType::Decimal128(12, 4), + true, + )], std::collections::HashMap::new(), ) .unwrap(), @@ -1082,7 +1098,7 @@ mod test { produce_one_row: false, schema: Arc::new( DFSchema::new_with_metadata( - vec![DFField::new(None, "a", data_type, true)], + vec![DFField::new_unqualified("a", data_type, true)], std::collections::HashMap::new(), ) .unwrap(), @@ -1095,7 +1111,7 @@ mod test { // gt let schema = Arc::new( DFSchema::new_with_metadata( - vec![DFField::new(None, "a", DataType::Int64, true)], + vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), ) .unwrap(), @@ -1109,7 +1125,7 @@ mod test { // eq let schema = Arc::new( DFSchema::new_with_metadata( - vec![DFField::new(None, "a", DataType::Int64, true)], + vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), ) .unwrap(), @@ -1123,7 +1139,7 @@ mod test { // lt let schema = Arc::new( DFSchema::new_with_metadata( - vec![DFField::new(None, "a", DataType::Int64, true)], + vec![DFField::new_unqualified("a", DataType::Int64, true)], std::collections::HashMap::new(), ) .unwrap(), diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs 
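`distribute_negation` above pushes a negation through bitwise `&` and `|` using De Morgan's laws (`~(A & B)` becomes `~A | ~B`, `~(A | B)` becomes `~A & ~B`, and a double negation cancels), which is what the new `test_simplify_by_de_morgan_laws` cases exercise at the `Expr` level. The same identities hold for Rust's integer bitwise operators, so a std-only sanity check looks like this:

```rust
fn main() {
    let (a, b): (u8, u8) = (0b1100_1010, 0b1010_0110);

    // ~(a & b) == ~a | ~b
    assert_eq!(!(a & b), !a | !b);
    // ~(a | b) == ~a & ~b
    assert_eq!(!(a | b), !a & !b);
    // ~(~a) == a
    assert_eq!(!(!a), a);
}
```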
b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 46c4d35227ef..4c2a24f05515 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; +use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; @@ -31,6 +32,7 @@ use datafusion_expr::utils::from_plan; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; +use std::cmp::Ordering; use std::sync::Arc; /// [`UnwrapCastInComparison`] attempts to remove casts from @@ -400,16 +402,36 @@ fn try_cast_literal_to_type( DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), DataType::Timestamp(TimeUnit::Second, tz) => { - ScalarValue::TimestampSecond(Some(value as i64), tz.clone()) + let value = cast_between_timestamp( + lit_data_type, + DataType::Timestamp(TimeUnit::Second, tz.clone()), + value, + ); + ScalarValue::TimestampSecond(value, tz.clone()) } DataType::Timestamp(TimeUnit::Millisecond, tz) => { - ScalarValue::TimestampMillisecond(Some(value as i64), tz.clone()) + let value = cast_between_timestamp( + lit_data_type, + DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), + value, + ); + ScalarValue::TimestampMillisecond(value, tz.clone()) } DataType::Timestamp(TimeUnit::Microsecond, tz) => { - ScalarValue::TimestampMicrosecond(Some(value as i64), tz.clone()) + let value = cast_between_timestamp( + lit_data_type, + DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + value, + ); + ScalarValue::TimestampMicrosecond(value, tz.clone()) } DataType::Timestamp(TimeUnit::Nanosecond, tz) => { - ScalarValue::TimestampNanosecond(Some(value as i64), tz.clone()) + let value = cast_between_timestamp( + lit_data_type, + DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), + value, + ); + ScalarValue::TimestampNanosecond(value, tz.clone()) } DataType::Decimal128(p, s) => { ScalarValue::Decimal128(Some(value), *p, *s) @@ -428,6 +450,32 @@ fn try_cast_literal_to_type( } } +/// Cast a timestamp value from one unit to another +fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option { + let value = value as i64; + let from_scale = match from { + DataType::Timestamp(TimeUnit::Second, _) => 1, + DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + _ => return Some(value), + }; + + let to_scale = match to { + DataType::Timestamp(TimeUnit::Second, _) => 1, + DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + _ => return Some(value), + }; + + match from_scale.cmp(&to_scale) { + Ordering::Less => value.checked_mul(to_scale / from_scale), + Ordering::Greater => Some(value / (from_scale / to_scale)), + Ordering::Equal => Some(value), + } +} + #[cfg(test)] mod tests { use super::*; @@ -695,14 +743,22 @@ mod tests { Arc::new( DFSchema::new_with_metadata( vec![ - DFField::new(None, "c1", 
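`cast_between_timestamp` above converts a timestamp literal between units by comparing source and target scale factors (ticks per second: 1, `MILLISECONDS`, `MICROSECONDS`, `NANOSECONDS`), using `checked_mul` when upscaling so an overflow becomes `None` and, per the new tests, a NULL literal. A std-only sketch of the same arithmetic with the Arrow scale constants written out as plain numbers:

```rust
use std::cmp::Ordering;

const MILLISECONDS: i64 = 1_000;
const MICROSECONDS: i64 = 1_000_000;
const NANOSECONDS: i64 = 1_000_000_000;

/// Convert `value` from `from_scale` ticks-per-second to `to_scale` ticks-per-second,
/// returning `None` when upscaling overflows (the `checked_mul` branch in the hunk).
fn convert(value: i64, from_scale: i64, to_scale: i64) -> Option<i64> {
    match from_scale.cmp(&to_scale) {
        Ordering::Less => value.checked_mul(to_scale / from_scale),
        Ordering::Greater => Some(value / (from_scale / to_scale)),
        Ordering::Equal => Some(value),
    }
}

fn main() {
    // 123 seconds -> 123_000 milliseconds
    assert_eq!(convert(123, 1, MILLISECONDS), Some(123_000));
    // 123_456 nanoseconds -> 123 microseconds (integer division truncates)
    assert_eq!(convert(123_456, NANOSECONDS, MICROSECONDS), Some(123));
    // i64::MAX seconds do not fit in milliseconds, so the cast yields None
    assert_eq!(convert(i64::MAX, 1, MILLISECONDS), None);
}
```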
DataType::Int32, false), - DFField::new(None, "c2", DataType::Int64, false), - DFField::new(None, "c3", DataType::Decimal128(18, 2), false), - DFField::new(None, "c4", DataType::Decimal128(38, 37), false), - DFField::new(None, "c5", DataType::Float32, false), - DFField::new(None, "c6", DataType::UInt32, false), - DFField::new(None, "ts_nano_none", timestamp_nano_none_type(), false), - DFField::new(None, "ts_nano_utf", timestamp_nano_utc_type(), false), + DFField::new_unqualified("c1", DataType::Int32, false), + DFField::new_unqualified("c2", DataType::Int64, false), + DFField::new_unqualified("c3", DataType::Decimal128(18, 2), false), + DFField::new_unqualified("c4", DataType::Decimal128(38, 37), false), + DFField::new_unqualified("c5", DataType::Float32, false), + DFField::new_unqualified("c6", DataType::UInt32, false), + DFField::new_unqualified( + "ts_nano_none", + timestamp_nano_none_type(), + false, + ), + DFField::new_unqualified( + "ts_nano_utf", + timestamp_nano_utc_type(), + false, + ), ], HashMap::new(), ) @@ -1070,4 +1126,162 @@ mod tests { } } } + + #[test] + fn test_try_cast_literal_to_timestamp() { + // same timestamp + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123456), None) + ); + + // TimestampNanosecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123), None) + ); + + // TimestampNanosecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); + + // TimestampNanosecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap() + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None)); + + // TimestampMicrosecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000), None) + ); + + // TimestampMicrosecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap() + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); + + // TimestampMicrosecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123456789), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap() + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None)); + + // TimestampMillisecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, 
+ ScalarValue::TimestampNanosecond(Some(123000000), None) + ); + + // TimestampMillisecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123000), None) + ); + // TimestampMillisecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123456789), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap() + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None)); + + // TimestampSecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000000000), None) + ); + + // TimestampSecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123000000), None) + ); + + // TimestampSecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMillisecond(Some(123000), None) + ); + + // overflow + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(i64::MAX), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap() + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); + } } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 617c123d837c..235b07d9e86c 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -21,7 +21,9 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{plan_err, Column, DFSchemaRef, DataFusionError}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::expr::{BinaryExpr, Sort}; -use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; +use datafusion_expr::expr_rewriter::{ + strip_outer_reference, ExprRewritable, ExprRewriter, +}; use datafusion_expr::expr_visitor::inspect_expr_pre; use datafusion_expr::logical_plan::LogicalPlanBuilder; use datafusion_expr::utils::{check_all_columns_from_schema, from_plan}; @@ -291,46 +293,52 @@ pub fn find_join_exprs( let mut joins = vec![]; let mut others = vec![]; for filter in exprs.iter() { - let (left, op, right) = match filter { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - (*left.clone(), *op, *right.clone()) - } - _ => { - others.push((*filter).clone()); - continue; - } - }; - let left = match left { - Expr::Column(c) => c, - _ => { + // If the expression contains correlated predicates, add it to join filters + if filter.contains_outer() { + joins.push(strip_outer_reference((*filter).clone())); + } else { + // TODO remove the logic + let (left, op, right) = match filter { + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + (*left.clone(), *op, *right.clone()) + } + _ => { + others.push((*filter).clone()); + continue; + } + }; + let left = match left { + Expr::Column(c) => c, + _ => { + 
others.push((*filter).clone()); + continue; + } + }; + let right = match right { + Expr::Column(c) => c, + _ => { + others.push((*filter).clone()); + continue; + } + }; + if fields.contains(&left.flat_name()) && fields.contains(&right.flat_name()) { others.push((*filter).clone()); - continue; + continue; // both columns present (none closed-upon) } - }; - let right = match right { - Expr::Column(c) => c, - _ => { + if !fields.contains(&left.flat_name()) && !fields.contains(&right.flat_name()) + { others.push((*filter).clone()); - continue; + continue; // neither column present (syntax error?) } - }; - if fields.contains(&left.flat_name()) && fields.contains(&right.flat_name()) { - others.push((*filter).clone()); - continue; // both columns present (none closed-upon) - } - if !fields.contains(&left.flat_name()) && !fields.contains(&right.flat_name()) { - others.push((*filter).clone()); - continue; // neither column present (syntax error?) - } - match op { - Operator::Eq => {} - Operator::NotEq => {} - _ => { - plan_err!(format!("can't optimize {op} column comparison"))?; + match op { + Operator::Eq => {} + Operator::NotEq => {} + _ => { + plan_err!(format!("can't optimize {op} column comparison"))?; + } } + joins.push((*filter).clone()) } - - joins.push((*filter).clone()) } Ok((joins, others)) @@ -460,13 +468,17 @@ fn add_alias_if_changed(original_name: String, expr: Expr) -> Result { /// merge inputs schema into a single schema. pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { - inputs - .iter() - .map(|input| input.schema()) - .fold(DFSchema::empty(), |mut lhs, rhs| { - lhs.merge(rhs); - lhs - }) + if inputs.len() == 1 { + inputs[0].schema().clone().as_ref().clone() + } else { + inputs.iter().map(|input| input.schema()).fold( + DFSchema::empty(), + |mut lhs, rhs| { + lhs.merge(rhs); + lhs + }, + ) + } } /// Extract join predicates from the correclated subquery. @@ -485,11 +497,16 @@ pub(crate) fn extract_join_filters( let mut join_filters: Vec = vec![]; let mut subquery_filters: Vec = vec![]; for expr in subquery_filter_exprs { - let cols = expr.to_columns()?; - if check_all_columns_from_schema(&cols, input_schema.clone())? { - subquery_filters.push(expr.clone()); + // If the expression contains correlated predicates, add it to join filters + if expr.contains_outer() { + join_filters.push(strip_outer_reference(expr.clone())) } else { - join_filters.push(expr.clone()) + let cols = expr.to_columns()?; + if check_all_columns_from_schema(&cols, input_schema.clone())? 
{ + subquery_filters.push(expr.clone()); + } else { + join_filters.push(expr.clone()) + } } } diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index cd1201eb281f..e21569ce83f1 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -69,7 +69,7 @@ uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] criterion = "0.4" rand = "0.8" -rstest = "0.16.0" +rstest = "0.17.0" [[bench]] harness = false diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index 670030b09e66..083671c6b55d 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -367,6 +367,11 @@ impl Accumulator for ApproxPercentileAccumulator { } fn evaluate(&self) -> Result { + if self.digest.count() == 0.0 { + return Err(DataFusionError::Execution( + "aggregate function needs at least one non-null element".to_string(), + )); + } let q = self.digest.estimate_quantile(self.percentile); // These acceptable return types MUST match the validation in diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 38b193ebf270..dc77b794a283 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -137,13 +137,13 @@ impl Accumulator for CountAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let array = &values[0]; - self.count += (array.len() - array.data().null_count()) as i64; + self.count += (array.len() - array.null_count()) as i64; Ok(()) } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let array = &values[0]; - self.count -= (array.len() - array.data().null_count()) as i64; + self.count -= (array.len() - array.null_count()) as i64; Ok(()) } @@ -183,7 +183,7 @@ impl RowAccumulator for CountRowAccumulator { accessor: &mut RowAccessor, ) -> Result<()> { let array = &values[0]; - let delta = (array.len() - array.data().null_count()) as u64; + let delta = (array.len() - array.null_count()) as u64; accessor.add_u64(self.state_index, delta); Ok(()) } diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs index 9daabc01d8be..c933f1f75b59 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -143,6 +143,10 @@ impl Accumulator for MedianAccumulator { } fn evaluate(&self) -> Result { + if !self.all_values.iter().any(|v| !v.is_null()) { + return ScalarValue::try_from(&self.data_type); + } + // Create an array of all the non null values and find the // sorted indexes let array = ScalarValue::iter_to_array( diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 0a21de22d32d..a815a33c8c7f 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -236,7 +236,7 @@ impl Accumulator for SumAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; - self.count += (values.len() - values.data().null_count()) as u64; + self.count += (values.len() - values.null_count()) as u64; let delta = sum_batch(values, &self.sum.get_datatype())?; self.sum = self.sum.add(&delta)?; Ok(()) @@ -244,7 +244,7 @@ impl Accumulator for SumAccumulator { fn retract_batch(&mut 
self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; - self.count -= (values.len() - values.data().null_count()) as u64; + self.count -= (values.len() - values.null_count()) as u64; let delta = sum_batch(values, &self.sum.get_datatype())?; self.sum = self.sum.sub(&delta)?; Ok(()) diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 4b3eeb4424d9..4ce1f18e8b83 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -17,6 +17,7 @@ //! DateTime expressions +use arrow::array::Float64Builder; use arrow::compute::cast; use arrow::{ array::TimestampNanosecondArray, compute::kernels::temporal, datatypes::TimeUnit, @@ -478,6 +479,7 @@ pub fn date_part(args: &[ColumnarValue]) -> Result { "millisecond" => extract_date_part!(&array, millis), "microsecond" => extract_date_part!(&array, micros), "nanosecond" => extract_date_part!(&array, nanos), + "epoch" => extract_date_part!(&array, epoch), _ => Err(DataFusionError::Execution(format!( "Date part '{date_part}' not supported" ))), @@ -537,6 +539,40 @@ where to_ticks(array, 1_000_000_000) } +fn epoch(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + let mut b = Float64Builder::with_capacity(array.len()); + match array.data_type() { + DataType::Timestamp(tu, _) => { + for i in 0..array.len() { + if array.is_null(i) { + b.append_null(); + } else { + let scale = match tu { + TimeUnit::Second => 1, + TimeUnit::Millisecond => 1_000, + TimeUnit::Microsecond => 1_000_000, + TimeUnit::Nanosecond => 1_000_000_000, + }; + + let n: i64 = array.value(i).into(); + b.append_value(n as f64 / scale as f64); + } + } + } + _ => { + return Err(DataFusionError::Internal(format!( + "Can not convert {:?} to epoch", + array.data_type() + ))) + } + } + Ok(b.finish()) +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index fe268c5a34ba..a67c3748ca4e 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -19,13 +19,13 @@ mod adapter; mod kernels; mod kernels_arrow; -use std::convert::TryInto; use std::{any::Any, sync::Arc}; use arrow::array::*; use arrow::compute::kernels::arithmetic::{ add_dyn, add_scalar_dyn as add_dyn_scalar, divide_dyn_opt, - divide_scalar_dyn as divide_dyn_scalar, modulus, modulus_scalar, multiply_dyn, + divide_scalar_dyn as divide_dyn_scalar, modulus_dyn, + modulus_scalar_dyn as modulus_dyn_scalar, multiply_dyn, multiply_scalar_dyn as multiply_dyn_scalar, subtract_dyn, subtract_scalar_dyn as subtract_dyn_scalar, }; @@ -64,7 +64,7 @@ use kernels_arrow::{ is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from, is_not_distinct_from_bool, is_not_distinct_from_decimal, is_not_distinct_from_f32, is_not_distinct_from_f64, is_not_distinct_from_null, is_not_distinct_from_utf8, - modulus_decimal, modulus_decimal_scalar, multiply_decimal_dyn_scalar, + modulus_decimal_dyn_scalar, modulus_dyn_decimal, multiply_decimal_dyn_scalar, multiply_dyn_decimal, subtract_decimal_dyn_scalar, subtract_dyn_decimal, }; @@ -76,7 +76,7 @@ use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::intervals::{apply_operator, Interval}; use crate::physical_expr::down_cast_any_ref; use crate::{analysis_expect, AnalysisContext, ExprBoundaries, PhysicalExpr}; 
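The `epoch` helper added to `date_part` above divides the raw timestamp value by the scale of its `TimeUnit` to produce fractional seconds as `f64`. A minimal standalone sketch of that scaling, in plain Rust rather than the Arrow builder API used in the patch (illustrative only, with a string in place of `TimeUnit`):

```rust
/// Illustrative only: convert a raw timestamp value into fractional epoch
/// seconds, mirroring the per-unit scale used by the `epoch` helper above.
fn to_epoch_seconds(raw: i64, unit: &str) -> f64 {
    let scale: i64 = match unit {
        "second" => 1,
        "millisecond" => 1_000,
        "microsecond" => 1_000_000,
        "nanosecond" => 1_000_000_000,
        _ => panic!("unsupported time unit: {unit}"),
    };
    raw as f64 / scale as f64
}

fn main() {
    // 1_500 ms and 1_500_000 us both map to 1.5 seconds.
    assert_eq!(to_epoch_seconds(1_500, "millisecond"), 1.5);
    assert_eq!(to_epoch_seconds(1_500_000, "microsecond"), 1.5);
}
```

With the new match arm above, a query such as `date_part('epoch', ts)` on a millisecond timestamp of 1500 would accordingly yield 1.5.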
-use datafusion_common::cast::{as_boolean_array, as_decimal128_array}; +use datafusion_common::cast::as_boolean_array; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::type_coercion::binary::binary_operator_data_type; @@ -160,21 +160,6 @@ macro_rules! compute_decimal_op_dyn_scalar { }}; } -macro_rules! compute_decimal_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = as_decimal128_array($LEFT).unwrap(); - if let ScalarValue::Decimal128(Some(_), _, _) = $RIGHT { - Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}( - ll, - $RIGHT.try_into()?, - )?)) - } else { - // when the $RIGHT is a NULL, generate a NULL array of LEFT's datatype - Ok(Arc::new(new_null_array($LEFT.data_type(), $LEFT.len()))) - } - }}; -} - macro_rules! compute_decimal_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ let ll = $LEFT.as_any().downcast_ref::<$DT>().unwrap(); @@ -335,25 +320,6 @@ macro_rules! compute_bool_op { }}; } -/// Invoke a compute kernel on a data array and a scalar value -/// LEFT is array, RIGHT is scalar value -macro_rules! compute_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - if $RIGHT.is_null() { - Ok(Arc::new(new_null_array($LEFT.data_type(), $LEFT.len()))) - } else { - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast left side array"); - Ok(Arc::new(paste::expr! {[<$OP _scalar>]}( - &ll, - $RIGHT.try_into()?, - )?)) - } - }}; -} - /// Invoke a dyn compute kernel on a data array and a scalar value /// LEFT is Primitive or Dictionary array of numeric values, RIGHT is scalar value /// OP_TYPE is the return type of scalar function @@ -448,31 +414,6 @@ macro_rules! binary_string_array_op { }}; } -/// Invoke a compute kernel on a pair of arrays -/// The binary_primitive_array_op macro only evaluates for primitive types -/// like integers and floats. -macro_rules! binary_primitive_array_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - match $LEFT.data_type() { - DataType::Decimal128(_,_) => compute_decimal_op!($LEFT, $RIGHT, $OP, Decimal128Array), - DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array), - DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array), - DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array), - DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array), - DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array), - DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array), - DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array), - DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array), - DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array), - DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array), - other => Err(DataFusionError::Internal(format!( - "Data type {:?} not supported for binary operation '{}' on primitive arrays", - other, stringify!($OP) - ))), - } - }}; -} - /// Invoke a compute kernel on a pair of arrays /// The binary_primitive_array_op macro only evaluates for primitive types /// like integers and floats. @@ -525,32 +466,6 @@ macro_rules! binary_primitive_array_op_dyn_scalar { }} } -/// Invoke a compute kernel on an array and a scalar -/// The binary_primitive_array_op_scalar macro only evaluates for primitive -/// types like integers and floats. -macro_rules! 
binary_primitive_array_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - let result: Result> = match $LEFT.data_type() { - DataType::Decimal128(_,_) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, Decimal128Array), - DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array), - DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array), - DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array), - DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array), - DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array), - DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array), - DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array), - DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array), - DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array), - DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array), - other => Err(DataFusionError::Internal(format!( - "Data type {:?} not supported for scalar operation '{}' on primitive array", - other, stringify!($OP) - ))), - }; - Some(result) - }}; -} - /// The binary_array_op macro includes types that extend beyond the primitive, /// such as Utf8 strings. #[macro_export] @@ -1128,8 +1043,7 @@ impl BinaryExpr { binary_primitive_array_op_dyn_scalar!(array, scalar, divide) } Operator::Modulo => { - // todo: change to binary_primitive_array_op_dyn_scalar! once modulo is implemented - binary_primitive_array_op_scalar!(array, scalar, modulus) + binary_primitive_array_op_dyn_scalar!(array, scalar, modulus) } Operator::RegexMatch => binary_string_array_flag_op_scalar!( array, @@ -1239,7 +1153,9 @@ impl BinaryExpr { Operator::Divide => { binary_primitive_array_op_dyn!(left, right, divide_dyn_opt) } - Operator::Modulo => binary_primitive_array_op!(left, right, modulus), + Operator::Modulo => { + binary_primitive_array_op_dyn!(left, right, modulus_dyn) + } Operator::And => { if left_data_type == &DataType::Boolean { boolean_op!(&left, &right, and_kleene) @@ -2638,6 +2554,201 @@ mod tests { Ok(()) } + #[test] + #[cfg(feature = "dictionary_expressions")] + fn modulus_op_dict() -> Result<()> { + let schema = Schema::new(vec![ + Field::new( + "a", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), + true, + ), + Field::new( + "b", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), + true, + ), + ]); + + let mut dict_builder = PrimitiveDictionaryBuilder::::new(); + + dict_builder.append(1)?; + dict_builder.append_null(); + dict_builder.append(2)?; + dict_builder.append(5)?; + dict_builder.append(0)?; + + let a = dict_builder.finish(); + + let b = Int32Array::from(vec![1, 2, 4, 8, 16]); + let keys = Int8Array::from(vec![0, 1, 1, 2, 1]); + let b = DictionaryArray::try_new(&keys, &b)?; + + apply_arithmetic::( + Arc::new(schema), + vec![Arc::new(a), Arc::new(b)], + Operator::Modulo, + Int32Array::from(vec![Some(0), None, Some(0), Some(1), Some(0)]), + )?; + + Ok(()) + } + + #[test] + #[cfg(feature = "dictionary_expressions")] + fn modulus_op_dict_decimal() -> Result<()> { + let schema = Schema::new(vec![ + Field::new( + "a", + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Decimal128(10, 0)), + ), + true, + ), + Field::new( + "b", + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Decimal128(10, 0)), + ), + true, + ), + ]); + + let value = 123; + let decimal_array = Arc::new(create_decimal_array( + &[ + Some(value), + 
Some(value + 2), + Some(value - 1), + Some(value + 1), + ], + 10, + 0, + )) as ArrayRef; + + let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]); + let a = DictionaryArray::try_new(&keys, &decimal_array)?; + + let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]); + let decimal_array = create_decimal_array( + &[ + Some(value + 1), + Some(value + 3), + Some(value), + Some(value + 2), + ], + 10, + 0, + ); + let b = DictionaryArray::try_new(&keys, &decimal_array)?; + + apply_arithmetic( + Arc::new(schema), + vec![Arc::new(a), Arc::new(b)], + Operator::Modulo, + create_decimal_array(&[Some(123), None, None, Some(1), Some(0)], 10, 0), + )?; + + Ok(()) + } + + #[test] + fn modulus_op_scalar() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + + apply_arithmetic_scalar( + Arc::new(schema), + vec![Arc::new(a)], + Operator::Modulo, + ScalarValue::Int32(Some(2)), + Arc::new(Int32Array::from(vec![1, 0, 1, 0, 1])), + )?; + + Ok(()) + } + + #[test] + fn modules_op_dict_scalar() -> Result<()> { + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), + true, + )]); + + let mut dict_builder = PrimitiveDictionaryBuilder::::new(); + + dict_builder.append(1)?; + dict_builder.append_null(); + dict_builder.append(2)?; + dict_builder.append(5)?; + + let a = dict_builder.finish(); + + let mut dict_builder = PrimitiveDictionaryBuilder::::new(); + + dict_builder.append(1)?; + dict_builder.append_null(); + dict_builder.append(0)?; + dict_builder.append(1)?; + let expected = dict_builder.finish(); + + apply_arithmetic_scalar( + Arc::new(schema), + vec![Arc::new(a)], + Operator::Modulo, + ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Int32(Some(2))), + ), + Arc::new(expected), + )?; + + Ok(()) + } + + #[test] + fn modulus_op_dict_scalar_decimal() -> Result<()> { + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Decimal128(10, 0)), + ), + true, + )]); + + let value = 123; + let decimal_array = Arc::new(create_decimal_array( + &[Some(value), None, Some(value - 1), Some(value + 1)], + 10, + 0, + )) as ArrayRef; + + let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); + let a = DictionaryArray::try_new(&keys, &decimal_array)?; + + let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); + let decimal_array = + create_decimal_array(&[Some(1), None, Some(0), Some(0)], 10, 0); + let expected = DictionaryArray::try_new(&keys, &decimal_array)?; + + apply_arithmetic_scalar( + Arc::new(schema), + vec![Arc::new(a)], + Operator::Modulo, + ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Decimal128(Some(2), 10, 0)), + ), + Arc::new(expected), + )?; + + Ok(()) + } + fn apply_arithmetic( schema: SchemaRef, data: Vec, diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs index 6ef515b33c54..57cf6a1cf80d 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs @@ -19,8 +19,9 @@ //! destined for arrow-rs but are in datafusion until they are ported. 
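The decimal modulus kernels in `kernels_arrow.rs` below take the remainder on the raw `i128` representation and then re-attach the left operand's precision and scale; a zero divisor surfaces as an Arrow "Divide by zero" error. A small standalone illustration of that convention (assumed semantics, plain Rust rather than the Arrow kernels):

```rust
/// Illustrative only: a Decimal128 value is a raw i128 plus (precision, scale).
/// The remainder is taken on the raw values, and the kernels below keep the
/// left operand's precision and scale for the result.
fn decimal_modulus_raw(left_raw: i128, right_raw: i128) -> Option<i128> {
    if right_raw == 0 {
        None // the real kernels report this as an Arrow "Divide by zero" error
    } else {
        Some(left_raw % right_raw)
    }
}

fn main() {
    // With scale 3, raw 123_456 is 123.456; modulo raw 10_000 (10.000) leaves
    // raw 3_456, i.e. 3.456 at the same scale.
    assert_eq!(decimal_modulus_raw(123_456, 10_000), Some(3_456));
    assert_eq!(decimal_modulus_raw(123_456, 0), None);
}
```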
use arrow::compute::{ - add_dyn, add_scalar_dyn, divide_dyn_opt, divide_scalar_dyn, modulus, modulus_scalar, - multiply_dyn, multiply_scalar_dyn, subtract_dyn, subtract_scalar_dyn, + add_dyn, add_scalar_dyn, divide_dyn_opt, divide_scalar_dyn, modulus_dyn, + modulus_scalar_dyn, multiply_dyn, multiply_scalar_dyn, subtract_dyn, + subtract_scalar_dyn, }; use arrow::datatypes::Decimal128Type; use arrow::{array::*, datatypes::ArrowNumericType, downcast_dictionary_array}; @@ -390,23 +391,23 @@ pub(crate) fn divide_dyn_opt_decimal( decimal_array_with_precision_scale(array, precision, scale) } -pub(crate) fn modulus_decimal( - left: &Decimal128Array, - right: &Decimal128Array, -) -> Result { - let array = - modulus(left, right)?.with_precision_and_scale(left.precision(), left.scale())?; - Ok(array) +pub(crate) fn modulus_dyn_decimal( + left: &dyn Array, + right: &dyn Array, +) -> Result { + let (precision, scale) = get_precision_scale(left)?; + let array = modulus_dyn(left, right)?; + decimal_array_with_precision_scale(array, precision, scale) } -pub(crate) fn modulus_decimal_scalar( - left: &Decimal128Array, +pub(crate) fn modulus_decimal_dyn_scalar( + left: &dyn Array, right: i128, -) -> Result { - // `0` for right will be checked in `modulus_scalar` - let array = modulus_scalar(left, right)? - .with_precision_and_scale(left.precision(), left.scale())?; - Ok(array) +) -> Result { + let (precision, scale) = get_precision_scale(left)?; + + let array = modulus_scalar_dyn::(left, right)?; + decimal_array_with_precision_scale(array, precision, scale) } #[cfg(test)] @@ -570,14 +571,16 @@ mod tests { 3, ); assert_eq!(&expect, result); - let result = modulus_decimal(&left_decimal_array, &right_decimal_array)?; + let result = modulus_dyn_decimal(&left_decimal_array, &right_decimal_array)?; + let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(7), None, Some(37), Some(16), None], 25, 3); - assert_eq!(expect, result); - let result = modulus_decimal_scalar(&left_decimal_array, 10)?; + assert_eq!(&expect, result); + let result = modulus_decimal_dyn_scalar(&left_decimal_array, 10)?; + let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(7), None, Some(7), Some(7), Some(7)], 25, 3); - assert_eq!(expect, result); + assert_eq!(&expect, result); Ok(()) } @@ -589,9 +592,10 @@ mod tests { let err = divide_decimal_dyn_scalar(&left_decimal_array, 0).unwrap_err(); assert_eq!("Arrow error: Divide by zero error", err.to_string()); - let err = modulus_decimal(&left_decimal_array, &right_decimal_array).unwrap_err(); + let err = + modulus_dyn_decimal(&left_decimal_array, &right_decimal_array).unwrap_err(); assert_eq!("Arrow error: Divide by zero error", err.to_string()); - let err = modulus_decimal_scalar(&left_decimal_array, 0).unwrap_err(); + let err = modulus_decimal_dyn_scalar(&left_decimal_array, 0).unwrap_err(); assert_eq!("Arrow error: Divide by zero error", err.to_string()); } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 3dff4ae9745a..d649474455db 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -716,10 +716,10 @@ mod tests { //let valid_array = vec![true, false, false, true, false, tru let null_buffer = Buffer::from([0b00101001u8]); - let load4 = ArrayDataBuilder::new(load4.data_type().clone()) - .len(load4.len()) + let load4 = load4 + .into_data() + .into_builder() .null_bit_buffer(Some(null_buffer)) - 
.buffers(load4.data().buffers().to_vec()) .build() .unwrap(); let load4: Float64Array = load4.into(); diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 3a5a25ff66cf..e37d5cd7498a 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -86,7 +86,7 @@ where { fn new(array: &T, hash_set: ArrayHashSet) -> Self { Self { - array: T::from(array.data().clone()), + array: downcast_array(array), hash_set, } } @@ -103,15 +103,14 @@ where v => { let values_contains = self.contains(v.values().as_ref(), negated)?; let result = take(&values_contains, v.keys(), None)?; - return Ok(BooleanArray::from(result.data().clone())) + return Ok(downcast_array(result.as_ref())) } _ => {} } let v = v.as_any().downcast_ref::().unwrap(); - let in_data = self.array.data(); let in_array = &self.array; - let has_nulls = in_data.null_count() != 0; + let has_nulls = in_array.null_count() != 0; Ok(ArrayIter::new(v) .map(|v| { diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 302a86cdc927..66367001c642 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -326,7 +326,7 @@ impl ExprIntervalGraph { // ``` /// This function associates stable node indices with [PhysicalExpr]s so - /// that we can match Arc and NodeIndex objects during + /// that we can match `Arc` and NodeIndex objects during /// membership tests. pub fn gather_node_indices( &mut self, diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index a80a92bc5a5a..5b357d931821 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -377,35 +377,18 @@ pub fn reassign_predicate_columns( schema: &SchemaRef, ignore_not_found: bool, ) -> Result, DataFusionError> { - let mut rewriter = ColumnAssigner { - schema, - ignore_not_found, - }; - pred.transform_using(&mut rewriter) -} - -#[derive(Debug)] -struct ColumnAssigner<'a> { - schema: &'a SchemaRef, - ignore_not_found: bool, -} - -impl<'a> TreeNodeRewriter> for ColumnAssigner<'a> { - fn mutate( - &mut self, - expr: Arc, - ) -> Result, DataFusionError> { + pred.transform(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { - let index = match self.schema.index_of(column.name()) { + let index = match schema.index_of(column.name()) { Ok(idx) => idx, - Err(_) if self.ignore_not_found => usize::MAX, + Err(_) if ignore_not_found => usize::MAX, Err(e) => return Err(e.into()), }; - return Ok(Arc::new(Column::new(column.name(), index))); + return Ok(Some(Arc::new(Column::new(column.name(), index)))); } - Ok(expr) - } + Ok(None) + }) } #[cfg(test)] diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index e9b43cd070cd..95fd86148ac2 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -27,7 +27,7 @@ use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{Accumulator, WindowFrame, WindowFrameUnits}; +use datafusion_expr::{Accumulator, WindowFrame}; use crate::window::window_expr::{reverse_order_bys, AggregateWindowExpr}; use crate::window::{ @@ -115,11 +115,8 @@ impl WindowExpr for PlainAggregateWindowExpr { })?; let mut state = &mut 
window_state.state; if self.window_frame.start_bound.is_unbounded() { - state.window_frame_range.start = if state.window_frame_range.end >= 1 { - state.window_frame_range.end - 1 - } else { - 0 - }; + state.window_frame_range.start = + state.window_frame_range.end.saturating_sub(1); } } Ok(()) @@ -159,10 +156,8 @@ impl WindowExpr for PlainAggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { - // NOTE: Currently, groups queries do not support the bounded memory variant. self.aggregate.supports_bounded_execution() && !self.window_frame.end_bound.is_unbounded() - && !matches!(self.window_frame.units, WindowFrameUnits::Groups) } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 70ddb2c7671a..329eac333460 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -32,11 +32,11 @@ use crate::window::{ }; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use arrow::array::{new_empty_array, Array, ArrayRef}; -use arrow::compute::{concat, SortOptions}; +use arrow::compute::SortOptions; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::{WindowFrame, WindowFrameUnits}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::WindowFrame; /// A window expr that takes the form of a built in window function #[derive(Debug)] @@ -104,17 +104,20 @@ impl WindowExpr for BuiltInWindowExpr { let mut row_wise_results = vec![]; let (values, order_bys) = self.get_values_orderbys(batch)?; - let mut window_frame_ctx = WindowFrameContext::new( - &self.window_frame, - sort_options, - Range { start: 0, end: 0 }, - ); + let mut window_frame_ctx = + WindowFrameContext::new(self.window_frame.clone(), sort_options); + let mut last_range = Range { start: 0, end: 0 }; // We iterate on each row to calculate window frame range and and window function result for idx in 0..num_rows { - let range = - window_frame_ctx.calculate_range(&order_bys, num_rows, idx)?; + let range = window_frame_ctx.calculate_range( + &order_bys, + &last_range, + num_rows, + idx, + )?; let value = evaluator.evaluate_inside_range(&values, &range)?; row_wise_results.push(value); + last_range = range; } ScalarValue::iter_to_array(row_wise_results.into_iter()) } else if evaluator.include_rank() { @@ -139,26 +142,23 @@ impl WindowExpr for BuiltInWindowExpr { let out_type = field.data_type(); let sort_options = self.order_by.iter().map(|o| o.options).collect::>(); for (partition_row, partition_batch_state) in partition_batches.iter() { - if !window_agg_state.contains_key(partition_row) { - let evaluator = self.expr.create_evaluator()?; - window_agg_state.insert( - partition_row.clone(), - WindowState { - state: WindowAggState::new(out_type)?, - window_fn: WindowFn::Builtin(evaluator), - }, - ); - }; let window_state = - window_agg_state.get_mut(partition_row).ok_or_else(|| { - DataFusionError::Execution("Cannot find state".to_string()) - })?; + if let Some(window_state) = window_agg_state.get_mut(partition_row) { + window_state + } else { + let evaluator = self.expr.create_evaluator()?; + window_agg_state + .entry(partition_row.clone()) + .or_insert(WindowState { + state: WindowAggState::new(out_type)?, + window_fn: WindowFn::Builtin(evaluator), + }) + }; let evaluator = match &mut window_state.window_fn { WindowFn::Builtin(evaluator) => evaluator, _ => unreachable!(), }; let mut state = &mut 
window_state.state; - state.is_end = partition_batch_state.is_end; let (values, order_bys) = self.get_values_orderbys(&partition_batch_state.record_batch)?; @@ -166,13 +166,6 @@ impl WindowExpr for BuiltInWindowExpr { // We iterate on each row to perform a running calculation. let record_batch = &partition_batch_state.record_batch; let num_rows = record_batch.num_rows(); - let last_range = state.window_frame_range.clone(); - let mut window_frame_ctx = WindowFrameContext::new( - &self.window_frame, - sort_options.clone(), - // Start search from the last range - last_range, - ); let sort_partition_points = if evaluator.include_rank() { let columns = self.sort_columns(record_batch)?; self.evaluate_partition_points(num_rows, &columns)? @@ -180,33 +173,43 @@ impl WindowExpr for BuiltInWindowExpr { vec![] }; let mut row_wise_results: Vec = vec![]; - let mut last_range = state.window_frame_range.clone(); for idx in state.last_calculated_index..num_rows { - state.window_frame_range = if self.expr.uses_window_frame() { - window_frame_ctx.calculate_range(&order_bys, num_rows, idx) + let frame_range = if self.expr.uses_window_frame() { + state + .window_frame_ctx + .get_or_insert_with(|| { + WindowFrameContext::new( + self.window_frame.clone(), + sort_options.clone(), + ) + }) + .calculate_range( + &order_bys, + // Start search from the last range + &state.window_frame_range, + num_rows, + idx, + ) } else { - evaluator.get_range(state, num_rows) + evaluator.get_range(idx, num_rows) }?; - evaluator.update_state(state, &order_bys, &sort_partition_points)?; - let frame_range = &state.window_frame_range; // Exit if the range extends all the way: - if frame_range.end == num_rows && !state.is_end { + if frame_range.end == num_rows && !partition_batch_state.is_end { break; } + // Update last range + state.window_frame_range = frame_range; + evaluator.update_state(state, idx, &order_bys, &sort_partition_points)?; row_wise_results.push(evaluator.evaluate_stateful(&values)?); - last_range.clone_from(frame_range); - state.last_calculated_index += 1; } - state.window_frame_range = last_range; let out_col = if row_wise_results.is_empty() { new_empty_array(out_type) } else { ScalarValue::iter_to_array(row_wise_results.into_iter())? }; - state.out_col = concat(&[&state.out_col, &out_col])?; - state.n_row_result_missing = num_rows - state.last_calculated_index; + state.update(&out_col, partition_batch_state)?; if self.window_frame.start_bound.is_unbounded() { let mut evaluator_state = evaluator.state()?; if let BuiltinWindowState::NthValue(nth_value_state) = @@ -236,11 +239,9 @@ impl WindowExpr for BuiltInWindowExpr { } fn uses_bounded_memory(&self) -> bool { - // NOTE: Currently, groups queries do not support the bounded memory variant. 
self.expr.supports_bounded_execution() && (!self.expr.uses_window_frame() - || !(self.window_frame.end_bound.is_unbounded() - || matches!(self.window_frame.units, WindowFrameUnits::Groups))) + || !self.window_frame.end_bound.is_unbounded()) } } @@ -271,9 +272,7 @@ fn memoize_nth_value( let result = ScalarValue::try_from_array(out, size - 1)?; nth_value_state.finalized_result = Some(result); } - if state.window_frame_range.end > 0 { - state.window_frame_range.start = state.window_frame_range.end - 1; - } + state.window_frame_range.start = state.window_frame_range.end.saturating_sub(1); } Ok(()) } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 794c7bd39249..e2dfd52daf71 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -189,33 +189,25 @@ impl PartitionEvaluator for WindowShiftEvaluator { fn update_state( &mut self, - state: &WindowAggState, + _state: &WindowAggState, + idx: usize, _range_columns: &[ArrayRef], _sort_partition_points: &[Range], ) -> Result<()> { - self.state.idx = state.last_calculated_index; + self.state.idx = idx; Ok(()) } - fn get_range(&self, state: &WindowAggState, n_rows: usize) -> Result> { + fn get_range(&self, idx: usize, n_rows: usize) -> Result> { if self.shift_offset > 0 { let offset = self.shift_offset as usize; - let start = if state.last_calculated_index > offset { - state.last_calculated_index - offset - } else { - 0 - }; - Ok(Range { - start, - end: state.last_calculated_index + 1, - }) + let start = idx.saturating_sub(offset); + let end = idx + 1; + Ok(Range { start, end }) } else { let offset = (-self.shift_offset) as usize; - let end = min(state.last_calculated_index + offset, n_rows); - Ok(Range { - start: state.last_calculated_index, - end, - }) + let end = min(idx + offset, n_rows); + Ok(Range { start: idx, end }) } } diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index ef6e3c6d016d..4da91e75ef20 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -160,6 +160,7 @@ impl PartitionEvaluator for NthValueEvaluator { fn update_state( &mut self, state: &WindowAggState, + _idx: usize, _range_columns: &[ArrayRef], _sort_partition_points: &[Range], ) -> Result<()> { diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index 7887d1412b98..758f7c3b1b23 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -38,9 +38,15 @@ pub trait PartitionEvaluator: Debug + Send { Ok(BuiltinWindowState::Default) } + /// Updates the internal state for Built-in window function + // state is useful to update internal state for Built-in window function. + // idx is the index of last row for which result is calculated. + // range_columns is the result of order by column values. It is used to calculate rank boundaries + // sort_partition_points is the boundaries of each rank in the range_column. It is used to update rank. 
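The reworked `PartitionEvaluator` methods above now receive the current row index `idx` directly instead of reading it from `WindowAggState`. A standalone sketch of the kinds of frame ranges the concrete evaluators in this patch derive from `idx` (illustrative only, plain `std` Rust, not the actual trait):

```rust
use std::ops::Range;

/// Illustrative only: `LAG(expr, offset)` needs the rows up to and including
/// `idx`, reaching back `offset` rows; this mirrors the positive-shift branch
/// of `WindowShiftEvaluator::get_range` above.
fn lag_range(idx: usize, offset: usize) -> Range<usize> {
    idx.saturating_sub(offset)..idx + 1
}

/// Illustrative only: rank-style evaluators need only the current row, as in
/// `RankEvaluator::get_range` below.
fn rank_range(idx: usize) -> Range<usize> {
    idx..idx + 1
}

fn main() {
    assert_eq!(lag_range(0, 3), 0..1); // saturates at the partition start
    assert_eq!(lag_range(5, 3), 2..6);
    assert_eq!(rank_range(4), 4..5);
}
```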
fn update_state( &mut self, _state: &WindowAggState, + _idx: usize, _range_columns: &[ArrayRef], _sort_partition_points: &[Range], ) -> Result<()> { @@ -54,20 +60,23 @@ pub trait PartitionEvaluator: Debug + Send { )) } - fn get_range(&self, _state: &WindowAggState, _n_rows: usize) -> Result> { + /// Gets the range where Built-in window function result is calculated. + // idx is the index of last row for which result is calculated. + // n_rows is the number of rows of the input record batch (Used during bound check) + fn get_range(&self, _idx: usize, _n_rows: usize) -> Result> { Err(DataFusionError::NotImplemented( "get_range is not implemented for this window function".to_string(), )) } - /// evaluate the partition evaluator against the partition + /// Evaluate the partition evaluator against the partition fn evaluate(&self, _values: &[ArrayRef], _num_rows: usize) -> Result { Err(DataFusionError::NotImplemented( "evaluate is not implemented by default".into(), )) } - /// evaluate window function result inside given range + /// Evaluate window function result inside given range fn evaluate_stateful(&mut self, _values: &[ArrayRef]) -> Result { Err(DataFusionError::NotImplemented( "evaluate_stateful is not implemented by default".into(), diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index ead9d44535ba..5f016739cfa0 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -25,6 +25,7 @@ use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::array::{Float64Array, UInt64Array}; use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::get_row_at_idx; use datafusion_common::{DataFusionError, Result, ScalarValue}; use std::any::Any; use std::iter; @@ -118,11 +119,10 @@ pub(crate) struct RankEvaluator { } impl PartitionEvaluator for RankEvaluator { - fn get_range(&self, state: &WindowAggState, _n_rows: usize) -> Result> { - Ok(Range { - start: state.last_calculated_index, - end: state.last_calculated_index + 1, - }) + fn get_range(&self, idx: usize, _n_rows: usize) -> Result> { + let start = idx; + let end = idx + 1; + Ok(Range { start, end }) } fn state(&self) -> Result { @@ -132,22 +132,21 @@ impl PartitionEvaluator for RankEvaluator { fn update_state( &mut self, state: &WindowAggState, + idx: usize, range_columns: &[ArrayRef], sort_partition_points: &[Range], ) -> Result<()> { - // find range inside `sort_partition_points` containing `state.last_calculated_index` + // find range inside `sort_partition_points` containing `idx` let chunk_idx = sort_partition_points .iter() - .position(|elem| { - elem.start <= state.last_calculated_index - && state.last_calculated_index < elem.end - }) - .ok_or_else(|| DataFusionError::Execution("Expects sort_partition_points to contain state.last_calculated_index".to_string()))?; + .position(|elem| elem.start <= idx && idx < elem.end) + .ok_or_else(|| { + DataFusionError::Execution( + "Expects sort_partition_points to contain idx".to_string(), + ) + })?; let chunk = &sort_partition_points[chunk_idx]; - let last_rank_data = range_columns - .iter() - .map(|c| ScalarValue::try_from_array(c, chunk.end - 1)) - .collect::>>()?; + let last_rank_data = get_row_at_idx(range_columns, chunk.end - 1)?; let empty = self.state.last_rank_data.is_empty(); if empty || self.state.last_rank_data != last_rank_data { self.state.last_rank_data = last_rank_data; diff --git a/datafusion/physical-expr/src/window/row_number.rs 
b/datafusion/physical-expr/src/window/row_number.rs index c858a5724a20..8961b277e7de 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -19,7 +19,7 @@ use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::window_expr::{BuiltinWindowState, NumRowsState}; -use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; +use crate::window::BuiltInWindowFunctionExpr; use crate::PhysicalExpr; use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; @@ -81,11 +81,10 @@ impl PartitionEvaluator for NumRowsEvaluator { Ok(BuiltinWindowState::NumRows(self.state.clone())) } - fn get_range(&self, state: &WindowAggState, _n_rows: usize) -> Result> { - Ok(Range { - start: state.last_calculated_index, - end: state.last_calculated_index + 1, - }) + fn get_range(&self, idx: usize, _n_rows: usize) -> Result> { + let start = idx; + let end = idx + 1; + Ok(Range { start, end }) } /// evaluate window function result inside given range diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 0723f05c598f..7fa33d71ca44 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -26,7 +26,7 @@ use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{Accumulator, WindowFrame, WindowFrameUnits}; +use datafusion_expr::{Accumulator, WindowFrame}; use crate::window::window_expr::{reverse_order_bys, AggregateWindowExpr}; use crate::window::{ @@ -138,10 +138,8 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { - // NOTE: Currently, groups queries do not support the bounded memory variant. self.aggregate.supports_bounded_execution() && !self.window_frame.end_bound.is_unbounded() - && !matches!(self.window_frame.units, WindowFrameUnits::Groups) } } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 96e22976b3d8..7568fa3b2b58 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -18,7 +18,7 @@ use crate::window::partition_evaluator::PartitionEvaluator; use crate::window::window_frame_state::WindowFrameContext; use crate::{PhysicalExpr, PhysicalSortExpr}; -use arrow::array::{new_empty_array, ArrayRef}; +use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::partition::lexicographical_partition_ranges; use arrow::compute::kernels::sort::SortColumn; use arrow::compute::{concat, SortOptions}; @@ -164,8 +164,18 @@ pub trait AggregateWindowExpr: WindowExpr { fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result { let mut accumulator = self.get_accumulator()?; let mut last_range = Range { start: 0, end: 0 }; - let mut idx = 0; - self.get_result_column(&mut accumulator, batch, &mut last_range, &mut idx, false) + let sort_options: Vec = + self.order_by().iter().map(|o| o.options).collect(); + let mut window_frame_ctx = + WindowFrameContext::new(self.get_window_frame().clone(), sort_options); + self.get_result_column( + &mut accumulator, + batch, + &mut last_range, + &mut window_frame_ctx, + 0, + false, + ) } /// Statefully evaluates the window function against the batch. 
Maintains @@ -196,20 +206,25 @@ pub trait AggregateWindowExpr: WindowExpr { WindowFn::Aggregate(accumulator) => accumulator, _ => unreachable!(), }; - let mut state = &mut window_state.state; - + let state = &mut window_state.state; let record_batch = &partition_batch_state.record_batch; + + // If there is no window state context, initialize it. + let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| { + let sort_options: Vec = + self.order_by().iter().map(|o| o.options).collect(); + WindowFrameContext::new(self.get_window_frame().clone(), sort_options) + }); let out_col = self.get_result_column( accumulator, record_batch, + // Start search from the last range &mut state.window_frame_range, - &mut state.last_calculated_index, + window_frame_ctx, + state.last_calculated_index, !partition_batch_state.is_end, )?; - state.is_end = partition_batch_state.is_end; - state.out_col = concat(&[&state.out_col, &out_col])?; - state.n_row_result_missing = - record_batch.num_rows() - state.last_calculated_index; + state.update(&out_col, partition_batch_state)?; } Ok(()) } @@ -221,23 +236,18 @@ pub trait AggregateWindowExpr: WindowExpr { accumulator: &mut Box, record_batch: &RecordBatch, last_range: &mut Range, - idx: &mut usize, + window_frame_ctx: &mut WindowFrameContext, + mut idx: usize, not_end: bool, ) -> Result { let (values, order_bys) = self.get_values_orderbys(record_batch)?; // We iterate on each row to perform a running calculation. let length = values[0].len(); - let sort_options: Vec = - self.order_by().iter().map(|o| o.options).collect(); - let mut window_frame_ctx = WindowFrameContext::new( - self.get_window_frame(), - sort_options, - // Start search from the last range - last_range.clone(), - ); let mut row_wise_results: Vec = vec![]; - while *idx < length { - let cur_range = window_frame_ctx.calculate_range(&order_bys, length, *idx)?; + while idx < length { + // Start search from the last_range. This squeezes searched range. + let cur_range = + window_frame_ctx.calculate_range(&order_bys, last_range, length, idx)?; // Exit if the range extends all the way: if cur_range.end == length && not_end { break; @@ -248,9 +258,10 @@ pub trait AggregateWindowExpr: WindowExpr { &values, accumulator, )?; - last_range.clone_from(&cur_range); + // Update last range + *last_range = cur_range; row_wise_results.push(value); - *idx += 1; + idx += 1; } if row_wise_results.is_empty() { let field = self.field()?; @@ -340,6 +351,7 @@ pub enum BuiltinWindowState { pub struct WindowAggState { /// The range that we calculate the window function pub window_frame_range: Range, + pub window_frame_ctx: Option, /// The index of the last row that its result is calculated inside the partition record batch buffer. pub last_calculated_index: usize, /// The offset of the deleted row number @@ -353,6 +365,54 @@ pub struct WindowAggState { pub is_end: bool, } +impl WindowAggState { + pub fn prune_state(&mut self, n_prune: usize) { + self.window_frame_range = Range { + start: self.window_frame_range.start - n_prune, + end: self.window_frame_range.end - n_prune, + }; + self.last_calculated_index -= n_prune; + self.offset_pruned_rows += n_prune; + + match self.window_frame_ctx.as_mut() { + // Rows have no state do nothing + Some(WindowFrameContext::Rows(_)) => {} + Some(WindowFrameContext::Range { .. }) => {} + Some(WindowFrameContext::Groups { state, .. 
}) => { + let mut n_group_to_del = 0; + for (_, end_idx) in &state.group_end_indices { + if n_prune < *end_idx { + break; + } + n_group_to_del += 1; + } + state.group_end_indices.drain(0..n_group_to_del); + state + .group_end_indices + .iter_mut() + .for_each(|(_, start_idx)| *start_idx -= n_prune); + state.current_group_idx -= n_group_to_del; + } + None => {} + }; + } +} + +impl WindowAggState { + pub fn update( + &mut self, + out_col: &ArrayRef, + partition_batch_state: &PartitionBatchState, + ) -> Result<()> { + self.last_calculated_index += out_col.len(); + self.out_col = concat(&[&self.out_col, &out_col])?; + self.n_row_result_missing = + partition_batch_state.record_batch.num_rows() - self.last_calculated_index; + self.is_end = partition_batch_state.is_end; + Ok(()) + } +} + /// State for each unique partition determined according to PARTITION BY column(s) #[derive(Debug)] pub struct PartitionBatchState { @@ -383,6 +443,7 @@ impl WindowAggState { let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); Ok(Self { window_frame_range: Range { start: 0, end: 0 }, + window_frame_ctx: None, last_calculated_index: 0, offset_pruned_rows: 0, out_col: empty_out_col, diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 64abacde49c1..01a4f9ad71a8 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -31,37 +31,33 @@ use std::sync::Arc; /// This object stores the window frame state for use in incremental calculations. #[derive(Debug)] -pub enum WindowFrameContext<'a> { +pub enum WindowFrameContext { /// ROWS frames are inherently stateless. - Rows(&'a Arc), + Rows(Arc), /// RANGE frames are stateful, they store indices specifying where the /// previous search left off. This amortizes the overall cost to O(n) /// where n denotes the row count. Range { - window_frame: &'a Arc, + window_frame: Arc, state: WindowFrameStateRange, }, /// GROUPS frames are stateful, they store group boundaries and indices /// specifying where the previous search left off. This amortizes the /// overall cost to O(n) where n denotes the row count. Groups { - window_frame: &'a Arc, + window_frame: Arc, state: WindowFrameStateGroups, }, } -impl<'a> WindowFrameContext<'a> { +impl WindowFrameContext { /// Create a new state object for the given window frame. 
- pub fn new( - window_frame: &'a Arc, - sort_options: Vec, - last_range: Range, - ) -> Self { + pub fn new(window_frame: Arc, sort_options: Vec) -> Self { match window_frame.units { WindowFrameUnits::Rows => WindowFrameContext::Rows(window_frame), WindowFrameUnits::Range => WindowFrameContext::Range { window_frame, - state: WindowFrameStateRange::new(sort_options, last_range), + state: WindowFrameStateRange::new(sort_options), }, WindowFrameUnits::Groups => WindowFrameContext::Groups { window_frame, @@ -74,10 +70,11 @@ impl<'a> WindowFrameContext<'a> { pub fn calculate_range( &mut self, range_columns: &[ArrayRef], + last_range: &Range, length: usize, idx: usize, ) -> Result> { - match *self { + match self { WindowFrameContext::Rows(window_frame) => { Self::calculate_range_rows(window_frame, length, idx) } @@ -87,7 +84,13 @@ impl<'a> WindowFrameContext<'a> { WindowFrameContext::Range { window_frame, ref mut state, - } => state.calculate_range(window_frame, range_columns, length, idx), + } => state.calculate_range( + window_frame, + last_range, + range_columns, + length, + idx, + ), // Sort options is not used in GROUPS mode calculations as the // inequality of two rows indicates a group change, and ordering // or position of NULLs do not impact inequality. @@ -159,33 +162,29 @@ impl<'a> WindowFrameContext<'a> { } /// This structure encapsulates all the state information we require as we scan -/// ranges of data while processing RANGE frames. Attribute `last_range` stores -/// the resulting indices from the previous search. Since the indices only -/// advance forward, we start from `last_range` subsequently. Thus, the overall -/// time complexity of linear search amortizes to O(n) where n denotes the total -/// row count. +/// ranges of data while processing RANGE frames. /// Attribute `sort_options` stores the column ordering specified by the ORDER /// BY clause. This information is used to calculate the range. #[derive(Debug, Default)] pub struct WindowFrameStateRange { - last_range: Range, sort_options: Vec, } impl WindowFrameStateRange { /// Create a new object to store the search state. - fn new(sort_options: Vec, last_range: Range) -> Self { - Self { - // Stores the search range we calculate for future use. - last_range, - sort_options, - } + fn new(sort_options: Vec) -> Self { + Self { sort_options } } /// This function calculates beginning/ending indices for the frame of the current row. + // Argument `last_range` stores the resulting indices from the previous search. Since the indices only + // advance forward, we start from `last_range` subsequently. Thus, the overall + // time complexity of linear search amortizes to O(n) where n denotes the total + // row count. 
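The comment above explains why `calculate_range` now takes `last_range` as an argument: each search resumes from the previous result, so scanning a whole partition stays O(n) overall. A small standalone illustration of that resumable linear search (not the actual DataFusion code):

```rust
/// Illustrative only: find the first index whose value is >= `target`,
/// resuming from `hint`. When successive targets are non-decreasing, the
/// total work over a whole partition amortizes to O(n).
fn advance_from(values: &[i64], hint: usize, target: i64) -> usize {
    let mut i = hint;
    while i < values.len() && values[i] < target {
        i += 1;
    }
    i
}

fn main() {
    let order_col = [1, 2, 2, 4, 7, 9];
    let mut last = 0;
    for target in [2, 5, 9] {
        last = advance_from(&order_col, last, target);
        println!("first index with value >= {target}: {last}");
    }
    // prints 1, 4, 5: each call picks up where the previous one stopped
}
```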
fn calculate_range( &mut self, window_frame: &Arc, + last_range: &Range, range_columns: &[ArrayRef], length: usize, idx: usize, @@ -198,6 +197,7 @@ impl WindowFrameStateRange { } else { self.calculate_index_of_row::( range_columns, + last_range, idx, Some(n), length, @@ -206,6 +206,7 @@ impl WindowFrameStateRange { } WindowFrameBound::CurrentRow => self.calculate_index_of_row::( range_columns, + last_range, idx, None, length, @@ -213,6 +214,7 @@ impl WindowFrameStateRange { WindowFrameBound::Following(ref n) => self .calculate_index_of_row::( range_columns, + last_range, idx, Some(n), length, @@ -222,12 +224,14 @@ impl WindowFrameStateRange { WindowFrameBound::Preceding(ref n) => self .calculate_index_of_row::( range_columns, + last_range, idx, Some(n), length, )?, WindowFrameBound::CurrentRow => self.calculate_index_of_row::( range_columns, + last_range, idx, None, length, @@ -239,6 +243,7 @@ impl WindowFrameStateRange { } else { self.calculate_index_of_row::( range_columns, + last_range, idx, Some(n), length, @@ -246,9 +251,6 @@ impl WindowFrameStateRange { } } }; - // Store the resulting range so we can start from here subsequently: - self.last_range.start = start; - self.last_range.end = end; Ok(Range { start, end }) } @@ -258,6 +260,7 @@ impl WindowFrameStateRange { fn calculate_index_of_row( &mut self, range_columns: &[ArrayRef], + last_range: &Range, idx: usize, delta: Option<&ScalarValue>, length: usize, @@ -298,9 +301,9 @@ impl WindowFrameStateRange { current_row_values }; let search_start = if SIDE { - self.last_range.start + last_range.start } else { - self.last_range.end + last_range.end }; let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { let cmp = compare_rows(current, target, &self.sort_options)?; @@ -332,16 +335,16 @@ impl WindowFrameStateRange { // last row of the group that comes "offset" groups after the current group. // - UNBOUNDED FOLLOWING: End with the last row of the partition. Possible only in frame_end. -// This structure encapsulates all the state information we require as we -// scan groups of data while processing window frames. +/// This structure encapsulates all the state information we require as we +/// scan groups of data while processing window frames. #[derive(Debug, Default)] pub struct WindowFrameStateGroups { /// A tuple containing group values and the row index where the group ends. /// Example: [[1, 1], [1, 1], [2, 1], [2, 1], ...] would correspond to /// [([1, 1], 2), ([2, 1], 4), ...]. - group_start_indices: VecDeque<(Vec, usize)>, + pub group_end_indices: VecDeque<(Vec, usize)>, /// The group index to which the row index belongs. 
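The renamed `group_end_indices` buffer documented above maps each distinct ORDER BY key to the row index where its group ends. A standalone sketch reproducing the example from the comment (illustrative only):

```rust
/// Illustrative only: derive (group key, end index) pairs from rows that are
/// already sorted on the group key, matching the `group_end_indices` example
/// in the comment above.
fn group_end_indices(rows: &[Vec<i32>]) -> Vec<(Vec<i32>, usize)> {
    let mut out: Vec<(Vec<i32>, usize)> = Vec::new();
    for (i, row) in rows.iter().enumerate() {
        let same_group = out.last().map_or(false, |(key, _)| key == row);
        if same_group {
            // Extend the current group's end boundary.
            out.last_mut().unwrap().1 = i + 1;
        } else {
            // A new key starts a new group ending (so far) right after this row.
            out.push((row.clone(), i + 1));
        }
    }
    out
}

fn main() {
    let rows = vec![vec![1, 1], vec![1, 1], vec![2, 1], vec![2, 1]];
    assert_eq!(group_end_indices(&rows), vec![(vec![1, 1], 2), (vec![2, 1], 4)]);
}
```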
- current_group_idx: usize, + pub current_group_idx: usize, } impl WindowFrameStateGroups { @@ -435,14 +438,28 @@ impl WindowFrameStateGroups { 0 }; let mut group_start = 0; - let last_group = self.group_start_indices.back(); - if let Some((_, group_end)) = last_group { + let last_group = self.group_end_indices.back_mut(); + if let Some((group_row, group_end)) = last_group { + if *group_end < length { + let new_group_row = get_row_at_idx(range_columns, *group_end)?; + // If last/current group keys are the same, we extend the last group: + if new_group_row.eq(group_row) { + // Update the end boundary of the group (search right boundary): + *group_end = search_in_slice( + range_columns, + group_row, + check_equality, + *group_end, + length, + )?; + } + } // Start searching from the last group boundary: group_start = *group_end; } // Advance groups until `idx` is inside a group: - while idx > group_start { + while idx >= group_start { let group_row = get_row_at_idx(range_columns, group_start)?; // Find end boundary of the group (search right boundary): let group_end = search_in_slice( @@ -452,13 +469,13 @@ impl WindowFrameStateGroups { group_start, length, )?; - self.group_start_indices.push_back((group_row, group_end)); + self.group_end_indices.push_back((group_row, group_end)); group_start = group_end; } // Update the group index `idx` belongs to: - while self.current_group_idx < self.group_start_indices.len() - && idx >= self.group_start_indices[self.current_group_idx].1 + while self.current_group_idx < self.group_end_indices.len() + && idx >= self.group_end_indices[self.current_group_idx].1 { self.current_group_idx += 1; } @@ -475,7 +492,7 @@ impl WindowFrameStateGroups { }; // Extend `group_start_indices` until it includes at least `group_idx`: - while self.group_start_indices.len() <= group_idx && group_start < length { + while self.group_end_indices.len() <= group_idx && group_start < length { let group_row = get_row_at_idx(range_columns, group_start)?; // Find end boundary of the group (search right boundary): let group_end = search_in_slice( @@ -485,7 +502,7 @@ impl WindowFrameStateGroups { group_start, length, )?; - self.group_start_indices.push_back((group_row, group_end)); + self.group_end_indices.push_back((group_row, group_end)); group_start = group_end; } @@ -493,10 +510,10 @@ impl WindowFrameStateGroups { Ok(match (SIDE, SEARCH_SIDE) { // Window frame start: (true, _) => { - let group_idx = min(group_idx, self.group_start_indices.len()); + let group_idx = min(group_idx, self.group_end_indices.len()); if group_idx > 0 { // Normally, start at the boundary of the previous group. - self.group_start_indices[group_idx - 1].1 + self.group_end_indices[group_idx - 1].1 } else { // If previous group is out of the table, start at zero. 0 @@ -506,7 +523,7 @@ impl WindowFrameStateGroups { (false, true) => { if self.current_group_idx >= delta { let group_idx = self.current_group_idx - delta; - self.group_start_indices[group_idx].1 + self.group_end_indices[group_idx].1 } else { // Group is out of the table, therefore end at zero. 
0 @@ -516,9 +533,9 @@ impl WindowFrameStateGroups { (false, false) => { let group_idx = min( self.current_group_idx + delta, - self.group_start_indices.len() - 1, + self.group_end_indices.len() - 1, ); - self.group_start_indices[group_idx].1 + self.group_end_indices[group_idx].1 } }) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d0deb567fad4..76ec5b001708 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -97,7 +97,8 @@ message ParquetFormat { message AvroFormat {} message ListingTableScanNode { - string table_name = 1; + reserved 1; // was string table_name + OwnedTableReference table_name = 14; repeated string paths = 2; string file_extension = 3; ProjectionColumns projection = 4; @@ -115,7 +116,8 @@ message ListingTableScanNode { } message ViewTableScanNode { - string table_name = 1; + reserved 1; // was string table_name + OwnedTableReference table_name = 6; LogicalPlanNode input = 2; Schema schema = 3; ProjectionColumns projection = 4; @@ -124,7 +126,8 @@ message ViewTableScanNode { // Logical Plan to Scan a CustomTableProvider registered at runtime message CustomTableScanNode { - string table_name = 1; + reserved 1; // was string table_name + OwnedTableReference table_name = 6; ProjectionColumns projection = 2; Schema schema = 3; repeated LogicalExprNode filters = 4; @@ -180,6 +183,7 @@ message CreateExternalTableNode { string delimiter = 8; string definition = 9; string file_compression_type = 10; + repeated LogicalExprNode order_exprs = 13; map options = 11; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7246b5f2b2cf..406d9ee27aa5 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -3362,6 +3362,9 @@ impl serde::Serialize for CreateExternalTableNode { if !self.file_compression_type.is_empty() { len += 1; } + if !self.order_exprs.is_empty() { + len += 1; + } if !self.options.is_empty() { len += 1; } @@ -3396,6 +3399,9 @@ impl serde::Serialize for CreateExternalTableNode { if !self.file_compression_type.is_empty() { struct_ser.serialize_field("fileCompressionType", &self.file_compression_type)?; } + if !self.order_exprs.is_empty() { + struct_ser.serialize_field("orderExprs", &self.order_exprs)?; + } if !self.options.is_empty() { struct_ser.serialize_field("options", &self.options)?; } @@ -3424,6 +3430,8 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "definition", "file_compression_type", "fileCompressionType", + "order_exprs", + "orderExprs", "options", ]; @@ -3439,6 +3447,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { Delimiter, Definition, FileCompressionType, + OrderExprs, Options, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -3471,6 +3480,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "delimiter" => Ok(GeneratedField::Delimiter), "definition" => Ok(GeneratedField::Definition), "fileCompressionType" | "file_compression_type" => Ok(GeneratedField::FileCompressionType), + "orderExprs" | "order_exprs" => Ok(GeneratedField::OrderExprs), "options" => Ok(GeneratedField::Options), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -3501,6 +3511,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { let mut delimiter__ = None; let mut definition__ = None; let mut file_compression_type__ = None; + let mut order_exprs__ = None; let mut options__ = None; while let Some(k) = 
map.next_key()? { match k { @@ -3564,6 +3575,12 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { } file_compression_type__ = Some(map.next_value()?); } + GeneratedField::OrderExprs => { + if order_exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("orderExprs")); + } + order_exprs__ = Some(map.next_value()?); + } GeneratedField::Options => { if options__.is_some() { return Err(serde::de::Error::duplicate_field("options")); @@ -3585,6 +3602,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { delimiter: delimiter__.unwrap_or_default(), definition: definition__.unwrap_or_default(), file_compression_type: file_compression_type__.unwrap_or_default(), + order_exprs: order_exprs__.unwrap_or_default(), options: options__.unwrap_or_default(), }) } @@ -4286,7 +4304,7 @@ impl serde::Serialize for CustomTableScanNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.table_name.is_empty() { + if self.table_name.is_some() { len += 1; } if self.projection.is_some() { @@ -4302,8 +4320,8 @@ impl serde::Serialize for CustomTableScanNode { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.CustomTableScanNode", len)?; - if !self.table_name.is_empty() { - struct_ser.serialize_field("tableName", &self.table_name)?; + if let Some(v) = self.table_name.as_ref() { + struct_ser.serialize_field("tableName", v)?; } if let Some(v) = self.projection.as_ref() { struct_ser.serialize_field("projection", v)?; @@ -4399,7 +4417,7 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { if table_name__.is_some() { return Err(serde::de::Error::duplicate_field("tableName")); } - table_name__ = Some(map.next_value()?); + table_name__ = map.next_value()?; } GeneratedField::Projection => { if projection__.is_some() { @@ -4430,7 +4448,7 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { } } Ok(CustomTableScanNode { - table_name: table_name__.unwrap_or_default(), + table_name: table_name__, projection: projection__, schema: schema__, filters: filters__.unwrap_or_default(), @@ -9586,7 +9604,7 @@ impl serde::Serialize for ListingTableScanNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.table_name.is_empty() { + if self.table_name.is_some() { len += 1; } if !self.paths.is_empty() { @@ -9620,8 +9638,8 @@ impl serde::Serialize for ListingTableScanNode { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.ListingTableScanNode", len)?; - if !self.table_name.is_empty() { - struct_ser.serialize_field("tableName", &self.table_name)?; + if let Some(v) = self.table_name.as_ref() { + struct_ser.serialize_field("tableName", v)?; } if !self.paths.is_empty() { struct_ser.serialize_field("paths", &self.paths)?; @@ -9779,7 +9797,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { if table_name__.is_some() { return Err(serde::de::Error::duplicate_field("tableName")); } - table_name__ = Some(map.next_value()?); + table_name__ = map.next_value()?; } GeneratedField::Paths => { if paths__.is_some() { @@ -9861,7 +9879,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { } } Ok(ListingTableScanNode { - table_name: table_name__.unwrap_or_default(), + table_name: table_name__, paths: paths__.unwrap_or_default(), file_extension: file_extension__.unwrap_or_default(), projection: projection__, @@ -20863,7 +20881,7 @@ impl serde::Serialize for ViewTableScanNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.table_name.is_empty() { + if self.table_name.is_some() { len += 1; 
} if self.input.is_some() { @@ -20879,8 +20897,8 @@ impl serde::Serialize for ViewTableScanNode { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.ViewTableScanNode", len)?; - if !self.table_name.is_empty() { - struct_ser.serialize_field("tableName", &self.table_name)?; + if let Some(v) = self.table_name.as_ref() { + struct_ser.serialize_field("tableName", v)?; } if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; @@ -20975,7 +20993,7 @@ impl<'de> serde::Deserialize<'de> for ViewTableScanNode { if table_name__.is_some() { return Err(serde::de::Error::duplicate_field("tableName")); } - table_name__ = Some(map.next_value()?); + table_name__ = map.next_value()?; } GeneratedField::Input => { if input__.is_some() { @@ -21004,7 +21022,7 @@ impl<'de> serde::Deserialize<'de> for ViewTableScanNode { } } Ok(ViewTableScanNode { - table_name: table_name__.unwrap_or_default(), + table_name: table_name__, input: input__, schema: schema__, projection: projection__, diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index da95fd558fef..e5b4534f60a0 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -130,8 +130,8 @@ pub struct AvroFormat {} #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ListingTableScanNode { - #[prost(string, tag = "1")] - pub table_name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "14")] + pub table_name: ::core::option::Option, #[prost(string, repeated, tag = "2")] pub paths: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, #[prost(string, tag = "3")] @@ -171,8 +171,8 @@ pub mod listing_table_scan_node { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ViewTableScanNode { - #[prost(string, tag = "1")] - pub table_name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "6")] + pub table_name: ::core::option::Option, #[prost(message, optional, boxed, tag = "2")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "3")] @@ -186,8 +186,8 @@ pub struct ViewTableScanNode { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CustomTableScanNode { - #[prost(string, tag = "1")] - pub table_name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "6")] + pub table_name: ::core::option::Option, #[prost(message, optional, tag = "2")] pub projection: ::core::option::Option, #[prost(message, optional, tag = "3")] @@ -291,6 +291,8 @@ pub struct CreateExternalTableNode { pub definition: ::prost::alloc::string::String, #[prost(string, tag = "10")] pub file_compression_type: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "13")] + pub order_exprs: ::prost::alloc::vec::Vec, #[prost(map = "string, string", tag = "11")] pub options: ::std::collections::HashMap< ::prost::alloc::string::String, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 1b704f3aa526..aa416e63b8a6 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -145,10 +145,7 @@ impl From for Column { fn from(c: protobuf::Column) -> Self { let protobuf::Column { relation, name } = c; - Self { - relation: relation.map(|r| r.relation), - name, - } + Self::new(relation.map(|r| r.relation), 
name) } } @@ -190,7 +187,7 @@ impl TryFrom<&protobuf::DfField> for DFField { let field = df_field.field.as_ref().required("field")?; Ok(match &df_field.qualifier { - Some(q) => DFField::from_qualified(&q.relation, field), + Some(q) => DFField::from_qualified(q.relation.clone(), field), None => DFField::from(field), }) } @@ -217,21 +214,17 @@ impl TryFrom for OwnedTableReference { match table_reference_enum { TableReferenceEnum::Bare(protobuf::BareTableReference { table }) => { - Ok(OwnedTableReference::Bare { table }) + Ok(OwnedTableReference::bare(table)) } TableReferenceEnum::Partial(protobuf::PartialTableReference { schema, table, - }) => Ok(OwnedTableReference::Partial { schema, table }), + }) => Ok(OwnedTableReference::partial(schema, table)), TableReferenceEnum::Full(protobuf::FullTableReference { catalog, schema, table, - }) => Ok(OwnedTableReference::Full { - catalog, - schema, - table, - }), + }) => Ok(OwnedTableReference::full(catalog, schema, table)), } } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 9c9be27328e4..2a8d24d4103b 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -398,8 +398,13 @@ impl AsLogicalPlan for LogicalPlanNode { let provider = ListingTable::try_new(config)?; + let table_name = from_owned_table_reference( + scan.table_name.as_ref(), + "ListingTableScan", + )?; + LogicalPlanBuilder::scan_with_filters( - &scan.table_name, + table_name, provider_as_source(Arc::new(provider)), projection, filters, @@ -430,8 +435,11 @@ impl AsLogicalPlan for LogicalPlanNode { ctx, )?; + let table_name = + from_owned_table_reference(scan.table_name.as_ref(), "CustomScan")?; + LogicalPlanBuilder::scan_with_filters( - &scan.table_name, + table_name, provider_as_source(provider), projection, filters, @@ -502,6 +510,12 @@ impl AsLogicalPlan for LogicalPlanNode { )))? } + let order_exprs = create_extern_table + .order_exprs + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?; + Ok(LogicalPlan::CreateExternalTable(CreateExternalTable { schema: pb_schema.try_into()?, name: from_owned_table_reference(create_extern_table.name.as_ref(), "CreateExternalTable")?, @@ -514,6 +528,7 @@ impl AsLogicalPlan for LogicalPlanNode { table_partition_cols: create_extern_table .table_partition_cols .clone(), + order_exprs, if_not_exists: create_extern_table.if_not_exists, file_compression_type: CompressionTypeVariant::from_str(&create_extern_table.file_compression_type).map_err(|_| DataFusionError::NotImplemented(format!("Unsupported file compression type {}", create_extern_table.file_compression_type)))?, definition, @@ -730,8 +745,11 @@ impl AsLogicalPlan for LogicalPlanNode { let provider = ViewTable::try_new(input, definition)?; + let table_name = + from_owned_table_reference(scan.table_name.as_ref(), "ViewScan")?; + LogicalPlanBuilder::scan( - &scan.table_name, + table_name, provider_as_source(Arc::new(provider)), projection, )? 
@@ -843,7 +861,7 @@ impl AsLogicalPlan for LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ListingScan( protobuf::ListingTableScanNode { file_format_type: Some(file_format_type), - table_name: table_name.to_owned(), + table_name: Some(table_name.clone().into()), collect_stat: options.collect_stat, file_extension: options.file_extension.clone(), table_partition_cols: options @@ -868,7 +886,7 @@ impl AsLogicalPlan for LogicalPlanNode { Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ViewScan(Box::new( protobuf::ViewTableScanNode { - table_name: table_name.to_owned(), + table_name: Some(table_name.clone().into()), input: Some(Box::new( protobuf::LogicalPlanNode::try_from_logical_plan( view_table.logical_plan(), @@ -890,7 +908,7 @@ impl AsLogicalPlan for LogicalPlanNode { .try_encode_table_provider(provider, &mut bytes) .map_err(|e| context!("Error serializing custom table", e))?; let scan = CustomScan(CustomTableScanNode { - table_name: table_name.clone(), + table_name: Some(table_name.clone().into()), projection, schema: Some(schema), filters, @@ -1163,6 +1181,7 @@ impl AsLogicalPlan for LogicalPlanNode { if_not_exists, definition, file_compression_type, + order_exprs, options, }) => Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateExternalTable( @@ -1175,6 +1194,10 @@ impl AsLogicalPlan for LogicalPlanNode { table_partition_cols: table_partition_cols.clone(), if_not_exists: *if_not_exists, delimiter: String::from(*delimiter), + order_exprs: order_exprs + .iter() + .map(|expr| expr.try_into()) + .collect::, to_proto::Error>>()?, definition: definition.clone().unwrap_or_default(), file_compression_type: file_compression_type.to_string(), options: options.clone(), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index c240ef853f5b..da16cd9e0275 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -240,9 +240,9 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { impl From for protobuf::Column { fn from(c: Column) -> Self { Self { - relation: c - .relation - .map(|relation| protobuf::ColumnRelation { relation }), + relation: c.relation.map(|relation| protobuf::ColumnRelation { + relation: relation.to_string(), + }), name: c.name, } } @@ -854,10 +854,10 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Expr::Wildcard => Self { expr_type: Some(ExprType::Wildcard(true)), }, - Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } => { + Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::OuterReferenceColumn{..} => { // we would need to add logical plan operators to datafusion.proto to support this // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 - return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } not supported".to_string())); + return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. 
} | Exp:OuterReferenceColumn not supported".to_string())); } Expr::GetIndexedField(GetIndexedField { key, expr }) => Self { @@ -1321,12 +1321,14 @@ impl From for protobuf::OwnedTableReference { use protobuf::owned_table_reference::TableReferenceEnum; let table_reference_enum = match t { OwnedTableReference::Bare { table } => { - TableReferenceEnum::Bare(protobuf::BareTableReference { table }) + TableReferenceEnum::Bare(protobuf::BareTableReference { + table: table.to_string(), + }) } OwnedTableReference::Partial { schema, table } => { TableReferenceEnum::Partial(protobuf::PartialTableReference { - schema, - table, + schema: schema.to_string(), + table: table.to_string(), }) } OwnedTableReference::Full { @@ -1334,9 +1336,9 @@ impl From for protobuf::OwnedTableReference { schema, table, } => TableReferenceEnum::Full(protobuf::FullTableReference { - catalog, - schema, - table, + catalog: catalog.to_string(), + schema: schema.to_string(), + table: table.to_string(), }), }; diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 4997ff41cd85..ac5ee38fd820 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -46,4 +46,4 @@ sqlparser = "0.32" [dev-dependencies] ctor = "0.1.22" env_logger = "0.10" -rstest = "0.16" +rstest = "0.17" diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 49104ee05d1a..9a2ac28c46ab 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -93,7 +93,7 @@ pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result /// assert_eq!(data_type, DataType::Int32); /// ``` /// -/// Remove if added to arrow: https://github.com/apache/arrow-rs/issues/3821 +/// Remove if added to arrow: pub fn parse_data_type(val: &str) -> Result { Parser::new(val).parse() } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 68a5df054b69..e5af0eb26976 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -48,7 +48,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // next, scalar built-in if let Ok(fun) = BuiltinScalarFunction::from_str(&name) { - let args = self.function_args_to_expr(function.args, schema)?; + let args = + self.function_args_to_expr(function.args, schema, planner_context)?; return Ok(Expr::ScalarFunction { fun, args }); }; @@ -62,7 +63,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let order_by = window .order_by .into_iter() - .map(|e| self.order_by_to_sort_expr(e, schema)) + .map(|e| self.order_by_to_sort_expr(e, schema, planner_context)) .collect::>>()?; let window_frame = window .window_frame @@ -80,8 +81,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let fun = self.find_window_func(&name)?; let expr = match fun { WindowFunction::AggregateFunction(aggregate_fun) => { - let (aggregate_fun, args) = - self.aggregate_fn_to_expr(aggregate_fun, function.args, schema)?; + let (aggregate_fun, args) = self.aggregate_fn_to_expr( + aggregate_fun, + function.args, + schema, + planner_context, + )?; Expr::WindowFunction(expr::WindowFunction::new( WindowFunction::AggregateFunction(aggregate_fun), @@ -93,7 +98,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } _ => Expr::WindowFunction(expr::WindowFunction::new( fun, - self.function_args_to_expr(function.args, schema)?, + self.function_args_to_expr(function.args, schema, planner_context)?, partition_by, order_by, window_frame, @@ -105,7 +110,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // next, aggregate 
built-ins if let Ok(fun) = AggregateFunction::from_str(&name) { let distinct = function.distinct; - let (fun, args) = self.aggregate_fn_to_expr(fun, function.args, schema)?; + let (fun, args) = + self.aggregate_fn_to_expr(fun, function.args, schema, planner_context)?; return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( fun, args, distinct, None, ))); @@ -113,13 +119,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // finally, user-defined functions (UDF) and UDAF if let Some(fm) = self.schema_provider.get_function_meta(&name) { - let args = self.function_args_to_expr(function.args, schema)?; + let args = + self.function_args_to_expr(function.args, schema, planner_context)?; return Ok(Expr::ScalarUDF { fun: fm, args }); } // User defined aggregate functions if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) { - let args = self.function_args_to_expr(function.args, schema)?; + let args = + self.function_args_to_expr(function.args, schema, planner_context)?; return Ok(Expr::AggregateUDF { fun: fm, args, @@ -129,7 +137,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Special case arrow_cast (as its type is dependent on its argument value) if name == ARROW_CAST_NAME { - let args = self.function_args_to_expr(function.args, schema)?; + let args = + self.function_args_to_expr(function.args, schema, planner_context)?; return super::arrow_cast::create_arrow_cast(args, schema); } @@ -189,11 +198,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, args: Vec, schema: &DFSchema, + planner_context: &mut PlannerContext, ) -> Result> { args.into_iter() - .map(|a| { - self.sql_fn_arg_to_logical_expr(a, schema, &mut PlannerContext::new()) - }) + .map(|a| self.sql_fn_arg_to_logical_expr(a, schema, planner_context)) .collect::>>() } @@ -202,6 +210,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun: AggregateFunction, args: Vec, schema: &DFSchema, + planner_context: &mut PlannerContext, ) -> Result<(AggregateFunction, Vec)> { let args = match fun { // Special case rewrite COUNT(*) to COUNT(constant) @@ -211,14 +220,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())) } - _ => self.sql_fn_arg_to_logical_expr( - a, - schema, - &mut PlannerContext::new(), - ), + _ => self.sql_fn_arg_to_logical_expr(a, schema, planner_context), }) .collect::>>()?, - _ => self.function_args_to_expr(args, schema)?, + _ => self.function_args_to_expr(args, schema, planner_context)?, }; Ok((fun, args)) diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index a581d5f4720d..5c561270f946 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -15,18 +15,21 @@ // specific language governing permissions and limitations // under the License. 
-use crate::planner::{ - idents_to_table_reference, ContextProvider, PlannerContext, SqlToRel, -}; +use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::normalize_ident; use datafusion_common::{ - Column, DFSchema, DataFusionError, OwnedTableReference, Result, ScalarValue, + Column, DFField, DFSchema, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::{Case, Expr, GetIndexedField}; use sqlparser::ast::{Expr as SQLExpr, Ident}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { - pub(super) fn sql_identifier_to_expr(&self, id: Ident) -> Result { + pub(super) fn sql_identifier_to_expr( + &self, + id: Ident, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { if id.value.starts_with('@') { // TODO: figure out if ScalarVariables should be insensitive. let var_names = vec![id.value]; @@ -44,11 +47,42 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // interpret names with '.' as if they were // compound identifiers, but this is not a compound // identifier. (e.g. it is "foo.bar" not foo.bar) - - Ok(Expr::Column(Column { - relation: None, - name: normalize_ident(id), - })) + let normalize_ident = normalize_ident(id); + match schema.field_with_unqualified_name(normalize_ident.as_str()) { + Ok(_) => { + // found a match without a qualified name, this is a inner table column + Ok(Expr::Column(Column { + relation: None, + name: normalize_ident, + })) + } + Err(_) => { + // check the outer_query_schema and try to find a match + let outer_query_schema_opt = + planner_context.outer_query_schema.as_ref(); + if let Some(outer) = outer_query_schema_opt { + match outer.field_with_unqualified_name(normalize_ident.as_str()) + { + Ok(field) => { + // found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + Ok(Expr::OuterReferenceColumn( + field.data_type().clone(), + field.qualified_column(), + )) + } + Err(_) => Ok(Expr::Column(Column { + relation: None, + name: normalize_ident, + })), + } + } else { + Ok(Expr::Column(Column { + relation: None, + name: normalize_ident, + })) + } + } + } } } @@ -56,7 +90,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, ids: Vec, schema: &DFSchema, + planner_context: &mut PlannerContext, ) -> Result { + if ids.len() < 2 { + return Err(DataFusionError::Internal(format!( + "Not a compound identifier: {ids:?}" + ))); + } + if ids[0].value.starts_with('@') { let var_names: Vec<_> = ids.into_iter().map(normalize_ident).collect(); let ty = self @@ -69,44 +110,97 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { })?; Ok(Expr::ScalarVariable(ty, var_names)) } else { - // only support "schema.table" type identifiers here - let (name, relation) = match idents_to_table_reference( - ids, - self.options.enable_ident_normalization, - )? { - OwnedTableReference::Partial { schema, table } => (table, schema), - r @ OwnedTableReference::Bare { .. } - | r @ OwnedTableReference::Full { .. 
} => { - return Err(DataFusionError::Plan(format!( - "Unsupported compound identifier '{r:?}'", - ))); - } - }; + let ids = ids + .into_iter() + .map(|id| { + if self.options.enable_ident_normalization { + normalize_ident(id) + } else { + id.value + } + }) + .collect::>(); - // Try and find the reference in schema - match schema.field_with_qualified_name(&relation, &name) { - Ok(_) => { - // found an exact match on a qualified name so this is a table.column identifier - Ok(Expr::Column(Column { - relation: Some(relation), - name, - })) + // Currently not supporting more than one nested level + // Though ideally once that support is in place, this code should work with it + // TODO: remove when can support multiple nested identifiers + if ids.len() > 5 { + return Err(DataFusionError::Internal(format!( + "Unsupported compound identifier: {ids:?}" + ))); + } + + let search_result = search_dfschema(&ids, schema); + match search_result { + // found matching field with spare identifier(s) for nested field(s) in structure + Some((field, nested_names)) if !nested_names.is_empty() => { + // TODO: remove when can support multiple nested identifiers + if nested_names.len() > 1 { + return Err(DataFusionError::Internal(format!( + "Nested identifiers not yet supported for column {}", + field.qualified_column().quoted_flat_name() + ))); + } + let nested_name = nested_names[0].to_string(); + Ok(Expr::GetIndexedField(GetIndexedField::new( + Box::new(Expr::Column(field.qualified_column())), + ScalarValue::Utf8(Some(nested_name)), + ))) } - Err(_) => { - if let Some(field) = - schema.fields().iter().find(|f| f.name().eq(&relation)) - { - // Access to a field of a column which is a structure, example: SELECT my_struct.key - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(Expr::Column(field.qualified_column())), - ScalarValue::Utf8(Some(name)), + // found matching field with no spare identifier(s) + Some((field, _nested_names)) => { + Ok(Expr::Column(field.qualified_column())) + } + None => { + // return default where use all identifiers to not have a nested field + // this len check is because at 5 identifiers will have to have a nested field + if ids.len() == 5 { + Err(DataFusionError::Internal(format!( + "Unsupported compound identifier: {ids:?}" ))) } else { - // table.column identifier - Ok(Expr::Column(Column { - relation: Some(relation), - name, - })) + let outer_query_schema_opt = + planner_context.outer_query_schema.as_ref(); + // check the outer_query_schema and try to find a match + if let Some(outer) = outer_query_schema_opt { + let search_result = search_dfschema(&ids, outer); + match search_result { + // found matching field with spare identifier(s) for nested field(s) in structure + Some((field, nested_names)) + if !nested_names.is_empty() => + { + // TODO: remove when can support nested identifiers for OuterReferenceColumn + Err(DataFusionError::Internal(format!( + "Nested identifiers are not yet supported for OuterReferenceColumn {}", + field.qualified_column().quoted_flat_name() + ))) + } + // found matching field with no spare identifier(s) + Some((field, _nested_names)) => { + // found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + Ok(Expr::OuterReferenceColumn( + field.data_type().clone(), + field.qualified_column(), + )) + } + // found no matching field, will return a default + None => { + let s = &ids[0..ids.len()]; + // safe unwrap as s can never be empty or exceed the bounds + let (relation, column_name) = + 
form_identifier(s).unwrap(); + let relation = + relation.map(|r| r.to_owned_reference()); + Ok(Expr::Column(Column::new(relation, column_name))) + } + } + } else { + let s = &ids[0..ids.len()]; + // safe unwrap as s can never be empty or exceed the bounds + let (relation, column_name) = form_identifier(s).unwrap(); + let relation = relation.map(|r| r.to_owned_reference()); + Ok(Expr::Column(Column::new(relation, column_name))) + } } } } @@ -160,3 +254,244 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } } + +// (relation, column name) +fn form_identifier(idents: &[String]) -> Result<(Option, &String)> { + match idents.len() { + 1 => Ok((None, &idents[0])), + 2 => Ok(( + Some(TableReference::Bare { + table: (&idents[0]).into(), + }), + &idents[1], + )), + 3 => Ok(( + Some(TableReference::Partial { + schema: (&idents[0]).into(), + table: (&idents[1]).into(), + }), + &idents[2], + )), + 4 => Ok(( + Some(TableReference::Full { + catalog: (&idents[0]).into(), + schema: (&idents[1]).into(), + table: (&idents[2]).into(), + }), + &idents[3], + )), + _ => Err(DataFusionError::Internal(format!( + "Incorrect number of identifiers: {}", + idents.len() + ))), + } +} + +fn search_dfschema<'ids, 'schema>( + ids: &'ids [String], + schema: &'schema DFSchema, +) -> Option<(&'schema DFField, &'ids [String])> { + generate_schema_search_terms(ids).find_map(|(qualifier, column, nested_names)| { + let field = schema.field_with_name(qualifier.as_ref(), column).ok(); + field.map(|f| (f, nested_names)) + }) +} + +// Possibilities we search with, in order from top to bottom for each len: +// +// len = 2: +// 1. (table.column) +// 2. (column).nested +// +// len = 3: +// 1. (schema.table.column) +// 2. (table.column).nested +// 3. (column).nested1.nested2 +// +// len = 4: +// 1. (catalog.schema.table.column) +// 2. (schema.table.column).nested1 +// 3. (table.column).nested1.nested2 +// 4. (column).nested1.nested2.nested3 +// +// len = 5: +// 1. (catalog.schema.table.column).nested +// 2. (schema.table.column).nested1.nested2 +// 3. (table.column).nested1.nested2.nested3 +// 4. (column).nested1.nested2.nested3.nested4 +// +// len > 5: +// 1. (catalog.schema.table.column).nested[.nestedN]+ +// 2. (schema.table.column).nested1.nested2[.nestedN]+ +// 3. (table.column).nested1.nested2.nested3[.nestedN]+ +// 4. 
(column).nested1.nested2.nested3.nested4[.nestedN]+ +fn generate_schema_search_terms( + ids: &[String], +) -> impl Iterator, &String, &[String])> { + // take at most 4 identifiers to form a Column to search with + // - 1 for the column name + // - 0 to 3 for the TableReference + let bound = ids.len().min(4); + // search terms from most specific to least specific + (0..bound).rev().map(|i| { + let nested_names_index = i + 1; + let qualifier_and_column = &ids[0..nested_names_index]; + // safe unwrap as qualifier_and_column can never be empty or exceed the bounds + let (relation, column_name) = form_identifier(qualifier_and_column).unwrap(); + (relation, column_name, &ids[nested_names_index..]) + }) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + // testing according to documentation of generate_schema_search_terms function + // where ensure generated search terms are in correct order with correct values + fn test_generate_schema_search_terms() -> Result<()> { + type ExpectedItem = ( + Option>, + &'static str, + &'static [&'static str], + ); + fn assert_vec_eq( + expected: Vec, + actual: Vec<(Option, &String, &[String])>, + ) { + for (expected, actual) in expected.into_iter().zip(actual) { + assert_eq!(expected.0, actual.0, "qualifier"); + assert_eq!(expected.1, actual.1, "column name"); + assert_eq!(expected.2, actual.2, "nested names"); + } + } + + let actual = generate_schema_search_terms(&[]).collect::>(); + assert!(actual.is_empty()); + + let ids = vec!["a".to_string()]; + let actual = generate_schema_search_terms(&ids).collect::>(); + let expected: Vec = vec![(None, "a", &[])]; + assert_vec_eq(expected, actual); + + let ids = vec!["a".to_string(), "b".to_string()]; + let actual = generate_schema_search_terms(&ids).collect::>(); + let expected: Vec = vec![ + (Some(TableReference::bare("a")), "b", &[]), + (None, "a", &["b"]), + ]; + assert_vec_eq(expected, actual); + + let ids = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + let actual = generate_schema_search_terms(&ids).collect::>(); + let expected: Vec = vec![ + (Some(TableReference::partial("a", "b")), "c", &[]), + (Some(TableReference::bare("a")), "b", &["c"]), + (None, "a", &["b", "c"]), + ]; + assert_vec_eq(expected, actual); + + let ids = vec![ + "a".to_string(), + "b".to_string(), + "c".to_string(), + "d".to_string(), + ]; + let actual = generate_schema_search_terms(&ids).collect::>(); + let expected: Vec = vec![ + (Some(TableReference::full("a", "b", "c")), "d", &[]), + (Some(TableReference::partial("a", "b")), "c", &["d"]), + (Some(TableReference::bare("a")), "b", &["c", "d"]), + (None, "a", &["b", "c", "d"]), + ]; + assert_vec_eq(expected, actual); + + let ids = vec![ + "a".to_string(), + "b".to_string(), + "c".to_string(), + "d".to_string(), + "e".to_string(), + ]; + let actual = generate_schema_search_terms(&ids).collect::>(); + let expected: Vec = vec![ + (Some(TableReference::full("a", "b", "c")), "d", &["e"]), + (Some(TableReference::partial("a", "b")), "c", &["d", "e"]), + (Some(TableReference::bare("a")), "b", &["c", "d", "e"]), + (None, "a", &["b", "c", "d", "e"]), + ]; + assert_vec_eq(expected, actual); + + let ids = vec![ + "a".to_string(), + "b".to_string(), + "c".to_string(), + "d".to_string(), + "e".to_string(), + "f".to_string(), + ]; + let actual = generate_schema_search_terms(&ids).collect::>(); + let expected: Vec = vec![ + (Some(TableReference::full("a", "b", "c")), "d", &["e", "f"]), + ( + Some(TableReference::partial("a", "b")), + "c", + &["d", "e", "f"], + ), + 
(Some(TableReference::bare("a")), "b", &["c", "d", "e", "f"]), + (None, "a", &["b", "c", "d", "e", "f"]), + ]; + assert_vec_eq(expected, actual); + + Ok(()) + } + + #[test] + fn test_form_identifier() -> Result<()> { + let err = form_identifier(&[]).expect_err("empty identifiers didn't fail"); + let expected = "Internal error: Incorrect number of identifiers: 0. \ + This was likely caused by a bug in DataFusion's code and we would \ + welcome that you file an bug report in our issue tracker"; + assert_eq!(err.to_string(), expected); + + let ids = vec!["a".to_string()]; + let (qualifier, column) = form_identifier(&ids)?; + assert_eq!(qualifier, None); + assert_eq!(column, "a"); + + let ids = vec!["a".to_string(), "b".to_string()]; + let (qualifier, column) = form_identifier(&ids)?; + assert_eq!(qualifier, Some(TableReference::bare("a"))); + assert_eq!(column, "b"); + + let ids = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + let (qualifier, column) = form_identifier(&ids)?; + assert_eq!(qualifier, Some(TableReference::partial("a", "b"))); + assert_eq!(column, "c"); + + let ids = vec![ + "a".to_string(), + "b".to_string(), + "c".to_string(), + "d".to_string(), + ]; + let (qualifier, column) = form_identifier(&ids)?; + assert_eq!(qualifier, Some(TableReference::full("a", "b", "c"))); + assert_eq!(column, "d"); + + let err = form_identifier(&[ + "a".to_string(), + "b".to_string(), + "c".to_string(), + "d".to_string(), + "e".to_string(), + ]) + .expect_err("too many identifiers didn't fail"); + let expected = "Internal error: Incorrect number of identifiers: 5. \ + This was likely caused by a bug in DataFusion's code and we would \ + welcome that you file an bug report in our issue tracker"; + assert_eq!(err.to_string(), expected); + + Ok(()) + } +} diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index ad05fbcc16ca..441c29775c77 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -78,7 +78,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut expr = self.sql_expr_to_logical_expr(sql, schema, planner_context)?; expr = self.rewrite_partial_qualifier(expr, schema); self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?; - let expr = infer_placeholder_types(expr, schema.clone())?; + let expr = infer_placeholder_types(expr, schema)?; Ok(expr) } @@ -93,7 +93,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .find(|field| match field.qualifier() { Some(field_q) => { field.name() == &col.name - && field_q.ends_with(&format!(".{q}")) + && field_q.to_string().ends_with(&format!(".{q}")) } _ => false, }) { @@ -144,7 +144,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { last_field, fractional_seconds_precision, ), - SQLExpr::Identifier(id) => self.sql_identifier_to_expr(id), + SQLExpr::Identifier(id) => self.sql_identifier_to_expr(id, schema, planner_context), SQLExpr::MapAccess { column, keys } => { if let SQLExpr::Identifier(id) = *column { @@ -161,7 +161,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_indexed(expr, indexes) } - SQLExpr::CompoundIdentifier(ids) => self.sql_compound_identifier_to_expr(ids, schema), + SQLExpr::CompoundIdentifier(ids) => self.sql_compound_identifier_to_expr(ids, schema, planner_context), SQLExpr::Case { operand, @@ -499,36 +499,33 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } -/// Find all `PlaceHolder` tokens in a logical plan, and try to infer their type from context -fn infer_placeholder_types(expr: Expr, schema: DFSchema) -> Result { - rewrite_expr(expr, |expr| { 
- let expr = match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let left = (*left).clone(); - let right = (*right).clone(); - let lt = left.get_type(&schema); - let rt = right.get_type(&schema); - let left = match (&left, rt) { - (Expr::Placeholder { id, data_type }, Ok(dt)) => Expr::Placeholder { - id: id.clone(), - data_type: Some(data_type.clone().unwrap_or(dt)), - }, - _ => left.clone(), - }; - let right = match (&right, lt) { - (Expr::Placeholder { id, data_type }, Ok(dt)) => Expr::Placeholder { - id: id.clone(), - data_type: Some(data_type.clone().unwrap_or(dt)), - }, - _ => right.clone(), - }; - Expr::BinaryExpr(BinaryExpr { - left: Box::new(left), - op, - right: Box::new(right), - }) +// modifies expr if it is a placeholder with datatype of right +fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { + if let Expr::Placeholder { id: _, data_type } = expr { + if data_type.is_none() { + let other_dt = other.get_type(schema); + match other_dt { + Err(e) => { + return Err(e.context(format!( + "Can not find type of {other} needed to infer type of {expr}" + )))?; + } + Ok(dt) => { + *data_type = Some(dt); + } } - _ => expr.clone(), + }; + } + Ok(()) +} + +/// Find all [`Expr::PlaceHolder`] tokens in a logical plan, and try to infer their type from context +fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result { + rewrite_expr(expr, |mut expr| { + // Default to assuming the arguments are the same type + if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { + rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; + rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; }; Ok(expr) }) diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index f3a4f2b0432b..d8c0e34f3414 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -27,6 +27,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, e: OrderByExpr, schema: &DFSchema, + planner_context: &mut PlannerContext, ) -> Result { let OrderByExpr { asc, @@ -55,7 +56,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let field = schema.field(field_index - 1); Expr::Column(field.qualified_column()) } - e => self.sql_expr_to_logical_expr(e, schema, &mut PlannerContext::new())?, + e => self.sql_expr_to_logical_expr(e, schema, planner_context)?, }; Ok({ let asc = asc.unwrap_or(true); diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index bb2cf30d3345..8bcb7af95911 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -30,13 +30,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { + let old_outer_query_schema = planner_context.outer_query_schema.clone(); + planner_context.outer_query_schema = Some(input_schema.clone()); + let sub_plan = self.query_to_plan(subquery, planner_context)?; + let outer_ref_columns = sub_plan.all_out_ref_exprs(); + planner_context.outer_query_schema = old_outer_query_schema; Ok(Expr::Exists { subquery: Subquery { - subquery: Arc::new(self.subquery_to_plan( - subquery, - planner_context, - input_schema, - )?), + subquery: Arc::new(sub_plan), + outer_ref_columns, }, negated, }) @@ -50,14 +52,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { + let old_outer_query_schema = planner_context.outer_query_schema.clone(); + 
planner_context.outer_query_schema = Some(input_schema.clone()); + let sub_plan = self.query_to_plan(subquery, planner_context)?; + let outer_ref_columns = sub_plan.all_out_ref_exprs(); + planner_context.outer_query_schema = old_outer_query_schema; + let expr = Box::new(self.sql_to_expr(expr, input_schema, planner_context)?); Ok(Expr::InSubquery { - expr: Box::new(self.sql_to_expr(expr, input_schema, planner_context)?), + expr, subquery: Subquery { - subquery: Arc::new(self.subquery_to_plan( - subquery, - planner_context, - input_schema, - )?), + subquery: Arc::new(sub_plan), + outer_ref_columns, }, negated, }) @@ -69,12 +74,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { + let old_outer_query_schema = planner_context.outer_query_schema.clone(); + planner_context.outer_query_schema = Some(input_schema.clone()); + let sub_plan = self.query_to_plan(subquery, planner_context)?; + let outer_ref_columns = sub_plan.all_out_ref_exprs(); + planner_context.outer_query_schema = old_outer_query_schema; Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(self.subquery_to_plan( - subquery, - planner_context, - input_schema, - )?), + subquery: Arc::new(sub_plan), + outer_ref_columns, })) } } diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 521466214ff8..27bfa3250102 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -18,6 +18,7 @@ //! DataFusion SQL Parser based on [`sqlparser`] use datafusion_common::parsers::CompressionTypeVariant; +use sqlparser::ast::OrderByExpr; use sqlparser::{ ast::{ ColumnDef, ColumnOptionDef, ObjectName, Statement as SQLStatement, @@ -58,6 +59,8 @@ pub struct CreateExternalTable { pub location: String, /// Partition Columns pub table_partition_cols: Vec, + /// Ordered expressions + pub order_exprs: Vec, /// Option to not error if table already exists pub if_not_exists: bool, /// File compression type (GZIP, BZIP2, XZ) @@ -215,7 +218,7 @@ impl<'a> DFParser<'a> { })) } - /// Parse a SQL `CREATE` statementm handling `CREATE EXTERNAL TABLE` + /// Parse a SQL `CREATE` statement handling `CREATE EXTERNAL TABLE` pub fn parse_create(&mut self) -> Result { if self.parser.parse_keyword(Keyword::EXTERNAL) { self.parse_create_external_table() @@ -253,6 +256,49 @@ impl<'a> DFParser<'a> { Ok(partitions) } + /// Parse the ordering clause of a `CREATE EXTERNAL TABLE` SQL statement + pub fn parse_order_by_exprs(&mut self) -> Result, ParserError> { + let mut values = vec![]; + self.parser.expect_token(&Token::LParen)?; + loop { + values.push(self.parse_order_by_expr()?); + if !self.parser.consume_token(&Token::Comma) { + self.parser.expect_token(&Token::RParen)?; + return Ok(values); + } + } + } + + /// Parse an ORDER BY sub-expression optionally followed by ASC or DESC. + pub fn parse_order_by_expr(&mut self) -> Result { + let expr = self.parser.parse_expr()?; + + let asc = if self.parser.parse_keyword(Keyword::ASC) { + Some(true) + } else if self.parser.parse_keyword(Keyword::DESC) { + Some(false) + } else { + None + }; + + let nulls_first = if self + .parser + .parse_keywords(&[Keyword::NULLS, Keyword::FIRST]) + { + Some(true) + } else if self.parser.parse_keywords(&[Keyword::NULLS, Keyword::LAST]) { + Some(false) + } else { + None + }; + + Ok(OrderByExpr { + expr, + asc, + nulls_first, + }) + } + // This is a copy of the equivalent implementation in sqlparser. 
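The `WITH ORDER` clause handled by the new `parse_order_by_exprs`/`parse_order_by_expr` methods above can be exercised end to end through the parser's public entry point. The following is a minimal usage sketch, not part of this diff: it assumes the existing `datafusion_sql::parser::DFParser::parse_sql` API and mirrors one of the "Ordered Col" test cases added later in this patch, so the assertion is illustrative rather than normative.

```rust
use datafusion_sql::parser::{DFParser, Statement};

fn main() {
    // Same shape as the multi-column "Ordered Col" test case added below.
    let sql = "CREATE EXTERNAL TABLE t(c1 INT, c2 INT) STORED AS CSV \
               WITH ORDER (c1 ASC, c2 DESC NULLS FIRST) LOCATION 'foo.csv'";
    let statements = DFParser::parse_sql(sql).expect("valid DDL");
    if let Some(Statement::CreateExternalTable(cmd)) = statements.front() {
        // `order_exprs` is the new field populated by `parse_order_by_exprs`.
        assert_eq!(cmd.order_exprs.len(), 2);
        println!("{:?}", cmd.order_exprs);
    }
}
```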
fn parse_columns( &mut self, @@ -359,6 +405,12 @@ impl<'a> DFParser<'a> { vec![] }; + let order_exprs = if self.parse_has_order() { + self.parse_order_by_exprs()? + } else { + vec![] + }; + let options = if self.parse_has_options() { self.parse_options()? } else { @@ -376,6 +428,7 @@ impl<'a> DFParser<'a> { delimiter, location, table_partition_cols, + order_exprs, if_not_exists, file_compression_type, options, @@ -458,12 +511,17 @@ impl<'a> DFParser<'a> { self.parser .parse_keywords(&[Keyword::PARTITIONED, Keyword::BY]) } + + fn parse_has_order(&mut self) -> bool { + self.parser.parse_keywords(&[Keyword::WITH, Keyword::ORDER]) + } } #[cfg(test)] mod tests { use super::*; - use sqlparser::ast::{DataType, Ident}; + use sqlparser::ast::Expr::Identifier; + use sqlparser::ast::{BinaryOperator, DataType, Expr, Ident}; use CompressionTypeVariant::UNCOMPRESSED; fn expect_parse_ok(sql: &str, expected: Statement) -> Result<(), ParserError> { @@ -520,6 +578,7 @@ mod tests { delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec![], + order_exprs: vec![], if_not_exists: false, file_compression_type: UNCOMPRESSED, options: HashMap::new(), @@ -537,6 +596,7 @@ mod tests { delimiter: '|', location: "foo.csv".into(), table_partition_cols: vec![], + order_exprs: vec![], if_not_exists: false, file_compression_type: UNCOMPRESSED, options: HashMap::new(), @@ -554,6 +614,7 @@ mod tests { delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec!["p1".to_string(), "p2".to_string()], + order_exprs: vec![], if_not_exists: false, file_compression_type: UNCOMPRESSED, options: HashMap::new(), @@ -574,6 +635,7 @@ mod tests { delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec![], + order_exprs: vec![], if_not_exists: false, file_compression_type: UNCOMPRESSED, options: HashMap::new(), @@ -597,6 +659,7 @@ mod tests { delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec![], + order_exprs: vec![], if_not_exists: false, file_compression_type: CompressionTypeVariant::from_str( file_compression_type, @@ -616,6 +679,7 @@ mod tests { delimiter: ',', location: "foo.parquet".into(), table_partition_cols: vec![], + order_exprs: vec![], if_not_exists: false, file_compression_type: UNCOMPRESSED, options: HashMap::new(), @@ -632,6 +696,7 @@ mod tests { delimiter: ',', location: "foo.parquet".into(), table_partition_cols: vec![], + order_exprs: vec![], if_not_exists: false, file_compression_type: UNCOMPRESSED, options: HashMap::new(), @@ -648,6 +713,7 @@ mod tests { delimiter: ',', location: "foo.avro".into(), table_partition_cols: vec![], + order_exprs: vec![], if_not_exists: false, file_compression_type: UNCOMPRESSED, options: HashMap::new(), @@ -665,6 +731,7 @@ mod tests { delimiter: ',', location: "foo.parquet".into(), table_partition_cols: vec![], + order_exprs: vec![], if_not_exists: true, file_compression_type: UNCOMPRESSED, options: HashMap::new(), @@ -687,6 +754,7 @@ mod tests { delimiter: ',', location: "blahblah".into(), table_partition_cols: vec![], + order_exprs: vec![], if_not_exists: false, file_compression_type: UNCOMPRESSED, options: HashMap::from([("k1".into(), "v1".into())]), @@ -704,6 +772,7 @@ mod tests { delimiter: ',', location: "blahblah".into(), table_partition_cols: vec![], + order_exprs: vec![], if_not_exists: false, file_compression_type: UNCOMPRESSED, options: HashMap::from([ @@ -713,11 +782,139 @@ mod tests { }); expect_parse_ok(sql, expected)?; + // Ordered Col + let sqls = vec!["CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1) 
LOCATION 'foo.csv'", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 NULLS FIRST) LOCATION 'foo.csv'", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 NULLS LAST) LOCATION 'foo.csv'", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 ASC) LOCATION 'foo.csv'", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 DESC) LOCATION 'foo.csv'", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 DESC NULLS FIRST) LOCATION 'foo.csv'", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 DESC NULLS LAST) LOCATION 'foo.csv'", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 ASC NULLS FIRST) LOCATION 'foo.csv'", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 ASC NULLS LAST) LOCATION 'foo.csv'"]; + let expected = vec![ + (None, None), + (None, Some(true)), + (None, Some(false)), + (Some(true), None), + (Some(false), None), + (Some(false), Some(true)), + (Some(false), Some(false)), + (Some(true), Some(true)), + (Some(true), Some(false)), + ]; + for (sql, (asc, nulls_first)) in sqls.iter().zip(expected.into_iter()) { + let expected = Statement::CreateExternalTable(CreateExternalTable { + name: "t".into(), + columns: vec![make_column_def("c1", DataType::Int(None))], + file_type: "CSV".to_string(), + has_header: false, + delimiter: ',', + location: "foo.csv".into(), + table_partition_cols: vec![], + order_exprs: vec![OrderByExpr { + expr: Identifier(Ident { + value: "c1".to_owned(), + quote_style: None, + }), + asc, + nulls_first, + }], + if_not_exists: false, + file_compression_type: UNCOMPRESSED, + options: HashMap::new(), + }); + expect_parse_ok(sql, expected)?; + } + + // Ordered Col + let sql = "CREATE EXTERNAL TABLE t(c1 int, c2 int) STORED AS CSV WITH ORDER (c1 ASC, c2 DESC NULLS FIRST) LOCATION 'foo.csv'"; + let display = None; + let expected = Statement::CreateExternalTable(CreateExternalTable { + name: "t".into(), + columns: vec![ + make_column_def("c1", DataType::Int(display)), + make_column_def("c2", DataType::Int(display)), + ], + file_type: "CSV".to_string(), + has_header: false, + delimiter: ',', + location: "foo.csv".into(), + table_partition_cols: vec![], + order_exprs: vec![ + OrderByExpr { + expr: Identifier(Ident { + value: "c1".to_owned(), + quote_style: None, + }), + asc: Some(true), + nulls_first: None, + }, + OrderByExpr { + expr: Identifier(Ident { + value: "c2".to_owned(), + quote_style: None, + }), + asc: Some(false), + nulls_first: Some(true), + }, + ], + if_not_exists: false, + file_compression_type: UNCOMPRESSED, + options: HashMap::new(), + }); + expect_parse_ok(sql, expected)?; + + // Ordered Binary op + let sql = "CREATE EXTERNAL TABLE t(c1 int, c2 int) STORED AS CSV WITH ORDER (c1 - c2 ASC) LOCATION 'foo.csv'"; + let display = None; + let expected = Statement::CreateExternalTable(CreateExternalTable { + name: "t".into(), + columns: vec![ + make_column_def("c1", DataType::Int(display)), + make_column_def("c2", DataType::Int(display)), + ], + file_type: "CSV".to_string(), + has_header: false, + delimiter: ',', + location: "foo.csv".into(), + table_partition_cols: vec![], + order_exprs: vec![OrderByExpr { + expr: Expr::BinaryOp { + left: Box::new(Identifier(Ident { + value: "c1".to_owned(), + quote_style: None, + })), + op: BinaryOperator::Minus, + right: Box::new(Identifier(Ident { + value: "c2".to_owned(), + quote_style: None, + })), + }, + asc: Some(true), + nulls_first: None, + }], + if_not_exists: false, + file_compression_type: UNCOMPRESSED, + options: 
HashMap::new(), + }); + expect_parse_ok(sql, expected)?; + // Error cases: partition column does not support type let sql = "CREATE EXTERNAL TABLE t STORED AS x OPTIONS ('k1' 'v1', k2 v2, k3) LOCATION 'blahblah'"; expect_parse_error(sql, "sql parser error: Expected literal string, found: )"); + // Error cases: partition column does not support type + let sql = + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER c1 LOCATION 'foo.csv'"; + expect_parse_error(sql, "sql parser error: Expected (, found: c1"); + + // Error cases: partition column does not support type + let sql = + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 LOCATION 'foo.csv'"; + expect_parse_error(sql, "sql parser error: Expected ), found: LOCATION"); + // Error case: `with header` is an invalid syntax let sql = "CREATE EXTERNAL TABLE t STORED AS CSV WITH HEADER LOCATION 'abc'"; expect_parse_error(sql, "sql parser error: Expected LOCATION, found: WITH"); diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index bf418d0b4832..87c090e3eeb7 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -21,13 +21,14 @@ use std::sync::Arc; use std::vec; use arrow_schema::*; +use datafusion_common::field_not_found; use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{field_not_found, DFSchema, DataFusionError, Result}; +use datafusion_common::{unqualified_field_not_found, DFSchema, DataFusionError, Result}; use datafusion_common::{OwnedTableReference, TableReference}; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::utils::find_column_exprs; @@ -69,13 +70,18 @@ impl Default for ParserOptions { } #[derive(Debug, Clone)] -/// Struct to store Common Table Expression (CTE) provided with WITH clause and -/// Parameter Data Types provided with PREPARE statement +/// Struct to store the states used by the Planner. The Planner will leverage the states to resolve +/// CTEs, Views, subqueries and PREPARE statements. 
The states include +/// Common Table Expression (CTE) provided with WITH clause and +/// Parameter Data Types provided with PREPARE statement and the query schema of the +/// outer query plan pub struct PlannerContext { /// Data type provided with prepare statement pub prepare_param_data_types: Vec, /// Map of CTE name to logical plan of the WITH clause pub ctes: HashMap, + /// The query schema of the outer query plan, used to resolve the columns in subquery + pub outer_query_schema: Option, } impl Default for PlannerContext { @@ -90,6 +96,7 @@ impl PlannerContext { Self { prepare_param_data_types: vec![], ctes: HashMap::new(), + outer_query_schema: None, } } @@ -100,6 +107,7 @@ impl PlannerContext { Self { prepare_param_data_types, ctes: HashMap::new(), + outer_query_schema: None, } } } @@ -201,16 +209,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if !schema.fields_with_unqualified_name(&col.name).is_empty() { Ok(()) } else { - Err(field_not_found(None, col.name.as_str(), schema)) + Err(unqualified_field_not_found(col.name.as_str(), schema)) } } } .map_err(|_: DataFusionError| { - field_not_found( - col.relation.as_ref().map(|s| s.to_owned()), - col.name.as_str(), - schema, - ) + field_not_found(col.relation.clone(), col.name.as_str(), schema) }), _ => Err(DataFusionError::Internal("Not a column".to_string())), }) @@ -377,22 +381,18 @@ pub(crate) fn idents_to_table_reference( match taker.0.len() { 1 => { let table = taker.take(enable_normalization); - Ok(OwnedTableReference::Bare { table }) + Ok(OwnedTableReference::bare(table)) } 2 => { let table = taker.take(enable_normalization); let schema = taker.take(enable_normalization); - Ok(OwnedTableReference::Partial { schema, table }) + Ok(OwnedTableReference::partial(schema, table)) } 3 => { let table = taker.take(enable_normalization); let schema = taker.take(enable_normalization); let catalog = taker.take(enable_normalization); - Ok(OwnedTableReference::Full { - catalog, - schema, - table, - }) + Ok(OwnedTableReference::full(catalog, schema, table)) } _ => Err(DataFusionError::Plan(format!( "Unsupported compound identifier '{:?}'", diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index eb7ece87d6b1..df64888f5894 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -17,7 +17,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::normalize_ident; -use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use sqlparser::ast::{Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query}; @@ -30,17 +30,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { query: Query, planner_context: &mut PlannerContext, ) -> Result { - self.query_to_plan_with_schema(query, planner_context, None) - } - - /// Generate a logical plan from a SQL subquery - pub(crate) fn subquery_to_plan( - &self, - query: Query, - planner_context: &mut PlannerContext, - outer_query_schema: &DFSchema, - ) -> Result { - self.query_to_plan_with_schema(query, planner_context, Some(outer_query_schema)) + self.query_to_plan_with_schema(query, planner_context) } /// Generate a logic plan from an SQL query. 
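The `outer_query_schema` field added to `PlannerContext` above is what allows the subquery planners in `datafusion/sql/src/expr/subquery.rs` to resolve identifiers against the enclosing query and emit `Expr::OuterReferenceColumn`. The save/restore discipline around that field can be summarized with a standalone sketch; the types below are simplified stand-ins, not DataFusion's actual definitions.

```rust
// Toy stand-ins for the planner context and schema, for illustration only.
#[derive(Clone, Debug, Default)]
struct Schema(Vec<String>);

#[derive(Default)]
struct PlannerContext {
    outer_query_schema: Option<Schema>,
}

fn plan_subquery(ctx: &mut PlannerContext, outer: &Schema) -> String {
    // Stash whatever outer schema an enclosing subquery may already have set.
    let saved = ctx.outer_query_schema.clone();
    ctx.outer_query_schema = Some(outer.clone());
    // ... the inner query would be planned here; identifiers missing from the
    // inner schema can now fall back to `ctx.outer_query_schema` and become
    // outer-reference columns ...
    let plan = format!("subquery planned with outer columns {:?}", outer.0);
    // Restore the previous value before handing control back to the caller.
    ctx.outer_query_schema = saved;
    plan
}

fn main() {
    let mut ctx = PlannerContext::default();
    let outer = Schema(vec!["t1.a".into(), "t1.b".into()]);
    println!("{}", plan_subquery(&mut ctx, &outer));
    // The context is left exactly as it was before planning the subquery.
    assert!(ctx.outer_query_schema.is_none());
}
```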
@@ -50,7 +40,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, query: Query, planner_context: &mut PlannerContext, - outer_query_schema: Option<&DFSchema>, ) -> Result { let set_expr = query.body; if let Some(with) = query.with { @@ -82,9 +71,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context.ctes.insert(cte_name, logical_plan); } } - let plan = - self.set_expr_to_plan(*set_expr, planner_context, outer_query_schema)?; - let plan = self.order_by(plan, query.order_by)?; + let plan = self.set_expr_to_plan(*set_expr, planner_context)?; + let plan = self.order_by(plan, query.order_by, planner_context)?; self.limit(plan, query.offset, query.limit) } @@ -145,6 +133,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, plan: LogicalPlan, order_by: Vec, + planner_context: &mut PlannerContext, ) -> Result { if order_by.is_empty() { return Ok(plan); @@ -152,7 +141,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let order_by_rex = order_by .into_iter() - .map(|e| self.order_by_to_sort_expr(e, plan.schema())) + .map(|e| self.order_by_to_sort_expr(e, plan.schema(), planner_context)) .collect::>>()?; LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 12f5bd6b49fc..1b369fedf329 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -37,11 +37,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ( match ( cte, - self.schema_provider.get_table_provider((&table_ref).into()), + self.schema_provider.get_table_provider(table_ref.clone()), ) { (Some(cte_plan), _) => Ok(cte_plan.clone()), (_, Ok(provider)) => { - LogicalPlanBuilder::scan(&table_name, provider, None)?.build() + LogicalPlanBuilder::scan(table_ref, provider, None)?.build() } (None, Err(e)) => Err(e), }?, diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index bdf9a7c302b1..cde4946571e8 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -20,7 +20,7 @@ use crate::utils::{ check_columns_satisfy_exprs, extract_aliases, normalize_ident, rebase_expr, resolve_aliases_to_exprs, resolve_columns, resolve_positions_to_exprs, }; -use datafusion_common::{DFSchema, DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::expr_rewriter::{ normalize_col, normalize_col_with_schemas_and_ambiguity_check, }; @@ -44,7 +44,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, select: Select, planner_context: &mut PlannerContext, - outer_query_schema: Option<&DFSchema>, ) -> Result { // check for unsupported syntax first if !select.cluster_by.is_empty() { @@ -68,12 +67,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // process `where` clause - let plan = self.plan_selection( - select.selection, - plan, - outer_query_schema, - planner_context, - )?; + let plan = self.plan_selection(select.selection, plan, planner_context)?; // process the SELECT expressions, with wildcards expanded. 
let select_exprs = self.prepare_select_exprs( @@ -234,28 +228,26 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, selection: Option, plan: LogicalPlan, - outer_query_schema: Option<&DFSchema>, planner_context: &mut PlannerContext, ) -> Result { match selection { Some(predicate_expr) => { - let mut join_schema = (**plan.schema()).clone(); - let fallback_schemas = plan.fallback_normalize_schemas(); - let outer_query_schema = if let Some(outer) = outer_query_schema { - join_schema.merge(outer); - vec![outer] - } else { - vec![] - }; + let outer_query_schema = planner_context.outer_query_schema.clone(); + let outer_query_schema_vec = + if let Some(outer) = outer_query_schema.as_ref() { + vec![outer] + } else { + vec![] + }; let filter_expr = - self.sql_to_expr(predicate_expr, &join_schema, planner_context)?; + self.sql_to_expr(predicate_expr, plan.schema(), planner_context)?; let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; let filter_expr = normalize_col_with_schemas_and_ambiguity_check( filter_expr, - &[&[plan.schema()], &fallback_schemas, &outer_query_schema], + &[&[plan.schema()], &fallback_schemas, &outer_query_schema_vec], &[using_columns], )?; diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index d870501769ee..b615f0444268 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -16,7 +16,7 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{DFSchema, DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; use sqlparser::ast::{SetExpr, SetOperator, SetQuantifier}; @@ -25,12 +25,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, set_expr: SetExpr, planner_context: &mut PlannerContext, - outer_query_schema: Option<&DFSchema>, ) -> Result { match set_expr { - SetExpr::Select(s) => { - self.select_to_plan(*s, planner_context, outer_query_schema) - } + SetExpr::Select(s) => self.select_to_plan(*s, planner_context), SetExpr::Values(v) => { self.sql_values_to_plan(v, &planner_context.prepare_param_data_types) } @@ -45,10 +42,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SetQuantifier::Distinct | SetQuantifier::None => false, }; - let left_plan = - self.set_expr_to_plan(*left, planner_context, outer_query_schema)?; - let right_plan = - self.set_expr_to_plan(*right, planner_context, outer_query_schema)?; + let left_plan = self.set_expr_to_plan(*left, planner_context)?; + let right_plan = self.set_expr_to_plan(*right, planner_context)?; match (op, all) { (SetOperator::Union, true) => LogicalPlanBuilder::from(left_plan) .union(right_plan)? 
diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index fa0eed1b9099..e32443e01a19 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -40,8 +40,8 @@ use datafusion_expr::{ }; use sqlparser::ast; use sqlparser::ast::{ - Assignment, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query, SchemaName, - SetExpr, ShowCreateObject, ShowStatementFilter, Statement, TableFactor, + Assignment, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, OrderByExpr, Query, + SchemaName, SetExpr, ShowCreateObject, ShowStatementFilter, Statement, TableFactor, TableWithJoins, UnaryOperator, Value, }; use sqlparser::parser::ParserError::ParserError; @@ -413,9 +413,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let DescribeTableStmt { table_name } = statement; let table_ref = self.object_name_to_table_reference(table_name)?; - let table_source = self - .schema_provider - .get_table_provider((&table_ref).into())?; + let table_source = self.schema_provider.get_table_provider(table_ref)?; let schema = table_source.schema(); @@ -425,6 +423,41 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { })) } + fn build_order_by( + &self, + order_exprs: Vec, + schema: &DFSchemaRef, + planner_context: &mut PlannerContext, + ) -> Result> { + // Ask user to provide a schema if schema is empty. + if !order_exprs.is_empty() && schema.fields().is_empty() { + return Err(DataFusionError::Plan( + "Provide a schema before specifying the order while creating a table." + .to_owned(), + )); + } + // Convert each OrderByExpr to a SortExpr: + let result = order_exprs + .into_iter() + .map(|e| self.order_by_to_sort_expr(e, schema, planner_context)) + .collect::>>()?; + // Verify that columns of all SortExprs exist in the schema: + for expr in result.iter() { + for column in expr.to_columns()?.iter() { + if !schema.has_column(column) { + // Return an error if any column is not in the schema: + return Err(DataFusionError::Plan(format!( + "Column {} is not in schema", + column + ))); + } + } + } + + // If all SortExprs are valid, return them as an expression vector + Ok(result) + } + /// Generate a logical plan from a CREATE EXTERNAL TABLE statement fn external_table_to_plan( &self, @@ -441,6 +474,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table_partition_cols, if_not_exists, file_compression_type, + order_exprs, options, } = statement; @@ -461,12 +495,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let schema = self.build_schema(columns)?; + let df_schema = schema.to_dfschema_ref()?; + + let ordered_exprs = + self.build_order_by(order_exprs, &df_schema, &mut PlannerContext::new())?; // External tables do not support schemas at the moment, so the name is just a table name - let name = OwnedTableReference::Bare { table: name }; + let name = OwnedTableReference::bare(name); Ok(LogicalPlan::CreateExternalTable(PlanCreateExternalTable { - schema: schema.to_dfschema_ref()?, + schema: df_schema, name, location, file_type, @@ -476,6 +514,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if_not_exists, definition, file_compression_type, + order_exprs: ordered_exprs, options, })) } @@ -634,9 +673,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; - let provider = self - .schema_provider - .get_table_provider((&table_ref).into())?; + let provider = self.schema_provider.get_table_provider(table_ref.clone())?; let schema = 
(*provider.schema()).clone(); let schema = DFSchema::try_from(schema)?; let scan = @@ -688,7 +725,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let table_name = self.object_name_to_table_reference(table_name)?; let provider = self .schema_provider - .get_table_provider((&table_name).into())?; + .get_table_provider(table_name.clone())?; let arrow_schema = (*provider.schema()).clone(); let table_schema = Arc::new(DFSchema::try_from(arrow_schema)?); let values = table_schema.fields().iter().map(|f| { @@ -790,7 +827,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let table_name = self.object_name_to_table_reference(table_name)?; let provider = self .schema_provider - .get_table_provider((&table_name).into())?; + .get_table_provider(table_name.clone())?; let arrow_schema = (*provider.schema()).clone(); let table_schema = DFSchema::try_from(arrow_schema)?; @@ -896,9 +933,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(sql_table_name)?; - let _ = self - .schema_provider - .get_table_provider((&table_ref).into())?; + let _ = self.schema_provider.get_table_provider(table_ref)?; // treat both FULL and EXTENDED as the same let select_list = if full || extended { @@ -934,9 +969,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(sql_table_name)?; - let _ = self - .schema_provider - .get_table_provider((&table_ref).into())?; + let _ = self.schema_provider.get_table_provider(table_ref)?; let query = format!( "SELECT table_catalog, table_schema, table_name, definition FROM information_schema.views WHERE {where_clause}" diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index a7f7a3dd692a..91cef6d4712e 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -361,6 +361,7 @@ where *nulls_first, ))), Expr::Column { .. } + | Expr::OuterReferenceColumn(_, _) | Expr::Literal(_) | Expr::ScalarVariable(_, _) | Expr::Exists { .. } diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index 124ab0c6988d..a77c1f1f9669 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -226,11 +226,11 @@ Dml: op=[Insert] table=[test_decimal] #[rstest] #[case::duplicate_columns( "INSERT INTO test_decimal (id, price, price) VALUES (1, 2, 3), (4, 5, 6)", - "Schema error: Schema contains duplicate unqualified field name 'price'" + "Schema error: Schema contains duplicate unqualified field name \"price\"" )] #[case::non_existing_column( "INSERT INTO test_decimal (nonexistent, price) VALUES (1, 2), (4, 5)", - "Schema error: No field named 'nonexistent'. Valid fields are 'id', 'price'." + "Schema error: No field named \"nonexistent\". Valid fields are \"id\", \"price\"." 
)] #[case::type_mismatch( "INSERT INTO test_decimal SELECT '2022-01-01', to_timestamp('2022-01-01T12:00:00')", @@ -515,7 +515,7 @@ fn select_with_ambiguous_column() { let sql = "SELECT id FROM person a, person b"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "SchemaError(AmbiguousReference { qualifier: None, name: \"id\" })", + "SchemaError(AmbiguousReference { field: Column { relation: None, name: \"id\" } })", format!("{err:?}") ); } @@ -538,7 +538,7 @@ fn where_selection_with_ambiguous_column() { let sql = "SELECT * FROM person a, person b WHERE id = id + 1"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "SchemaError(AmbiguousReference { qualifier: None, name: \"id\" })", + "SchemaError(AmbiguousReference { field: Column { relation: None, name: \"id\" } })", format!("{err:?}") ); } @@ -1121,9 +1121,9 @@ fn select_simple_aggregate_with_groupby_column_unselected() { fn select_simple_aggregate_with_groupby_and_column_in_group_by_does_not_exist() { let sql = "SELECT SUM(age) FROM person GROUP BY doesnotexist"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!("Schema error: No field named 'doesnotexist'. Valid fields are 'SUM(person.age)', \ - 'person'.'id', 'person'.'first_name', 'person'.'last_name', 'person'.'age', 'person'.'state', \ - 'person'.'salary', 'person'.'birth_date', 'person'.'😀'.", format!("{err}")); + assert_eq!("Schema error: No field named \"doesnotexist\". Valid fields are \"SUM(person.age)\", \ + \"person\".\"id\", \"person\".\"first_name\", \"person\".\"last_name\", \"person\".\"age\", \"person\".\"state\", \ + \"person\".\"salary\", \"person\".\"birth_date\", \"person\".\"😀\".", format!("{err}")); } #[test] @@ -1432,6 +1432,17 @@ fn select_where_with_positive_operator() { quick_test(sql, expected); } +#[test] +fn select_where_compound_identifiers() { + let sql = "SELECT aggregate_test_100.c3 \ + FROM public.aggregate_test_100 \ + WHERE aggregate_test_100.c3 > 0.1"; + let expected = "Projection: public.aggregate_test_100.c3\ + \n Filter: public.aggregate_test_100.c3 > Float64(0.1)\ + \n TableScan: public.aggregate_test_100"; + quick_test(sql, expected); +} + #[test] fn select_order_by_index() { let sql = "SELECT id FROM person ORDER BY 1"; @@ -2635,7 +2646,7 @@ fn exists_subquery() { \n Filter: EXISTS ()\ \n Subquery:\ \n Projection: person.first_name\ - \n Filter: person.last_name = p.last_name AND person.state = p.state\ + \n Filter: person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state)\ \n TableScan: person\ \n SubqueryAlias: p\ \n TableScan: person"; @@ -2656,7 +2667,7 @@ fn exists_subquery_schema_outer_schema_overlap() { \n Filter: person.id = p.id AND EXISTS ()\ \n Subquery:\ \n Projection: person.first_name\ - \n Filter: person.id = p2.id AND person.last_name = p.last_name AND person.state = p.state\ + \n Filter: person.id = p2.id AND person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state)\ \n CrossJoin:\ \n TableScan: person\ \n SubqueryAlias: p2\ @@ -2679,7 +2690,7 @@ fn exists_subquery_wildcard() { \n Filter: EXISTS ()\ \n Subquery:\ \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀\ - \n Filter: person.last_name = p.last_name AND person.state = p.state\ + \n Filter: person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state)\ \n TableScan: person\ \n SubqueryAlias: p\ \n TableScan: person"; @@ -2710,7 
+2721,7 @@ fn not_in_subquery_correlated() { \n Filter: p.id NOT IN ()\ \n Subquery:\ \n Projection: person.id\ - \n Filter: person.last_name = p.last_name AND person.state = Utf8(\"CO\")\ + \n Filter: person.last_name = outer_ref(p.last_name) AND person.state = Utf8(\"CO\")\ \n TableScan: person\ \n SubqueryAlias: p\ \n TableScan: person"; @@ -2725,7 +2736,7 @@ fn scalar_subquery() { \n Subquery:\ \n Projection: MAX(person.id)\ \n Aggregate: groupBy=[[]], aggr=[[MAX(person.id)]]\ - \n Filter: person.last_name = p.last_name\ + \n Filter: person.last_name = outer_ref(p.last_name)\ \n TableScan: person\ \n SubqueryAlias: p\ \n TableScan: person"; @@ -2747,7 +2758,7 @@ fn scalar_subquery_reference_outer_field() { \n Subquery:\ \n Projection: COUNT(UInt8(1))\ \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n Filter: j2.j2_id = j1.j1_id AND j1.j1_id = j3.j3_id\ + \n Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id\ \n CrossJoin:\ \n TableScan: j1\ \n TableScan: j3\ @@ -2768,7 +2779,7 @@ fn subquery_references_cte() { \n Filter: EXISTS ()\ \n Subquery:\ \n Projection: cte.id, cte.first_name, cte.last_name, cte.age, cte.state, cte.salary, cte.birth_date, cte.😀\ - \n Filter: cte.id = person.id\ + \n Filter: cte.id = outer_ref(person.id)\ \n SubqueryAlias: cte\ \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀\ \n TableScan: person\ @@ -2964,7 +2975,7 @@ fn order_by_unaliased_name() { #[test] fn order_by_ambiguous_name() { let sql = "select * from person a join person b using (id) order by age"; - let expected = "Schema error: Ambiguous reference to unqualified field 'age'"; + let expected = "Schema error: Ambiguous reference to unqualified field \"age\""; let err = logical_plan(sql).unwrap_err(); assert_eq!(err.to_string(), expected); @@ -2973,7 +2984,7 @@ fn order_by_ambiguous_name() { #[test] fn group_by_ambiguous_name() { let sql = "select max(id) from person a join person b using (id) group by age"; - let expected = "Schema error: Ambiguous reference to unqualified field 'age'"; + let expected = "Schema error: Ambiguous reference to unqualified field \"age\""; let err = logical_plan(sql).unwrap_err(); assert_eq!(err.to_string(), expected); @@ -3241,6 +3252,17 @@ fn test_select_distinct_order_by() { assert_eq!(err.to_string(), expected); } +#[test] +fn select_order_by_with_cast() { + let sql = + "SELECT first_name AS first_name FROM (SELECT first_name AS first_name FROM person) ORDER BY CAST(first_name as INT)"; + let expected = "Sort: CAST(first_name AS first_name AS Int32) ASC NULLS LAST\ + \n Projection: first_name AS first_name\ + \n Projection: person.first_name AS first_name\ + \n TableScan: person"; + quick_test(sql, expected); +} + #[test] fn test_duplicated_left_join_key_inner_join() { // person.id * 2 happen twice in left side. @@ -3278,7 +3300,7 @@ fn test_ambiguous_column_references_in_on_join() { INNER JOIN person as p2 ON id = 1"; - let expected = "Schema error: Ambiguous reference to unqualified field 'id'"; + let expected = "Schema error: Ambiguous reference to unqualified field \"id\""; // It should return error. let result = logical_plan(sql); @@ -3875,7 +3897,7 @@ fn assert_field_not_found(err: DataFusionError, name: &str) { match err { DataFusionError::SchemaError { .. 
} => { let msg = format!("{err}"); - let expected = format!("Schema error: No field named '{name}'."); + let expected = format!("Schema error: No field named \"{name}\"."); if !msg.starts_with(&expected) { panic!("error [{msg}] did not start with [{expected}]"); } diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index aba350c985e4..b132f280ddc1 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -34,7 +34,7 @@ datafusion = { version = "20.0.0", path = "../core" } itertools = "0.10.5" object_store = "0.5.4" prost = "0.11" -substrait = "0.4" +substrait = "0.5" tokio = "1.17" [build-dependencies] diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index cf4003c1c80e..f891cf6cc9f1 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -124,7 +124,7 @@ pub fn to_substrait_rel( }), advanced_extension: None, read_type: Some(ReadType::NamedTable(NamedTable { - names: vec![scan.table_name.clone()], + names: vec![scan.table_name.to_string()], advanced_extension: None, })), }))), diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index c497c66e597b..a2cd109a61ef 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -118,9 +118,12 @@ async fn main() -> datafusion::error::Result<()> { let ctx = SessionContext::new(); let df = ctx.read_csv("tests/data/capitalized_example.csv", CsvReadOptions::new()).await?; - let df = df.filter(col("A").lt_eq(col("c")))? - .aggregate(vec![col("A")], vec![min(col("b"))])? - .limit(0, Some(100))?; + let df = df + // col will parse the input string, hence requiring double quotes to maintain the capitalization + .filter(col("\"A\"").lt_eq(col("c")))? + // alternatively use ident to pass in an unqualified column name directly without parsing + .aggregate(vec![ident("A")], vec![min(col("b"))])? + .limit(0, Some(100))?; // execute and print results df.show().await?; diff --git a/docs/source/user-guide/sql/ddl.md b/docs/source/user-guide/sql/ddl.md index c531312b1e58..29a156bd01b1 100644 --- a/docs/source/user-guide/sql/ddl.md +++ b/docs/source/user-guide/sql/ddl.md @@ -63,6 +63,47 @@ WITH HEADER ROW LOCATION '/path/to/aggregate_test_100.csv'; ``` +When creating an output from a data source that is already ordered by an expression, you can pre-specify the order of +the data using the `WITH ORDER` clause. This applies even if the expression used for sorting is complex, +allowing for greater flexibility. 
+ +Here is an example of using the `WITH ORDER` clause: + +```sql +CREATE EXTERNAL TABLE test ( + c1 VARCHAR NOT NULL, + c2 INT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT NOT NULL, + c5 INT NOT NULL, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (c2 ASC, c5 + c8 DESC NULLS FIRST) +LOCATION '/path/to/aggregate_test_100.csv'; +``` + +where the `WITH ORDER` clause specifies the sort order: + +```sql +WITH ORDER (sort_expression1 [ASC | DESC] [NULLS { FIRST | LAST }] + [, sort_expression2 [ASC | DESC] [NULLS { FIRST | LAST }] ...]) +``` + +#### Cautions When Using the WITH ORDER Clause + +- The `WITH ORDER` clause in a `CREATE EXTERNAL TABLE` statement only declares the order in which the data should be read from the external file. If the data in the file is not actually sorted in the specified order, query results may be incorrect. + +- The `WITH ORDER` clause does not reorder or otherwise modify the data in the external file itself. + If data sources are already partitioned in Hive style, `PARTITIONED BY` can be used for partition pruning. ```
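As a quick end-to-end illustration of the new clause (not part of this patch), the sketch below registers a pre-sorted CSV file through `SessionContext::sql` and then queries it. The table name, column list, and file path are placeholders, and whether the planner actually elides a sort based on the declared ordering depends on the physical optimizer, so treat this as an assumption-laden sketch rather than a guarantee.

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Declare an external CSV file that is already sorted by c2.
    // The schema and LOCATION below are placeholders; point them at a real file.
    ctx.sql(
        "CREATE EXTERNAL TABLE sorted_example (c1 VARCHAR, c2 INT) \
         STORED AS CSV \
         WITH HEADER ROW \
         WITH ORDER (c2 ASC) \
         LOCATION 'tests/data/sorted_example.csv'",
    )
    .await?;

    // A query that orders by the declared expression; the optimizer may rely on
    // the declared ordering instead of inserting an extra sort.
    let df = ctx
        .sql("SELECT c1, c2 FROM sorted_example ORDER BY c2")
        .await?;
    df.show().await?;

    Ok(())
}
```

Declaring the ordering at table-creation time makes the file's physical order a property of the table, which is exactly why the cautions above stress that the file must really be sorted as declared.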