count not aggregating #3421

Open · 2 tasks
andrewgazelka opened this issue Nov 25, 2024 · 8 comments

Labels: bug (Something isn't working) · p1 (Important to tackle soon, but preemptable by p0)

andrewgazelka (Member) commented Nov 25, 2024

import daft
from daft import col, lit
# Create DataFrame from range
df = daft.from_pydict({"col": list(range(10))})

# Method 1: Using DataFrame API
result = df.agg([lit(1).count()])

# Convert to pydict and print
print(result.collect().to_pydict())

yields

{'literal': [1]}

However, I believe the correct result should be
{'literal': [10]}

This needs to work because the plan that Spark Connect creates for .count() is a similar type of aggregation:

ExecutePlanRequest: ExecutePlanRequest {
    session_id: "0635d57d-c2d2-43e1-b1b6-022f6173b071",
    client_observed_server_side_session_id: None,
    user_context: Some(
        UserContext {
            user_id: "andrewgazelka",
            user_name: "",
            extensions: [],
        },
    ),
    operation_id: Some(
        "3518eaa1-5332-4ef3-8bdd-1b84b3b954d9",
    ),
    plan: Some(
        Plan {
            op_type: Some(
                Root(
                    Relation {
                        common: Some(
                            RelationCommon {
                                source_info: "",
                                plan_id: Some(
                                    1,
                                ),
                                origin: None,
                            },
                        ),
                        rel_type: Some(
                            Aggregate(
                                Aggregate {
                                    input: Some(
                                        Relation {
                                            common: Some(
                                                RelationCommon {
                                                    source_info: "",
                                                    plan_id: Some(
                                                        0,
                                                    ),
                                                    origin: None,
                                                },
                                            ),
                                            rel_type: Some(
                                                Range(
                                                    Range {
                                                        start: Some(
                                                            0,
                                                        ),
                                                        end: 10,
                                                        step: 1,
                                                        num_partitions: None,
                                                    },
                                                ),
                                            ),
                                        },
                                    ),
                                    group_type: Groupby,
                                    grouping_expressions: [],
                                    aggregate_expressions: [
                                        Expression {
                                            common: None,
                                            expr_type: Some(
                                                UnresolvedFunction(
                                                    UnresolvedFunction {
                                                        function_name: "count",
                                                        arguments: [
                                                            Expression {
                                                                common: None,
                                                                expr_type: Some(
                                                                    Literal(
                                                                        Literal {
                                                                            literal_type: Some(
                                                                                Integer(
                                                                                    1,
                                                                                ),
                                                                            ),
                                                                        },
                                                                    ),
                                                                ),
                                                            },
                                                        ],
                                                        is_distinct: false,
                                                        is_user_defined_function: false,
                                                    },
                                                ),
                                            ),
                                        },
                                    ],
                                    pivot: None,
                                    grouping_sets: [],
                                },
                            ),
                        ),
                    },
                ),
            ),
        },
    ),
    client_type: Some(
        "_SPARK_CONNECT_PYTHON spark/3.5.3 os/darwin python/3.12.6",
    ),
    request_options: [
        RequestOption {
            request_option: Some(
                ReattachOptions(
                    ReattachOptions {
                        reattachable: true,
                    },
                ),
            ),
        },
    ],
    tags: [],
}
  • Can expected behavior be verified?
  • Can we verify that the pure Daft snippet above has 1:1 behavior with the protobuf plan above?

Want to try it yourself?

def test_count(spark_session):
    # Create a range using Spark: a DataFrame with the numbers 0 to 9
    spark_range = spark_session.range(10)

    # Count the rows
    count = spark_range.count()

    # Verify the DataFrame has the expected number of rows
    assert count == 10, "DataFrame should have 10 rows"

andrewgazelka added the bug (Something isn't working) label on Nov 25, 2024
universalmind303 (Contributor) commented Nov 25, 2024

@andrewgazelka the way that we perform this kind of count in SQL is by doing a count(*). So I think if you did col("*").count() you would get the expected result.

col("*").count() under the hood will be resolved to col(<schema[0].name>).count()

andrewgazelka (Member, Author) commented Nov 25, 2024

Sounds like we can fix this by adding it to the optimizer:

pub fn optimize(&self) -> DaftResult<Self> {

andrewgazelka (Member, Author) commented

@colin-ho What are your thoughts?

universalmind303 (Contributor) commented Nov 25, 2024

Idk if this should be an optimization; it should probably either be special handling for spark connect, since our implementations diverge, or a change to our algorithm to match spark's.

If you look at other engines, polars resolves this to 1, but most SQL-based engines (duckdb, datafusion, spark, ...) resolve it to 10. So I think we need to decide whether to change our logic to match the other engines or keep it as is.

I think the quickest path for the spark-connect use case is to map df.count() to df.agg(col(<schema[0].name>).count()).
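
To make the engine divergence above concrete, here is a quick check with DuckDB (used purely to illustrate the SQL-engine behavior described above; assumes a recent duckdb Python package where duckdb.sql is available):

import duckdb

# In SQL engines, count(1) over a 10-row relation counts the rows,
# not the literal, so this returns 10 rather than 1.
print(duckdb.sql("SELECT count(1) AS c FROM range(10)").fetchall())  # [(10,)]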

andrewgazelka (Member, Author) commented

> Idk if this should be an optimization; it should probably either be special handling for spark connect, since our implementations diverge, or a change to our algorithm to match spark's.
>
> If you look at other engines, polars resolves this to 1, but most SQL-based engines (duckdb, datafusion, spark, ...) resolve it to 10. So I think we need to decide whether to change our logic to match the other engines or keep it as is.
>
> I think the quickest path for the spark-connect use case is to map df.count() to df.agg(col(<schema[0].name>).count()).

The issue is that we cannot really map df.count() directly; we need to map the protobuf plan it generates.

We could potentially check whether that plan is 1:1 with the protobuf at the top of the issue, but I'm unsure if that is an acceptable solution.

Also, for added context, this is how we currently handle count:

pub fn handle_count(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() {
        Ok(arguments) => arguments,
        Err(arguments) => {
            bail!("requires exactly one argument; got {arguments:?}");
        }
    };
    let [arg] = arguments;
    let count = arg.count(CountMode::All);
    Ok(count)
}

universalmind303 (Contributor) commented

This is essentially the same as what we're doing in SQL, but instead of count(1) it's count(*). I think it's fine to have special handling for this case.

colin-ho (Contributor) commented

> This is essentially the same as what we're doing in SQL, but instead of count(1) it's count(*). I think it's fine to have special handling for this case.

+1 on this, makes sense to just corner case count(1) for daft-connect
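
As a rough sketch of what that corner case means (illustrative only; the real change lives in the Rust daft-connect translation layer, and the helper name, flag, and DataFrame.column_names usage here are assumptions, not the actual implementation):

import daft
from daft import col, lit

def translate_count_aggregation(df, argument_is_bare_literal):
    # Hypothetical helper: if Spark Connect sent count(<literal>) with no
    # grouping expressions, count rows via the first schema column instead
    # (column_names is assumed to be available on the DataFrame).
    if argument_is_bare_literal:
        return df.agg([col(df.column_names[0]).count()])
    # Current behavior for comparison: counting the literal yields 1.
    return df.agg([lit(1).count()])

df = daft.from_pydict({"col": list(range(10))})
print(translate_count_aggregation(df, True).collect().to_pydict())   # expected: {'col': [10]}
print(translate_count_aggregation(df, False).collect().to_pydict())  # {'literal': [1]}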

andrewgazelka (Member, Author) commented Nov 25, 2024

> This is essentially the same as what we're doing in SQL, but instead of count(1) it's count(*). I think it's fine to have special handling for this case.

> +1 on this, makes sense to just corner case count(1) for daft-connect

I added a special case

https://github.com/Eventual-Inc/Daft/pull/3352/files#diff-2672bf7ca3e3abc51d6d3686d9494bc490f476a6926c3537ae577d55ea81d085R44-R50

desmondcheongzx added the p1 (Important to tackle soon, but preemptable by p0) label and removed the needs triage label on Nov 26, 2024