From f324cb7dfdab58918867d3e9bcc2a25ede42a950 Mon Sep 17 00:00:00 2001
From: Bo Li <drluodian@gmail.com>
Date: Sun, 10 Dec 2023 13:08:05 +0800
Subject: [PATCH] Refactor image processing and add support for resizing images

---
 pipeline/utils/convert_to_parquet.py | 41 +++++++++++++++++++---------
 1 file changed, 28 insertions(+), 13 deletions(-)

diff --git a/pipeline/utils/convert_to_parquet.py b/pipeline/utils/convert_to_parquet.py
index 7a2ad23b..65661e16 100644
--- a/pipeline/utils/convert_to_parquet.py
+++ b/pipeline/utils/convert_to_parquet.py
@@ -5,9 +5,9 @@
 from tqdm import tqdm
 import argparse
 import orjson
+import dask.dataframe as dd
 
-
-def process_images(base64_str):
+def process_images(base64_str, resize_res=-1):
     import base64
     from PIL import Image
     from io import BytesIO
@@ -21,7 +21,10 @@ def process_images(base64_str):
         base64_str += "=" * padding_needed
 
     try:
-        img = Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))).convert("RGB")
+        img = Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))).convert("RGB")
+        if resize_res != -1:
+            # Resize to a square target resolution when resize_res is set
+            img = img.resize((resize_res, resize_res))
     except Exception as e:
         print(f"Warning: Failed to open image. Error: {e}")
         return None
@@ -35,23 +38,30 @@ def process_images(base64_str):
     return new_base64_str
 
 
-def convert_json_to_parquet(input_path, output_path):
+def convert_json_to_parquet(input_path, output_path, max_partition_size, resize_res=-1):
     start_time = time.time()
     with open(input_path, "rb") as f:
-        data_dict = orjson.loads(f.read())
-    # with open(input_path, "r") as f:
-    #     data_dict = json.load(f)
+        data = f.read()
+        data_dict = orjson.loads(data)
+
+    # Estimate the size of the raw JSON payload in bytes
+    total_size = len(data)
+    print(f"Total size of the JSON data: {total_size} bytes")
 
+    # Calculate the number of partitions needed
+    npartitions = max(1, total_size // max_partition_size)
+    print(f"Number of partitions: {npartitions}")
+
     resized_data_dict = {}
     dropped_keys = []
     for key, value in tqdm(data_dict.items(), desc=f"Processing {input_path}"):
         if isinstance(value, list):
             value = value[0]
-        # resized_base64 = process_images(value)
-        resized_data_dict[key] = value
+        resized_base64 = process_images(value, resize_res)
+        resized_data_dict[key] = resized_base64
 
-    df = pd.DataFrame.from_dict(resized_data_dict, orient="index", columns=["base64"])
-    df.to_parquet(output_path, engine="pyarrow")
+    ddf = dd.from_pandas(pd.DataFrame.from_dict(resized_data_dict, orient="index", columns=["base64"]), npartitions=npartitions)
+    ddf.to_parquet(output_path, engine="pyarrow")
 
     end_time = time.time()
     print(f"Converting {input_path} to parquet takes {end_time - start_time} seconds.")
@@ -62,11 +72,16 @@ def main():
     parser = argparse.ArgumentParser(description="Convert JSON to Parquet")
     parser.add_argument("--input_path", help="Path to the input JSON file")
     parser.add_argument("--output_path", help="Path for the output Parquet file")
+    parser.add_argument("--resize_res", type=int, default=-1)
+    parser.add_argument("--max_partition_size_gb", type=float, default=1.5, help="Maximum size of each partition in GB")
     args = parser.parse_args()
 
-    dropped_keys = convert_json_to_parquet(args.input_path, args.output_path)
-    print(dropped_keys)
+    # Convert GB to an integer number of bytes for max_partition_size
+    max_partition_size = int(args.max_partition_size_gb * 1024**3)
 
+    dropped_keys = convert_json_to_parquet(args.input_path, args.output_path, max_partition_size, args.resize_res)
+    print(f"Number of dropped keys: {len(dropped_keys)}")
+    print(f"Dropped keys: {dropped_keys}")
 
 if __name__ == "__main__":
     main()
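
For reference, the base64 round trip behind the patched process_images can be sketched as a standalone snippet. This is only an illustrative sketch, not part of the patch: the helper name round_trip_base64 and the 336x336 target resolution are made up for the example, the re-encoding step (JPEG plus b64encode) is an assumption because that part of the function lies outside the hunks shown, and Pillow must be installed.

    # Minimal sketch of the base64 -> PIL -> base64 round trip behind process_images.
    # The helper name, JPEG re-encoding, and 336x336 resolution are illustrative assumptions.
    import base64
    from io import BytesIO
    from PIL import Image

    def round_trip_base64(base64_str, resize_res=-1):
        # Pad to a multiple of 4 so urlsafe_b64decode accepts the string.
        base64_str += "=" * ((4 - len(base64_str) % 4) % 4)
        img = Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))).convert("RGB")
        if resize_res != -1:
            img = img.resize((resize_res, resize_res))
        # Re-encode the (possibly resized) image back to a base64 string.
        buffer = BytesIO()
        img.save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    # Example: shrink one encoded image to 336x336.
    # resized_b64 = round_trip_base64(original_b64, resize_res=336)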
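
The size-based partitioning can be exercised in isolation as well. The following is a minimal sketch under the assumption that dask[dataframe] and pyarrow are installed; the toy dictionary, the output path, and the 1.5 GB ceiling are placeholder values, not part of the patch.

    # Minimal sketch of the size-based partitioning used in convert_json_to_parquet.
    import pandas as pd
    import dask.dataframe as dd

    data_dict = {"sample_0": "aGVsbG8=", "sample_1": "d29ybGQ="}  # toy base64 payloads
    total_size = sum(len(v) for v in data_dict.values())          # rough byte estimate
    max_partition_size = int(1.5 * 1024**3)                       # 1.5 GB per partition

    # One partition per max_partition_size bytes, with at least one partition overall.
    npartitions = max(1, total_size // max_partition_size)
    df = pd.DataFrame.from_dict(data_dict, orient="index", columns=["base64"])
    ddf = dd.from_pandas(df, npartitions=npartitions)
    ddf.to_parquet("toy_output_parquet", engine="pyarrow")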