From 3c6b8523890d59099a2d83354f51f428f39015e9 Mon Sep 17 00:00:00 2001 From: Yoshi-Egawa Date: Tue, 25 Jun 2024 23:30:03 +0900 Subject: [PATCH] fix:copy_into_location_type --- .../_internal/analyzer/snowflake_plan_node.py | 3 +- .../snowpark/_internal/type_utils.py | 8 +++++ src/snowflake/snowpark/dataframe_writer.py | 31 ++++++++++++++----- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py index 0535f360bee..3b82e5cc13e 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py @@ -12,6 +12,7 @@ PlanNodeCategory, sum_node_complexities, ) +from snowflake.snowpark._internal.type_utils import CopyOptions from snowflake.snowpark.row import Row from snowflake.snowpark.types import StructType @@ -334,7 +335,7 @@ def __init__( file_format_type: Optional[str] = None, format_type_options: Optional[Dict[str, str]] = None, header: bool = False, - copy_options: Dict[str, Any], + copy_options: CopyOptions, ) -> None: super().__init__() self.child = child diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index 9f2f08eee90..38b4d8bfe8d 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -25,6 +25,7 @@ Optional, Tuple, Type, + TypedDict, Union, get_args, get_origin, @@ -987,3 +988,10 @@ def type_string_to_type_object(type_str: str) -> DataType: ColumnOrSqlExpr = Union["snowflake.snowpark.column.Column", str] LiteralType = Union[VALID_PYTHON_TYPES_FOR_LITERAL_VALUE] ColumnOrLiteral = Union["snowflake.snowpark.column.Column", LiteralType] + +class CopyOptions(TypedDict, total=False): + overwrite: bool + single: bool + max_file_size: float + include_query_id: bool + detailed_output: bool diff --git a/src/snowflake/snowpark/dataframe_writer.py b/src/snowflake/snowpark/dataframe_writer.py index 65cf6205724..058b2e25f66 100644 --- a/src/snowflake/snowpark/dataframe_writer.py +++ b/src/snowflake/snowpark/dataframe_writer.py @@ -4,6 +4,7 @@ import sys from typing import Any, Dict, List, Literal, Optional, Union, overload +from typing_extensions import Unpack import snowflake.snowpark # for forward references of type hints from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( @@ -17,7 +18,7 @@ add_api_call, dfw_collect_api_telemetry, ) -from snowflake.snowpark._internal.type_utils import ColumnOrName, ColumnOrSqlExpr +from snowflake.snowpark._internal.type_utils import ColumnOrName, ColumnOrSqlExpr, CopyOptions from snowflake.snowpark._internal.utils import ( SUPPORTED_TABLE_TYPES, get_aliased_option_name, @@ -293,7 +294,7 @@ def copy_into_location( header: bool = False, statement_params: Optional[Dict[str, str]] = None, block: Literal[True] = True, - **copy_options: Optional[Dict[str, Any]], + **copy_options: Unpack[CopyOptions], ) -> List[Row]: ... # pragma: no cover @@ -309,10 +310,26 @@ def copy_into_location( header: bool = False, statement_params: Optional[Dict[str, str]] = None, block: Literal[False] = False, - **copy_options: Optional[Dict[str, Any]], + **copy_options: Unpack[CopyOptions], ) -> AsyncJob: ... # pragma: no cover + @overload + def copy_into_location( + self, + location: str, + *, + partition_by: Optional[ColumnOrSqlExpr] = None, + file_format_name: Optional[str] = None, + file_format_type: Optional[str] = None, + format_type_options: Optional[Dict[str, str]] = None, + header: bool = False, + statement_params: Optional[Dict[str, str]] = None, + block: bool = True, + **copy_options: Unpack[CopyOptions], + ) -> Union[List[Row], AsyncJob]: + ... # pragma: no cover + def copy_into_location( self, location: str, @@ -324,7 +341,7 @@ def copy_into_location( header: bool = False, statement_params: Optional[Dict[str, str]] = None, block: bool = True, - **copy_options: Optional[Dict[str, Any]], + **copy_options: Unpack[CopyOptions], ) -> Union[List[Row], AsyncJob]: """Executes a `COPY INTO `__ to unload data from a ``DataFrame`` into one or more files in a stage or external stage. @@ -412,7 +429,7 @@ def csv( header: bool = False, statement_params: Optional[Dict[str, str]] = None, block: bool = True, - **copy_options: Optional[str], + **copy_options: Unpack[CopyOptions], ) -> Union[List[Row], AsyncJob]: """Executes internally a `COPY INTO `__ to unload data from a ``DataFrame`` into one or more CSV files in a stage or external stage. @@ -458,7 +475,7 @@ def json( header: bool = False, statement_params: Optional[Dict[str, str]] = None, block: bool = True, - **copy_options: Optional[str], + **copy_options: Unpack[CopyOptions], ) -> Union[List[Row], AsyncJob]: """Executes internally a `COPY INTO `__ to unload data from a ``DataFrame`` into a JSON file in a stage or external stage. @@ -505,7 +522,7 @@ def parquet( header: bool = False, statement_params: Optional[Dict[str, str]] = None, block: bool = True, - **copy_options: Optional[str], + **copy_options: Unpack[CopyOptions], ) -> Union[List[Row], AsyncJob]: """Executes internally a `COPY INTO `__ to unload data from a ``DataFrame`` into a PARQUET file in a stage or external stage.