Improvements to SQL translation
magbak committed Sep 2, 2024
1 parent 9f297c1 commit ba83658
Showing 12 changed files with 157 additions and 66 deletions.
26 changes: 18 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default.

10 changes: 5 additions & 5 deletions Cargo.toml
@@ -14,11 +14,11 @@ members = [
#pydf_io = { path = "../maplib/lib/pydf_io"}
#representation = { path = "../maplib/lib/representation", features = ["rdf-star"]}
#templates = { path = "../maplib/lib/templates"}
spargebra = { git = "https://github.com/DataTreehouse/maplib", rev="3bf75ac20a71c9afeab07a7b4e0196fe51e43c61", features = ["rdf-star"]}
query_processing = { git = "https://github.com/DataTreehouse/maplib", rev="3bf75ac20a71c9afeab07a7b4e0196fe51e43c61" }
pydf_io = { git = "https://github.com/DataTreehouse/maplib", rev="3bf75ac20a71c9afeab07a7b4e0196fe51e43c61" }
representation = { git = "https://github.com/DataTreehouse/maplib", rev="3bf75ac20a71c9afeab07a7b4e0196fe51e43c61", features = ["rdf-star"] }
templates = { git = "https://github.com/DataTreehouse/maplib", rev="3bf75ac20a71c9afeab07a7b4e0196fe51e43c61" }
spargebra = { git = "https://github.com/DataTreehouse/maplib", rev="07dbea46a9fed5db3eb71475996e5e1fcfec3247", features = ["rdf-star"]}
query_processing = { git = "https://github.com/DataTreehouse/maplib", rev="07dbea46a9fed5db3eb71475996e5e1fcfec3247" }
pydf_io = { git = "https://github.com/DataTreehouse/maplib", rev="07dbea46a9fed5db3eb71475996e5e1fcfec3247" }
representation = { git = "https://github.com/DataTreehouse/maplib", rev="07dbea46a9fed5db3eb71475996e5e1fcfec3247", features = ["rdf-star"] }
templates = { git = "https://github.com/DataTreehouse/maplib", rev="07dbea46a9fed5db3eb71475996e5e1fcfec3247" }


sparesults = { version = "0.2.0-alpha.5", features = ["rdf-star"] }
2 changes: 2 additions & 0 deletions lib/chrontext/src/constants.rs
@@ -4,5 +4,7 @@ pub const HAS_DATA_POINT: &str = "https://github.com/DataTreehouse/chrontext#has
pub const HAS_VALUE: &str = "https://github.com/DataTreehouse/chrontext#hasValue";
pub const HAS_RESOURCE: &str = "https://github.com/DataTreehouse/chrontext#hasResource";
pub const HAS_EXTERNAL_ID: &str = "https://github.com/DataTreehouse/chrontext#hasExternalId";

pub const DATE_BIN: &str = "https://github.com/DataTreehouse/chrontext#dateBin";
pub const NEST: &str = "https://github.com/DataTreehouse/chrontext#nestAggregation";
pub const GROUPING_COL: &str = "grouping_col";
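
The new `dateBin` IRI registers a custom SPARQL function that the SQL translator further down maps to BigQuery's TIMESTAMP_BUCKET. A minimal sketch of how it might be invoked from a query follows; the query shape and the `ct:hasTimestamp` predicate are invented for illustration, and the argument order (duration, timestamp, origin) is inferred from the TIMESTAMP_BUCKET translation in sql_translation.rs below.

# Hypothetical chrontext query using the dateBin custom function.
QUERY = """
PREFIX ct: <https://github.com/DataTreehouse/chrontext#>
PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
SELECT ?bin (AVG(?v) AS ?mean) WHERE {
    ?ts ct:hasDataPoint ?dp .
    ?dp ct:hasValue ?v ;
        ct:hasTimestamp ?t .
    BIND(ct:dateBin("PT15M"^^xsd:duration, ?t,
                    "2024-01-01T00:00:00Z"^^xsd:dateTime) AS ?bin)
}
GROUP BY ?bin
"""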
6 changes: 4 additions & 2 deletions lib/chrontext/src/engine.rs
@@ -1,6 +1,7 @@
use crate::combiner::Combiner;
use crate::errors::ChrontextError;
use crate::preprocessing::Preprocessor;
use crate::rename_vars::rename_query_vars;
use crate::rewriting::StaticQueryRewriter;
use crate::sparql_database::sparql_embedded_oxigraph::{EmbeddedOxigraph, EmbeddedOxigraphConfig};
use crate::sparql_database::sparql_endpoint::SparqlEndpoint;
@@ -16,7 +17,6 @@ use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use virtualization::{Virtualization, VirtualizedDatabase};
use virtualized_query::pushdown_setting::PushdownSetting;
use crate::rename_vars::rename_query_vars;

pub struct EngineConfig {
pub sparql_endpoint: Option<String>,
@@ -120,7 +120,9 @@ impl Engine {
.map_err(|x| ChrontextError::CombinerError(x))?;
for (original, renamed) in rename_map {
if let Some(dt) = solution_mappings.rdf_node_types.remove(&renamed) {
solution_mappings.mappings = solution_mappings.mappings.rename(&[renamed], &[original.clone()]);
solution_mappings.mappings = solution_mappings
.mappings
.rename(&[renamed], &[original.clone()]);
solution_mappings.rdf_node_types.insert(original, dt);
}
}
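The loop above undoes the variable renaming once the combined query has executed. A rough Python/Polars analogue of the same bookkeeping, with invented names, assuming `rename_map` maps original names to generated ones:

import polars as pl

# Illustration only (the real code is the Rust above): map generated column
# names back to the originals and move the dtype entry along with them.
rename_map = {"1v": "ren_0"}  # original -> renamed
df = pl.DataFrame({"ren_0": [1, 2, 3]})
rdf_node_types = {"ren_0": "xsd:integer"}
for original, renamed in rename_map.items():
    if renamed in rdf_node_types:
        df = df.rename({renamed: original})
        rdf_node_types[original] = rdf_node_types.pop(renamed)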
2 changes: 1 addition & 1 deletion lib/chrontext/src/lib.rs
@@ -6,8 +6,8 @@ pub mod engine;
pub mod errors;
mod preparing;
pub mod preprocessing;
mod rename_vars;
pub mod rewriting;
pub mod sparql_database;
mod sparql_result_to_polars;
pub mod splitter;
mod rename_vars;
20 changes: 8 additions & 12 deletions lib/chrontext/src/preparing/graph_patterns/project_pattern.rs
@@ -15,17 +15,13 @@ impl TimeseriesQueryPrepper {
solution_mappings: &mut SolutionMappings,
context: &Context,
) -> GPPrepReturn {
if try_groupby_complex_query {
debug!("Encountered graph inside project, not supported for complex groupby pushdown");
return GPPrepReturn::fail_groupby_complex_query();
} else {
let inner_rewrite = self.prepare_graph_pattern(
inner,
try_groupby_complex_query,
solution_mappings,
&context.extension_with(PathEntry::ProjectInner),
);
inner_rewrite
}
let inner_context = context.extension_with(PathEntry::ProjectInner);
let mut inner_rewrite = self.prepare_graph_pattern(
inner,
try_groupby_complex_query,
solution_mappings,
&inner_context,
);
inner_rewrite
}
}
2 changes: 1 addition & 1 deletion lib/chrontext/src/rename_vars.rs
@@ -323,5 +323,5 @@ fn create_new_variable_name() -> String {
}

fn should_rename_var(v: &Variable) -> bool {
!v.as_str().starts_with(|x:char| x.is_alphabetic())
!v.as_str().starts_with(|x: char| x.is_alphabetic())
}
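
The touched predicate decides which variables get temporary names before the query is sent off. A Python restatement of the same rule, for illustration only:

# Same rule as the Rust predicate: rename any variable whose name does not
# start with an alphabetic character.
def should_rename_var(name: str) -> bool:
    return not (len(name) > 0 and name[0].isalpha())

assert should_rename_var("_x")         # leading underscore -> renamed
assert should_rename_var("1v")         # leading digit -> renamed
assert not should_rename_var("value")  # alphabetic start -> kept unchanged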
103 changes: 85 additions & 18 deletions lib/virtualization/src/python/sql_translation.rs
@@ -2,20 +2,22 @@ pub const PYTHON_CODE: &str = r#"
from datetime import datetime
from typing import Dict, Literal, Any, List, Union
import sqlalchemy.types as types
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql.base import ColumnCollection
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy_bigquery.base import BigQueryDialect
from databricks.sqlalchemy import DatabricksDialect
from chrontext.vq import Expression, VirtualizedQuery, AggregateExpression
from chrontext.vq import Expression, VirtualizedQuery, AggregateExpression, XSDDuration
from sqlalchemy import ColumnElement, Column, Table, MetaData, Select, select, literal, DateTime, values, cast, \
BigInteger, CompoundSelect, and_, literal_column, case, func, TIMESTAMP
BigInteger, CompoundSelect, and_, literal_column, case, func, TIMESTAMP, text
XSD = "http://www.w3.org/2001/XMLSchema#"
XSD_INTEGER = "<http://www.w3.org/2001/XMLSchema#integer>"
FLOOR_DATE_TIME_TO_SECONDS_INTERVAL = "<https://github.com/DataTreehouse/chrontext#FloorDateTimeToSecondsInterval>"
XSD_INTEGER = "http://www.w3.org/2001/XMLSchema#integer"
XSD_DURATION = "http://www.w3.org/2001/XMLSchema#duration"
FLOOR_DATE_TIME_TO_SECONDS_INTERVAL = "https://github.com/DataTreehouse/chrontext#FloorDateTimeToSecondsInterval"
DATE_BIN = "https://github.com/DataTreehouse/chrontext#dateBin"
import warnings
@@ -102,11 +104,11 @@ class SPARQLMapper:
)
if self.dialect == "postgres" or self.dialect == "databricks":
values_sub = values(
Column("id"), Column(query.grouping_column_name),
name=self.inner_name()
).data(
query.id_grouping_tuples
)
Column("id"), Column(query.grouping_column_name),
name=self.inner_name()
).data(
query.id_grouping_tuples
)
table = values_sub.join(
table,
onclause=and_(
@@ -220,6 +222,8 @@ class SPARQLMapper:
return func.avg(sql_expression)
case "SUM":
return func.sum(sql_expression)
case "COUNT":
return func.count(sql_expression)
case "GROUP_CONCAT":
if aggregate_expression.separator is not None:
return func.aggregate_strings(sql_expression,
@@ -234,7 +238,7 @@
def expression_to_sql(
self,
expression: Expression,
columns: ColumnCollection[str, ColumnElement]
columns: ColumnCollection[str, ColumnElement],
) -> Column | ColumnElement | int | float | bool | str:
expression_type = expression.expression_type()
match expression_type:
@@ -299,6 +303,9 @@ class SPARQLMapper:
if type(native) == datetime:
if native.tzinfo is not None:
return literal(native, TIMESTAMP)
elif expression.literal.datatype.iri == XSD_DURATION:
if self.dialect == "bigquery":
return bigquery_duration_literal(native)
return literal(native)
case "FunctionCall":
sql_args = []
@@ -337,7 +344,7 @@ class SPARQLMapper:
case "MONTH":
return func.extract("MONTH", sql_args[0])
case "YEAR":
return func.extract("YEAR", sql_args[0])
return func.extract("YEAR", sql_args[0])
case "FLOOR":
return func.floor(sql_args[0])
case "CEILING":
@@ -348,22 +355,29 @@
elif IRI == FLOOR_DATE_TIME_TO_SECONDS_INTERVAL:
if self.dialect == "postgres":
return func.to_timestamp(
func.extract("EPOCH", sql_args[0]) - func.mod(
func.extract("EPOCH", sql_args[0]),
sql_args[1])
)
func.extract("EPOCH", sql_args[0]) - func.mod(
func.extract("EPOCH", sql_args[0]),
sql_args[1])
)
elif self.dialect == "databricks":
return func.TIMESTAMP_SECONDS(
func.UNIX_TIMESTAMP(sql_args[0]) - func.mod(
func.UNIX_TIMESTAMP(sql_args[0]),
sql_args[1])
func.UNIX_TIMESTAMP(sql_args[0]),
sql_args[1])
)
elif self.dialect == "bigquery":
return func.TIMESTAMP_SECONDS(
func.UNIX_SECONDS(sql_args[0]) - func.mod(
func.UNIX_SECONDS(sql_args[0]),
sql_args[1])
)
elif IRI == DATE_BIN:
if self.dialect == "bigquery":
# https://cloud.google.com/bigquery/docs/reference/standard-sql/time-series-functions#timestamp_bucket
# Duration is second arg here
return func.TIMESTAMP_BUCKET(sql_args[1], sql_args[0], sql_args[2])
print("Unknown function")
print(function)
assert False
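
For the BigQuery dialect, dateBin becomes TIMESTAMP_BUCKET. A sketch of the generated SQL, using made-up table and column names; `func.TIMESTAMP_BUCKET` relies on SQLAlchemy's generic-function mechanism, so no explicit registration is needed:

from sqlalchemy import Column, MetaData, TIMESTAMP, Table, func, select, text
from sqlalchemy_bigquery.base import BigQueryDialect

readings = Table("readings", MetaData(), Column("ts", TIMESTAMP))
stmt = select(
    func.TIMESTAMP_BUCKET(
        readings.c.ts,
        text("INTERVAL '15:0' MINUTE TO SECOND"),  # duration argument
        text("TIMESTAMP '2024-01-01 00:00:00'"),   # origin argument
    )
)
print(stmt.compile(dialect=BigQueryDialect()))
# Emits roughly:
# SELECT TIMESTAMP_BUCKET(`readings`.`ts`, INTERVAL '15:0' MINUTE TO SECOND,
#     TIMESTAMP '2024-01-01 00:00:00') ...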
@@ -372,4 +386,57 @@
name = f"inner_{self.counter}"
self.counter += 1
return name
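# Build a BigQuery INTERVAL literal such as INTERVAL '1-2 3 4:5:6' YEAR TO
# SECOND from an XSDDuration: f tracks the coarsest date/time part emitted,
# last the finest. Fractional seconds are dropped.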
def bigquery_duration_literal(native:XSDDuration):
f = None
s = ""
last = None
if native.years > 0:
f = "YEAR"
last = "YEAR"
s += str(native.years) + "-"
if native.months > 0 or last is not None:
if f is None:
f = "MONTH"
last = "MONTH"
s += str(native.months)
if native.days > 0 or last is not None:
if f is None:
f = "DAY"
last = "DAY"
if len(s) > 0:
s += " "
s += str(native.days)
if native.hours > 0 or last is not None:
if f is None:
f = "HOUR"
last = "HOUR"
if len(s) > 0:
s += " "
s += str(native.hours)
if native.minutes > 0 or last is not None:
if f is None:
f = "MINUTE"
last = "MINUTE"
if len(s) > 0:
s += ":"
s += str(native.minutes)
whole,decimal = native.seconds
if whole > 0 or last is not None:
if f is None:
f = "SECOND"
last = "SECOND"
if len(s) > 0:
s += ":"
s += str(whole)
return text(f"INTERVAL '{s}' {f} TO {last}")
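# For example, P1Y2M3DT4H5M6S builds s = "1-2 3 4:5:6" with f = "YEAR" and
# last = "SECOND", yielding INTERVAL '1-2 3 4:5:6' YEAR TO SECOND; a bare
# PT15M yields INTERVAL '15:0' MINUTE TO SECOND.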
"#;
7 changes: 5 additions & 2 deletions lib/virtualized_query/src/python.rs
@@ -3,7 +3,7 @@ use polars::export::ahash::{HashMap, HashMapExt};
use polars::prelude::AnyValue;
use pyo3::prelude::*;
use representation::python::{PyIRI, PyLiteral, PyVariable};
use spargebra::algebra::{AggregateExpression, AggregateFunction, Expression, OrderExpression};
use spargebra::algebra::{AggregateExpression, AggregateFunction, Expression, Function, OrderExpression};
use spargebra::term::TermPattern;

#[derive(Clone)]
@@ -605,7 +605,10 @@ impl PyExpression {
py_expressions.push(Py::new(py, PyExpression::new(a, py)?)?);
}
PyExpression::FunctionCall {
function: function.to_string(),
function: match function {
Function::Custom(c) => {c.as_str().to_string()}
n => n.to_string()
},
arguments: py_expressions,
}
}
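
Unwrapping `Function::Custom` with `as_str` matters because `NamedNode`'s Display wraps IRIs in angle brackets; passing the bare IRI is what lets the bracket-free constants in sql_translation.rs compare equal. A trivial Python sketch of the effect, with hypothetical values:

# A custom function now arrives on the Python side as a bare IRI string.
DATE_BIN = "https://github.com/DataTreehouse/chrontext#dateBin"
received = "https://github.com/DataTreehouse/chrontext#dateBin"  # previously "<...>"-wrapped
assert received == DATE_BIN  # comparisons against the constants now match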
2 changes: 1 addition & 1 deletion py_chrontext/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "py_chrontext"
version = "0.9.7"
version = "0.9.8"
edition = "2021"

[dependencies]
