Add compression for Pandas.to_parquet #28

Merged · 1 commit · Sep 19, 2019
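
In short, this change plumbs a compression option through the Glue metadata path so that the codec chosen when writing Parquet (or CSV) is also recorded on the Glue catalog table. A hedged usage sketch, assuming the 0.x Session-based API implied by the PR title; none of the calls below appear in this diff:

import awswrangler
import pandas as pd

# Hypothetical end-to-end call; the compression keyword is the feature this
# PR adds, the rest of the API surface is an assumption from the PR title.
df = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
session = awswrangler.Session()
session.pandas.to_parquet(
    dataframe=df,
    database="my_database",           # hypothetical Glue database
    path="s3://my-bucket/my_table/",  # hypothetical S3 prefix
    compression="snappy",
)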
4 changes: 4 additions & 0 deletions awswrangler/exceptions.py
@@ -72,3 +72,7 @@ class InvalidSerDe(Exception):

class ApiError(Exception):
pass


class InvalidCompression(Exception):
pass
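
The new InvalidCompression exception gives callers a dedicated error type for unsupported codecs. A minimal sketch of the kind of validation it enables; the helper name and the accepted codec set are assumptions, not part of this diff:

from awswrangler.exceptions import InvalidCompression

# Hypothetical helper: the name and the accepted codec set are illustrative.
VALID_COMPRESSIONS = {None, "snappy", "gzip"}

def check_compression(compression):
    # Reject anything outside the assumed set of supported codecs.
    if compression not in VALID_COMPRESSIONS:
        raise InvalidCompression(f"{compression} is not a supported codec")
    return compression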
48 changes: 34 additions & 14 deletions awswrangler/glue.py
@@ -133,12 +133,14 @@ def metadata_to_glue(self,
partition_cols=None,
preserve_index=True,
mode="append",
compression=None,
cast_columns=None,
extra_args=None):
schema, partition_cols_schema = Glue._build_schema(
dataframe=dataframe,
partition_cols=partition_cols,
preserve_index=preserve_index)
preserve_index=preserve_index,
cast_columns=cast_columns)
table = table if table else Glue._parse_table_name(path)
table = table.lower().replace(".", "_")
if mode == "overwrite":
@@ -151,6 +153,7 @@ def metadata_to_glue(self,
partition_cols_schema=partition_cols_schema,
path=path,
file_format=file_format,
compression=compression,
extra_args=extra_args)
if partition_cols:
partitions_tuples = Glue._parse_partitions_tuples(
@@ -159,6 +162,7 @@ def metadata_to_glue(self,
table=table,
partition_paths=partitions_tuples,
file_format=file_format,
compression=compression,
extra_args=extra_args)

def delete_table_if_exists(self, database, table):
@@ -180,16 +184,18 @@ def create_table(self,
schema,
path,
file_format,
compression,
partition_cols_schema=None,
extra_args=None):
if file_format == "parquet":
table_input = Glue.parquet_table_definition(
table, partition_cols_schema, schema, path)
table, partition_cols_schema, schema, path, compression)
elif file_format == "csv":
table_input = Glue.csv_table_definition(table,
partition_cols_schema,
schema,
path,
compression,
extra_args=extra_args)
else:
raise UnsupportedFileFormat(file_format)
@@ -227,15 +233,21 @@ def get_connection_details(self, name):
Name=name, HidePassword=False)["Connection"]

@staticmethod
def _extract_pyarrow_schema(dataframe, preserve_index):
def _extract_pyarrow_schema(dataframe, preserve_index, cast_columns=None):
cols = []
cols_dtypes = {}
schema = []

casted = []
if cast_columns is not None:
casted = cast_columns.keys()

for name, dtype in dataframe.dtypes.to_dict().items():
dtype = str(dtype)
if str(dtype) == "Int64":
if dtype == "Int64":
cols_dtypes[name] = "int64"
elif name in casted:
cols_dtypes[name] = cast_columns[name]
else:
cols.append(name)

@@ -252,13 +264,18 @@ def _extract_pyarrow_schema(dataframe, preserve_index):
return schema

@staticmethod
def _build_schema(dataframe, partition_cols, preserve_index):
def _build_schema(dataframe,
partition_cols,
preserve_index,
cast_columns={}):
logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}")
if not partition_cols:
partition_cols = []

pyarrow_schema = Glue._extract_pyarrow_schema(
dataframe=dataframe, preserve_index=preserve_index)
dataframe=dataframe,
preserve_index=preserve_index,
cast_columns=cast_columns)

schema_built = []
partition_cols_types = {}
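
As an aside on the new cast_columns path above: when a column name appears in cast_columns, its declared type comes from that mapping instead of being inferred from the pandas dtype. A small standalone illustration of that lookup (the dataframe and type names are made up for the example, and the pyarrow plumbing is omitted):

import pandas as pd

df = pd.DataFrame({"id": pd.Series([1, 2], dtype="Int64"),
                   "price": [1.0, 2.5]})
cast_columns = {"price": "double"}

cols, cols_dtypes = [], {}
casted = cast_columns.keys()
for name, dtype in df.dtypes.to_dict().items():
    dtype = str(dtype)
    if dtype == "Int64":          # pandas nullable integer -> int64
        cols_dtypes[name] = "int64"
    elif name in casted:          # explicit override wins over inference
        cols_dtypes[name] = cast_columns[name]
    else:
        cols.append(name)         # left for pyarrow to infer
print(cols_dtypes)  # {'id': 'int64', 'price': 'double'}
print(cols)         # []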
@@ -285,9 +302,10 @@ def _parse_table_name(path):

@staticmethod
def csv_table_definition(table, partition_cols_schema, schema, path,
extra_args):
compression, extra_args):
if not partition_cols_schema:
partition_cols_schema = []
compressed = False if compression is None else True
sep = extra_args["sep"] if "sep" in extra_args else ","
serde = extra_args.get("serde")
if serde == "OpenCSVSerDe":
@@ -322,7 +340,7 @@ def csv_table_definition(table, partition_cols_schema, schema, path,
"EXTERNAL_TABLE",
"Parameters": {
"classification": "csv",
"compressionType": "none",
"compressionType": str(compression).lower(),
"typeOfData": "file",
"delimiter": sep,
"columnsOrdered": "true",
@@ -337,7 +355,7 @@ def csv_table_definition(table, partition_cols_schema, schema, path,
"InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
"OutputFormat":
"org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
"Compressed": False,
"Compressed": True,
"NumberOfBuckets": -1,
"SerdeInfo": {
"Parameters": param,
Expand All @@ -347,7 +365,7 @@ def csv_table_definition(table, partition_cols_schema, schema, path,
"SortColumns": [],
"Parameters": {
"classification": "csv",
"compressionType": "none",
"compressionType": str(compression).lower(),
"typeOfData": "file",
"delimiter": sep,
"columnsOrdered": "true",
@@ -386,9 +404,11 @@ def csv_partition_definition(partition, extra_args):
}

@staticmethod
def parquet_table_definition(table, partition_cols_schema, schema, path):
def parquet_table_definition(table, partition_cols_schema, schema, path,
compression):
if not partition_cols_schema:
partition_cols_schema = []
compressed = False if compression is None else True
return {
"Name":
table,
@@ -400,7 +420,7 @@ def parquet_table_definition(table, partition_cols_schema, schema, path):
"EXTERNAL_TABLE",
"Parameters": {
"classification": "parquet",
"compressionType": "none",
"compressionType": str(compression).lower(),
"typeOfData": "file",
},
"StorageDescriptor": {
@@ -413,7 +433,7 @@ def parquet_table_definition(table, partition_cols_schema, schema, path):
"org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat",
"OutputFormat":
"org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat",
"Compressed": False,
"Compressed": compressed,
"NumberOfBuckets": -1,
"SerdeInfo": {
"SerializationLibrary":
Expand All @@ -427,7 +447,7 @@ def parquet_table_definition(table, partition_cols_schema, schema, path):
"Parameters": {
"CrawlerSchemaDeserializerVersion": "1.0",
"classification": "parquet",
"compressionType": "none",
"compressionType": str(compression).lower(),
"typeOfData": "file",
},
},
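
Taken together, the table definitions now derive both Glue compression fields from the single compression argument. A small standalone sketch of that mapping, using the same expressions as the diff (the codec values are just for illustration):

# How the new argument maps onto the Glue table fields used above.
for compression in (None, "snappy", "GZIP"):
    compressed = False if compression is None else True   # "Compressed"
    compression_type = str(compression).lower()           # "compressionType"
    print(compression, compressed, compression_type)
# None   -> Compressed=False, compressionType="none"
# snappy -> Compressed=True,  compressionType="snappy"
# GZIP   -> Compressed=True,  compressionType="gzip"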