From ca7000631be45f4033863ca71e009704596b5200 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 1 Dec 2023 15:49:58 -0800 Subject: [PATCH 01/62] initial commit --- scripts/data_prep/convert_delta_to_json.py | 91 ++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 scripts/data_prep/convert_delta_to_json.py diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py new file mode 100644 index 0000000000..e59daeee5d --- /dev/null +++ b/scripts/data_prep/convert_delta_to_json.py @@ -0,0 +1,91 @@ +import os +import logging +import requests +import json +import boto3 +from botocore.config import Config +import argparse +import tempfile +from urllib.parse import urlparse +from pyspark.sql import SparkSession +#from deltalake import DeltaTable +import time +import pandas as pd +from databricks import sql + +log = logging.getLogger(__name__) + +def delta_to_json(spark, delta_table_path, json_output_path): + log.info('Convert table {delta_table_path} to json and saving to {json_output_path}') + obj = urlparse(delta_table_path) + scheme, bucket, path = obj.scheme, obj.netloc, obj.path + if scheme == '' and bucket == '' and path == '': + raise FileNotFoundError( + f'Check data availability! local index {url[0]} is not accessible.' + + f'remote index {url[1]} does not have a valid url format') + + if scheme == '': # local + #delta_df = read_table_from_local(path) + delta_df = read_table_from_uc(delta_table_path) + #elif scheme == 'dbfs': # uc table format: dbfs:/Volumes////path/to/folder + # if path.startswith('/Volumes'): + # delta_df = read_table_from_uc(delta_table_path) + else: + log.warning(f"Support of scheme {scheme} has not implemented!") + raise NotImplementedError + + delta_df.write.mode("overwrite").format("json").save(json_output_path) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Download delta table from UC and convert to json to save local") + parser.add_argument("--delta_table_name", required=True, type=str, help="UC table of format ..") + parser.add_argument("--json_output_path", required=True, type=str, help="Local path to save the converted json") + parser.add_argument("--debug", type=bool, required=False, default=False) + args = parser.parse_args() + + # Note: delta-io has been renamed to delta-spark after 3.0.0 + # For spark < 3.5.0, update the configuration from + # https://docs.delta.io/latest/quick-start.html + spark = SparkSession.builder \ + .appName("Delta to JSON Test") \ + .config("spark.jars.packages", "io.delta:delta-spark_2.12:3.0.0") \ + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \ + .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \ + .getOrCreate() + + if 0: # args.debug == True: + # Test Local + sample_data = [("Alice", 1), ("Bob", 2)] + columns = ["Name", "Id"] + spark_df = spark.createDataFrame(sample_data, schema=columns) + + #with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = '/tmp/test_delta_to_json' + print('temp_dir = ', temp_dir) + delta_table_path = os.path.join(temp_dir, "delta_table") + spark_df.write.format("delta").save(delta_table_path) + json_output_path = os.path.join(temp_dir, "json_output") + + delta_to_json(spark, delta_table_path, json_output_path) + + + #connection = sql.connect( + # server_hostname=os.getenv("DATABRICKS_HOST"), + # http_path="sql/protocolv1/o/7395834863327820/1116-234530-6seh113n", # from compute.JDBC + # 
access_token=os.getenv("DATABRICKS_TOKEN") + # ) + connection = sql.connect( + server_hostname="e2-dogfood.staging.cloud.databricks.com", + http_path="/sql/1.0/warehouses/7e083095329f3ca5", + access_token="dapi18a0a6fa53b5fb1afbf1101c93eee31f" + ) + cursor = connection.cursor() + cursor.execute(f"USE CATALOG main;") + cursor.execute(f"USE SCHEMA streaming;") + cursor.execute(f"SELECT * FROM dummy_table") + ans = cursor.fetchall() + connection.commit() + + result = [ row.asDict() for row in ans ] + df = pd.DataFrame.from_dict(result) + df.to_json(args.json_output_path) From 34b1bb97b43cceee4db6200a81feb4f4a13ad562 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 1 Dec 2023 17:00:01 -0800 Subject: [PATCH 02/62] use databricks-sql to read delta table and convert to json --- scripts/data_prep/convert_delta_to_json.py | 161 +++++++++++---------- setup.py | 1 + 2 files changed, 87 insertions(+), 75 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index e59daeee5d..d6e0c2fc27 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -1,91 +1,102 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + import os -import logging -import requests -import json -import boto3 -from botocore.config import Config import argparse -import tempfile -from urllib.parse import urlparse -from pyspark.sql import SparkSession -#from deltalake import DeltaTable -import time +import logging import pandas as pd from databricks import sql log = logging.getLogger(__name__) -def delta_to_json(spark, delta_table_path, json_output_path): - log.info('Convert table {delta_table_path} to json and saving to {json_output_path}') - obj = urlparse(delta_table_path) - scheme, bucket, path = obj.scheme, obj.netloc, obj.path - if scheme == '' and bucket == '' and path == '': - raise FileNotFoundError( - f'Check data availability! local index {url[0]} is not accessible.' 
+ - f'remote index {url[1]} does not have a valid url format') - - if scheme == '': # local - #delta_df = read_table_from_local(path) - delta_df = read_table_from_uc(delta_table_path) - #elif scheme == 'dbfs': # uc table format: dbfs:/Volumes////path/to/folder - # if path.startswith('/Volumes'): - # delta_df = read_table_from_uc(delta_table_path) - else: - log.warning(f"Support of scheme {scheme} has not implemented!") - raise NotImplementedError - - delta_df.write.mode("overwrite").format("json").save(json_output_path) +""" +Sample tables are created here + + - E2-dogfood: https://e2-dogfood.staging.cloud.databricks.com/?o=6051921418418893#notebook/3642707736157009/command/551761898400018 + - Data Force One: https://dbc-559ffd80-2bfc.cloud.databricks.com/?o=7395834863327820#notebook/2500382962301597/command/2500382962301599 + +The script can be called as: + + - python scripts/data_prep/convert_delta_to_json.py --delta_table_name 'main.streaming.dummy_table' --json_output_path /tmp/delta2json2 --debug False --http_path 'sql/protocolv1/o/7395834863327820/1116-234530-6seh113n' + + - python scripts/data_prep/convert_delta_to_json.py --delta_table_name 'main.streaming.dummy_table' --json_output_path /tmp/delta2json2 --debug False --http_path /sql/1.0/warehouses/7e083095329f3ca5 --DATABRICKS_HOST e2-dogfood.staging.cloud.databricks.com --DATABRICKS_TOKEN dapi18a0a6fa53b5fb1afbf1101c93eee31f + +""" + +def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name', batch_size=3): + + cursor = connection.cursor() + cursor.execute(f"USE CATALOG main;") + cursor.execute(f"USE SCHEMA streaming;") + cursor.execute(f"SELECT COUNT(*) FROM {tablename}") + ans = cursor.fetchall() + + total_rows = [ row.asDict() for row in ans ][0].popitem()[1] + print('total_rows = ', total_rows) + + cursor.execute(f"SHOW COLUMNS IN {tablename}") + ans = cursor.fetchall() + + order_by = [ row.asDict() for row in ans ][0].popitem()[1] + print('order by column ', order_by) + + for start in range(0, total_rows, batch_size): + end = min(start + batch_size, total_rows) + query = f""" + WITH NumberedRows AS ( + SELECT + *, + ROW_NUMBER() OVER (ORDER BY {key}) AS rn + FROM + {tablename} + ) + SELECT * + FROM NumberedRows + WHERE rn BETWEEN {start+1} AND {end}""" + cursor.execute(query) + ans = cursor.fetchall() + + result = [ row.asDict() for row in ans ] + print(result) + df = pd.DataFrame.from_dict(result) + df.to_json(os.path.join(args.json_output_path, f'shard_{start+1}_{end}.json')) + + cursor.close() + connection.close() + if __name__ == "__main__": + print(f"Start .... Convert delta to json") parser = argparse.ArgumentParser(description="Download delta table from UC and convert to json to save local") parser.add_argument("--delta_table_name", required=True, type=str, help="UC table of format ..
") parser.add_argument("--json_output_path", required=True, type=str, help="Local path to save the converted json") + parser.add_argument("--DATABRICKS_HOST", required=False, type=str, help="DATABRICKS_HOST") + parser.add_argument("--DATABRICKS_TOKEN", required=False, type=str, help="DATABRICKS_TOKEN") + parser.add_argument("--http_path", required=True, type=str, help="http_path from either dedicated cluster or serverless sql warehouse") parser.add_argument("--debug", type=bool, required=False, default=False) args = parser.parse_args() - # Note: delta-io has been renamed to delta-spark after 3.0.0 - # For spark < 3.5.0, update the configuration from - # https://docs.delta.io/latest/quick-start.html - spark = SparkSession.builder \ - .appName("Delta to JSON Test") \ - .config("spark.jars.packages", "io.delta:delta-spark_2.12:3.0.0") \ - .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \ - .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \ - .getOrCreate() - - if 0: # args.debug == True: - # Test Local - sample_data = [("Alice", 1), ("Bob", 2)] - columns = ["Name", "Id"] - spark_df = spark.createDataFrame(sample_data, schema=columns) - - #with tempfile.TemporaryDirectory() as temp_dir: - temp_dir = '/tmp/test_delta_to_json' - print('temp_dir = ', temp_dir) - delta_table_path = os.path.join(temp_dir, "delta_table") - spark_df.write.format("delta").save(delta_table_path) - json_output_path = os.path.join(temp_dir, "json_output") - - delta_to_json(spark, delta_table_path, json_output_path) - - - #connection = sql.connect( - # server_hostname=os.getenv("DATABRICKS_HOST"), - # http_path="sql/protocolv1/o/7395834863327820/1116-234530-6seh113n", # from compute.JDBC - # access_token=os.getenv("DATABRICKS_TOKEN") - # ) - connection = sql.connect( - server_hostname="e2-dogfood.staging.cloud.databricks.com", - http_path="/sql/1.0/warehouses/7e083095329f3ca5", - access_token="dapi18a0a6fa53b5fb1afbf1101c93eee31f" - ) - cursor = connection.cursor() - cursor.execute(f"USE CATALOG main;") - cursor.execute(f"USE SCHEMA streaming;") - cursor.execute(f"SELECT * FROM dummy_table") - ans = cursor.fetchall() - connection.commit() + server_hostname = args.DATABRICKS_HOST if args.DATABRICKS_HOST else os.getenv("DATABRICKS_HOST") + access_token = args.DATABRICKS_TOKEN if args.DATABRICKS_TOKEN else os.getenv("DATABRICKS_TOKEN") + http_path= args.http_path # " + + try: + connection = sql.connect( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + ) + except Exception as e: + raise RuntimeError("Failed to create sql connection to db workspace. Check {server_hostname} and {http_path} and access token!") from exc + + #if os.path.exists(args.json_output_path): + # if not os.path.isdir(args.json_output_path) or os.listdir(args.json_output_path): + # raise RuntimeError(f"A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!") + + os.makedirs(args.json_output_path, exist_ok=True) + log.info(f"Directory {args.json_output_path} created.") + + stream_delta_to_json(connection, args.delta_table_name, args.json_output_path) - result = [ row.asDict() for row in ans ] - df = pd.DataFrame.from_dict(result) - df.to_json(args.json_output_path) + print(f"Convert delta to json is done. check {args.json_output_path}.") + log.info(f"Convert delta to json is done. 
check {args.json_output_path}.") diff --git a/setup.py b/setup.py index 152c682a2d..de1d14ccda 100644 --- a/setup.py +++ b/setup.py @@ -85,6 +85,7 @@ extra_deps['databricks'] = [ 'mosaicml[databricks]>=0.17.2,<0.18', + 'databricks-sql-connector>=3, <4', ] extra_deps['tensorboard'] = [ From 5944314819ec3af52017295f0d56356151a635b8 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 1 Dec 2023 21:45:18 -0800 Subject: [PATCH 03/62] update --- scripts/data_prep/convert_delta_to_json.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index d6e0c2fc27..300c04fe9d 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -19,11 +19,11 @@ - python scripts/data_prep/convert_delta_to_json.py --delta_table_name 'main.streaming.dummy_table' --json_output_path /tmp/delta2json2 --debug False --http_path 'sql/protocolv1/o/7395834863327820/1116-234530-6seh113n' - - python scripts/data_prep/convert_delta_to_json.py --delta_table_name 'main.streaming.dummy_table' --json_output_path /tmp/delta2json2 --debug False --http_path /sql/1.0/warehouses/7e083095329f3ca5 --DATABRICKS_HOST e2-dogfood.staging.cloud.databricks.com --DATABRICKS_TOKEN dapi18a0a6fa53b5fb1afbf1101c93eee31f + - python scripts/data_prep/convert_delta_to_json.py --delta_table_name 'main.streaming.dummy_table' --json_output_path /tmp/delta2json2 --debug False --http_path {http_path} --DATABRICKS_HOST {your host} --DATABRICKS_TOKEN {your token} """ -def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name', batch_size=3): +def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name', batch_size=1<<20): cursor = connection.cursor() cursor.execute(f"USE CATALOG main;") @@ -32,16 +32,17 @@ def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name' ans = cursor.fetchall() total_rows = [ row.asDict() for row in ans ][0].popitem()[1] - print('total_rows = ', total_rows) + log.info(f'total_rows = {total_rows}') cursor.execute(f"SHOW COLUMNS IN {tablename}") ans = cursor.fetchall() order_by = [ row.asDict() for row in ans ][0].popitem()[1] - print('order by column ', order_by) + log.info(f'order by column {order_by}') for start in range(0, total_rows, batch_size): end = min(start + batch_size, total_rows) + query = f""" WITH NumberedRows AS ( SELECT @@ -53,6 +54,7 @@ def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name' SELECT * FROM NumberedRows WHERE rn BETWEEN {start+1} AND {end}""" + cursor.execute(query) ans = cursor.fetchall() @@ -78,7 +80,7 @@ def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name' server_hostname = args.DATABRICKS_HOST if args.DATABRICKS_HOST else os.getenv("DATABRICKS_HOST") access_token = args.DATABRICKS_TOKEN if args.DATABRICKS_TOKEN else os.getenv("DATABRICKS_TOKEN") - http_path= args.http_path # " + http_path= args.http_path try: connection = sql.connect( @@ -89,9 +91,9 @@ def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name' except Exception as e: raise RuntimeError("Failed to create sql connection to db workspace. 
Check {server_hostname} and {http_path} and access token!") from exc - #if os.path.exists(args.json_output_path): - # if not os.path.isdir(args.json_output_path) or os.listdir(args.json_output_path): - # raise RuntimeError(f"A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!") + if os.path.exists(args.json_output_path): + if not os.path.isdir(args.json_output_path) or os.listdir(args.json_output_path): + raise RuntimeError(f"A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!") os.makedirs(args.json_output_path, exist_ok=True) log.info(f"Directory {args.json_output_path} created.") From 3cc004e66725cbcd5cc8614dd9267bc1269a0535 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 1 Dec 2023 21:50:22 -0800 Subject: [PATCH 04/62] update --- scripts/data_prep/convert_delta_to_json.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 300c04fe9d..7c5d19c81b 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -9,20 +9,6 @@ log = logging.getLogger(__name__) -""" -Sample tables are created here - - - E2-dogfood: https://e2-dogfood.staging.cloud.databricks.com/?o=6051921418418893#notebook/3642707736157009/command/551761898400018 - - Data Force One: https://dbc-559ffd80-2bfc.cloud.databricks.com/?o=7395834863327820#notebook/2500382962301597/command/2500382962301599 - -The script can be called as: - - - python scripts/data_prep/convert_delta_to_json.py --delta_table_name 'main.streaming.dummy_table' --json_output_path /tmp/delta2json2 --debug False --http_path 'sql/protocolv1/o/7395834863327820/1116-234530-6seh113n' - - - python scripts/data_prep/convert_delta_to_json.py --delta_table_name 'main.streaming.dummy_table' --json_output_path /tmp/delta2json2 --debug False --http_path {http_path} --DATABRICKS_HOST {your host} --DATABRICKS_TOKEN {your token} - -""" - def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name', batch_size=1<<20): cursor = connection.cursor() From b22c4bbd72cf1b2b0b39027ff971d4d9538e96c6 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 1 Dec 2023 22:11:38 -0800 Subject: [PATCH 05/62] update --- scripts/data_prep/convert_delta_to_json.py | 44 ++++++++++++---------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 7c5d19c81b..ed91bc5c1a 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -9,7 +9,25 @@ log = logging.getLogger(__name__) -def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name', batch_size=1<<20): +def stream_delta_to_json(args: argparse.Namespace, + batch_size:int =1<<20): + """Read UC delta table and convert it to json. Save json files to local. 
+ In the case of table has more than batch_size rows, read the table batch_size rows a time + """ + server_hostname = args.DATABRICKS_HOST if args.DATABRICKS_HOST else os.getenv("DATABRICKS_HOST") + access_token = args.DATABRICKS_TOKEN if args.DATABRICKS_TOKEN else os.getenv("DATABRICKS_TOKEN") + http_path= args.http_path + tablename = args.delta_table_name + json_output_path = args.json_output_path + + try: + connection = sql.connect( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + ) + except Exception as e: + raise RuntimeError("Failed to create sql connection to db workspace. Check {server_hostname} and {http_path} and access token!") from exc cursor = connection.cursor() cursor.execute(f"USE CATALOG main;") @@ -23,6 +41,7 @@ def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name' cursor.execute(f"SHOW COLUMNS IN {tablename}") ans = cursor.fetchall() + # Get the first column to order by. can be any column order_by = [ row.asDict() for row in ans ][0].popitem()[1] log.info(f'order by column {order_by}') @@ -33,7 +52,7 @@ def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name' WITH NumberedRows AS ( SELECT *, - ROW_NUMBER() OVER (ORDER BY {key}) AS rn + ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn FROM {tablename} ) @@ -45,16 +64,15 @@ def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name' ans = cursor.fetchall() result = [ row.asDict() for row in ans ] - print(result) df = pd.DataFrame.from_dict(result) - df.to_json(os.path.join(args.json_output_path, f'shard_{start+1}_{end}.json')) + df.to_json(os.path.join(json_output_path, f'shard_{start+1}_{end}.json')) cursor.close() connection.close() if __name__ == "__main__": - print(f"Start .... Convert delta to json") + log.info(f"Start .... Convert delta to json") parser = argparse.ArgumentParser(description="Download delta table from UC and convert to json to save local") parser.add_argument("--delta_table_name", required=True, type=str, help="UC table of format ..
") parser.add_argument("--json_output_path", required=True, type=str, help="Local path to save the converted json") @@ -64,27 +82,15 @@ def stream_delta_to_json(connection, tablename, json_output_folder, key = 'name' parser.add_argument("--debug", type=bool, required=False, default=False) args = parser.parse_args() - server_hostname = args.DATABRICKS_HOST if args.DATABRICKS_HOST else os.getenv("DATABRICKS_HOST") - access_token = args.DATABRICKS_TOKEN if args.DATABRICKS_TOKEN else os.getenv("DATABRICKS_TOKEN") - http_path= args.http_path - - try: - connection = sql.connect( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - ) - except Exception as e: - raise RuntimeError("Failed to create sql connection to db workspace. Check {server_hostname} and {http_path} and access token!") from exc - if os.path.exists(args.json_output_path): if not os.path.isdir(args.json_output_path) or os.listdir(args.json_output_path): raise RuntimeError(f"A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!") os.makedirs(args.json_output_path, exist_ok=True) + log.info(f"Directory {args.json_output_path} created.") - stream_delta_to_json(connection, args.delta_table_name, args.json_output_path) + stream_delta_to_json(args) print(f"Convert delta to json is done. check {args.json_output_path}.") log.info(f"Convert delta to json is done. check {args.json_output_path}.") From b0dd43ade6075e81b1f2f29994872244b506fd42 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 1 Dec 2023 22:38:25 -0800 Subject: [PATCH 06/62] add mocked unittest --- tests/test_convert_delta_to_json.py | 53 +++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 tests/test_convert_delta_to_json.py diff --git a/tests/test_convert_delta_to_json.py b/tests/test_convert_delta_to_json.py new file mode 100644 index 0000000000..2280093f5c --- /dev/null +++ b/tests/test_convert_delta_to_json.py @@ -0,0 +1,53 @@ +# copyright 2022 mosaicml llm foundry authors +# spdx-license-identifier: apache-2.0 + +import pytest +import os +import sys +import warnings + +# Add repo root to path so we can import scripts and test it +repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(repo_dir) + +import unittest +from unittest.mock import patch, MagicMock, mock_open +from scripts.data_prep.convert_delta_to_json import stream_delta_to_json + +class TestStreamDeltaToJson(unittest.TestCase): + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') + def test_stream_delta_to_json(self, mock_to_json, mock_connect): + mock_args = MagicMock() + mock_args.DATABRICKS_HOST = 'test_host' + mock_args.DATABRICKS_TOKEN = 'test_token' + mock_args.http_path = 'test_http_path' + mock_args.delta_table_name = 'test_table' + mock_args.json_output_path = 'test_output_path' + + # Mock database connection and cursor + mock_cursor = MagicMock() + mock_connection = MagicMock() + mock_connection.cursor.return_value = mock_cursor + mock_connect.return_value = mock_connection + + # Mock fetchall response + count_response = MagicMock() + count_response.asDict.return_value = {'COUNT(*)': 3} + column_response_item = MagicMock() + column_response_item.asDict.return_value = {'COLUMN_NAME': 'name'} # Assuming SHOW COLUMNS query returns this format + data_response_item = MagicMock() + data_response_item.asDict.return_value = {'name': 'test', 'id': 1} # Assuming SELECT query 
returns this format + mock_cursor.fetchall.side_effect = [[count_response], [column_response_item], [data_response_item]] + + stream_delta_to_json(mock_args) + + mock_connect.assert_called_once_with(server_hostname='test_host', http_path='test_http_path', access_token='test_token') + mock_to_json.assert_called() + mock_cursor.close.assert_called() + mock_connection.close.assert_called() + +if __name__ == '__main__': + unittest.main() + From 286650e11dd2bb3bad49075ba23cd07ec48ad2ae Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 1 Dec 2023 22:58:30 -0800 Subject: [PATCH 07/62] Fix lints --- scripts/data_prep/convert_delta_to_json.py | 101 ++++++++++++++------- tests/test_convert_delta_to_json.py | 30 ++++-- 2 files changed, 88 insertions(+), 43 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index ed91bc5c1a..3f98049ceb 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -1,48 +1,55 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import os import argparse import logging +import os + import pandas as pd from databricks import sql log = logging.getLogger(__name__) -def stream_delta_to_json(args: argparse.Namespace, - batch_size:int =1<<20): - """Read UC delta table and convert it to json. Save json files to local. - In the case of table has more than batch_size rows, read the table batch_size rows a time + +def stream_delta_to_json(args: argparse.Namespace, batch_size: int = 1 << 20): + """Read UC delta table and convert it to json. + + Save json files to local. In the case of table has more than batch_size + rows, read the table batch_size rows a time """ - server_hostname = args.DATABRICKS_HOST if args.DATABRICKS_HOST else os.getenv("DATABRICKS_HOST") - access_token = args.DATABRICKS_TOKEN if args.DATABRICKS_TOKEN else os.getenv("DATABRICKS_TOKEN") - http_path= args.http_path + server_hostname = args.DATABRICKS_HOST if args.DATABRICKS_HOST else os.getenv( + 'DATABRICKS_HOST') + access_token = args.DATABRICKS_TOKEN if args.DATABRICKS_TOKEN else os.getenv( + 'DATABRICKS_TOKEN') + http_path = args.http_path tablename = args.delta_table_name json_output_path = args.json_output_path try: connection = sql.connect( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - ) + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + ) except Exception as e: - raise RuntimeError("Failed to create sql connection to db workspace. Check {server_hostname} and {http_path} and access token!") from exc + raise RuntimeError( + 'Failed to create sql connection to db workspace. Check {server_hostname} and {http_path} and access token!' + ) from e cursor = connection.cursor() - cursor.execute(f"USE CATALOG main;") - cursor.execute(f"USE SCHEMA streaming;") - cursor.execute(f"SELECT COUNT(*) FROM {tablename}") + cursor.execute(f'USE CATALOG main;') + cursor.execute(f'USE SCHEMA streaming;') + cursor.execute(f'SELECT COUNT(*) FROM {tablename}') ans = cursor.fetchall() - total_rows = [ row.asDict() for row in ans ][0].popitem()[1] + total_rows = [row.asDict() for row in ans][0].popitem()[1] log.info(f'total_rows = {total_rows}') - cursor.execute(f"SHOW COLUMNS IN {tablename}") + cursor.execute(f'SHOW COLUMNS IN {tablename}') ans = cursor.fetchall() # Get the first column to order by. 
can be any column - order_by = [ row.asDict() for row in ans ][0].popitem()[1] + order_by = [row.asDict() for row in ans][0].popitem()[1] log.info(f'order by column {order_by}') for start in range(0, total_rows, batch_size): @@ -63,34 +70,58 @@ def stream_delta_to_json(args: argparse.Namespace, cursor.execute(query) ans = cursor.fetchall() - result = [ row.asDict() for row in ans ] + result = [row.asDict() for row in ans] df = pd.DataFrame.from_dict(result) - df.to_json(os.path.join(json_output_path, f'shard_{start+1}_{end}.json')) + df.to_json(os.path.join(json_output_path, + f'shard_{start+1}_{end}.json')) cursor.close() connection.close() -if __name__ == "__main__": - log.info(f"Start .... Convert delta to json") - parser = argparse.ArgumentParser(description="Download delta table from UC and convert to json to save local") - parser.add_argument("--delta_table_name", required=True, type=str, help="UC table of format ..
") - parser.add_argument("--json_output_path", required=True, type=str, help="Local path to save the converted json") - parser.add_argument("--DATABRICKS_HOST", required=False, type=str, help="DATABRICKS_HOST") - parser.add_argument("--DATABRICKS_TOKEN", required=False, type=str, help="DATABRICKS_TOKEN") - parser.add_argument("--http_path", required=True, type=str, help="http_path from either dedicated cluster or serverless sql warehouse") - parser.add_argument("--debug", type=bool, required=False, default=False) +if __name__ == '__main__': + log.info(f'Start .... Convert delta to json') + parser = argparse.ArgumentParser( + description= + 'Download delta table from UC and convert to json to save local') + parser.add_argument( + '--delta_table_name', + required=True, + type=str, + help='UC table of format ..
') + parser.add_argument('--json_output_path', + required=True, + type=str, + help='Local path to save the converted json') + parser.add_argument('--DATABRICKS_HOST', + required=False, + type=str, + help='DATABRICKS_HOST') + parser.add_argument('--DATABRICKS_TOKEN', + required=False, + type=str, + help='DATABRICKS_TOKEN') + parser.add_argument( + '--http_path', + required=True, + type=str, + help= + 'http_path from either dedicated cluster or serverless sql warehouse') + parser.add_argument('--debug', type=bool, required=False, default=False) args = parser.parse_args() if os.path.exists(args.json_output_path): - if not os.path.isdir(args.json_output_path) or os.listdir(args.json_output_path): - raise RuntimeError(f"A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!") + if not os.path.isdir(args.json_output_path) or os.listdir( + args.json_output_path): + raise RuntimeError( + f'A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!' + ) os.makedirs(args.json_output_path, exist_ok=True) - log.info(f"Directory {args.json_output_path} created.") + log.info(f'Directory {args.json_output_path} created.') stream_delta_to_json(args) - print(f"Convert delta to json is done. check {args.json_output_path}.") - log.info(f"Convert delta to json is done. check {args.json_output_path}.") + print(f'Convert delta to json is done. check {args.json_output_path}.') + log.info(f'Convert delta to json is done. check {args.json_output_path}.') diff --git a/tests/test_convert_delta_to_json.py b/tests/test_convert_delta_to_json.py index 2280093f5c..6773e2e5d8 100644 --- a/tests/test_convert_delta_to_json.py +++ b/tests/test_convert_delta_to_json.py @@ -1,19 +1,24 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + # copyright 2022 mosaicml llm foundry authors # spdx-license-identifier: apache-2.0 -import pytest import os import sys -import warnings + +import pytest # Add repo root to path so we can import scripts and test it repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) sys.path.append(repo_dir) import unittest -from unittest.mock import patch, MagicMock, mock_open +from unittest.mock import MagicMock, mock_open, patch + from scripts.data_prep.convert_delta_to_json import stream_delta_to_json + class TestStreamDeltaToJson(unittest.TestCase): @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @@ -36,18 +41,27 @@ def test_stream_delta_to_json(self, mock_to_json, mock_connect): count_response = MagicMock() count_response.asDict.return_value = {'COUNT(*)': 3} column_response_item = MagicMock() - column_response_item.asDict.return_value = {'COLUMN_NAME': 'name'} # Assuming SHOW COLUMNS query returns this format + column_response_item.asDict.return_value = { + 'COLUMN_NAME': 'name' + } # Assuming SHOW COLUMNS query returns this format data_response_item = MagicMock() - data_response_item.asDict.return_value = {'name': 'test', 'id': 1} # Assuming SELECT query returns this format - mock_cursor.fetchall.side_effect = [[count_response], [column_response_item], [data_response_item]] + data_response_item.asDict.return_value = { + 'name': 'test', + 'id': 1 + } # Assuming SELECT query returns this format + mock_cursor.fetchall.side_effect = [[count_response], + [column_response_item], + [data_response_item]] stream_delta_to_json(mock_args) - mock_connect.assert_called_once_with(server_hostname='test_host', http_path='test_http_path', access_token='test_token') + 
mock_connect.assert_called_once_with(server_hostname='test_host', + http_path='test_http_path', + access_token='test_token') mock_to_json.assert_called() mock_cursor.close.assert_called() mock_connection.close.assert_called() + if __name__ == '__main__': unittest.main() - From 8a45576fb4793c9138379f01cea0dcb1299efc3b Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 1 Dec 2023 23:07:50 -0800 Subject: [PATCH 08/62] update --- tests/test_convert_delta_to_json.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/test_convert_delta_to_json.py b/tests/test_convert_delta_to_json.py index 6773e2e5d8..a9bc00e7e2 100644 --- a/tests/test_convert_delta_to_json.py +++ b/tests/test_convert_delta_to_json.py @@ -6,24 +6,21 @@ import os import sys - -import pytest +from typing import Any # Add repo root to path so we can import scripts and test it repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) sys.path.append(repo_dir) import unittest -from unittest.mock import MagicMock, mock_open, patch - +from unittest.mock import MagicMock, patch from scripts.data_prep.convert_delta_to_json import stream_delta_to_json - class TestStreamDeltaToJson(unittest.TestCase): @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') - def test_stream_delta_to_json(self, mock_to_json, mock_connect): + def test_stream_delta_to_json(self, mock_to_json:Any, mock_connect:Any): mock_args = MagicMock() mock_args.DATABRICKS_HOST = 'test_host' mock_args.DATABRICKS_TOKEN = 'test_token' From b452ec30ac4278003e32107b8d136225da3ee172 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 1 Dec 2023 23:22:13 -0800 Subject: [PATCH 09/62] update --- tests/test_convert_delta_to_json.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_convert_delta_to_json.py b/tests/test_convert_delta_to_json.py index a9bc00e7e2..bd67d2e1c1 100644 --- a/tests/test_convert_delta_to_json.py +++ b/tests/test_convert_delta_to_json.py @@ -14,13 +14,15 @@ import unittest from unittest.mock import MagicMock, patch + from scripts.data_prep.convert_delta_to_json import stream_delta_to_json + class TestStreamDeltaToJson(unittest.TestCase): @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') - def test_stream_delta_to_json(self, mock_to_json:Any, mock_connect:Any): + def test_stream_delta_to_json(self, mock_to_json: Any, mock_connect: Any): mock_args = MagicMock() mock_args.DATABRICKS_HOST = 'test_host' mock_args.DATABRICKS_TOKEN = 'test_token' From c902b67f13db7a9e8be42735c9ffa4ceac50a63d Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Mon, 4 Dec 2023 14:49:44 -0800 Subject: [PATCH 10/62] restructure code --- scripts/data_prep/convert_delta_to_json.py | 50 +++++++++++-------- .../data_prep}/test_convert_delta_to_json.py | 21 ++++---- 2 files changed, 37 insertions(+), 34 deletions(-) rename tests/{ => a_scripts/data_prep}/test_convert_delta_to_json.py (78%) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 3f98049ceb..c82e930580 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -11,19 +11,29 @@ log = logging.getLogger(__name__) -def stream_delta_to_json(args: argparse.Namespace, batch_size: int = 1 << 20): +def stream_delta_to_json(server_hostname: str, + access_token: str, + http_path: str, + tablename: 
str, + json_output_path: str, + batch_size: int = 1 << 20): """Read UC delta table and convert it to json. Save json files to local. In the case of table has more than batch_size rows, read the table batch_size rows a time """ - server_hostname = args.DATABRICKS_HOST if args.DATABRICKS_HOST else os.getenv( - 'DATABRICKS_HOST') - access_token = args.DATABRICKS_TOKEN if args.DATABRICKS_TOKEN else os.getenv( - 'DATABRICKS_TOKEN') - http_path = args.http_path - tablename = args.delta_table_name - json_output_path = args.json_output_path + log.info(f'Start .... Convert delta to json') + + if os.path.exists(json_output_path): + if not os.path.isdir(json_output_path) or os.listdir( + json_output_path): + raise RuntimeError( + f'A file or a folder {json_output_path} already exists and is not empty. Remove it and retry!' + ) + + os.makedirs(json_output_path, exist_ok=True) + + log.info(f'Directory {json_output_path} created.') try: connection = sql.connect( @@ -78,9 +88,10 @@ def stream_delta_to_json(args: argparse.Namespace, batch_size: int = 1 << 20): cursor.close() connection.close() + print(f'Convert delta to json is done. check {json_output_path}.') + log.info(f'Convert delta to json is done. check {json_output_path}.') if __name__ == '__main__': - log.info(f'Start .... Convert delta to json') parser = argparse.ArgumentParser( description= 'Download delta table from UC and convert to json to save local') @@ -110,18 +121,13 @@ def stream_delta_to_json(args: argparse.Namespace, batch_size: int = 1 << 20): parser.add_argument('--debug', type=bool, required=False, default=False) args = parser.parse_args() - if os.path.exists(args.json_output_path): - if not os.path.isdir(args.json_output_path) or os.listdir( - args.json_output_path): - raise RuntimeError( - f'A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!' - ) - - os.makedirs(args.json_output_path, exist_ok=True) - - log.info(f'Directory {args.json_output_path} created.') + server_hostname = args.DATABRICKS_HOST if args.DATABRICKS_HOST is not None else os.getenv( + 'DATABRICKS_HOST') + access_token = args.DATABRICKS_TOKEN if args.DATABRICKS_TOKEN is not None else os.getenv( + 'DATABRICKS_TOKEN') + http_path = args.http_path + tablename = args.delta_table_name + json_output_path = args.json_output_path - stream_delta_to_json(args) + stream_delta_to_json(server_hostname, access_token, http_path, tablename, json_output_path) - print(f'Convert delta to json is done. check {args.json_output_path}.') - log.info(f'Convert delta to json is done. 
check {args.json_output_path}.') diff --git a/tests/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py similarity index 78% rename from tests/test_convert_delta_to_json.py rename to tests/a_scripts/data_prep/test_convert_delta_to_json.py index bd67d2e1c1..f759913d47 100644 --- a/tests/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -8,27 +8,17 @@ import sys from typing import Any -# Add repo root to path so we can import scripts and test it -repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(repo_dir) - import unittest from unittest.mock import MagicMock, patch from scripts.data_prep.convert_delta_to_json import stream_delta_to_json -class TestStreamDeltaToJson(unittest.TestCase): +class TestStreamDeltaToJson(): @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') def test_stream_delta_to_json(self, mock_to_json: Any, mock_connect: Any): - mock_args = MagicMock() - mock_args.DATABRICKS_HOST = 'test_host' - mock_args.DATABRICKS_TOKEN = 'test_token' - mock_args.http_path = 'test_http_path' - mock_args.delta_table_name = 'test_table' - mock_args.json_output_path = 'test_output_path' # Mock database connection and cursor mock_cursor = MagicMock() @@ -52,7 +42,14 @@ def test_stream_delta_to_json(self, mock_to_json: Any, mock_connect: Any): [column_response_item], [data_response_item]] - stream_delta_to_json(mock_args) + stream_delta_to_json( + server_hostname = 'test_host', + access_token = 'test_token', + http_path = 'test_http_path', + tablename = 'test_table', + json_output_path = 'test_output_path' + ) + mock_connect.assert_called_once_with(server_hostname='test_host', http_path='test_http_path', From f57c21ce05d36be798b80a1f16fccd738646b9eb Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Mon, 4 Dec 2023 14:54:02 -0800 Subject: [PATCH 11/62] Add timer for optimizing --- scripts/data_prep/convert_delta_to_json.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index c82e930580..d89ab84f1a 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -4,6 +4,7 @@ import argparse import logging import os +import time import pandas as pd from databricks import sql @@ -129,5 +130,10 @@ def stream_delta_to_json(server_hostname: str, tablename = args.delta_table_name json_output_path = args.json_output_path + tik = time.time() + print("start timer", tik) + stream_delta_to_json(server_hostname, access_token, http_path, tablename, json_output_path) + print("end timer", time.time() - tik) + From 7ff9e5a898c5bf5080245c1e7573393bba27c2c8 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 5 Dec 2023 10:38:08 -0800 Subject: [PATCH 12/62] Add db-connect --- scripts/data_prep/convert_delta_to_json.py | 106 +++++++++++++++++---- 1 file changed, 89 insertions(+), 17 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index d89ab84f1a..ce6467acd6 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -11,17 +11,89 @@ log = logging.getLogger(__name__) +def fetch_DT(*args: Any, **kwargs: Any): + r"""Fetch Delta Table from UC and save to local + + This can be called as + + ``` + fetch_DT(server_hostname: str, + access_token: str, + tablename: str, + 
json_output_path: str, + batch_size: int = 1 << 20) + or + + fetch_DT(server_hostname: str, + access_token: str, + http_path: str, + tablename: str, + json_output_path: str, + batch_size: int = 1 << 20) + ``` + Based on the arguments, the call is redirected to either fetch_DT_with_dbconnect or fetch_DT_with_dbsql + """ + if 'http_path' not in args and 'http_path' not in kwargs: + return fetch_DT_with_dbconnect(*args, **kwargs) + else: + return fetch_DT_with_dbsql(*args, **kargs) + +def fetch_DT_with_dbconnect(server_hostname: str, + access_token: str, + tablename: str, + json_output_path: str, + batch_size: int = 1 << 20): + """Fetch UC delta table with databricks-connnect and convert them to json. + In the case when table is very large, we fetch batch_size rows a time. + Compared to fetch_DT_with_dbsql, this function does not need http_path. + """ + from databricks.connect import DatabricksSession + from uuid import uuid4 + + session_id = str(uuid4()) + spark = DatabricksSession.builder.host("https://e2-dogfood.staging.cloud.databricks.com/").token("TOKEN").header("x-databricks-session-id", session_id).getOrCreate() + + try: + ans = spark.sql(f"SELECT COUNT(*) FROM {tablename}").collect() + total_rows = [row.asDict() for row in ans][0].popitem()[1] + + ans = spark.sql(f"SHOW COLUMNS IN {tablename}").collect() + order_by = [row.asDict() for row in ans][0].popitem()[1] + + log.info(f'total_rows = {total_rows}') + log.info(f'order by column {order_by}') + except e: + raise RuntimeError(f"Error in get total rows / columns from {tablename}. Restart sparksession and try again") from e + + for start in range(0, total_rows, batch_size): + end = min(start + batch_size, total_rows) + + query = f""" + WITH NumberedRows AS ( + SELECT + *, + ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn + FROM + {tablename} + ) + SELECT * + FROM NumberedRows + WHERE rn BETWEEN {start+1} AND {end}""" + + ans = spark.sql(query).collect() + df = spark.createDataFrame(ans).collect() + shard = os.path.join(json_output_path, f'shard_{start+1}_{end}.json') + shard.write.format('json').mode('overwrite').option('header', 'true').save('/tmp/new') -def stream_delta_to_json(server_hostname: str, - access_token: str, - http_path: str, - tablename: str, - json_output_path: str, - batch_size: int = 1 << 20): - """Read UC delta table and convert it to json. - Save json files to local. In the case of table has more than batch_size - rows, read the table batch_size rows a time +def fetch_DT_with_dbsql(server_hostname: str, + access_token: str, + http_path: str, + tablename: str, + json_output_path: str, + batch_size: int = 1 << 20): + """Fetch UC delta table locally as dataframes and convert them to json. + In the case when table is very large, we fetch batch_size rows a time. """ log.info(f'Start .... 
Convert delta to json') @@ -122,18 +194,18 @@ def stream_delta_to_json(server_hostname: str, parser.add_argument('--debug', type=bool, required=False, default=False) args = parser.parse_args() - server_hostname = args.DATABRICKS_HOST if args.DATABRICKS_HOST is not None else os.getenv( - 'DATABRICKS_HOST') - access_token = args.DATABRICKS_TOKEN if args.DATABRICKS_TOKEN is not None else os.getenv( - 'DATABRICKS_TOKEN') - http_path = args.http_path - tablename = args.delta_table_name - json_output_path = args.json_output_path + #server_hostname = args.DATABRICKS_HOST if args.DATABRICKS_HOST is not None else os.getenv( + # 'DATABRICKS_HOST') + #access_token = args.DATABRICKS_TOKEN if args.DATABRICKS_TOKEN is not None else os.getenv( + # 'DATABRICKS_TOKEN') + #http_path = args.http_path + #tablename = args.delta_table_name + #json_output_path = args.json_output_path tik = time.time() print("start timer", tik) - stream_delta_to_json(server_hostname, access_token, http_path, tablename, json_output_path) + fetch_DT(*args) print("end timer", time.time() - tik) From 6d17cefb2fdf507f44c8517f36efe33f23677f26 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 5 Dec 2023 12:06:54 -0800 Subject: [PATCH 13/62] add wrapper --- scripts/data_prep/convert_delta_to_json.py | 376 ++++++++++++++------- 1 file changed, 255 insertions(+), 121 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index ce6467acd6..ccd971fa70 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -8,62 +8,62 @@ import pandas as pd from databricks import sql +from typing import Any, Optional, List +from databricks.connect import DatabricksSession +from uuid import uuid4 +from pyspark.sql.types import Row log = logging.getLogger(__name__) -def fetch_DT(*args: Any, **kwargs: Any): - r"""Fetch Delta Table from UC and save to local - - This can be called as - - ``` - fetch_DT(server_hostname: str, - access_token: str, - tablename: str, - json_output_path: str, - batch_size: int = 1 << 20) - or - - fetch_DT(server_hostname: str, - access_token: str, - http_path: str, - tablename: str, - json_output_path: str, - batch_size: int = 1 << 20) - ``` - Based on the arguments, the call is redirected to either fetch_DT_with_dbconnect or fetch_DT_with_dbsql - """ - if 'http_path' not in args and 'http_path' not in kwargs: - return fetch_DT_with_dbconnect(*args, **kwargs) - else: - return fetch_DT_with_dbsql(*args, **kargs) - -def fetch_DT_with_dbconnect(server_hostname: str, - access_token: str, - tablename: str, - json_output_path: str, - batch_size: int = 1 << 20): +def run_query(q:str, method:str, cursor=None, spark=None, collect=True) -> Optional[List[Row]]: + if not q: + return + + if method == 'dbsql': + if cursor is None: + raise ValueError(f"cursor cannot be None if using method dbsql") + cursor.execute(q) + if collect: + return cursor.fetchall() + + if method == 'dbconnect': + if spark == None: + raise ValueError(f"sparksession is required for dbconnect") + df = spark.sql(q) + if collect: + return df.collect() + return df + + return None + +def fetch(method, + tablename: str, + json_output_path: str, + batch_size: int = 1 << 20, + sparksession = None, + dbsql = None, + ): """Fetch UC delta table with databricks-connnect and convert them to json. In the case when table is very large, we fetch batch_size rows a time. Compared to fetch_DT_with_dbsql, this function does not need http_path. 
""" - from databricks.connect import DatabricksSession - from uuid import uuid4 - - session_id = str(uuid4()) - spark = DatabricksSession.builder.host("https://e2-dogfood.staging.cloud.databricks.com/").token("TOKEN").header("x-databricks-session-id", session_id).getOrCreate() + cursor = dbsql.cursor() if dbsql is not None else None try: - ans = spark.sql(f"SELECT COUNT(*) FROM {tablename}").collect() + ans = run_query(f"SELECT COUNT(*) FROM {tablename}", method, cursor, sparksession) total_rows = [row.asDict() for row in ans][0].popitem()[1] - - ans = spark.sql(f"SHOW COLUMNS IN {tablename}").collect() - order_by = [row.asDict() for row in ans][0].popitem()[1] - log.info(f'total_rows = {total_rows}') + except Exception as e: + raise RuntimeError(f"Error in get total rows from {tablename}. Restart sparksession and try again") from e + + try: + ans = run_query(f"SHOW COLUMNS IN {tablename}", method, cursor, sparksession) + columns = [row.asDict().popitem()[1] for row in ans] + order_by = columns[0] + columns_str = ','.join(columns) log.info(f'order by column {order_by}') - except e: - raise RuntimeError(f"Error in get total rows / columns from {tablename}. Restart sparksession and try again") from e + except Exception as e: + raise RuntimeError(f"Error in get columns from {tablename}. Restart sparksession and try again") from e for start in range(0, total_rows, batch_size): end = min(start + batch_size, total_rows) @@ -76,93 +76,222 @@ def fetch_DT_with_dbconnect(server_hostname: str, FROM {tablename} ) - SELECT * + SELECT {columns_str} FROM NumberedRows WHERE rn BETWEEN {start+1} AND {end}""" - ans = spark.sql(query).collect() - df = spark.createDataFrame(ans).collect() - shard = os.path.join(json_output_path, f'shard_{start+1}_{end}.json') - shard.write.format('json').mode('overwrite').option('header', 'true').save('/tmp/new') + if method == 'dbconnect': + pdf = run_query(query, method, cursor, sparksession, collect=False).toPandas() + elif method == 'dbsql': + ans = run_query(query, method, cursor, sparksession, collect=True) + pdf = pd.DataFrame.from_dict([row.asDict() for row in ans]) + + pdf.to_json(os.path.join(json_output_path, + f'part_{start+1}_{end}.json')) + + if cursor is not None: + cursor.close() + +#def fetch_DT_with_dbconnect(server_hostname: str, +# access_token: str, +# tablename: str, +# json_output_path: str, +# batch_size: int = 1 << 20): +# """Fetch UC delta table with databricks-connnect and convert them to json. +# In the case when table is very large, we fetch batch_size rows a time. +# Compared to fetch_DT_with_dbsql, this function does not need http_path. +# """ +# from databricks.connect import DatabricksSession +# from uuid import uuid4 +# +# session_id = str(uuid4()) +# spark = DatabricksSession.builder.host(server_hostname).token(access_token).header("x-databricks-session-id", session_id).getOrCreate() +# +# try: +# ans = spark.sql(f"SELECT COUNT(*) FROM {tablename}").collect() +# total_rows = [row.asDict() for row in ans][0].popitem()[1] +# log.info(f'total_rows = {total_rows}') +# except Exception as e: +# raise RuntimeError(f"Error in get total rows from {tablename}. Restart sparksession and try again") from e +# +# try: +# ans = spark.sql(f"SHOW COLUMNS IN {tablename}").collect() +# columns = [row.asDict().popitem()[1] for row in ans] +# order_by = columns[0] +# columns_str = ','.join(columns) +# log.info(f'order by column {order_by}') +# except Exception as e: +# raise RuntimeError(f"Error in get columns from {tablename}. 
Restart sparksession and try again") from e +# +# for start in range(0, total_rows, batch_size): +# end = min(start + batch_size, total_rows) +# +# query = f""" +# WITH NumberedRows AS ( +# SELECT +# *, +# ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn +# FROM +# {tablename} +# ) +# SELECT {columns_str} +# FROM NumberedRows +# WHERE rn BETWEEN {start+1} AND {end}""" +# +# df = spark.sql(query) +# pdf = df.toPandas() +# shard_file = os.path.join(json_output_path, f'shard_{start+1}_{end}.json') +# pdf.to_json(shard_file) +# +# +#def fetch_DT_with_dbsql(server_hostname: str, +# access_token: str, +# http_path: str, +# tablename: str, +# json_output_path: str, +# batch_size: int = 1 << 20): +# print("use dbsql") +# """Fetch UC delta table locally as dataframes and convert them to json. +# In the case when table is very large, we fetch batch_size rows a time. +# """ +# log.info(f'Start .... Convert delta to json') +# +# try: +# connection = sql.connect( +# server_hostname=server_hostname, +# http_path=http_path, +# access_token=access_token, +# ) +# except Exception as e: +# raise RuntimeError( +# 'Failed to create sql connection to db workspace. Check {server_hostname} and {http_path} and access token!' +# ) from e +# +# cursor = connection.cursor() +# cursor.execute(f'USE CATALOG main;') +# cursor.execute(f'USE SCHEMA streaming;') +# cursor.execute(f'SELECT COUNT(*) FROM {tablename}') +# ans = cursor.fetchall() +# +# total_rows = [row.asDict() for row in ans][0].popitem()[1] +# log.info(f'total_rows = {total_rows}') +# +# cursor.execute(f'SHOW COLUMNS IN {tablename}') +# ans = cursor.fetchall() +# +# # Get the first column to order by. can be any column +# columns = [row.asDict().popitem()[1] for row in ans] +# order_by = columns[0] +# columns_str = ','.join(columns) +# log.info(f'order by column {order_by}') +# print(order_by, columns_str) +# +# for start in range(0, total_rows, batch_size): +# end = min(start + batch_size, total_rows) +# +# query = f""" +# WITH NumberedRows AS ( +# SELECT +# *, +# ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn +# FROM +# {tablename} +# ) +# SELECT {columns_str} +# FROM NumberedRows +# WHERE rn BETWEEN {start+1} AND {end}""" +# +# cursor.execute(query) +# ans = cursor.fetchall() +# +# result = [row.asDict() for row in ans] +# df = pd.DataFrame.from_dict(result) +# df.to_json(os.path.join(json_output_path, +# f'shard_{start+1}_{end}.json')) +# +# cursor.close() +# connection.close() +# +# print(f'Convert delta to json is done. check {json_output_path}.') +# log.info(f'Convert delta to json is done. check {json_output_path}.') -def fetch_DT_with_dbsql(server_hostname: str, - access_token: str, - http_path: str, - tablename: str, - json_output_path: str, - batch_size: int = 1 << 20): - """Fetch UC delta table locally as dataframes and convert them to json. - In the case when table is very large, we fetch batch_size rows a time. +def fetch_DT(*args: Any, **kwargs: Any): + r"""Fetch Delta Table from UC and save to local + + This can be called as + + ``` + fetch_DT(server_hostname: str, + access_token: str, + tablename: str, + json_output_path: str, + batch_size: int = 1 << 20) + or + + fetch_DT(server_hostname: str, + access_token: str, + http_path: str, + tablename: str, + json_output_path: str, + batch_size: int = 1 << 20) + ``` + Based on the arguments, the call is redirected to either fetch_DT_with_dbconnect or fetch_DT_with_dbsql """ + print('args = ', args) + print(type(args)) + args = args[0] log.info(f'Start .... 
Convert delta to json') - if os.path.exists(json_output_path): - if not os.path.isdir(json_output_path) or os.listdir( - json_output_path): + if os.path.exists(args.json_output_path): + if not os.path.isdir(args.json_output_path) or os.listdir( + args.json_output_path): raise RuntimeError( - f'A file or a folder {json_output_path} already exists and is not empty. Remove it and retry!' + f'A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!' ) - os.makedirs(json_output_path, exist_ok=True) + os.makedirs(args.json_output_path, exist_ok=True) - log.info(f'Directory {json_output_path} created.') + log.info(f'Directory {args.json_output_path} created.') - try: - connection = sql.connect( - server_hostname=server_hostname, - http_path=http_path, - access_token=access_token, - ) - except Exception as e: - raise RuntimeError( - 'Failed to create sql connection to db workspace. Check {server_hostname} and {http_path} and access token!' - ) from e - - cursor = connection.cursor() - cursor.execute(f'USE CATALOG main;') - cursor.execute(f'USE SCHEMA streaming;') - cursor.execute(f'SELECT COUNT(*) FROM {tablename}') - ans = cursor.fetchall() - - total_rows = [row.asDict() for row in ans][0].popitem()[1] - log.info(f'total_rows = {total_rows}') - - cursor.execute(f'SHOW COLUMNS IN {tablename}') - ans = cursor.fetchall() - - # Get the first column to order by. can be any column - order_by = [row.asDict() for row in ans][0].popitem()[1] - log.info(f'order by column {order_by}') - - for start in range(0, total_rows, batch_size): - end = min(start + batch_size, total_rows) + method = '' + dbsql = None + sparkSession = None - query = f""" - WITH NumberedRows AS ( - SELECT - *, - ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn - FROM - {tablename} - ) - SELECT * - FROM NumberedRows - WHERE rn BETWEEN {start+1} AND {end}""" - - cursor.execute(query) - ans = cursor.fetchall() + if hasattr(args, 'http_path') and args.http_path: + method = 'dbsql' + try: + dbsql = sql.connect( + server_hostname=args.DATABRICKS_HOST, + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN, + ) + except Exception as e: + raise RuntimeError( + 'Failed to create sql connection to db workspace. Check {server_hostname} and {http_path} and access token!' + ) from e + + #return fetch_DT_with_dbsql(args.DATABRICKS_HOST, + # args.DATABRICKS_TOKEN, + # args.http_path, + # args.delta_table_name, + # args.json_output_path, + # args.batch_size) + else: + method = 'dbconnect' + session_id = str(uuid4()) + sparkSession = DatabricksSession.builder.host(args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header("x-databricks-session-id", session_id).getOrCreate() - result = [row.asDict() for row in ans] - df = pd.DataFrame.from_dict(result) - df.to_json(os.path.join(json_output_path, - f'shard_{start+1}_{end}.json')) + #return fetch_DT_with_dbconnect(args.DATABRICKS_HOST, + # args.DATABRICKS_TOKEN, + # args.delta_table_name, + # args.json_output_path, + # args.batch_size) - cursor.close() - connection.close() + fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, sparkSession, dbsql) - print(f'Convert delta to json is done. check {json_output_path}.') - log.info(f'Convert delta to json is done. 
check {json_output_path}.') + if dbsql is not None: + dbsql.close() if __name__ == '__main__': parser = argparse.ArgumentParser( @@ -185,12 +314,17 @@ def fetch_DT_with_dbsql(server_hostname: str, required=False, type=str, help='DATABRICKS_TOKEN') - parser.add_argument( - '--http_path', - required=True, - type=str, - help= - 'http_path from either dedicated cluster or serverless sql warehouse') + parser.add_argument('--http_path', + required=False, + type=str, + help= + 'http_path from either dedicated cluster or serverless sql warehouse') + parser.add_argument('--batch_size', + required=False, + type=int, + default=1<<20, + help= + 'chunk of rows to transmit a time') parser.add_argument('--debug', type=bool, required=False, default=False) args = parser.parse_args() @@ -205,7 +339,7 @@ def fetch_DT_with_dbsql(server_hostname: str, tik = time.time() print("start timer", tik) - fetch_DT(*args) + fetch_DT(args) print("end timer", time.time() - tik) From 8116804dd0bdc17bc359958fc33de1d380347393 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 5 Dec 2023 12:45:22 -0800 Subject: [PATCH 14/62] update --- scripts/data_prep/convert_delta_to_json.py | 199 ++------------------- 1 file changed, 19 insertions(+), 180 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index ccd971fa70..03d1d9ef67 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -12,6 +12,7 @@ from databricks.connect import DatabricksSession from uuid import uuid4 from pyspark.sql.types import Row +import concurrent.futures log = logging.getLogger(__name__) @@ -28,7 +29,7 @@ def run_query(q:str, method:str, cursor=None, spark=None, collect=True) -> Optio if method == 'dbconnect': if spark == None: - raise ValueError(f"sparksession is required for dbconnect") + raise ValueError(f"sparkSession is required for dbconnect") df = spark.sql(q) if collect: return df.collect() @@ -40,7 +41,7 @@ def fetch(method, tablename: str, json_output_path: str, batch_size: int = 1 << 20, - sparksession = None, + sparkSession = None, dbsql = None, ): """Fetch UC delta table with databricks-connnect and convert them to json. @@ -50,24 +51,22 @@ def fetch(method, cursor = dbsql.cursor() if dbsql is not None else None try: - ans = run_query(f"SELECT COUNT(*) FROM {tablename}", method, cursor, sparksession) + ans = run_query(f"SELECT COUNT(*) FROM {tablename}", method, cursor, sparkSession) total_rows = [row.asDict() for row in ans][0].popitem()[1] log.info(f'total_rows = {total_rows}') except Exception as e: - raise RuntimeError(f"Error in get total rows from {tablename}. Restart sparksession and try again") from e + raise RuntimeError(f"Error in get total rows from {tablename}. Restart sparkSession and try again") from e try: - ans = run_query(f"SHOW COLUMNS IN {tablename}", method, cursor, sparksession) + ans = run_query(f"SHOW COLUMNS IN {tablename}", method, cursor, sparkSession) columns = [row.asDict().popitem()[1] for row in ans] order_by = columns[0] columns_str = ','.join(columns) log.info(f'order by column {order_by}') except Exception as e: - raise RuntimeError(f"Error in get columns from {tablename}. Restart sparksession and try again") from e - - for start in range(0, total_rows, batch_size): - end = min(start + batch_size, total_rows) + raise RuntimeError(f"Error in get columns from {tablename}. 
Restart sparkSession and try again") from e + def fetch_data(s, e, order_by, tablename, json_output_path): query = f""" WITH NumberedRows AS ( SELECT @@ -81,165 +80,29 @@ def fetch(method, WHERE rn BETWEEN {start+1} AND {end}""" if method == 'dbconnect': - pdf = run_query(query, method, cursor, sparksession, collect=False).toPandas() + pdf = run_query(query, method, cursor, sparkSession, collect=False).toPandas() elif method == 'dbsql': - ans = run_query(query, method, cursor, sparksession, collect=True) + ans = run_query(query, method, cursor, sparkSession, collect=True) pdf = pd.DataFrame.from_dict([row.asDict() for row in ans]) pdf.to_json(os.path.join(json_output_path, - f'part_{start+1}_{end}.json')) + f'part_{s+1}_{e}.json')) + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + for start in range(0, total_rows, batch_size): + end = min(start + batch_size, total_rows) + futures.append(executor.submit(fetch_data, start, end, order_by, tablename, json_output_path)) + if cursor is not None: cursor.close() -#def fetch_DT_with_dbconnect(server_hostname: str, -# access_token: str, -# tablename: str, -# json_output_path: str, -# batch_size: int = 1 << 20): -# """Fetch UC delta table with databricks-connnect and convert them to json. -# In the case when table is very large, we fetch batch_size rows a time. -# Compared to fetch_DT_with_dbsql, this function does not need http_path. -# """ -# from databricks.connect import DatabricksSession -# from uuid import uuid4 -# -# session_id = str(uuid4()) -# spark = DatabricksSession.builder.host(server_hostname).token(access_token).header("x-databricks-session-id", session_id).getOrCreate() -# -# try: -# ans = spark.sql(f"SELECT COUNT(*) FROM {tablename}").collect() -# total_rows = [row.asDict() for row in ans][0].popitem()[1] -# log.info(f'total_rows = {total_rows}') -# except Exception as e: -# raise RuntimeError(f"Error in get total rows from {tablename}. Restart sparksession and try again") from e -# -# try: -# ans = spark.sql(f"SHOW COLUMNS IN {tablename}").collect() -# columns = [row.asDict().popitem()[1] for row in ans] -# order_by = columns[0] -# columns_str = ','.join(columns) -# log.info(f'order by column {order_by}') -# except Exception as e: -# raise RuntimeError(f"Error in get columns from {tablename}. Restart sparksession and try again") from e -# -# for start in range(0, total_rows, batch_size): -# end = min(start + batch_size, total_rows) -# -# query = f""" -# WITH NumberedRows AS ( -# SELECT -# *, -# ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn -# FROM -# {tablename} -# ) -# SELECT {columns_str} -# FROM NumberedRows -# WHERE rn BETWEEN {start+1} AND {end}""" -# -# df = spark.sql(query) -# pdf = df.toPandas() -# shard_file = os.path.join(json_output_path, f'shard_{start+1}_{end}.json') -# pdf.to_json(shard_file) -# -# -#def fetch_DT_with_dbsql(server_hostname: str, -# access_token: str, -# http_path: str, -# tablename: str, -# json_output_path: str, -# batch_size: int = 1 << 20): -# print("use dbsql") -# """Fetch UC delta table locally as dataframes and convert them to json. -# In the case when table is very large, we fetch batch_size rows a time. -# """ -# log.info(f'Start .... Convert delta to json') -# -# try: -# connection = sql.connect( -# server_hostname=server_hostname, -# http_path=http_path, -# access_token=access_token, -# ) -# except Exception as e: -# raise RuntimeError( -# 'Failed to create sql connection to db workspace. Check {server_hostname} and {http_path} and access token!' 
-# ) from e -# -# cursor = connection.cursor() -# cursor.execute(f'USE CATALOG main;') -# cursor.execute(f'USE SCHEMA streaming;') -# cursor.execute(f'SELECT COUNT(*) FROM {tablename}') -# ans = cursor.fetchall() -# -# total_rows = [row.asDict() for row in ans][0].popitem()[1] -# log.info(f'total_rows = {total_rows}') -# -# cursor.execute(f'SHOW COLUMNS IN {tablename}') -# ans = cursor.fetchall() -# -# # Get the first column to order by. can be any column -# columns = [row.asDict().popitem()[1] for row in ans] -# order_by = columns[0] -# columns_str = ','.join(columns) -# log.info(f'order by column {order_by}') -# print(order_by, columns_str) -# -# for start in range(0, total_rows, batch_size): -# end = min(start + batch_size, total_rows) -# -# query = f""" -# WITH NumberedRows AS ( -# SELECT -# *, -# ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn -# FROM -# {tablename} -# ) -# SELECT {columns_str} -# FROM NumberedRows -# WHERE rn BETWEEN {start+1} AND {end}""" -# -# cursor.execute(query) -# ans = cursor.fetchall() -# -# result = [row.asDict() for row in ans] -# df = pd.DataFrame.from_dict(result) -# df.to_json(os.path.join(json_output_path, -# f'shard_{start+1}_{end}.json')) -# -# cursor.close() -# connection.close() -# -# print(f'Convert delta to json is done. check {json_output_path}.') -# log.info(f'Convert delta to json is done. check {json_output_path}.') - def fetch_DT(*args: Any, **kwargs: Any): r"""Fetch Delta Table from UC and save to local - - This can be called as - - ``` - fetch_DT(server_hostname: str, - access_token: str, - tablename: str, - json_output_path: str, - batch_size: int = 1 << 20) - or - - fetch_DT(server_hostname: str, - access_token: str, - http_path: str, - tablename: str, - json_output_path: str, - batch_size: int = 1 << 20) - ``` Based on the arguments, the call is redirected to either fetch_DT_with_dbconnect or fetch_DT_with_dbsql """ - print('args = ', args) - print(type(args)) args = args[0] log.info(f'Start .... Convert delta to json') @@ -270,24 +133,11 @@ def fetch_DT(*args: Any, **kwargs: Any): raise RuntimeError( 'Failed to create sql connection to db workspace. Check {server_hostname} and {http_path} and access token!' 
) from e - - #return fetch_DT_with_dbsql(args.DATABRICKS_HOST, - # args.DATABRICKS_TOKEN, - # args.http_path, - # args.delta_table_name, - # args.json_output_path, - # args.batch_size) else: method = 'dbconnect' session_id = str(uuid4()) sparkSession = DatabricksSession.builder.host(args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header("x-databricks-session-id", session_id).getOrCreate() - #return fetch_DT_with_dbconnect(args.DATABRICKS_HOST, - # args.DATABRICKS_TOKEN, - # args.delta_table_name, - # args.json_output_path, - # args.batch_size) - fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, sparkSession, dbsql) if dbsql is not None: @@ -328,18 +178,7 @@ def fetch_DT(*args: Any, **kwargs: Any): parser.add_argument('--debug', type=bool, required=False, default=False) args = parser.parse_args() - #server_hostname = args.DATABRICKS_HOST if args.DATABRICKS_HOST is not None else os.getenv( - # 'DATABRICKS_HOST') - #access_token = args.DATABRICKS_TOKEN if args.DATABRICKS_TOKEN is not None else os.getenv( - # 'DATABRICKS_TOKEN') - #http_path = args.http_path - #tablename = args.delta_table_name - #json_output_path = args.json_output_path - tik = time.time() - print("start timer", tik) - fetch_DT(args) - - print("end timer", time.time() - tik) + print("Elapsed time", time.time() - tik) From 92cfeb2b46ea55530ec9f0a4750b1029aca39792 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 5 Dec 2023 12:46:54 -0800 Subject: [PATCH 15/62] add install dbconnect --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index de1d14ccda..4ffbd593a2 100644 --- a/setup.py +++ b/setup.py @@ -86,6 +86,7 @@ extra_deps['databricks'] = [ 'mosaicml[databricks]>=0.17.2,<0.18', 'databricks-sql-connector>=3, <4', + 'databricks-connect==14.0.0' ] extra_deps['tensorboard'] = [ From fa9743b965955c718cb00eb04edee5a10ebeacd4 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 5 Dec 2023 15:55:05 -0800 Subject: [PATCH 16/62] update --- scripts/data_prep/convert_delta_to_json.py | 87 +++++++++++++++------- 1 file changed, 59 insertions(+), 28 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 03d1d9ef67..178888de14 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -6,13 +6,16 @@ import os import time +import urllib.parse import pandas as pd from databricks import sql -from typing import Any, Optional, List +from typing import Any, Optional, List, Tuple from databricks.connect import DatabricksSession from uuid import uuid4 from pyspark.sql.types import Row import concurrent.futures +from multiprocessing import Pool +import subprocess log = logging.getLogger(__name__) @@ -37,10 +40,38 @@ def run_query(q:str, method:str, cursor=None, spark=None, collect=True) -> Optio return None + +def fetch_data_starargs(args: Tuple): + return fetch_data(*args) + +def fetch_data(method, cursor, sparkSession, s, e, order_by, tablename, columns_str, json_output_path): + query = f""" + WITH NumberedRows AS ( + SELECT + *, + ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn + FROM + {tablename} + ) + SELECT {columns_str} + FROM NumberedRows + WHERE rn BETWEEN {s+1} AND {e}""" + + if method == 'dbconnect': + pdf = run_query(query, method, cursor, sparkSession, collect=False).toPandas() + elif method == 'dbsql': + ans = run_query(query, method, cursor, sparkSession, collect=True) + pdf = pd.DataFrame.from_dict([row.asDict() for row in ans]) + + 
pdf.to_json(os.path.join(json_output_path, + f'part_{s+1}_{e}.json')) + + def fetch(method, tablename: str, json_output_path: str, batch_size: int = 1 << 20, + processes = 1, sparkSession = None, dbsql = None, ): @@ -66,34 +97,24 @@ def fetch(method, except Exception as e: raise RuntimeError(f"Error in get columns from {tablename}. Restart sparkSession and try again") from e - def fetch_data(s, e, order_by, tablename, json_output_path): - query = f""" - WITH NumberedRows AS ( - SELECT - *, - ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn - FROM - {tablename} - ) - SELECT {columns_str} - FROM NumberedRows - WHERE rn BETWEEN {start+1} AND {end}""" - - if method == 'dbconnect': - pdf = run_query(query, method, cursor, sparkSession, collect=False).toPandas() - elif method == 'dbsql': - ans = run_query(query, method, cursor, sparkSession, collect=True) - pdf = pd.DataFrame.from_dict([row.asDict() for row in ans]) - - pdf.to_json(os.path.join(json_output_path, - f'part_{s+1}_{e}.json')) - - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [] + obj = urllib.parse.urlparse(json_output_path) + + if method == 'dbconnect': + df = run_query(f"SELECT * FROM {tablename}", method, cursor, sparkSession, collect=False) + print('processes = ', processes) + + dbfs_cache = 'dbfs:/' + json_output_path.lstrip('/') + df.repartition(processes).write.mode("overwrite").json(dbfs_cache) + print(f"downloading from {dbfs_cache} to {json_output_path}") + subprocess.run(f"databricks fs cp -r {dbfs_cache} {json_output_path}", shell=True, capture_output=True, text=True) + subprocess.run(f"databricks fs rm -r {dbfs_cache}", shell=True, capture_output=True, text=True) + + elif method == 'dbsql': + ans = run_query(query, method, cursor, sparkSession, collect=True) + pdf = pd.DataFrame.from_dict([row.asDict() for row in ans]) for start in range(0, total_rows, batch_size): end = min(start + batch_size, total_rows) - futures.append(executor.submit(fetch_data, start, end, order_by, tablename, json_output_path)) - + fetch_data(method, cursor, sparkSession, start, end, order_by, tablename, columns_str, json_output_path) if cursor is not None: cursor.close() @@ -106,6 +127,10 @@ def fetch_DT(*args: Any, **kwargs: Any): args = args[0] log.info(f'Start .... 
Convert delta to json') + obj = urllib.parse.urlparse(args.json_output_path) + if obj.scheme != '': + raise ValueError(f"We don't support writing to remote yet in this script!") + if os.path.exists(args.json_output_path): if not os.path.isdir(args.json_output_path) or os.listdir( args.json_output_path): @@ -138,7 +163,7 @@ def fetch_DT(*args: Any, **kwargs: Any): session_id = str(uuid4()) sparkSession = DatabricksSession.builder.host(args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header("x-databricks-session-id", session_id).getOrCreate() - fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, sparkSession, dbsql) + fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, args.processes, sparkSession, dbsql) if dbsql is not None: dbsql.close() @@ -175,6 +200,12 @@ def fetch_DT(*args: Any, **kwargs: Any): default=1<<20, help= 'chunk of rows to transmit a time') + parser.add_argument('--processes', + required=False, + type=int, + default=1, + help= + 'number of processes allowed to use') parser.add_argument('--debug', type=bool, required=False, default=False) args = parser.parse_args() From 00c3f737035512023ac0aff20c06d363b6f1620f Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 5 Dec 2023 15:57:23 -0800 Subject: [PATCH 17/62] update --- scripts/data_prep/convert_delta_to_json.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 178888de14..fbf43815a6 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -71,7 +71,7 @@ def fetch(method, tablename: str, json_output_path: str, batch_size: int = 1 << 20, - processes = 1, + partitions = 1, sparkSession = None, dbsql = None, ): @@ -101,10 +101,10 @@ def fetch(method, if method == 'dbconnect': df = run_query(f"SELECT * FROM {tablename}", method, cursor, sparkSession, collect=False) - print('processes = ', processes) + print('partitions = ', partitions) dbfs_cache = 'dbfs:/' + json_output_path.lstrip('/') - df.repartition(processes).write.mode("overwrite").json(dbfs_cache) + df.repartition(partitions).write.mode("overwrite").json(dbfs_cache) print(f"downloading from {dbfs_cache} to {json_output_path}") subprocess.run(f"databricks fs cp -r {dbfs_cache} {json_output_path}", shell=True, capture_output=True, text=True) subprocess.run(f"databricks fs rm -r {dbfs_cache}", shell=True, capture_output=True, text=True) @@ -163,7 +163,7 @@ def fetch_DT(*args: Any, **kwargs: Any): session_id = str(uuid4()) sparkSession = DatabricksSession.builder.host(args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header("x-databricks-session-id", session_id).getOrCreate() - fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, args.processes, sparkSession, dbsql) + fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, args.partitions, sparkSession, dbsql) if dbsql is not None: dbsql.close() @@ -200,12 +200,12 @@ def fetch_DT(*args: Any, **kwargs: Any): default=1<<20, help= 'chunk of rows to transmit a time') - parser.add_argument('--processes', + parser.add_argument('--partitions', required=False, type=int, default=1, help= - 'number of processes allowed to use') + 'number of partitions allowed to use') parser.add_argument('--debug', type=bool, required=False, default=False) args = parser.parse_args() From 7304c4c64ea0f67b162b01764fd4a49391aaacb0 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: 
Fri, 8 Dec 2023 09:58:21 -0800 Subject: [PATCH 18/62] patch dbconnect to allow multiple return formats --- scripts/data_prep/convert_delta_to_json.py | 30 +++++++- scripts/data_prep/patch.py | 86 ++++++++++++++++++++++ 2 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 scripts/data_prep/patch.py diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index fbf43815a6..2b3975a319 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -13,10 +13,14 @@ from databricks.connect import DatabricksSession from uuid import uuid4 from pyspark.sql.types import Row -import concurrent.futures +from concurrent.futures import ProcessPoolExecutor from multiprocessing import Pool import subprocess +import patch # Monkey Patching for SparkConnectClient +import requests + + log = logging.getLogger(__name__) def run_query(q:str, method:str, cursor=None, spark=None, collect=True) -> Optional[List[Row]]: @@ -41,6 +45,17 @@ def run_query(q:str, method:str, cursor=None, spark=None, collect=True) -> Optio return None +def get_args(signed, json_output_path): + for i, r in enumerate(signed): + yield (i, r.url, json_output_path) + +def download_json(i, url, json_output_path): + data =requests.get(url).json() + pd.DataFrame.from_dict(data).to_json(os.path.join(json_output_path, 'part_'+str(i)+'.json')) + +def download_json_starargs(args: Tuple): + return download_json(*args) + def fetch_data_starargs(args: Tuple): return fetch_data(*args) @@ -106,8 +121,16 @@ def fetch(method, dbfs_cache = 'dbfs:/' + json_output_path.lstrip('/') df.repartition(partitions).write.mode("overwrite").json(dbfs_cache) print(f"downloading from {dbfs_cache} to {json_output_path}") - subprocess.run(f"databricks fs cp -r {dbfs_cache} {json_output_path}", shell=True, capture_output=True, text=True) - subprocess.run(f"databricks fs rm -r {dbfs_cache}", shell=True, capture_output=True, text=True) + #subprocess.run(f"databricks fs cp -r {dbfs_cache} {json_output_path}", shell=True, capture_output=True, text=True) + #subprocess.run(f"databricks fs rm -r {dbfs_cache}", shell=True, capture_output=True, text=True) + signed, rows, overflow = df.collect_cf("json") + print(len(signed)) + print(rows) + print(overflow) + + args = get_args(signed, json_output_path) + with ProcessPoolExecutor(max_workers=partitions) as executor: + list(executor.map(download_json_starargs, args)) elif method == 'dbsql': ans = run_query(query, method, cursor, sparkSession, collect=True) @@ -120,6 +143,7 @@ def fetch(method, cursor.close() + def fetch_DT(*args: Any, **kwargs: Any): r"""Fetch Delta Table from UC and save to local Based on the arguments, the call is redirected to either fetch_DT_with_dbconnect or fetch_DT_with_dbsql diff --git a/scripts/data_prep/patch.py b/scripts/data_prep/patch.py new file mode 100644 index 0000000000..30299fc1bf --- /dev/null +++ b/scripts/data_prep/patch.py @@ -0,0 +1,86 @@ +# This file is a monkey patch on top of the DB Connect package that allows +# the client to fetch the results in different formats from the server. To be +# able to use the code make sure to first import this module before importing +# the DB Connect classes. 
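# A rough usage sketch (it mirrors how convert_delta_to_json.py consumes this
# module; the table name below is only a placeholder):
#
#     import patch  # apply the monkey patch before building the session
#     from databricks.connect import DatabricksSession
#
#     spark = DatabricksSession.builder.getOrCreate()
#     df = spark.table('catalog.schema.table')
#     results, row_count, is_overflow = df.collect_cf('arrow')
#     for r in results:
#         print(r.url, r.row_count, r.compressed_size, r.uncompressed_size)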
+from typing import Tuple, List + +from pyspark.sql.connect.dataframe import DataFrame +from pyspark.sql.connect.client.core import SparkConnectClient +from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator + +# PB2 stuff +import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 +import pyspark.sql.connect.proto as pb2 +import google.protobuf.any_pb2 as any_pb2 +from collections import namedtuple + +Result = namedtuple("Result", ["url", "row_count", "compressed_size", "uncompressed_size"]) + +# Monkey Patching for SparkConnectClient +def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = "json"): + """ + Executes a given plan object and returns the results as cloud fetch + presigned URLS. It can handle the current outptu formats that are + supported by the server. + + In contrast to the regular API methods of the client, this method + does not return the schema and drops all other responses. + """ + req = self._execute_plan_request_with_metadata() + req.plan.CopyFrom(plan) + + # Add the request options + if type == "json": + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON + elif type == "csv": + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV + elif type == "arrow": + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW + else: + raise Exception("Invalid type") + + ro = cloud_pb2.ResultOptions( + type=cloud_pb2.ResultOptions.TYPE_CLOUD, + cloudOptions=cloud_pb2.ResultOptions.CloudOptions( + format=format, + useCompression=False, + )) + cloud_option = any_pb2.Any() + cloud_option.Pack(ro) + req.request_options.append(pb2.ExecutePlanRequest.RequestOption(extension=cloud_option)) + + # Create the iterator + iterator = ExecutePlanResponseReattachableIterator(req, self._stub, + self._retry_policy, + self._builder.metadata()) + # Iterate over the response + + result = [] + row_count = 0 + is_overflow = False + + for response in iterator: + if response.HasField("extension") and response.extension.Is( + cloud_pb2.CloudResultBatch.DESCRIPTOR): + batch = cloud_pb2.CloudResultBatch() + assert response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR) + response.extension.Unpack(batch) + result += [Result(b.url, b.row_count, b.compressed_size, b.uncompressed_size) for b in batch.results] + row_count += sum(result.row_count for result in batch.results) + is_overflow |= batch.truncated + return result, row_count, is_overflow + + +SparkConnectClient.to_cf = to_cf + + +# Monkey Patching for DataFrame + +def collect_as_cf(self: DataFrame, type: str = "json") -> Tuple[List[Result], int, bool]: + query = self._plan.to_proto(self._session.client) + results, row_count, is_overflow = self._session.client.to_cf( + query, type) + return results, row_count, is_overflow + + +DataFrame.collect_cf = collect_as_cf From 9a1916bf8010ada0ef52597929119fd353090af7 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 8 Dec 2023 11:42:49 -0800 Subject: [PATCH 19/62] update --- scripts/data_prep/convert_delta_to_json.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 2b3975a319..234070417f 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -115,16 +115,18 @@ def fetch(method, obj = urllib.parse.urlparse(json_output_path) if method == 'dbconnect': - df = run_query(f"SELECT * FROM {tablename}", method, cursor, sparkSession, collect=False) + #df = run_query(f"SELECT * FROM 
{tablename}", method, cursor, sparkSession, collect=False) print('partitions = ', partitions) + df = sparkSession.table("main.tpcds_sf100_delta.store_sales") - dbfs_cache = 'dbfs:/' + json_output_path.lstrip('/') - df.repartition(partitions).write.mode("overwrite").json(dbfs_cache) - print(f"downloading from {dbfs_cache} to {json_output_path}") + #dbfs_cache = 'dbfs:/' + json_output_path.lstrip('/') + #df.repartition(partitions).write.mode("overwrite").json(dbfs_cache) + #print(f"downloading from {dbfs_cache} to {json_output_path}") #subprocess.run(f"databricks fs cp -r {dbfs_cache} {json_output_path}", shell=True, capture_output=True, text=True) #subprocess.run(f"databricks fs rm -r {dbfs_cache}", shell=True, capture_output=True, text=True) signed, rows, overflow = df.collect_cf("json") print(len(signed)) + print(signed) print(rows) print(overflow) From d3defc14ac4526e4af5e9c1c2d229c1cee7e3e42 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Sat, 9 Dec 2023 00:11:12 -0800 Subject: [PATCH 20/62] add arrow --- scripts/data_prep/convert_delta_to_json.py | 87 +++++++++++++++++----- setup.py | 3 +- 2 files changed, 72 insertions(+), 18 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 234070417f..3179762ccf 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -5,6 +5,7 @@ import logging import os import time +import random import urllib.parse import pandas as pd @@ -19,6 +20,8 @@ import patch # Monkey Patching for SparkConnectClient import requests +import pyarrow as pa +import lz4.frame log = logging.getLogger(__name__) @@ -50,12 +53,64 @@ def get_args(signed, json_output_path): yield (i, r.url, json_output_path) def download_json(i, url, json_output_path): - data =requests.get(url).json() - pd.DataFrame.from_dict(data).to_json(os.path.join(json_output_path, 'part_'+str(i)+'.json')) + for attempt in range(3): + try: + response = requests.get(url) + if response.status_code == 200: + data = response.json() + pd.DataFrame.from_dict(data).to_json(os.path.join(json_output_path, 'part_'+str(i)+'.json')) + break # Break the loop if the request is successful + else: + print(f"Attempt {attempt + 1} failed with status code {response.status_code}") + except requests.RequestException as e: + print(f"An error occurred: {e}") + + time.sleep(random.randint(1, 5)) + + if attempt == retries - 1: + raise RuntimeError("Failed to download after 3 attempts") + def download_json_starargs(args: Tuple): return download_json(*args) + +def download_arrow(i, url, json_output_path): + resp = requests.get(url) + if resp.status_code == 200: + try: + # Debug: Check the first few bytes of the response + print("First 100 bytes of response:", resp.content[:100]) + + # Decompress the data + decompressed_data = lz4.frame.decompress(resp.content) + + except RuntimeError as e: + print("Decompression error:", e) + print("Response headers:", resp.headers) + # Optionally, write the raw response to a file for analysis + with open("raw_response_data.bin", "wb") as file: + file.write(resp.content) + raise + +#def download_arrow(i, url, json_output_path): +# resp = requests.get(url) +# if resp.status_code == 200: +# # The data is lz4 compressed arrow format. 
+# # Decompress the data +# decompressed_data = lz4.frame.decompress(resp.content) +# +# # Convert the decompressed data into a PyArrow table +# reader = pa.ipc.open_stream(decompressed_data) +# table = reader.read_all() +# +# # Convert the PyArrow table into a pandas DataFrame +# df = table.to_pandas() +# df.to_json(os.path.join(json_output_path, 'part_'+str(i)+'.json')) + +def download_arrow_starargs(args: Tuple): + return download_arrow(*args) + def fetch_data_starargs(args: Tuple): return fetch_data(*args) @@ -115,24 +170,20 @@ def fetch(method, obj = urllib.parse.urlparse(json_output_path) if method == 'dbconnect': - #df = run_query(f"SELECT * FROM {tablename}", method, cursor, sparkSession, collect=False) print('partitions = ', partitions) - df = sparkSession.table("main.tpcds_sf100_delta.store_sales") - - #dbfs_cache = 'dbfs:/' + json_output_path.lstrip('/') - #df.repartition(partitions).write.mode("overwrite").json(dbfs_cache) - #print(f"downloading from {dbfs_cache} to {json_output_path}") - #subprocess.run(f"databricks fs cp -r {dbfs_cache} {json_output_path}", shell=True, capture_output=True, text=True) - #subprocess.run(f"databricks fs rm -r {dbfs_cache}", shell=True, capture_output=True, text=True) - signed, rows, overflow = df.collect_cf("json") - print(len(signed)) - print(signed) - print(rows) - print(overflow) + df = sparkSession.table(tablename) # "main.tpcds_sf100_delta.store_sales") + # Running the query and collecting the data as arrow. + signed, rows, overflow = df.collect_cf("json") # "arrow") + print(f"len(signed) = {len(signed)}") args = get_args(signed, json_output_path) - with ProcessPoolExecutor(max_workers=partitions) as executor: - list(executor.map(download_json_starargs, args)) + # Stopping the SparkSession to avoid spilling connection state into the subprocesses. 
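        # The presigned URLs in `signed` are plain HTTPS links, so the worker
        # processes only need `requests`; stopping the session here presumably
        # avoids forking the Spark Connect client's open connection state into
        # the download workers.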
+ sparkSession.stop() + #with ProcessPoolExecutor(max_workers=partitions) as executor: + #list(executor.map(download_arrow_starargs, args)) + #list(executor.map(download_json_starargs, args)) + with Pool(partitions) as p: + p.map(download_json_starargs, args) elif method == 'dbsql': ans = run_query(query, method, cursor, sparkSession, collect=True) @@ -187,7 +238,9 @@ def fetch_DT(*args: Any, **kwargs: Any): else: method = 'dbconnect' session_id = str(uuid4()) + #sparkSession = DatabricksSession.builder.host(args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN)._cluster_id("0704-124501-tsc2fxq").getOrCreate() sparkSession = DatabricksSession.builder.host(args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header("x-databricks-session-id", session_id).getOrCreate() + #sparkSession = DatabricksSession.builder.remote(host =args.DATABRICKS_HOST, token =args.DATABRICKS_TOKEN, cluster_id ="0704-124501-tsc2fxq").getOrCreate() fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, args.partitions, sparkSession, dbsql) diff --git a/setup.py b/setup.py index 4ffbd593a2..25d68085a5 100644 --- a/setup.py +++ b/setup.py @@ -86,7 +86,8 @@ extra_deps['databricks'] = [ 'mosaicml[databricks]>=0.17.2,<0.18', 'databricks-sql-connector>=3, <4', - 'databricks-connect==14.0.0' + 'databricks-connect==14.1.0', + 'lz4>=4,<5', ] extra_deps['tensorboard'] = [ From da820ebf7bfa1ef962a6a4b52edd97fa257b4355 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Mon, 11 Dec 2023 12:35:42 -0800 Subject: [PATCH 21/62] use compression --- scripts/data_prep/convert_delta_to_json.py | 62 +++++++++++----------- scripts/data_prep/patch.py | 2 +- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 3179762ccf..c4c86dcd8a 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -75,38 +75,38 @@ def download_json_starargs(args: Tuple): return download_json(*args) -def download_arrow(i, url, json_output_path): - resp = requests.get(url) - if resp.status_code == 200: - try: - # Debug: Check the first few bytes of the response - print("First 100 bytes of response:", resp.content[:100]) - - # Decompress the data - decompressed_data = lz4.frame.decompress(resp.content) - - except RuntimeError as e: - print("Decompression error:", e) - print("Response headers:", resp.headers) - # Optionally, write the raw response to a file for analysis - with open("raw_response_data.bin", "wb") as file: - file.write(resp.content) - raise - #def download_arrow(i, url, json_output_path): # resp = requests.get(url) # if resp.status_code == 200: -# # The data is lz4 compressed arrow format. 
-# # Decompress the data -# decompressed_data = lz4.frame.decompress(resp.content) +# try: +# # Debug: Check the first few bytes of the response +# print("First 100 bytes of response:", resp.content[:100]) # -# # Convert the decompressed data into a PyArrow table -# reader = pa.ipc.open_stream(decompressed_data) -# table = reader.read_all() +# # Decompress the data +# decompressed_data = lz4.frame.decompress(resp.content) # -# # Convert the PyArrow table into a pandas DataFrame -# df = table.to_pandas() -# df.to_json(os.path.join(json_output_path, 'part_'+str(i)+'.json')) +# except RuntimeError as e: +# print("Decompression error:", e) +# print("Response headers:", resp.headers) +# # Optionally, write the raw response to a file for analysis +# with open("raw_response_data.bin", "wb") as file: +# file.write(resp.content) +# raise + +def download_arrow(i, url, json_output_path): + resp = requests.get(url) + if resp.status_code == 200: + # The data is lz4 compressed arrow format. + # Decompress the data + decompressed_data = lz4.frame.decompress(resp.content) + + # Convert the decompressed data into a PyArrow table + reader = pa.ipc.open_stream(decompressed_data) + table = reader.read_all() + + # Convert the PyArrow table into a pandas DataFrame + df = table.to_pandas() + df.to_json(os.path.join(json_output_path, 'part_'+str(i)+'.json')) def download_arrow_starargs(args: Tuple): return download_arrow(*args) @@ -179,11 +179,11 @@ def fetch(method, args = get_args(signed, json_output_path) # Stopping the SparkSession to avoid spilling connection state into the subprocesses. sparkSession.stop() - #with ProcessPoolExecutor(max_workers=partitions) as executor: - #list(executor.map(download_arrow_starargs, args)) + with ProcessPoolExecutor(max_workers=partitions) as executor: + list(executor.map(download_arrow_starargs, args)) #list(executor.map(download_json_starargs, args)) - with Pool(partitions) as p: - p.map(download_json_starargs, args) + #with Pool(partitions) as p: + # p.map(download_json_starargs, args) elif method == 'dbsql': ans = run_query(query, method, cursor, sparkSession, collect=True) diff --git a/scripts/data_prep/patch.py b/scripts/data_prep/patch.py index 30299fc1bf..1c9e835d4c 100644 --- a/scripts/data_prep/patch.py +++ b/scripts/data_prep/patch.py @@ -43,7 +43,7 @@ def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = "json"): type=cloud_pb2.ResultOptions.TYPE_CLOUD, cloudOptions=cloud_pb2.ResultOptions.CloudOptions( format=format, - useCompression=False, + useCompression=True, )) cloud_option = any_pb2.Any() cloud_option.Pack(ro) From 60a4be745bcd00af1ecbde68957d1ee1eb242ec5 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 12 Dec 2023 08:01:11 -0800 Subject: [PATCH 22/62] clean up --- scripts/data_prep/convert_delta_to_json.py | 95 +++++++++++++--------- 1 file changed, 57 insertions(+), 38 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index c4c86dcd8a..c3761c0b79 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -4,6 +4,7 @@ import argparse import logging import os +import json import time import random @@ -26,6 +27,28 @@ log = logging.getLogger(__name__) +def iterative_combine_jsons(json_directory, output_file): + """Combine json files in json_directory into one big jsonl file + Args: + json_directory(str): directory containing the JSON files + output_file(str): output JSONL file + """ + json_files = [f for f in 
os.listdir(json_directory) if f.endswith('.json')] + + def read_json(file_path): + with open(file_path, 'r', encoding='utf-8') as file: + return json.load(file) + + with open(output_file, 'w', encoding='utf-8') as outfile: + for json_file in json_files: + full_path = os.path.join(json_directory, json_file) + json_obj = read_json(full_path) + json.dump(json_obj, outfile, ensure_ascii=False) + outfile.write('\n') # Write a newline character after each JSON object + + print('JSON files have been combined into a JSONL file.') + + def run_query(q:str, method:str, cursor=None, spark=None, collect=True) -> Optional[List[Row]]: if not q: return @@ -52,8 +75,8 @@ def get_args(signed, json_output_path): for i, r in enumerate(signed): yield (i, r.url, json_output_path) -def download_json(i, url, json_output_path): - for attempt in range(3): +def download_json(i, url, json_output_path, max_retry=3): + for attempt in range(max_retry): try: response = requests.get(url) if response.status_code == 200: @@ -68,31 +91,11 @@ def download_json(i, url, json_output_path): time.sleep(random.randint(1, 5)) if attempt == retries - 1: - raise RuntimeError("Failed to download after 3 attempts") - + raise RuntimeError(f"Failed to download after {max_retry} attempts") def download_json_starargs(args: Tuple): return download_json(*args) - -#def download_arrow(i, url, json_output_path): -# resp = requests.get(url) -# if resp.status_code == 200: -# try: -# # Debug: Check the first few bytes of the response -# print("First 100 bytes of response:", resp.content[:100]) -# -# # Decompress the data -# decompressed_data = lz4.frame.decompress(resp.content) -# -# except RuntimeError as e: -# print("Decompression error:", e) -# print("Response headers:", resp.headers) -# # Optionally, write the raw response to a file for analysis -# with open("raw_response_data.bin", "wb") as file: -# file.write(resp.content) -# raise - def download_arrow(i, url, json_output_path): resp = requests.get(url) if resp.status_code == 200: @@ -111,9 +114,6 @@ def download_arrow(i, url, json_output_path): def download_arrow_starargs(args: Tuple): return download_arrow(*args) -def fetch_data_starargs(args: Tuple): - return fetch_data(*args) - def fetch_data(method, cursor, sparkSession, s, e, order_by, tablename, columns_str, json_output_path): query = f""" WITH NumberedRows AS ( @@ -145,9 +145,15 @@ def fetch(method, sparkSession = None, dbsql = None, ): - """Fetch UC delta table with databricks-connnect and convert them to json. - In the case when table is very large, we fetch batch_size rows a time. - Compared to fetch_DT_with_dbsql, this function does not need http_path. + """Fetch UC delta table with databricks-connnect and convert them to a number of json files. + In the case when table is very large, we fetch batch_size rows a time. + Args: + method (str): dbconnect or dbsql + tablename (str): catalog.scheme.tablename on UC + batch_size (int): the number of rows that each time to fetch + processes (int): max number of processes to use to parallelize the fetch + sparkSession (pyspark.sql.sparksession): spark session + dbsql (databricks.sql.connect): dbsql session """ cursor = dbsql.cursor() if dbsql is not None else None @@ -174,14 +180,16 @@ def fetch(method, df = sparkSession.table(tablename) # "main.tpcds_sf100_delta.store_sales") # Running the query and collecting the data as arrow. 
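        # collect_cf is added by the patch module: rather than returning rows,
        # it asks the server for cloud-fetch presigned URLs and yields a tuple
        # of (results, row_count, is_overflow), where each Result carries a url
        # plus row_count / compressed_size / uncompressed_size for that batch.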
- signed, rows, overflow = df.collect_cf("json") # "arrow") + signed, rows, overflow = df.collect_cf("arrow") print(f"len(signed) = {len(signed)}") + args = get_args(signed, json_output_path) + # Stopping the SparkSession to avoid spilling connection state into the subprocesses. sparkSession.stop() + with ProcessPoolExecutor(max_workers=partitions) as executor: list(executor.map(download_arrow_starargs, args)) - #list(executor.map(download_json_starargs, args)) #with Pool(partitions) as p: # p.map(download_json_starargs, args) @@ -198,15 +206,15 @@ def fetch(method, def fetch_DT(*args: Any, **kwargs: Any): - r"""Fetch Delta Table from UC and save to local - Based on the arguments, the call is redirected to either fetch_DT_with_dbconnect or fetch_DT_with_dbsql + """Fetch UC Delta Table to local as json files and combined into one big jsonl + By default, databricks-connect is used. Only when ``http_path`` is present in the argument, use dbsql. """ args = args[0] log.info(f'Start .... Convert delta to json') obj = urllib.parse.urlparse(args.json_output_path) if obj.scheme != '': - raise ValueError(f"We don't support writing to remote yet in this script!") + raise ValueError(f"Check the json_output_path and verify it is a local path!") if os.path.exists(args.json_output_path): if not os.path.isdir(args.json_output_path) or os.listdir( @@ -237,16 +245,21 @@ def fetch_DT(*args: Any, **kwargs: Any): ) from e else: method = 'dbconnect' - session_id = str(uuid4()) - #sparkSession = DatabricksSession.builder.host(args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN)._cluster_id("0704-124501-tsc2fxq").getOrCreate() - sparkSession = DatabricksSession.builder.host(args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header("x-databricks-session-id", session_id).getOrCreate() - #sparkSession = DatabricksSession.builder.remote(host =args.DATABRICKS_HOST, token =args.DATABRICKS_TOKEN, cluster_id ="0704-124501-tsc2fxq").getOrCreate() + if not args.cluster_id: + session_id = str(uuid4()) + sparkSession = DatabricksSession.builder.host(args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header("x-databricks-session-id", session_id).getOrCreate() + else: + # IMPORTANT: make sure cluster has runtime newer than 14.1.0, the databricks-connect client version. + compute_id = args.cluster_id # "1115-130834-ms4m0yv" + sparkSession = DatabricksSession.builder.remote(host =args.DATABRICKS_HOST, token =args.DATABRICKS_TOKEN, cluster_id = compute_id).getOrCreate() fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, args.partitions, sparkSession, dbsql) if dbsql is not None: dbsql.close() + iterative_combine_jsons(args.json_output_path, os.path.join(args.json_output_path, 'combined.jsonl')) + if __name__ == '__main__': parser = argparse.ArgumentParser( description= @@ -285,6 +298,12 @@ def fetch_DT(*args: Any, **kwargs: Any): default=1, help= 'number of partitions allowed to use') + parser.add_argument('--cluster_id', + required=False, + type=str, + default=None, + help= + 'Use serverless if not present. IMPORTANT! 
make sure cluster has runtime newer than 14.1.0, the databricks-connect client version') parser.add_argument('--debug', type=bool, required=False, default=False) args = parser.parse_args() From 382a8fad9ce49d4bcd0872f59be202191c576504 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 12 Dec 2023 15:09:19 -0800 Subject: [PATCH 23/62] Add cluster rt check --- scripts/data_prep/convert_delta_to_json.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index c3761c0b79..bede21277f 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -13,6 +13,7 @@ from databricks import sql from typing import Any, Optional, List, Tuple from databricks.connect import DatabricksSession +from databricks.sdk import WorkspaceClient from uuid import uuid4 from pyspark.sql.types import Row from concurrent.futures import ProcessPoolExecutor @@ -24,6 +25,7 @@ import pyarrow as pa import lz4.frame +MINIUM_DBR_VERSION = 14.1.0 log = logging.getLogger(__name__) @@ -251,6 +253,10 @@ def fetch_DT(*args: Any, **kwargs: Any): else: # IMPORTANT: make sure cluster has runtime newer than 14.1.0, the databricks-connect client version. compute_id = args.cluster_id # "1115-130834-ms4m0yv" + w = WorkspaceClient() + res = w.clusters.get(cluster_id="0704-124501-tsc2fxq") + runtime_version = res.spark_version.split('-scala')[0].replace('x-snapshot', '0') + assert version.parse(runtime_version) >= version.parse(MINIUM_DBR_VERSION), "You need at least {MINIMUM_DBR_VERSION} to use Databricks-connect to read delta table for FT API" sparkSession = DatabricksSession.builder.remote(host =args.DATABRICKS_HOST, token =args.DATABRICKS_TOKEN, cluster_id = compute_id).getOrCreate() fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, args.partitions, sparkSession, dbsql) @@ -299,7 +305,7 @@ def fetch_DT(*args: Any, **kwargs: Any): help= 'number of partitions allowed to use') parser.add_argument('--cluster_id', - required=False, + required=True, type=str, default=None, help= From f0286ddf61b11e7f31cb07868d5f362253820b92 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 12 Dec 2023 16:43:21 -0800 Subject: [PATCH 24/62] Fix lints --- scripts/data_prep/convert_delta_to_json.py | 274 ++++++++++-------- scripts/data_prep/patch.py | 69 +++-- .../data_prep/test_convert_delta_to_json.py | 23 +- 3 files changed, 206 insertions(+), 160 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index bede21277f..85e0e602c0 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -1,43 +1,46 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import argparse +import json import logging import os -import json -import time import random - +import time import urllib.parse +from argparse import ArgumentParser, Namespace +from concurrent.futures import ProcessPoolExecutor +from typing import List, Optional, Tuple, Union +from uuid import uuid4 + +import lz4.frame import pandas as pd +# Monkey Patching for SparkConnectClient +import patch # pyright: ignore +import pyarrow as pa +import requests from databricks import sql -from typing import Any, Optional, List, Tuple from databricks.connect import DatabricksSession from databricks.sdk import WorkspaceClient -from uuid import uuid4 +from packaging import version +from pyspark.sql 
import SparkSession +from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import Row -from concurrent.futures import ProcessPoolExecutor -from multiprocessing import Pool -import subprocess -import patch # Monkey Patching for SparkConnectClient -import requests -import pyarrow as pa -import lz4.frame - -MINIUM_DBR_VERSION = 14.1.0 +MINIUM_DBR_VERSION = '14.1.0' log = logging.getLogger(__name__) -def iterative_combine_jsons(json_directory, output_file): - """Combine json files in json_directory into one big jsonl file + +def iterative_combine_jsons(json_directory: str, output_file: str): + """Combine json files in json_directory into one big jsonl file. + Args: json_directory(str): directory containing the JSON files output_file(str): output JSONL file """ json_files = [f for f in os.listdir(json_directory) if f.endswith('.json')] - def read_json(file_path): + def read_json(file_path: str): with open(file_path, 'r', encoding='utf-8') as file: return json.load(file) @@ -46,25 +49,29 @@ def read_json(file_path): full_path = os.path.join(json_directory, json_file) json_obj = read_json(full_path) json.dump(json_obj, outfile, ensure_ascii=False) - outfile.write('\n') # Write a newline character after each JSON object + outfile.write( + '\n') # Write a newline character after each JSON object print('JSON files have been combined into a JSONL file.') -def run_query(q:str, method:str, cursor=None, spark=None, collect=True) -> Optional[List[Row]]: - if not q: - return +def run_query(q: str, + method: str, + cursor: Optional[sql.client.Cursor] = None, + spark: Optional[SparkSession] = None, + collect: bool = True) -> Optional[Union[List[Row], DataFrame]]: + assert method in ['dbsql', 'dbconnect'], f'Unrecognized method: {method}' if method == 'dbsql': if cursor is None: - raise ValueError(f"cursor cannot be None if using method dbsql") + raise ValueError(f'cursor cannot be None if using method dbsql') cursor.execute(q) if collect: return cursor.fetchall() - if method == 'dbconnect': + elif method == 'dbconnect': if spark == None: - raise ValueError(f"sparkSession is required for dbconnect") + raise ValueError(f'sparkSession is required for dbconnect') df = spark.sql(q) if collect: return df.collect() @@ -73,32 +80,42 @@ def run_query(q:str, method:str, cursor=None, spark=None, collect=True) -> Optio return None -def get_args(signed, json_output_path): +def get_args(signed: List, json_output_path: str): for i, r in enumerate(signed): yield (i, r.url, json_output_path) -def download_json(i, url, json_output_path, max_retry=3): + +def download_json(ipart: int, + url: str, + json_output_path: str, + max_retry: int = 3): for attempt in range(max_retry): try: response = requests.get(url) if response.status_code == 200: data = response.json() - pd.DataFrame.from_dict(data).to_json(os.path.join(json_output_path, 'part_'+str(i)+'.json')) + pd.DataFrame.from_dict(data).to_json( + os.path.join(json_output_path, + 'part_' + str(ipart) + '.json')) break # Break the loop if the request is successful else: - print(f"Attempt {attempt + 1} failed with status code {response.status_code}") + print( + f'Attempt {attempt + 1} failed with status code {response.status_code}' + ) except requests.RequestException as e: - print(f"An error occurred: {e}") + print(f'An error occurred: {e}') time.sleep(random.randint(1, 5)) - if attempt == retries - 1: - raise RuntimeError(f"Failed to download after {max_retry} attempts") + if attempt == max_retry - 1: + raise RuntimeError(f'Failed to download after {max_retry} 
attempts') + def download_json_starargs(args: Tuple): return download_json(*args) -def download_arrow(i, url, json_output_path): + +def download_arrow(ipart: int, url: str, json_output_path: str): resp = requests.get(url) if resp.status_code == 200: # The data is lz4 compressed arrow format. @@ -111,12 +128,18 @@ def download_arrow(i, url, json_output_path): # Convert the PyArrow table into a pandas DataFrame df = table.to_pandas() - df.to_json(os.path.join(json_output_path, 'part_'+str(i)+'.json')) + df.to_json( + os.path.join(json_output_path, 'part_' + str(ipart) + '.json')) + def download_arrow_starargs(args: Tuple): return download_arrow(*args) -def fetch_data(method, cursor, sparkSession, s, e, order_by, tablename, columns_str, json_output_path): + +def fetch_data(method: str, cursor: Optional[sql.client.Cursor], + sparkSession: Optional[SparkSession], s: int, e: int, + order_by: str, tablename: str, columns_str: str, + json_output_path: str): query = f""" WITH NumberedRows AS ( SELECT @@ -130,60 +153,73 @@ def fetch_data(method, cursor, sparkSession, s, e, order_by, tablename, columns_ WHERE rn BETWEEN {s+1} AND {e}""" if method == 'dbconnect': - pdf = run_query(query, method, cursor, sparkSession, collect=False).toPandas() - elif method == 'dbsql': + spark_df = run_query(query, method, cursor, sparkSession, collect=False) + if not spark_df: + raise RuntimeError( + f'Expect spark dataframe with {query} but got None') + pdf = spark_df.toPandas() # pyright: ignore + else: # method == 'dbsql': ans = run_query(query, method, cursor, sparkSession, collect=True) - pdf = pd.DataFrame.from_dict([row.asDict() for row in ans]) - - pdf.to_json(os.path.join(json_output_path, - f'part_{s+1}_{e}.json')) - - -def fetch(method, - tablename: str, - json_output_path: str, - batch_size: int = 1 << 20, - partitions = 1, - sparkSession = None, - dbsql = None, - ): - """Fetch UC delta table with databricks-connnect and convert them to a number of json files. - In the case when table is very large, we fetch batch_size rows a time. - Args: - method (str): dbconnect or dbsql - tablename (str): catalog.scheme.tablename on UC - batch_size (int): the number of rows that each time to fetch - processes (int): max number of processes to use to parallelize the fetch - sparkSession (pyspark.sql.sparksession): spark session - dbsql (databricks.sql.connect): dbsql session + if not ans: + raise RuntimeError(f'Got empty results with {query}') + records = [r.asDict() for r in ans] # pyright: ignore + pdf = pd.DataFrame.from_dict(records) + + pdf.to_json(os.path.join(json_output_path, f'part_{s+1}_{e}.json')) + + +def fetch( + method: str, + tablename: str, + json_output_path: str, + batch_size: int = 1 << 30, + partitions: int = 1, + sparkSession: Optional[SparkSession] = None, + dbsql: Optional[sql.Client.Connection] = None, +): + """Fetch UC delta table with databricks-connnect as JSONL. 
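    An illustrative call for the dbconnect path (values are placeholders; pass
    a ``dbsql`` connection and omit ``sparkSession`` for the dbsql path):

        fetch('dbconnect', 'catalog.schema.tablename', '/tmp/delta_json',
              batch_size=1 << 30, partitions=8, sparkSession=spark)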
+ + Args: + method (str): dbconnect or dbsql + tablename (str): catalog.scheme.tablename on UC + json_output_path (str): path to write the result json file to + batch_size (int): number of rows that fetch each time to avoid OOM + partitions (int): max number of processes to use to parallelize the fetch + sparkSession (pyspark.sql.sparksession): spark session + dbsql (databricks.sql.connect): dbsql session """ cursor = dbsql.cursor() if dbsql is not None else None try: - ans = run_query(f"SELECT COUNT(*) FROM {tablename}", method, cursor, sparkSession) - total_rows = [row.asDict() for row in ans][0].popitem()[1] - log.info(f'total_rows = {total_rows}') + ans = run_query(f'SELECT COUNT(*) FROM {tablename}', method, cursor, + sparkSession) + nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore + log.info(f'total_rows = {nrows}') except Exception as e: - raise RuntimeError(f"Error in get total rows from {tablename}. Restart sparkSession and try again") from e + raise RuntimeError( + f'Error in get total rows from {tablename}. Restart sparkSession and try again' + ) from e try: - ans = run_query(f"SHOW COLUMNS IN {tablename}", method, cursor, sparkSession) - columns = [row.asDict().popitem()[1] for row in ans] + ans = run_query(f'SHOW COLUMNS IN {tablename}', method, cursor, + sparkSession) + columns = [row.asDict().popitem()[1] for row in ans] # pyright: ignore order_by = columns[0] columns_str = ','.join(columns) log.info(f'order by column {order_by}') except Exception as e: - raise RuntimeError(f"Error in get columns from {tablename}. Restart sparkSession and try again") from e + raise RuntimeError( + f'Error in get columns from {tablename}. Restart sparkSession and try again' + ) from e - obj = urllib.parse.urlparse(json_output_path) - - if method == 'dbconnect': + if method == 'dbconnect' and sparkSession: print('partitions = ', partitions) - df = sparkSession.table(tablename) # "main.tpcds_sf100_delta.store_sales") + df = sparkSession.table( + tablename) # "main.tpcds_sf100_delta.store_sales") # Running the query and collecting the data as arrow. - signed, rows, overflow = df.collect_cf("arrow") - print(f"len(signed) = {len(signed)}") + signed, _, _ = df.collect_cf('arrow') # pyright: ignore + print(f'len(signed) = {len(signed)}') args = get_args(signed, json_output_path) @@ -192,31 +228,25 @@ def fetch(method, with ProcessPoolExecutor(max_workers=partitions) as executor: list(executor.map(download_arrow_starargs, args)) - #with Pool(partitions) as p: - # p.map(download_json_starargs, args) - elif method == 'dbsql': - ans = run_query(query, method, cursor, sparkSession, collect=True) - pdf = pd.DataFrame.from_dict([row.asDict() for row in ans]) - for start in range(0, total_rows, batch_size): - end = min(start + batch_size, total_rows) - fetch_data(method, cursor, sparkSession, start, end, order_by, tablename, columns_str, json_output_path) + elif method == 'dbsql' and cursor: + for start in range(0, nrows, batch_size): + end = min(start + batch_size, nrows) + fetch_data(method, cursor, sparkSession, start, end, order_by, + tablename, columns_str, json_output_path) if cursor is not None: cursor.close() - -def fetch_DT(*args: Any, **kwargs: Any): - """Fetch UC Delta Table to local as json files and combined into one big jsonl - By default, databricks-connect is used. Only when ``http_path`` is present in the argument, use dbsql. - """ - args = args[0] +def fetch_DT(args: Namespace): + """Fetch UC Delta Table to local as jsonl.""" log.info(f'Start .... 
Convert delta to json') obj = urllib.parse.urlparse(args.json_output_path) if obj.scheme != '': - raise ValueError(f"Check the json_output_path and verify it is a local path!") + raise ValueError( + f'Check the json_output_path and verify it is a local path!') if os.path.exists(args.json_output_path): if not os.path.isdir(args.json_output_path) or os.listdir( @@ -249,25 +279,37 @@ def fetch_DT(*args: Any, **kwargs: Any): method = 'dbconnect' if not args.cluster_id: session_id = str(uuid4()) - sparkSession = DatabricksSession.builder.host(args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header("x-databricks-session-id", session_id).getOrCreate() + sparkSession = DatabricksSession.builder.host( + args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header( + 'x-databricks-session-id', session_id).getOrCreate() else: # IMPORTANT: make sure cluster has runtime newer than 14.1.0, the databricks-connect client version. - compute_id = args.cluster_id # "1115-130834-ms4m0yv" + compute_id = args.cluster_id # "1115-130834-ms4m0yv" w = WorkspaceClient() - res = w.clusters.get(cluster_id="0704-124501-tsc2fxq") - runtime_version = res.spark_version.split('-scala')[0].replace('x-snapshot', '0') - assert version.parse(runtime_version) >= version.parse(MINIUM_DBR_VERSION), "You need at least {MINIMUM_DBR_VERSION} to use Databricks-connect to read delta table for FT API" - sparkSession = DatabricksSession.builder.remote(host =args.DATABRICKS_HOST, token =args.DATABRICKS_TOKEN, cluster_id = compute_id).getOrCreate() - - fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, args.partitions, sparkSession, dbsql) + res = w.clusters.get(cluster_id='0704-124501-tsc2fxq') + runtime_version = res.spark_version.split('-scala')[0].replace( + 'x-snapshot', '0') + assert version.parse(runtime_version) >= version.parse( + MINIUM_DBR_VERSION + ), 'You need at least {MINIMUM_DBR_VERSION} to use Databricks-connect to read delta table for FT API' + sparkSession = DatabricksSession.builder.remote( + host=args.DATABRICKS_HOST, + token=args.DATABRICKS_TOKEN, + cluster_id=compute_id).getOrCreate() + + fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, + args.partitions, sparkSession, dbsql) if dbsql is not None: dbsql.close() - iterative_combine_jsons(args.json_output_path, os.path.join(args.json_output_path, 'combined.jsonl')) + iterative_combine_jsons( + args.json_output_path, + os.path.join(args.json_output_path, 'combined.jsonl')) + if __name__ == '__main__': - parser = argparse.ArgumentParser( + parser = ArgumentParser( description= 'Download delta table from UC and convert to json to save local') parser.add_argument( @@ -287,33 +329,33 @@ def fetch_DT(*args: Any, **kwargs: Any): required=False, type=str, help='DATABRICKS_TOKEN') - parser.add_argument('--http_path', - required=False, - type=str, - help= - 'http_path from either dedicated cluster or serverless sql warehouse') + parser.add_argument( + '--http_path', + required=False, + type=str, + help= + 'http_path from either dedicated cluster or serverless sql warehouse') parser.add_argument('--batch_size', required=False, type=int, - default=1<<20, - help= - 'chunk of rows to transmit a time') + default=1 << 20, + help='chunk of rows to transmit a time') parser.add_argument('--partitions', required=False, type=int, default=1, - help= - 'number of partitions allowed to use') - parser.add_argument('--cluster_id', - required=True, - type=str, - default=None, - help= - 'Use serverless if not present. IMPORTANT! 
make sure cluster has runtime newer than 14.1.0, the databricks-connect client version') + help='number of partitions allowed to use') + parser.add_argument( + '--cluster_id', + required=True, + type=str, + default=None, + help= + 'Use serverless if not present. IMPORTANT! make sure cluster has runtime newer than 14.1.0, the databricks-connect client version' + ) parser.add_argument('--debug', type=bool, required=False, default=False) args = parser.parse_args() tik = time.time() fetch_DT(args) - print("Elapsed time", time.time() - tik) - + print('Elapsed time', time.time() - tik) diff --git a/scripts/data_prep/patch.py b/scripts/data_prep/patch.py index 1c9e835d4c..80b6a788e2 100644 --- a/scripts/data_prep/patch.py +++ b/scripts/data_prep/patch.py @@ -1,43 +1,47 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + # This file is a monkey patch on top of the DB Connect package that allows # the client to fetch the results in different formats from the server. To be # able to use the code make sure to first import this module before importing # the DB Connect classes. -from typing import Tuple, List - -from pyspark.sql.connect.dataframe import DataFrame -from pyspark.sql.connect.client.core import SparkConnectClient -from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator +from collections import namedtuple +from typing import List, Tuple +import google.protobuf.any_pb2 as any_pb2 +import pyspark.sql.connect.proto as pb2 # PB2 stuff import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 -import pyspark.sql.connect.proto as pb2 -import google.protobuf.any_pb2 as any_pb2 -from collections import namedtuple +from pyspark.sql.connect.client.core import SparkConnectClient +from pyspark.sql.connect.client.reattach import \ + ExecutePlanResponseReattachableIterator +from pyspark.sql.connect.dataframe import DataFrame + +Result = namedtuple( + 'Result', ['url', 'row_count', 'compressed_size', 'uncompressed_size' + ]) # pyright: ignore -Result = namedtuple("Result", ["url", "row_count", "compressed_size", "uncompressed_size"]) # Monkey Patching for SparkConnectClient -def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = "json"): - """ - Executes a given plan object and returns the results as cloud fetch - presigned URLS. It can handle the current outptu formats that are - supported by the server. +def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = 'json'): + """Executes plan object return as cloud fetch presigned URLS. - In contrast to the regular API methods of the client, this method - does not return the schema and drops all other responses. + It can handle the current outptu formats that are supported by the server. + In contrast to the regular API methods of the client, this method does not + return the schema and drops all other responses. 
""" req = self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) # Add the request options - if type == "json": + if type == 'json': format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON - elif type == "csv": + elif type == 'csv': format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV - elif type == "arrow": + elif type == 'arrow': format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW else: - raise Exception("Invalid type") + raise Exception('Invalid type') ro = cloud_pb2.ResultOptions( type=cloud_pb2.ResultOptions.TYPE_CLOUD, @@ -47,7 +51,8 @@ def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = "json"): )) cloud_option = any_pb2.Any() cloud_option.Pack(ro) - req.request_options.append(pb2.ExecutePlanRequest.RequestOption(extension=cloud_option)) + req.request_options.append( + pb2.ExecutePlanRequest.RequestOption(extension=cloud_option)) # Create the iterator iterator = ExecutePlanResponseReattachableIterator(req, self._stub, @@ -60,27 +65,29 @@ def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = "json"): is_overflow = False for response in iterator: - if response.HasField("extension") and response.extension.Is( + if response.HasField('extension') and response.extension.Is( cloud_pb2.CloudResultBatch.DESCRIPTOR): batch = cloud_pb2.CloudResultBatch() assert response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR) response.extension.Unpack(batch) - result += [Result(b.url, b.row_count, b.compressed_size, b.uncompressed_size) for b in batch.results] + result += [ + Result(b.url, b.row_count, b.compressed_size, + b.uncompressed_size) for b in batch.results + ] row_count += sum(result.row_count for result in batch.results) is_overflow |= batch.truncated return result, row_count, is_overflow -SparkConnectClient.to_cf = to_cf - +SparkConnectClient.to_cf = to_cf # pyright: ignore # Monkey Patching for DataFrame -def collect_as_cf(self: DataFrame, type: str = "json") -> Tuple[List[Result], int, bool]: - query = self._plan.to_proto(self._session.client) - results, row_count, is_overflow = self._session.client.to_cf( - query, type) - return results, row_count, is_overflow + +def collect_as_cf(self: DataFrame, + type: str = 'json') -> Tuple[List[Result], int, bool]: + query = self._plan.to_proto(self._session.client) # pyright: ignore + return self._session.client.to_cf(query, type) # pyright: ignore -DataFrame.collect_cf = collect_as_cf +DataFrame.collect_cf = collect_as_cf # pyright: ignore diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index f759913d47..54ce4b44be 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -4,14 +4,12 @@ # copyright 2022 mosaicml llm foundry authors # spdx-license-identifier: apache-2.0 -import os -import sys -from typing import Any - import unittest +from argparse import Namespace +from typing import Any from unittest.mock import MagicMock, patch -from scripts.data_prep.convert_delta_to_json import stream_delta_to_json +from scripts.data_prep.convert_delta_to_json import fetch_DT class TestStreamDeltaToJson(): @@ -42,15 +40,14 @@ def test_stream_delta_to_json(self, mock_to_json: Any, mock_connect: Any): [column_response_item], [data_response_item]] - stream_delta_to_json( - server_hostname = 'test_host', - access_token = 'test_token', - http_path = 'test_http_path', - tablename = 'test_table', - json_output_path = 'test_output_path' - ) - + args = 
Namespace(DATABRICKS_HOST='test_host', + DATABRICKS_TOKEN='test_token', + http_path='test_http_path', + tablename='test_table', + json_output_path='test_output_path', + cluster_id='test_cluster_id') + fetch_DT(args) mock_connect.assert_called_once_with(server_hostname='test_host', http_path='test_http_path', access_token='test_token') From 09efbb0e1324d9c7b94f6694183b0338489a14da Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 12 Dec 2023 16:53:16 -0800 Subject: [PATCH 25/62] remove patch.py for CI --- scripts/data_prep/convert_delta_to_json.py | 86 +++++++++++++++++ scripts/data_prep/patch.py | 93 ------------------- .../data_prep/test_convert_delta_to_json.py | 5 - 3 files changed, 86 insertions(+), 98 deletions(-) delete mode 100644 scripts/data_prep/patch.py diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 85e0e602c0..109624c57b 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -8,21 +8,30 @@ import time import urllib.parse from argparse import ArgumentParser, Namespace +from collections import namedtuple from concurrent.futures import ProcessPoolExecutor from typing import List, Optional, Tuple, Union from uuid import uuid4 +import google.protobuf.any_pb2 as any_pb2 import lz4.frame import pandas as pd # Monkey Patching for SparkConnectClient import patch # pyright: ignore import pyarrow as pa +import pyspark.sql.connect.proto as pb2 +# PB2 stuff +import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 import requests from databricks import sql from databricks.connect import DatabricksSession from databricks.sdk import WorkspaceClient from packaging import version from pyspark.sql import SparkSession +from pyspark.sql.connect.client.core import SparkConnectClient +from pyspark.sql.connect.client.reattach import \ + ExecutePlanResponseReattachableIterator +from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import Row @@ -30,6 +39,83 @@ log = logging.getLogger(__name__) +Result = namedtuple( + 'Result', ['url', 'row_count', 'compressed_size', 'uncompressed_size' + ]) # pyright: ignore + + +# Monkey Patching for SparkConnectClient +def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = 'json'): + """Executes plan object return as cloud fetch presigned URLS. + + It can handle the current outptu formats that are supported by the server. + In contrast to the regular API methods of the client, this method does not + return the schema and drops all other responses. 
+ """ + req = self._execute_plan_request_with_metadata() + req.plan.CopyFrom(plan) + + # Add the request options + if type == 'json': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON + elif type == 'csv': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV + elif type == 'arrow': + format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW + else: + raise Exception('Invalid type') + + ro = cloud_pb2.ResultOptions( + type=cloud_pb2.ResultOptions.TYPE_CLOUD, + cloudOptions=cloud_pb2.ResultOptions.CloudOptions( + format=format, + useCompression=True, + )) + cloud_option = any_pb2.Any() + cloud_option.Pack(ro) + req.request_options.append( + pb2.ExecutePlanRequest.RequestOption(extension=cloud_option)) + + # Create the iterator + iterator = ExecutePlanResponseReattachableIterator(req, self._stub, + self._retry_policy, + self._builder.metadata()) + # Iterate over the response + + result = [] + row_count = 0 + is_overflow = False + + for response in iterator: + if response.HasField('extension') and response.extension.Is( + cloud_pb2.CloudResultBatch.DESCRIPTOR): + batch = cloud_pb2.CloudResultBatch() + assert response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR) + response.extension.Unpack(batch) + result += [ + Result(b.url, b.row_count, b.compressed_size, + b.uncompressed_size) for b in batch.results + ] + row_count += sum(result.row_count for result in batch.results) + is_overflow |= batch.truncated + return result, row_count, is_overflow + + +SparkConnectClient.to_cf = to_cf # pyright: ignore + +# This is a monkey patch on top of the DB Connect package that allows +# the client to fetch the results in different formats from the server. To be +# able to use the code make sure this module is not overriden by DB Connect classes. + + +def collect_as_cf(self: DataFrame, + type: str = 'json') -> Tuple[List[Result], int, bool]: + query = self._plan.to_proto(self._session.client) # pyright: ignore + return self._session.client.to_cf(query, type) # pyright: ignore + + +DataFrame.collect_cf = collect_as_cf # pyright: ignore + def iterative_combine_jsons(json_directory: str, output_file: str): """Combine json files in json_directory into one big jsonl file. diff --git a/scripts/data_prep/patch.py b/scripts/data_prep/patch.py deleted file mode 100644 index 80b6a788e2..0000000000 --- a/scripts/data_prep/patch.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -# This file is a monkey patch on top of the DB Connect package that allows -# the client to fetch the results in different formats from the server. To be -# able to use the code make sure to first import this module before importing -# the DB Connect classes. -from collections import namedtuple -from typing import List, Tuple - -import google.protobuf.any_pb2 as any_pb2 -import pyspark.sql.connect.proto as pb2 -# PB2 stuff -import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 -from pyspark.sql.connect.client.core import SparkConnectClient -from pyspark.sql.connect.client.reattach import \ - ExecutePlanResponseReattachableIterator -from pyspark.sql.connect.dataframe import DataFrame - -Result = namedtuple( - 'Result', ['url', 'row_count', 'compressed_size', 'uncompressed_size' - ]) # pyright: ignore - - -# Monkey Patching for SparkConnectClient -def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = 'json'): - """Executes plan object return as cloud fetch presigned URLS. 
- - It can handle the current outptu formats that are supported by the server. - In contrast to the regular API methods of the client, this method does not - return the schema and drops all other responses. - """ - req = self._execute_plan_request_with_metadata() - req.plan.CopyFrom(plan) - - # Add the request options - if type == 'json': - format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON - elif type == 'csv': - format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV - elif type == 'arrow': - format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW - else: - raise Exception('Invalid type') - - ro = cloud_pb2.ResultOptions( - type=cloud_pb2.ResultOptions.TYPE_CLOUD, - cloudOptions=cloud_pb2.ResultOptions.CloudOptions( - format=format, - useCompression=True, - )) - cloud_option = any_pb2.Any() - cloud_option.Pack(ro) - req.request_options.append( - pb2.ExecutePlanRequest.RequestOption(extension=cloud_option)) - - # Create the iterator - iterator = ExecutePlanResponseReattachableIterator(req, self._stub, - self._retry_policy, - self._builder.metadata()) - # Iterate over the response - - result = [] - row_count = 0 - is_overflow = False - - for response in iterator: - if response.HasField('extension') and response.extension.Is( - cloud_pb2.CloudResultBatch.DESCRIPTOR): - batch = cloud_pb2.CloudResultBatch() - assert response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR) - response.extension.Unpack(batch) - result += [ - Result(b.url, b.row_count, b.compressed_size, - b.uncompressed_size) for b in batch.results - ] - row_count += sum(result.row_count for result in batch.results) - is_overflow |= batch.truncated - return result, row_count, is_overflow - - -SparkConnectClient.to_cf = to_cf # pyright: ignore - -# Monkey Patching for DataFrame - - -def collect_as_cf(self: DataFrame, - type: str = 'json') -> Tuple[List[Result], int, bool]: - query = self._plan.to_proto(self._session.client) # pyright: ignore - return self._session.client.to_cf(query, type) # pyright: ignore - - -DataFrame.collect_cf = collect_as_cf # pyright: ignore diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 54ce4b44be..8c6c2ad9bb 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -4,7 +4,6 @@ # copyright 2022 mosaicml llm foundry authors # spdx-license-identifier: apache-2.0 -import unittest from argparse import Namespace from typing import Any from unittest.mock import MagicMock, patch @@ -54,7 +53,3 @@ def test_stream_delta_to_json(self, mock_to_json: Any, mock_connect: Any): mock_to_json.assert_called() mock_cursor.close.assert_called() mock_connection.close.assert_called() - - -if __name__ == '__main__': - unittest.main() From 64214de14b5a8900e259c849b864796679761cef Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 12 Dec 2023 22:39:59 -0800 Subject: [PATCH 26/62] update --- scripts/data_prep/convert_delta_to_json.py | 2 -- tests/a_scripts/data_prep/test_convert_delta_to_json.py | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 109624c57b..826217a29a 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -16,8 +16,6 @@ import google.protobuf.any_pb2 as any_pb2 import lz4.frame import pandas as pd -# Monkey Patching for SparkConnectClient -import patch # pyright: ignore 
import pyarrow as pa import pyspark.sql.connect.proto as pb2 # PB2 stuff diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 8c6c2ad9bb..227a93f4a8 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -53,3 +53,6 @@ def test_stream_delta_to_json(self, mock_to_json: Any, mock_connect: Any): mock_to_json.assert_called() mock_cursor.close.assert_called() mock_connection.close.assert_called() + +if __name__ == '__main__': + unittest.main() From 13ce55b1eda4e1a8a3e66f82b1f61d4677c7de86 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 12 Dec 2023 22:55:43 -0800 Subject: [PATCH 27/62] update --- tests/a_scripts/data_prep/test_convert_delta_to_json.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 227a93f4a8..b10fe327f2 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -4,6 +4,7 @@ # copyright 2022 mosaicml llm foundry authors # spdx-license-identifier: apache-2.0 +import unittest from argparse import Namespace from typing import Any from unittest.mock import MagicMock, patch From 5edc497248749726ebfabfc880cb131c39973ed0 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 12 Dec 2023 23:02:58 -0800 Subject: [PATCH 28/62] updat --- scripts/data_prep/convert_delta_to_json.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 826217a29a..ea330a9f1a 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -141,7 +141,7 @@ def read_json(file_path: str): def run_query(q: str, method: str, - cursor: Optional[sql.client.Cursor] = None, + cursor: Optional[sql.Client.Cursor] = None, spark: Optional[SparkSession] = None, collect: bool = True) -> Optional[Union[List[Row], DataFrame]]: @@ -220,7 +220,7 @@ def download_arrow_starargs(args: Tuple): return download_arrow(*args) -def fetch_data(method: str, cursor: Optional[sql.client.Cursor], +def fetch_data(method: str, cursor: Optional[sql.Client.Cursor], sparkSession: Optional[SparkSession], s: int, e: int, order_by: str, tablename: str, columns_str: str, json_output_path: str): From ad062110a6215b9d699663966f7a4b26f041f842 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 12 Dec 2023 23:17:54 -0800 Subject: [PATCH 29/62] update --- scripts/data_prep/convert_delta_to_json.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index ea330a9f1a..2b3b5ac635 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -141,7 +141,7 @@ def read_json(file_path: str): def run_query(q: str, method: str, - cursor: Optional[sql.Client.Cursor] = None, + cursor: Optional[sql.client.Cursor] = None, spark: Optional[SparkSession] = None, collect: bool = True) -> Optional[Union[List[Row], DataFrame]]: @@ -220,7 +220,7 @@ def download_arrow_starargs(args: Tuple): return download_arrow(*args) -def fetch_data(method: str, cursor: Optional[sql.Client.Cursor], +def fetch_data(method: str, cursor: Optional[sql.client.Cursor], sparkSession: Optional[SparkSession], s: int, e: int, order_by: 
str, tablename: str, columns_str: str, json_output_path: str): @@ -259,7 +259,7 @@ def fetch( batch_size: int = 1 << 30, partitions: int = 1, sparkSession: Optional[SparkSession] = None, - dbsql: Optional[sql.Client.Connection] = None, + dbsql: Optional[sql.client.Connection] = None, ): """Fetch UC delta table with databricks-connnect as JSONL. From 7a18e96ac82d843b86e4f61982435a080ed21e7c Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 12 Dec 2023 23:55:08 -0800 Subject: [PATCH 30/62] fix tests --- scripts/data_prep/convert_delta_to_json.py | 9 ++- .../data_prep/test_convert_delta_to_json.py | 57 +++++++------------ 2 files changed, 26 insertions(+), 40 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 2b3b5ac635..a3e3d19969 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -21,6 +21,7 @@ # PB2 stuff import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 import requests +import databricks from databricks import sql from databricks.connect import DatabricksSession from databricks.sdk import WorkspaceClient @@ -32,6 +33,8 @@ from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import Row +from databricks.sql.client import Connection as Connection +from databricks.sql.client import Cursor as Cursor MINIUM_DBR_VERSION = '14.1.0' @@ -141,7 +144,7 @@ def read_json(file_path: str): def run_query(q: str, method: str, - cursor: Optional[sql.client.Cursor] = None, + cursor: Optional[Cursor] = None, spark: Optional[SparkSession] = None, collect: bool = True) -> Optional[Union[List[Row], DataFrame]]: @@ -220,7 +223,7 @@ def download_arrow_starargs(args: Tuple): return download_arrow(*args) -def fetch_data(method: str, cursor: Optional[sql.client.Cursor], +def fetch_data(method: str, cursor: Optional[Cursor], sparkSession: Optional[SparkSession], s: int, e: int, order_by: str, tablename: str, columns_str: str, json_output_path: str): @@ -259,7 +262,7 @@ def fetch( batch_size: int = 1 << 30, partitions: int = 1, sparkSession: Optional[SparkSession] = None, - dbsql: Optional[sql.client.Connection] = None, + dbsql: Optional[Connection] = None, ): """Fetch UC delta table with databricks-connnect as JSONL. 
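For reference, a minimal sketch of the dbsql code path that the hunks above re-type with the databricks.sql.client Cursor/Connection classes, assuming databricks-sql-connector is installed and a SQL warehouse plus UC table are reachable; the table name, environment variables and http_path below are placeholders, not values taken from these patches. It only issues the two probe queries fetch() runs (row count and column listing) before any batching.

import os

from databricks import sql

tablename = 'catalog.schema.some_table'  # hypothetical <catalog>.<schema>.<table>
connection = sql.connect(
    server_hostname=os.environ['DATABRICKS_HOST'],
    http_path=os.environ['DATABRICKS_HTTP_PATH'],  # placeholder warehouse path
    access_token=os.environ['DATABRICKS_TOKEN'],
)
cursor = connection.cursor()

# The same two probe queries fetch() issues before paging through the rows.
cursor.execute(f'SELECT COUNT(*) FROM {tablename}')
nrows = cursor.fetchall()[0].asDict().popitem()[1]

cursor.execute(f'SHOW COLUMNS IN {tablename}')
columns = [row.asDict().popitem()[1] for row in cursor.fetchall()]

print(f'{nrows} rows; order by {columns[0]}')
cursor.close()
connection.close()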
diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index b10fe327f2..67ca22e6c7 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -15,45 +15,28 @@ class TestStreamDeltaToJson(): @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') - def test_stream_delta_to_json(self, mock_to_json: Any, mock_connect: Any): - - # Mock database connection and cursor - mock_cursor = MagicMock() - mock_connection = MagicMock() - mock_connection.cursor.return_value = mock_cursor - mock_connect.return_value = mock_connection - - # Mock fetchall response - count_response = MagicMock() - count_response.asDict.return_value = {'COUNT(*)': 3} - column_response_item = MagicMock() - column_response_item.asDict.return_value = { - 'COLUMN_NAME': 'name' - } # Assuming SHOW COLUMNS query returns this format - data_response_item = MagicMock() - data_response_item.asDict.return_value = { - 'name': 'test', - 'id': 1 - } # Assuming SELECT query returns this format - mock_cursor.fetchall.side_effect = [[count_response], - [column_response_item], - [data_response_item]] - - args = Namespace(DATABRICKS_HOST='test_host', - DATABRICKS_TOKEN='test_token', - http_path='test_http_path', - tablename='test_table', - json_output_path='test_output_path', - cluster_id='test_cluster_id') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_stream_delta_to_json(self, mock_fetch, mock_combine_jsons, mock_makedirs, mock_sql_connect): + + args = MagicMock() + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/json' + args.DATABRICKS_HOST = 'test_host' + args.DATABRICKS_TOKEN = 'test_token' + args.http_path = 'test_path' + args.batch_size = 1000 + args.partitions = 1 + args.cluster_id = None + args.debug = False fetch_DT(args) - mock_connect.assert_called_once_with(server_hostname='test_host', - http_path='test_http_path', - access_token='test_token') - mock_to_json.assert_called() - mock_cursor.close.assert_called() - mock_connection.close.assert_called() + mock_sql_connect.assert_called_once_with(server_hostname='test_host', http_path='test_path', access_token='test_token') + mock_makedirs.assert_called_once_with('/path/to/json', exist_ok=True) + mock_fetch.assert_called_once() + mock_combine_jsons.assert_called_once_with('/path/to/json', '/path/to/json/combined.jsonl') + if __name__ == '__main__': unittest.main() From 89e658a5a15146f756cd7722d82e8f99eb0dbb54 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 13 Dec 2023 00:08:36 -0800 Subject: [PATCH 31/62] fix lint --- .../data_prep/test_convert_delta_to_json.py | 47 +++++++++---------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 67ca22e6c7..0c82fc5302 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -11,31 +11,28 @@ from scripts.data_prep.convert_delta_to_json import fetch_DT - -class TestStreamDeltaToJson(): - - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') - @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') - 
@patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') - @patch('scripts.data_prep.convert_delta_to_json.fetch') - def test_stream_delta_to_json(self, mock_fetch, mock_combine_jsons, mock_makedirs, mock_sql_connect): - - args = MagicMock() - args.delta_table_name = 'test_table' - args.json_output_path = '/path/to/json' - args.DATABRICKS_HOST = 'test_host' - args.DATABRICKS_TOKEN = 'test_token' - args.http_path = 'test_path' - args.batch_size = 1000 - args.partitions = 1 - args.cluster_id = None - args.debug = False - - fetch_DT(args) - mock_sql_connect.assert_called_once_with(server_hostname='test_host', http_path='test_path', access_token='test_token') - mock_makedirs.assert_called_once_with('/path/to/json', exist_ok=True) - mock_fetch.assert_called_once() - mock_combine_jsons.assert_called_once_with('/path/to/json', '/path/to/json/combined.jsonl') +@patch('scripts.data_prep.convert_delta_to_json.sql.connect') +@patch('scripts.data_prep.convert_delta_to_json.os.makedirs') +@patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') +@patch('scripts.data_prep.convert_delta_to_json.fetch') +def test_stream_delta_to_json(self, mock_fetch: Any, mock_combine_jsons: Any, mock_makedirs: Any, mock_sql_connect: Any): + + args = MagicMock() + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/json' + args.DATABRICKS_HOST = 'test_host' + args.DATABRICKS_TOKEN = 'test_token' + args.http_path = 'test_path' + args.batch_size = 1000 + args.partitions = 1 + args.cluster_id = None + args.debug = False + + fetch_DT(args) + mock_sql_connect.assert_called_once_with(server_hostname='test_host', http_path='test_path', access_token='test_token') + mock_makedirs.assert_called_once_with('/path/to/json', exist_ok=True) + mock_fetch.assert_called_once() + mock_combine_jsons.assert_called_once_with('/path/to/json', '/path/to/json/combined.jsonl') if __name__ == '__main__': From acef77c2ea15aef407fd67a06377c01c65b853fa Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 13 Dec 2023 00:09:24 -0800 Subject: [PATCH 32/62] update --- scripts/data_prep/convert_delta_to_json.py | 6 +++--- .../data_prep/test_convert_delta_to_json.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index a3e3d19969..c2d9876e6d 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -13,6 +13,7 @@ from typing import List, Optional, Tuple, Union from uuid import uuid4 +import databricks import google.protobuf.any_pb2 as any_pb2 import lz4.frame import pandas as pd @@ -21,10 +22,11 @@ # PB2 stuff import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 import requests -import databricks from databricks import sql from databricks.connect import DatabricksSession from databricks.sdk import WorkspaceClient +from databricks.sql.client import Connection as Connection +from databricks.sql.client import Cursor as Cursor from packaging import version from pyspark.sql import SparkSession from pyspark.sql.connect.client.core import SparkConnectClient @@ -33,8 +35,6 @@ from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import Row -from databricks.sql.client import Connection as Connection -from databricks.sql.client import Cursor as Cursor MINIUM_DBR_VERSION = '14.1.0' diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py 
b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 0c82fc5302..6cd6799dcf 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -5,17 +5,18 @@ # spdx-license-identifier: apache-2.0 import unittest -from argparse import Namespace from typing import Any from unittest.mock import MagicMock, patch from scripts.data_prep.convert_delta_to_json import fetch_DT + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') @patch('scripts.data_prep.convert_delta_to_json.fetch') -def test_stream_delta_to_json(self, mock_fetch: Any, mock_combine_jsons: Any, mock_makedirs: Any, mock_sql_connect: Any): +def test_stream_delta_to_json(mock_fetch: Any, mock_combine_jsons: Any, + mock_makedirs: Any, mock_sql_connect: Any): args = MagicMock() args.delta_table_name = 'test_table' @@ -29,10 +30,13 @@ def test_stream_delta_to_json(self, mock_fetch: Any, mock_combine_jsons: Any, mo args.debug = False fetch_DT(args) - mock_sql_connect.assert_called_once_with(server_hostname='test_host', http_path='test_path', access_token='test_token') + mock_sql_connect.assert_called_once_with(server_hostname='test_host', + http_path='test_path', + access_token='test_token') mock_makedirs.assert_called_once_with('/path/to/json', exist_ok=True) mock_fetch.assert_called_once() - mock_combine_jsons.assert_called_once_with('/path/to/json', '/path/to/json/combined.jsonl') + mock_combine_jsons.assert_called_once_with('/path/to/json', + '/path/to/json/combined.jsonl') if __name__ == '__main__': From 3e1b55f0a207fea8cabe1d4caf323b04d335ef4c Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 13 Dec 2023 00:10:01 -0800 Subject: [PATCH 33/62] update --- scripts/data_prep/convert_delta_to_json.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index c2d9876e6d..d098ff58a9 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -13,7 +13,6 @@ from typing import List, Optional, Tuple, Union from uuid import uuid4 -import databricks import google.protobuf.any_pb2 as any_pb2 import lz4.frame import pandas as pd From 6bd32a5ca51e6a6bd8489103db1a606f17c69202 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 13 Dec 2023 07:53:09 -0800 Subject: [PATCH 34/62] Add more tests --- .../data_prep/test_convert_delta_to_json.py | 147 ++++++++++++++---- 1 file changed, 116 insertions(+), 31 deletions(-) diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 6cd6799dcf..c61758b97b 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -6,37 +6,122 @@ import unittest from typing import Any -from unittest.mock import MagicMock, patch - -from scripts.data_prep.convert_delta_to_json import fetch_DT - - -@patch('scripts.data_prep.convert_delta_to_json.sql.connect') -@patch('scripts.data_prep.convert_delta_to_json.os.makedirs') -@patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') -@patch('scripts.data_prep.convert_delta_to_json.fetch') -def test_stream_delta_to_json(mock_fetch: Any, mock_combine_jsons: Any, - mock_makedirs: Any, mock_sql_connect: Any): - - args = MagicMock() - args.delta_table_name = 
'test_table' - args.json_output_path = '/path/to/json' - args.DATABRICKS_HOST = 'test_host' - args.DATABRICKS_TOKEN = 'test_token' - args.http_path = 'test_path' - args.batch_size = 1000 - args.partitions = 1 - args.cluster_id = None - args.debug = False - - fetch_DT(args) - mock_sql_connect.assert_called_once_with(server_hostname='test_host', - http_path='test_path', - access_token='test_token') - mock_makedirs.assert_called_once_with('/path/to/json', exist_ok=True) - mock_fetch.assert_called_once() - mock_combine_jsons.assert_called_once_with('/path/to/json', - '/path/to/json/combined.jsonl') +from unittest.mock import MagicMock, mock_open, patch + +from scripts.data_prep.convert_delta_to_json import (download_json, fetch_DT, + iterative_combine_jsons, + run_query) + + +class TestConverDeltaToJsonl(unittest.TestCase): + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_stream_delta_to_json(self, mock_fetch: Any, + mock_combine_jsons: Any, mock_makedirs: Any, + mock_sql_connect: Any): + + args = MagicMock() + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/json' + args.DATABRICKS_HOST = 'test_host' + args.DATABRICKS_TOKEN = 'test_token' + args.http_path = 'test_path' + args.batch_size = 1000 + args.partitions = 1 + args.cluster_id = None + args.debug = False + + fetch_DT(args) + mock_sql_connect.assert_called_once_with(server_hostname='test_host', + http_path='test_path', + access_token='test_token') + mock_makedirs.assert_called_once_with('/path/to/json', exist_ok=True) + mock_fetch.assert_called_once() + mock_combine_jsons.assert_called_once_with( + '/path/to/json', '/path/to/json/combined.jsonl') + + @patch('scripts.data_prep.convert_delta_to_json.os.listdir') + @patch('builtins.open', + new_callable=mock_open, + read_data='{"key": "value"}') + def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any): + mock_listdir.return_value = ['file1.json', 'file2.json'] + json_directory = '/fake/dir' + output_file = '/fake/output.jsonl' + + iterative_combine_jsons(json_directory, output_file) + + mock_listdir.assert_called_once_with(json_directory) + mock_file.assert_called() + """ + Diagnostic print + for call_args in mock_file().write.call_args_list: + print(call_args) + -------------------- + call('{') + call('"key"') + call(': ') + call('"value"') + call('}') + call('\n') + call('{') + call('"key"') + call(': ') + call('"value"') + call('}') + call('\n') + -------------------- + """ + self.assertEqual(mock_file().write.call_count, 12) + + @patch('scripts.data_prep.convert_delta_to_json.SparkSession') + def test_run_query_dbconnect(self, mock_spark: Any): + method = 'dbconnect' + mock_cursor = None + mock_spark.sql.return_value.collect.return_value = 'result' + + result = run_query('SELECT * FROM table', + method, + cursor=mock_cursor, + spark=mock_spark) + + mock_spark.sql.assert_called_once_with('SELECT * FROM table') + self.assertEqual(result, 'result') + + @patch('scripts.data_prep.convert_delta_to_json.Cursor') + def test_run_query_dbsql(self, mock_cursor: Any): + method = 'dbsql' + mock_cursor.fetchall.return_value = 'result' + mock_spark = None + + result = run_query('SELECT * FROM table', + method, + cursor=mock_cursor, + spark=mock_spark) + + mock_cursor.execute.assert_called_once_with('SELECT * FROM table') + 
self.assertEqual(result, 'result') + + @patch('scripts.data_prep.convert_delta_to_json.requests.get') + @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.from_dict') + @patch('scripts.data_prep.convert_delta_to_json.os.path.join', + return_value='/fake/path/part_1.json') + @patch('scripts.data_prep.convert_delta_to_json.time.sleep' + ) # Mock sleep to speed up the test + def test_download_json_success(self, mock_sleep: Any, mock_join: Any, + mock_from_dict: Any, mock_get: Any): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'data': 'test'} + mock_get.return_value = mock_response + + download_json(1, 'http://fakeurl.com/data', '/fake/path') + + mock_get.assert_called_once_with('http://fakeurl.com/data') + mock_from_dict.assert_called_once_with({'data': 'test'}) if __name__ == '__main__': From 720506376d5ac35cdfd8de0806d901a6b4cdc3b5 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 13 Dec 2023 10:08:16 -0800 Subject: [PATCH 35/62] update --- scripts/data_prep/convert_delta_to_json.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index d098ff58a9..7b83d0d9f8 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -35,7 +35,7 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import Row -MINIUM_DBR_VERSION = '14.1.0' +MINIMUM_DBR_VERSION = '14.1.0' log = logging.getLogger(__name__) @@ -370,14 +370,14 @@ def fetch_DT(args: Namespace): 'x-databricks-session-id', session_id).getOrCreate() else: # IMPORTANT: make sure cluster has runtime newer than 14.1.0, the databricks-connect client version. - compute_id = args.cluster_id # "1115-130834-ms4m0yv" + compute_id = args.cluster_id # "1115-130834-ms4m0yv" - valid 14.1.0 w = WorkspaceClient() - res = w.clusters.get(cluster_id='0704-124501-tsc2fxq') + res = w.clusters.get(cluster_id=comput_id)# '0704-124501-tsc2fxq' - invalid 12.2.x runtime_version = res.spark_version.split('-scala')[0].replace( - 'x-snapshot', '0') + 'x-snapshot', '0').replace('x', '0') assert version.parse(runtime_version) >= version.parse( - MINIUM_DBR_VERSION - ), 'You need at least {MINIMUM_DBR_VERSION} to use Databricks-connect to read delta table for FT API' + MINIMUM_DBR_VERSION + ), f'You need at least {MINIMUM_DBR_VERSION} to use Databricks-connect to read delta table for FT API but got {res.spark_version}' sparkSession = DatabricksSession.builder.remote( host=args.DATABRICKS_HOST, token=args.DATABRICKS_TOKEN, From 284b52f5ae1a4bcd7bae05223467fcf02e44b1f1 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 13 Dec 2023 10:13:03 -0800 Subject: [PATCH 36/62] update --- scripts/data_prep/convert_delta_to_json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 7b83d0d9f8..c6e45070c2 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -372,7 +372,7 @@ def fetch_DT(args: Namespace): # IMPORTANT: make sure cluster has runtime newer than 14.1.0, the databricks-connect client version. 
compute_id = args.cluster_id # "1115-130834-ms4m0yv" - valid 14.1.0 w = WorkspaceClient() - res = w.clusters.get(cluster_id=comput_id)# '0704-124501-tsc2fxq' - invalid 12.2.x + res = w.clusters.get(cluster_id=compute_id)# '0704-124501-tsc2fxq' - invalid 12.2.x runtime_version = res.spark_version.split('-scala')[0].replace( 'x-snapshot', '0').replace('x', '0') assert version.parse(runtime_version) >= version.parse( From 6f8957a2ff0beb5e35cdcc4dff03889926c9782b Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 13 Dec 2023 10:13:54 -0800 Subject: [PATCH 37/62] update --- scripts/data_prep/convert_delta_to_json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index c6e45070c2..39960fda79 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -35,7 +35,7 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import Row -MINIMUM_DBR_VERSION = '14.1.0' +MINIMUM_DBR_VERSION = '14.0.0' log = logging.getLogger(__name__) From 7ce02e188d8d9f51f26b31dbcea2e3ff8beb70f9 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 13 Dec 2023 23:00:45 -0800 Subject: [PATCH 38/62] change to download_json --- scripts/data_prep/convert_delta_to_json.py | 38 ++++++++++--------- .../data_prep/test_convert_delta_to_json.py | 15 +++++--- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 39960fda79..324bbd246f 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -13,12 +13,12 @@ from typing import List, Optional, Tuple, Union from uuid import uuid4 +# PB2 stuff import google.protobuf.any_pb2 as any_pb2 import lz4.frame import pandas as pd import pyarrow as pa import pyspark.sql.connect.proto as pb2 -# PB2 stuff import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2 import requests from databricks import sql @@ -32,7 +32,7 @@ from pyspark.sql.connect.client.reattach import \ ExecutePlanResponseReattachableIterator from pyspark.sql.connect.dataframe import DataFrame -from pyspark.sql.dataframe import DataFrame +from pyspark.sql.dataframe import DataFrame as SparkDataFrame from pyspark.sql.types import Row MINIMUM_DBR_VERSION = '14.0.0' @@ -141,11 +141,13 @@ def read_json(file_path: str): print('JSON files have been combined into a JSONL file.') -def run_query(q: str, - method: str, - cursor: Optional[Cursor] = None, - spark: Optional[SparkSession] = None, - collect: bool = True) -> Optional[Union[List[Row], DataFrame]]: +def run_query( + q: str, + method: str, + cursor: Optional[Cursor] = None, + spark: Optional[SparkSession] = None, + collect: bool = True +) -> Optional[Union[List[Row], DataFrame, SparkDataFrame]]: assert method in ['dbsql', 'dbconnect'], f'Unrecognized method: {method}' if method == 'dbsql': @@ -166,23 +168,24 @@ def run_query(q: str, return None -def get_args(signed: List, json_output_path: str): +def get_args(signed: List, json_output_path: str, columns: List): for i, r in enumerate(signed): - yield (i, r.url, json_output_path) + yield (i, r.url, json_output_path, columns) def download_json(ipart: int, url: str, json_output_path: str, + columns: List, max_retry: int = 3): for attempt in range(max_retry): try: response = requests.get(url) if response.status_code == 200: data = response.json() - pd.DataFrame.from_dict(data).to_json( - 
os.path.join(json_output_path, - 'part_' + str(ipart) + '.json')) + pd.DataFrame(data, columns=columns).to_json(os.path.join( + json_output_path, 'part_' + str(ipart) + '.json'), + orient='records') break # Break the loop if the request is successful else: print( @@ -304,16 +307,16 @@ def fetch( tablename) # "main.tpcds_sf100_delta.store_sales") # Running the query and collecting the data as arrow. - signed, _, _ = df.collect_cf('arrow') # pyright: ignore + signed, _, _ = df.collect_cf('json') # pyright: ignore print(f'len(signed) = {len(signed)}') - args = get_args(signed, json_output_path) + args = get_args(signed, json_output_path, columns) # Stopping the SparkSession to avoid spilling connection state into the subprocesses. sparkSession.stop() with ProcessPoolExecutor(max_workers=partitions) as executor: - list(executor.map(download_arrow_starargs, args)) + list(executor.map(download_json_starargs, args)) elif method == 'dbsql' and cursor: for start in range(0, nrows, batch_size): @@ -372,7 +375,8 @@ def fetch_DT(args: Namespace): # IMPORTANT: make sure cluster has runtime newer than 14.1.0, the databricks-connect client version. compute_id = args.cluster_id # "1115-130834-ms4m0yv" - valid 14.1.0 w = WorkspaceClient() - res = w.clusters.get(cluster_id=compute_id)# '0704-124501-tsc2fxq' - invalid 12.2.x + res = w.clusters.get( + cluster_id=compute_id) # '0704-124501-tsc2fxq' - invalid 12.2.x runtime_version = res.spark_version.split('-scala')[0].replace( 'x-snapshot', '0').replace('x', '0') assert version.parse(runtime_version) >= version.parse( @@ -433,7 +437,7 @@ def fetch_DT(args: Namespace): help='number of partitions allowed to use') parser.add_argument( '--cluster_id', - required=True, + required=False, type=str, default=None, help= diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index c61758b97b..8eba3a9c7b 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -106,22 +106,27 @@ def test_run_query_dbsql(self, mock_cursor: Any): self.assertEqual(result, 'result') @patch('scripts.data_prep.convert_delta_to_json.requests.get') - @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.from_dict') + @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') @patch('scripts.data_prep.convert_delta_to_json.os.path.join', return_value='/fake/path/part_1.json') @patch('scripts.data_prep.convert_delta_to_json.time.sleep' ) # Mock sleep to speed up the test def test_download_json_success(self, mock_sleep: Any, mock_join: Any, - mock_from_dict: Any, mock_get: Any): + mock_to_json: Any, mock_get: Any): mock_response = MagicMock() mock_response.status_code = 200 - mock_response.json.return_value = {'data': 'test'} + mock_response.json.return_value = [['val1.1', 'val1.2'], + ['val2.1', 'val2.2']] mock_get.return_value = mock_response - download_json(1, 'http://fakeurl.com/data', '/fake/path') + download_json(1, 'http://fakeurl.com/data', '/fake/path', ['A', 'B']) + + mock_get.assert_called_with('http://fakeurl.com/data') + mock_join.assert_called_with('/fake/path', 'part_1.json') + mock_to_json.assert_called_with('/fake/path/part_1.json', + orient='records') mock_get.assert_called_once_with('http://fakeurl.com/data') - mock_from_dict.assert_called_once_with({'data': 'test'}) if __name__ == '__main__': From aba0f4db03a410a5034ea341130905eb85234594 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 13 Dec 2023 23:40:52 
-0800 Subject: [PATCH 39/62] update --- scripts/data_prep/convert_delta_to_json.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 324bbd246f..adfbfd9145 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -44,7 +44,10 @@ ]) # pyright: ignore -# Monkey Patching for SparkConnectClient +# This is a monkey patch on top of the DB Connect package that allows +# the client to fetch the results in different formats from the server. To be +# able to use the code make sure this module is not overriden by DB Connect classes. + def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = 'json'): """Executes plan object return as cloud fetch presigned URLS. @@ -103,17 +106,11 @@ def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = 'json'): SparkConnectClient.to_cf = to_cf # pyright: ignore -# This is a monkey patch on top of the DB Connect package that allows -# the client to fetch the results in different formats from the server. To be -# able to use the code make sure this module is not overriden by DB Connect classes. - - def collect_as_cf(self: DataFrame, type: str = 'json') -> Tuple[List[Result], int, bool]: query = self._plan.to_proto(self._session.client) # pyright: ignore return self._session.client.to_cf(query, type) # pyright: ignore - DataFrame.collect_cf = collect_as_cf # pyright: ignore @@ -224,7 +221,6 @@ def download_arrow(ipart: int, url: str, json_output_path: str): def download_arrow_starargs(args: Tuple): return download_arrow(*args) - def fetch_data(method: str, cursor: Optional[Cursor], sparkSession: Optional[SparkSession], s: int, e: int, order_by: str, tablename: str, columns_str: str, @@ -272,7 +268,7 @@ def fetch( method (str): dbconnect or dbsql tablename (str): catalog.scheme.tablename on UC json_output_path (str): path to write the result json file to - batch_size (int): number of rows that fetch each time to avoid OOM + batch_size (int): number of rows that dbsql fetches each time to avoid OOM partitions (int): max number of processes to use to parallelize the fetch sparkSession (pyspark.sql.sparksession): spark session dbsql (databricks.sql.connect): dbsql session @@ -437,7 +433,7 @@ def fetch_DT(args: Namespace): help='number of partitions allowed to use') parser.add_argument( '--cluster_id', - required=False, + required=True, type=str, default=None, help= From 040620f07c13e0025f6a0b4cb288500bacfbe701 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 14 Dec 2023 00:01:50 -0800 Subject: [PATCH 40/62] fix lints --- scripts/data_prep/convert_delta_to_json.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index adfbfd9145..1bb2098118 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -43,11 +43,11 @@ 'Result', ['url', 'row_count', 'compressed_size', 'uncompressed_size' ]) # pyright: ignore - # This is a monkey patch on top of the DB Connect package that allows # the client to fetch the results in different formats from the server. To be # able to use the code make sure this module is not overriden by DB Connect classes. + def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = 'json'): """Executes plan object return as cloud fetch presigned URLS. 
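For reference, a minimal consumer sketch of the monkey-patched cloud-fetch path that the surrounding hunks describe, assuming the patched module above has already been imported (so DataFrame.collect_cf exists) and that the host/token environment variables, cluster id and table name below are placeholders. While to_cf() sets useCompression=True, each presigned part comes back as an lz4-framed Arrow stream.

import os

import lz4.frame
import pyarrow as pa
import requests
from databricks.connect import DatabricksSession

spark = DatabricksSession.builder.remote(
    host=os.environ['DATABRICKS_HOST'],
    token=os.environ['DATABRICKS_TOKEN'],
    cluster_id='my-cluster-id').getOrCreate()  # placeholder cluster id

df = spark.table('catalog.schema.some_table')  # placeholder UC table
signed, row_count, is_overflow = df.collect_cf('arrow')  # patched onto DataFrame

for ipart, part in enumerate(signed):
    resp = requests.get(part.url)  # presigned cloud-fetch URL
    # With useCompression=True, the payload is an lz4-framed Arrow IPC stream.
    table = pa.ipc.open_stream(lz4.frame.decompress(resp.content)).read_all()
    table.to_pandas().to_json(f'/tmp/part_{ipart}.json')

print(f'downloaded {row_count} rows, truncated: {is_overflow}')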
@@ -106,11 +106,13 @@ def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = 'json'): SparkConnectClient.to_cf = to_cf # pyright: ignore + def collect_as_cf(self: DataFrame, type: str = 'json') -> Tuple[List[Result], int, bool]: query = self._plan.to_proto(self._session.client) # pyright: ignore return self._session.client.to_cf(query, type) # pyright: ignore + DataFrame.collect_cf = collect_as_cf # pyright: ignore @@ -221,6 +223,7 @@ def download_arrow(ipart: int, url: str, json_output_path: str): def download_arrow_starargs(args: Tuple): return download_arrow(*args) + def fetch_data(method: str, cursor: Optional[Cursor], sparkSession: Optional[SparkSession], s: int, e: int, order_by: str, tablename: str, columns_str: str, From 51522f89b4b11e4040b90710f31244f553595256 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 15 Dec 2023 09:56:24 -0800 Subject: [PATCH 41/62] Add decompressed option for arrow --- scripts/data_prep/convert_delta_to_json.py | 58 +++++++++++----------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 1bb2098118..1d58dc9ede 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -72,7 +72,7 @@ def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = 'json'): type=cloud_pb2.ResultOptions.TYPE_CLOUD, cloudOptions=cloud_pb2.ResultOptions.CloudOptions( format=format, - useCompression=True, + useCompression=False, )) cloud_option = any_pb2.Any() cloud_option.Pack(ro) @@ -203,15 +203,20 @@ def download_json_starargs(args: Tuple): return download_json(*args) -def download_arrow(ipart: int, url: str, json_output_path: str): +def download_arrow(ipart: int, + url: str, + json_output_path: str, + compressed: bool = False): resp = requests.get(url) if resp.status_code == 200: - # The data is lz4 compressed arrow format. - # Decompress the data - decompressed_data = lz4.frame.decompress(resp.content) - - # Convert the decompressed data into a PyArrow table - reader = pa.ipc.open_stream(decompressed_data) + if compressed: + # The data is lz4 compressed arrow format. + # Decompress the data + decompressed_data = lz4.frame.decompress(resp.content) + # Convert the decompressed data into a PyArrow table + reader = pa.ipc.open_stream(decompressed_data) + else: + reader = pa.ipc.open_stream(resp.content) table = reader.read_all() # Convert the PyArrow table into a pandas DataFrame @@ -261,7 +266,7 @@ def fetch( tablename: str, json_output_path: str, batch_size: int = 1 << 30, - partitions: int = 1, + processes: int = 1, sparkSession: Optional[SparkSession] = None, dbsql: Optional[Connection] = None, ): @@ -272,7 +277,7 @@ def fetch( tablename (str): catalog.scheme.tablename on UC json_output_path (str): path to write the result json file to batch_size (int): number of rows that dbsql fetches each time to avoid OOM - partitions (int): max number of processes to use to parallelize the fetch + processes (int): max number of processes to use to parallelize the fetch sparkSession (pyspark.sql.sparksession): spark session dbsql (databricks.sql.connect): dbsql session """ @@ -301,7 +306,7 @@ def fetch( ) from e if method == 'dbconnect' and sparkSession: - print('partitions = ', partitions) + print('processes = ', processes) df = sparkSession.table( tablename) # "main.tpcds_sf100_delta.store_sales") @@ -314,7 +319,7 @@ def fetch( # Stopping the SparkSession to avoid spilling connection state into the subprocesses. 
sparkSession.stop() - with ProcessPoolExecutor(max_workers=partitions) as executor: + with ProcessPoolExecutor(max_workers=processes) as executor: list(executor.map(download_json_starargs, args)) elif method == 'dbsql' and cursor: @@ -387,7 +392,7 @@ def fetch_DT(args: Namespace): cluster_id=compute_id).getOrCreate() fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, - args.partitions, sparkSession, dbsql) + args.processes, sparkSession, dbsql) if dbsql is not None: dbsql.close() @@ -401,11 +406,10 @@ def fetch_DT(args: Namespace): parser = ArgumentParser( description= 'Download delta table from UC and convert to json to save local') - parser.add_argument( - '--delta_table_name', - required=True, - type=str, - help='UC table of format ..
') + parser.add_argument('--delta_table_name', + required=True, + type=str, + help='UC table ..
') parser.add_argument('--json_output_path', required=True, type=str, @@ -418,22 +422,20 @@ def fetch_DT(args: Namespace): required=False, type=str, help='DATABRICKS_TOKEN') - parser.add_argument( - '--http_path', - required=False, - type=str, - help= - 'http_path from either dedicated cluster or serverless sql warehouse') + parser.add_argument('--http_path', + required=False, + type=str, + help='http_path is set then dbsql method is used') parser.add_argument('--batch_size', required=False, type=int, - default=1 << 20, - help='chunk of rows to transmit a time') - parser.add_argument('--partitions', + default=1 << 30, + help='row chunks to transmit a time to avoid OOM') + parser.add_argument('--processes', required=False, type=int, default=1, - help='number of partitions allowed to use') + help='number of processes allowed to use') parser.add_argument( '--cluster_id', required=True, From 7ae65e311d4729436385e853d1f1907a3cecb760 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 15 Dec 2023 13:22:29 -0800 Subject: [PATCH 42/62] format json to jsonl --- scripts/data_prep/convert_delta_to_json.py | 88 ++++++------------- .../data_prep/test_convert_delta_to_json.py | 8 +- 2 files changed, 31 insertions(+), 65 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 1d58dc9ede..ea6e757cd5 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -117,26 +117,18 @@ def collect_as_cf(self: DataFrame, def iterative_combine_jsons(json_directory: str, output_file: str): - """Combine json files in json_directory into one big jsonl file. + """Combine jsonl files in json_directory into one big jsonl file. Args: json_directory(str): directory containing the JSON files output_file(str): output JSONL file """ - json_files = [f for f in os.listdir(json_directory) if f.endswith('.json')] - - def read_json(file_path: str): - with open(file_path, 'r', encoding='utf-8') as file: - return json.load(file) - - with open(output_file, 'w', encoding='utf-8') as outfile: - for json_file in json_files: - full_path = os.path.join(json_directory, json_file) - json_obj = read_json(full_path) - json.dump(json_obj, outfile, ensure_ascii=False) - outfile.write( - '\n') # Write a newline character after each JSON object - + json_files = [f for f in os.listdir(json_directory) if f.endswith('.jsonl')] + with open(output_file, 'w') as outfile: + for file_name in json_files: + with open(os.path.join(json_directory, file_name), 'r') as infile: + for line in infile: + outfile.write(line) print('JSON files have been combined into a JSONL file.') @@ -171,44 +163,20 @@ def get_args(signed: List, json_output_path: str, columns: List): for i, r in enumerate(signed): yield (i, r.url, json_output_path, columns) - -def download_json(ipart: int, - url: str, - json_output_path: str, - columns: List, - max_retry: int = 3): - for attempt in range(max_retry): - try: - response = requests.get(url) - if response.status_code == 200: - data = response.json() - pd.DataFrame(data, columns=columns).to_json(os.path.join( - json_output_path, 'part_' + str(ipart) + '.json'), - orient='records') - break # Break the loop if the request is successful - else: - print( - f'Attempt {attempt + 1} failed with status code {response.status_code}' - ) - except requests.RequestException as e: - print(f'An error occurred: {e}') - - time.sleep(random.randint(1, 5)) - - if attempt == max_retry - 1: - raise RuntimeError(f'Failed to download after 
{max_retry} attempts') - - -def download_json_starargs(args: Tuple): - return download_json(*args) - - -def download_arrow(ipart: int, - url: str, - json_output_path: str, - compressed: bool = False): +def download(ipart: int, + url: str, + json_output_path: str, + columns: Optional[List] = None, + resp_format: str = "arrow", + compressed: bool = False): resp = requests.get(url) if resp.status_code == 200: + if resp_format == "json": + data = resp.json() + pd.DataFrame(data, columns=columns).to_json(os.path.join( + json_output_path, 'part_' + str(ipart) + '.jsonl'), orient='records', lines=True) + return + if compressed: # The data is lz4 compressed arrow format. # Decompress the data @@ -217,17 +185,15 @@ def download_arrow(ipart: int, reader = pa.ipc.open_stream(decompressed_data) else: reader = pa.ipc.open_stream(resp.content) + print("I am here") table = reader.read_all() # Convert the PyArrow table into a pandas DataFrame df = table.to_pandas() - df.to_json( - os.path.join(json_output_path, 'part_' + str(ipart) + '.json')) - - -def download_arrow_starargs(args: Tuple): - return download_arrow(*args) + df.to_json(os.path.join(json_output_path, 'part_' + str(ipart) + '.jsonl'), orient='records', lines=True, force_ascii=False) +def download_starargs(args: Tuple): + return download(*args) def fetch_data(method: str, cursor: Optional[Cursor], sparkSession: Optional[SparkSession], s: int, e: int, @@ -258,7 +224,7 @@ def fetch_data(method: str, cursor: Optional[Cursor], records = [r.asDict() for r in ans] # pyright: ignore pdf = pd.DataFrame.from_dict(records) - pdf.to_json(os.path.join(json_output_path, f'part_{s+1}_{e}.json')) + pdf.to_json(os.path.join(json_output_path, f'part_{s+1}_{e}.jsonl')) def fetch( @@ -310,8 +276,8 @@ def fetch( df = sparkSession.table( tablename) # "main.tpcds_sf100_delta.store_sales") - # Running the query and collecting the data as arrow. - signed, _, _ = df.collect_cf('json') # pyright: ignore + # Running the query and collecting the data as arrow or json. 
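
For orientation only (an editorial sketch, not part of this patch): the hunks above fold the separate JSON/Arrow downloaders into a single download() that writes newline-delimited JSON shards from cloud-fetch presigned URLs, and fetch() below collects those URLs as Arrow before handing them to worker processes. A minimal standalone version of that download path might look like the following; download_part and out_dir are illustrative names, and df is assumed to be a DataFrame whose session client already carries the collect_cf monkey patch.

import os

import pyarrow as pa
import requests


def download_part(ipart: int, url: str, out_dir: str) -> None:
    # Fetch one presigned-URL part (an uncompressed Arrow IPC stream) and
    # write it out as a newline-delimited JSON shard.
    resp = requests.get(url)
    resp.raise_for_status()
    table = pa.ipc.open_stream(resp.content).read_all()
    table.to_pandas().to_json(os.path.join(out_dir, f'part_{ipart}.jsonl'),
                              orient='records',
                              lines=True,
                              force_ascii=False)

# Usage, mirroring the fetch() path in this patch:
#   signed, _, _ = df.collect_cf('arrow')
#   for i, r in enumerate(signed):
#       download_part(i, r.url, out_dir)
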
+ signed, _, _ = df.collect_cf('arrow') # pyright: ignore print(f'len(signed) = {len(signed)}') args = get_args(signed, json_output_path, columns) @@ -320,7 +286,7 @@ def fetch( sparkSession.stop() with ProcessPoolExecutor(max_workers=processes) as executor: - list(executor.map(download_json_starargs, args)) + list(executor.map(download_starargs, args)) elif method == 'dbsql' and cursor: for start in range(0, nrows, batch_size): diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 8eba3a9c7b..072a16281b 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -8,7 +8,7 @@ from typing import Any from unittest.mock import MagicMock, mock_open, patch -from scripts.data_prep.convert_delta_to_json import (download_json, fetch_DT, +from scripts.data_prep.convert_delta_to_json import (download, fetch_DT, iterative_combine_jsons, run_query) @@ -111,7 +111,7 @@ def test_run_query_dbsql(self, mock_cursor: Any): return_value='/fake/path/part_1.json') @patch('scripts.data_prep.convert_delta_to_json.time.sleep' ) # Mock sleep to speed up the test - def test_download_json_success(self, mock_sleep: Any, mock_join: Any, + def test_download_success(self, mock_sleep: Any, mock_join: Any, mock_to_json: Any, mock_get: Any): mock_response = MagicMock() mock_response.status_code = 200 @@ -119,12 +119,12 @@ def test_download_json_success(self, mock_sleep: Any, mock_join: Any, ['val2.1', 'val2.2']] mock_get.return_value = mock_response - download_json(1, 'http://fakeurl.com/data', '/fake/path', ['A', 'B']) + download(1, 'http://fakeurl.com/data', '/fake/path', ['A', 'B'], resp_format='json') mock_get.assert_called_with('http://fakeurl.com/data') mock_join.assert_called_with('/fake/path', 'part_1.json') mock_to_json.assert_called_with('/fake/path/part_1.json', - orient='records') + orient='records', lines=True) mock_get.assert_called_once_with('http://fakeurl.com/data') From 7f027fc2ff26018c2444261689464c1162c7a2c8 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Sat, 16 Dec 2023 15:31:55 -0800 Subject: [PATCH 43/62] Add comments --- scripts/data_prep/convert_delta_to_json.py | 58 ++++++++++++++----- .../data_prep/test_convert_delta_to_json.py | 30 +++++----- 2 files changed, 57 insertions(+), 31 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index ea6e757cd5..739039b4b1 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -1,16 +1,14 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import json import logging import os -import random import time import urllib.parse from argparse import ArgumentParser, Namespace from collections import namedtuple from concurrent.futures import ProcessPoolExecutor -from typing import List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Tuple, Union from uuid import uuid4 # PB2 stuff @@ -48,7 +46,9 @@ # able to use the code make sure this module is not overriden by DB Connect classes. -def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = 'json'): +def to_cf(self: SparkConnectClient, + plan: pb2.Plan, + type: str = 'json') -> Tuple[List[Result], int, bool]: """Executes plan object return as cloud fetch presigned URLS. It can handle the current outptu formats that are supported by the server. 
@@ -116,7 +116,7 @@ def collect_as_cf(self: DataFrame, DataFrame.collect_cf = collect_as_cf # pyright: ignore -def iterative_combine_jsons(json_directory: str, output_file: str): +def iterative_combine_jsons(json_directory: str, output_file: str) -> None: """Combine jsonl files in json_directory into one big jsonl file. Args: @@ -139,7 +139,15 @@ def run_query( spark: Optional[SparkSession] = None, collect: bool = True ) -> Optional[Union[List[Row], DataFrame, SparkDataFrame]]: + """Run SQL query via databricks-connect or databricks-sql. + Args: + q (str): sql query + method (str): select from dbsql and dbconnect + cursor (Cursor): connection.cursor + spark (SparkSession): spark session + collect (bool): whether to get the underlying data from spark dataframe + """ assert method in ['dbsql', 'dbconnect'], f'Unrecognized method: {method}' if method == 'dbsql': if cursor is None: @@ -159,22 +167,35 @@ def run_query( return None -def get_args(signed: List, json_output_path: str, columns: List): +def get_args(signed: List, json_output_path: str, columns: List) -> Iterable: for i, r in enumerate(signed): yield (i, r.url, json_output_path, columns) + def download(ipart: int, url: str, json_output_path: str, columns: Optional[List] = None, - resp_format: str = "arrow", - compressed: bool = False): + resp_format: str = 'arrow', + compressed: bool = False) -> None: + """Thread download presigned url and save to jsonl locally. + + Args: + ipart (int): presigned url id + url (str): presigned url + json_output_path (str): directory to save the ipart_th segment of dataframe + columns (list): schema to save to json + resp_format (str): whether to use arrow or json when collect + compressed (bool): if data is compressed before downloading. Need decompress if compressed=True. + """ resp = requests.get(url) if resp.status_code == 200: - if resp_format == "json": + if resp_format == 'json': data = resp.json() pd.DataFrame(data, columns=columns).to_json(os.path.join( - json_output_path, 'part_' + str(ipart) + '.jsonl'), orient='records', lines=True) + json_output_path, 'part_' + str(ipart) + '.jsonl'), + orient='records', + lines=True) return if compressed: @@ -185,20 +206,25 @@ def download(ipart: int, reader = pa.ipc.open_stream(decompressed_data) else: reader = pa.ipc.open_stream(resp.content) - print("I am here") table = reader.read_all() # Convert the PyArrow table into a pandas DataFrame df = table.to_pandas() - df.to_json(os.path.join(json_output_path, 'part_' + str(ipart) + '.jsonl'), orient='records', lines=True, force_ascii=False) + df.to_json(os.path.join(json_output_path, + 'part_' + str(ipart) + '.jsonl'), + orient='records', + lines=True, + force_ascii=False) -def download_starargs(args: Tuple): + +def download_starargs(args: Tuple) -> None: return download(*args) + def fetch_data(method: str, cursor: Optional[Cursor], sparkSession: Optional[SparkSession], s: int, e: int, order_by: str, tablename: str, columns_str: str, - json_output_path: str): + json_output_path: str) -> None: query = f""" WITH NumberedRows AS ( SELECT @@ -235,7 +261,7 @@ def fetch( processes: int = 1, sparkSession: Optional[SparkSession] = None, dbsql: Optional[Connection] = None, -): +) -> None: """Fetch UC delta table with databricks-connnect as JSONL. Args: @@ -298,7 +324,7 @@ def fetch( cursor.close() -def fetch_DT(args: Namespace): +def fetch_DT(args: Namespace) -> None: """Fetch UC Delta Table to local as jsonl.""" log.info(f'Start .... 
Convert delta to json') diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 072a16281b..532bf22164 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -25,7 +25,7 @@ def test_stream_delta_to_json(self, mock_fetch: Any, args = MagicMock() args.delta_table_name = 'test_table' - args.json_output_path = '/path/to/json' + args.json_output_path = '/path/to/jsonl' args.DATABRICKS_HOST = 'test_host' args.DATABRICKS_TOKEN = 'test_token' args.http_path = 'test_path' @@ -38,17 +38,17 @@ def test_stream_delta_to_json(self, mock_fetch: Any, mock_sql_connect.assert_called_once_with(server_hostname='test_host', http_path='test_path', access_token='test_token') - mock_makedirs.assert_called_once_with('/path/to/json', exist_ok=True) + mock_makedirs.assert_called_once_with('/path/to/jsonl', exist_ok=True) mock_fetch.assert_called_once() mock_combine_jsons.assert_called_once_with( - '/path/to/json', '/path/to/json/combined.jsonl') + '/path/to/jsonl', '/path/to/jsonl/combined.jsonl') @patch('scripts.data_prep.convert_delta_to_json.os.listdir') @patch('builtins.open', new_callable=mock_open, read_data='{"key": "value"}') def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any): - mock_listdir.return_value = ['file1.json', 'file2.json'] + mock_listdir.return_value = ['file1.jsonl', 'file2.jsonl'] json_directory = '/fake/dir' output_file = '/fake/output.jsonl' @@ -75,7 +75,7 @@ def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any): call('\n') -------------------- """ - self.assertEqual(mock_file().write.call_count, 12) + self.assertEqual(mock_file().write.call_count, 2) @patch('scripts.data_prep.convert_delta_to_json.SparkSession') def test_run_query_dbconnect(self, mock_spark: Any): @@ -108,26 +108,26 @@ def test_run_query_dbsql(self, mock_cursor: Any): @patch('scripts.data_prep.convert_delta_to_json.requests.get') @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json') @patch('scripts.data_prep.convert_delta_to_json.os.path.join', - return_value='/fake/path/part_1.json') + return_value='/fake/path/part_1.jsonl') @patch('scripts.data_prep.convert_delta_to_json.time.sleep' ) # Mock sleep to speed up the test def test_download_success(self, mock_sleep: Any, mock_join: Any, - mock_to_json: Any, mock_get: Any): + mock_to_json: Any, mock_get: Any): mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = [['val1.1', 'val1.2'], ['val2.1', 'val2.2']] mock_get.return_value = mock_response - download(1, 'http://fakeurl.com/data', '/fake/path', ['A', 'B'], resp_format='json') + download(1, + 'http://fakeurl.com/data', + '/fake/path', ['A', 'B'], + resp_format='json') mock_get.assert_called_with('http://fakeurl.com/data') - mock_join.assert_called_with('/fake/path', 'part_1.json') - mock_to_json.assert_called_with('/fake/path/part_1.json', - orient='records', lines=True) + mock_join.assert_called_with('/fake/path', 'part_1.jsonl') + mock_to_json.assert_called_with('/fake/path/part_1.jsonl', + orient='records', + lines=True) mock_get.assert_called_once_with('http://fakeurl.com/data') - - -if __name__ == '__main__': - unittest.main() From 34c5b72f7e37510c5437e5b7786c181d79503001 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Sat, 16 Dec 2023 15:36:51 -0800 Subject: [PATCH 44/62] Make cf_collect_type global option --- scripts/data_prep/convert_delta_to_json.py | 9 ++++++--- 
1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 739039b4b1..b82dd658b4 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -41,6 +41,8 @@ 'Result', ['url', 'row_count', 'compressed_size', 'uncompressed_size' ]) # pyright: ignore +cf_collect_type = 'arrow' # optionally change to json if arrow fails + # This is a monkey patch on top of the DB Connect package that allows # the client to fetch the results in different formats from the server. To be # able to use the code make sure this module is not overriden by DB Connect classes. @@ -167,9 +169,10 @@ def run_query( return None -def get_args(signed: List, json_output_path: str, columns: List) -> Iterable: +def get_args(signed: List, json_output_path: str, columns: List, + cf_collect_type: str) -> Iterable: for i, r in enumerate(signed): - yield (i, r.url, json_output_path, columns) + yield (i, r.url, json_output_path, columns, cf_collect_type) def download(ipart: int, @@ -306,7 +309,7 @@ def fetch( signed, _, _ = df.collect_cf('arrow') # pyright: ignore print(f'len(signed) = {len(signed)}') - args = get_args(signed, json_output_path, columns) + args = get_args(signed, json_output_path, columns, cf_collect_type) # Stopping the SparkSession to avoid spilling connection state into the subprocesses. sparkSession.stop() From 0e9c0296ef249f8678debbc34f8037e3a6dcc01e Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Sat, 16 Dec 2023 15:40:32 -0800 Subject: [PATCH 45/62] fix comments --- scripts/data_prep/convert_delta_to_json.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index b82dd658b4..82916419f5 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -150,7 +150,9 @@ def run_query( spark (SparkSession): spark session collect (bool): whether to get the underlying data from spark dataframe """ - assert method in ['dbsql', 'dbconnect'], f'Unrecognized method: {method}' + if method not in ['dbsql', 'dbconnect']: + raise ValueError(f'Unrecognized method: {method}') + if method == 'dbsql': if cursor is None: raise ValueError(f'cursor cannot be None if using method dbsql') @@ -378,9 +380,11 @@ def fetch_DT(args: Namespace) -> None: cluster_id=compute_id) # '0704-124501-tsc2fxq' - invalid 12.2.x runtime_version = res.spark_version.split('-scala')[0].replace( 'x-snapshot', '0').replace('x', '0') - assert version.parse(runtime_version) >= version.parse( - MINIMUM_DBR_VERSION - ), f'You need at least {MINIMUM_DBR_VERSION} to use Databricks-connect to read delta table for FT API but got {res.spark_version}' + if version.parse(runtime_version) < version.parse( + MINIMUM_DBR_VERSION): + raise RuntimeError( + f'You need at least {MINIMUM_DBR_VERSION} to use Databricks-connect to read delta table for FT API but got {res.spark_version}' + ) sparkSession = DatabricksSession.builder.remote( host=args.DATABRICKS_HOST, token=args.DATABRICKS_TOKEN, From 20076887553246c9bc46f36d22c94ab09379a075 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 21 Dec 2023 22:10:04 -0800 Subject: [PATCH 46/62] fix lints --- scripts/data_prep/convert_delta_to_json.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 82916419f5..dfc65c9b77 100644 --- 
a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -11,7 +11,6 @@ from typing import Iterable, List, Optional, Tuple, Union from uuid import uuid4 -# PB2 stuff import google.protobuf.any_pb2 as any_pb2 import lz4.frame import pandas as pd From 10fd85a18314ec7f03f5fcde9b5191915dd5678b Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 22 Dec 2023 07:56:24 -0800 Subject: [PATCH 47/62] fix comments --- scripts/data_prep/convert_delta_to_json.py | 138 ++++++++++++++------- setup.py | 4 +- 2 files changed, 94 insertions(+), 48 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index dfc65c9b77..22eb58d017 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -32,19 +32,18 @@ from pyspark.sql.dataframe import DataFrame as SparkDataFrame from pyspark.sql.types import Row -MINIMUM_DBR_VERSION = '14.0.0' +MINIMUM_DBR_VERSION = '14.1.0' log = logging.getLogger(__name__) +log.setLevel(logging.INFO) Result = namedtuple( 'Result', ['url', 'row_count', 'compressed_size', 'uncompressed_size' ]) # pyright: ignore -cf_collect_type = 'arrow' # optionally change to json if arrow fails - -# This is a monkey patch on top of the DB Connect package that allows -# the client to fetch the results in different formats from the server. To be -# able to use the code make sure this module is not overriden by DB Connect classes. +# ``collect_as_cf`` is an addon new feature monkey patch on top of the DB Connect package. +# It allows the client to fetch the results in different formats from the server. +# To be able to use the code make sure this module is not overriden by DB Connect classes. def to_cf(self: SparkConnectClient, @@ -52,9 +51,19 @@ def to_cf(self: SparkConnectClient, type: str = 'json') -> Tuple[List[Result], int, bool]: """Executes plan object return as cloud fetch presigned URLS. - It can handle the current outptu formats that are supported by the server. + It can handle the current output formats that are supported by the server. In contrast to the regular API methods of the client, this method does not return the schema and drops all other responses. + + Args: + plan (pb2.Plan): The plan object to be executed by spark. + type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. + Returns: + Tuple[List[Result], int, bool]: A tuple containing: + - A list of Result namedtuples, each containing a URL, row count, compressed size, + and uncompressed size of the part of the result. + - Total row count of all parts of the result. + - A boolean indicating whether the result is truncated or overflowed. """ req = self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) @@ -67,7 +76,7 @@ def to_cf(self: SparkConnectClient, elif type == 'arrow': format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW else: - raise Exception('Invalid type') + raise Exception(f'Only formats json, csv, and arrow are supported. 
Got invalid type {type}') ro = cloud_pb2.ResultOptions( type=cloud_pb2.ResultOptions.TYPE_CLOUD, @@ -85,7 +94,6 @@ def to_cf(self: SparkConnectClient, self._retry_policy, self._builder.metadata()) # Iterate over the response - result = [] row_count = 0 is_overflow = False @@ -94,7 +102,8 @@ def to_cf(self: SparkConnectClient, if response.HasField('extension') and response.extension.Is( cloud_pb2.CloudResultBatch.DESCRIPTOR): batch = cloud_pb2.CloudResultBatch() - assert response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR) + if not response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR): + raise ValueError("Response extension is not of type CloudResultBatch.") response.extension.Unpack(batch) result += [ Result(b.url, b.row_count, b.compressed_size, @@ -110,6 +119,22 @@ def to_cf(self: SparkConnectClient, def collect_as_cf(self: DataFrame, type: str = 'json') -> Tuple[List[Result], int, bool]: + """Collects the result of the DataFrame's execution plan as cloud fetch presigned URLs. + + This method is a wrapper around the `to_cf` method of SparkConnectClient. It takes the + execution plan of the current DataFrame, converts it to a protocol buffer format, and then + uses the `to_cf` method to execute the plan and fetch results as presigned URLs. + + Args: + type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. + + Returns: + Tuple[List[Result], int, bool]: A tuple containing: + - A list of Result namedtuples, each containing a URL, row count, compressed size, + and uncompressed size of the part of the result. + - Total row count of all parts of the result. + - A boolean indicating whether the result is truncated or overflowed. + """ query = self._plan.to_proto(self._session.client) # pyright: ignore return self._session.client.to_cf(query, type) # pyright: ignore @@ -120,9 +145,11 @@ def collect_as_cf(self: DataFrame, def iterative_combine_jsons(json_directory: str, output_file: str) -> None: """Combine jsonl files in json_directory into one big jsonl file. + This function does not work for nested subdirectories. + Args: - json_directory(str): directory containing the JSON files - output_file(str): output JSONL file + json_directory(str): directory containing the JSONL files + output_file(str): path to the output combined JSONL file """ json_files = [f for f in os.listdir(json_directory) if f.endswith('.jsonl')] with open(output_file, 'w') as outfile: @@ -130,50 +157,44 @@ def iterative_combine_jsons(json_directory: str, output_file: str) -> None: with open(os.path.join(json_directory, file_name), 'r') as infile: for line in infile: outfile.write(line) - print('JSON files have been combined into a JSONL file.') + log.info('JSON files have been combined into a JSONL file.') def run_query( - q: str, + query: str, method: str, cursor: Optional[Cursor] = None, spark: Optional[SparkSession] = None, collect: bool = True -) -> Optional[Union[List[Row], DataFrame, SparkDataFrame]]: +) -> Union[List[Row], DataFrame, SparkDataFrame]: """Run SQL query via databricks-connect or databricks-sql. 
Args: - q (str): sql query + query (str): sql query method (str): select from dbsql and dbconnect - cursor (Cursor): connection.cursor - spark (SparkSession): spark session + cursor (Optional[Cursor]): connection.cursor + spark (Optional[SparkSession]): spark session collect (bool): whether to get the underlying data from spark dataframe """ - if method not in ['dbsql', 'dbconnect']: - raise ValueError(f'Unrecognized method: {method}') - if method == 'dbsql': if cursor is None: raise ValueError(f'cursor cannot be None if using method dbsql') - cursor.execute(q) + cursor.execute(query) if collect: return cursor.fetchall() - elif method == 'dbconnect': if spark == None: raise ValueError(f'sparkSession is required for dbconnect') - df = spark.sql(q) + df = spark.sql(query) if collect: return df.collect() return df + else: + raise ValueError(f'Unrecognized method: {method}') - return None - - -def get_args(signed: List, json_output_path: str, columns: List, - cf_collect_type: str) -> Iterable: +def get_args(signed: List, json_output_path: str, columns: List) -> Iterable: for i, r in enumerate(signed): - yield (i, r.url, json_output_path, columns, cf_collect_type) + yield (i, r.url, json_output_path, columns) def download(ipart: int, @@ -202,6 +223,7 @@ def download(ipart: int, lines=True) return + # When resp_format is arrow: if compressed: # The data is lz4 compressed arrow format. # Decompress the data @@ -225,10 +247,34 @@ def download_starargs(args: Tuple) -> None: return download(*args) -def fetch_data(method: str, cursor: Optional[Cursor], - sparkSession: Optional[SparkSession], s: int, e: int, - order_by: str, tablename: str, columns_str: str, +def fetch_data(method: str, + cursor: Optional[Cursor], + sparkSession: Optional[SparkSession], + start: int, + end: int, + order_by: str, + tablename: str, + columns_str: str, json_output_path: str) -> None: + """Fetches a specified range of rows from a given table and writes the result to a json file. + + This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes, + from a specified table and column set. The fetched data is then exported as a JSON file. + + Args: + method (str): The method to use for fetching data, either 'dbconnect' or 'dbsql'. + cursor (Optional[Cursor]): The cursor object for executing queries in 'dbsql' method. + sparkSession (Optional[SparkSession]): The Spark session object for executing queries in 'dbconnect' method. + start (int): The starting index for row fetching. + end (int): The ending index for row fetching. + order_by (str): The column name to use for ordering the rows. + tablename (str): The name of the table from which to fetch the data. + columns_str (str): The string representation of the columns to select from the table. + json_output_path (str): The file path where the resulting JSON file will be saved. + + Returns: + None: The function doesn't return any value, but writes the result to a JSONL file. 
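
Editorial aside, not part of the diff: once fetch_data has written its windowed shards (named part_{start+1}_{end}.jsonl, alongside the part_<i>.jsonl shards produced by the dbconnect path), a quick local sanity check of the per-shard and total row counts could look like the sketch below; json_output_path here is a placeholder.

import glob
import os

json_output_path = '/tmp/json_output'  # placeholder path

total = 0
for shard in sorted(glob.glob(os.path.join(json_output_path, 'part_*.jsonl'))):
    with open(shard, 'r', encoding='utf-8') as f:
        n = sum(1 for _ in f)  # one JSON record per line
    print(f'{os.path.basename(shard)}: {n} rows')
    total += n
print(f'total rows across shards: {total}')
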
+ """ query = f""" WITH NumberedRows AS ( SELECT @@ -239,17 +285,17 @@ def fetch_data(method: str, cursor: Optional[Cursor], ) SELECT {columns_str} FROM NumberedRows - WHERE rn BETWEEN {s+1} AND {e}""" + WHERE rn BETWEEN {start+1} AND {end}""" if method == 'dbconnect': spark_df = run_query(query, method, cursor, sparkSession, collect=False) - if not spark_df: + if spark_df is None: raise RuntimeError( f'Expect spark dataframe with {query} but got None') pdf = spark_df.toPandas() # pyright: ignore else: # method == 'dbsql': ans = run_query(query, method, cursor, sparkSession, collect=True) - if not ans: + if ans is None: raise RuntimeError(f'Got empty results with {query}') records = [r.asDict() for r in ans] # pyright: ignore pdf = pd.DataFrame.from_dict(records) @@ -283,7 +329,7 @@ def fetch( ans = run_query(f'SELECT COUNT(*) FROM {tablename}', method, cursor, sparkSession) nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore - log.info(f'total_rows = {nrows}') + log.debug(f'total_rows = {nrows}') except Exception as e: raise RuntimeError( f'Error in get total rows from {tablename}. Restart sparkSession and try again' @@ -301,16 +347,16 @@ def fetch( f'Error in get columns from {tablename}. Restart sparkSession and try again' ) from e - if method == 'dbconnect' and sparkSession: - print('processes = ', processes) + if method == 'dbconnect' and sparkSession is not None: + log.info('processes = ', processes) df = sparkSession.table( tablename) # "main.tpcds_sf100_delta.store_sales") # Running the query and collecting the data as arrow or json. signed, _, _ = df.collect_cf('arrow') # pyright: ignore - print(f'len(signed) = {len(signed)}') + log.info(f'len(signed) = {len(signed)}') - args = get_args(signed, json_output_path, columns, cf_collect_type) + args = get_args(signed, json_output_path, columns) # Stopping the SparkSession to avoid spilling connection state into the subprocesses. sparkSession.stop() @@ -318,7 +364,7 @@ def fetch( with ProcessPoolExecutor(max_workers=processes) as executor: list(executor.map(download_starargs, args)) - elif method == 'dbsql' and cursor: + elif method == 'dbsql' and cursor is not None: for start in range(0, nrows, batch_size): end = min(start + batch_size, nrows) fetch_data(method, cursor, sparkSession, start, end, order_by, @@ -382,7 +428,8 @@ def fetch_DT(args: Namespace) -> None: if version.parse(runtime_version) < version.parse( MINIMUM_DBR_VERSION): raise RuntimeError( - f'You need at least {MINIMUM_DBR_VERSION} to use Databricks-connect to read delta table for FT API but got {res.spark_version}' + f'You need at least {MINIMUM_DBR_VERSION} to use Databricks-connect' + + ' to read delta table for FT API but got {res.spark_version}' ) sparkSession = DatabricksSession.builder.remote( host=args.DATABRICKS_HOST, @@ -432,7 +479,7 @@ def fetch_DT(args: Namespace) -> None: parser.add_argument('--processes', required=False, type=int, - default=1, + default=os.cpu_count(), help='number of processes allowed to use') parser.add_argument( '--cluster_id', @@ -442,9 +489,8 @@ def fetch_DT(args: Namespace) -> None: help= 'Use serverless if not present. IMPORTANT! 
make sure cluster has runtime newer than 14.1.0, the databricks-connect client version' ) - parser.add_argument('--debug', type=bool, required=False, default=False) args = parser.parse_args() tik = time.time() fetch_DT(args) - print('Elapsed time', time.time() - tik) + log.info('Elapsed time', time.time() - tik) diff --git a/setup.py b/setup.py index 25d68085a5..dcf7b70434 100644 --- a/setup.py +++ b/setup.py @@ -84,8 +84,8 @@ ] extra_deps['databricks'] = [ - 'mosaicml[databricks]>=0.17.2,<0.18', - 'databricks-sql-connector>=3, <4', + 'mosaicml[databricks]>=0.17.1,<0.18', + 'databricks-sql-connector>=3,<4', 'databricks-connect==14.1.0', 'lz4>=4,<5', ] From 462f9b71c82209eb5f00c41b586947b66d585e42 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Sun, 7 Jan 2024 22:12:26 -0800 Subject: [PATCH 48/62] Fix lints --- scripts/data_prep/convert_delta_to_json.py | 31 +++++++++++----------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 22eb58d017..0ff72a02ad 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -58,6 +58,7 @@ def to_cf(self: SparkConnectClient, Args: plan (pb2.Plan): The plan object to be executed by spark. type (str): The output format of the result, supported formats are 'json', 'csv', and 'arrow'. + Returns: Tuple[List[Result], int, bool]: A tuple containing: - A list of Result namedtuples, each containing a URL, row count, compressed size, @@ -76,7 +77,9 @@ def to_cf(self: SparkConnectClient, elif type == 'arrow': format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW else: - raise Exception(f'Only formats json, csv, and arrow are supported. Got invalid type {type}') + raise Exception( + f'Only formats json, csv, and arrow are supported. Got invalid type {type}' + ) ro = cloud_pb2.ResultOptions( type=cloud_pb2.ResultOptions.TYPE_CLOUD, @@ -103,7 +106,8 @@ def to_cf(self: SparkConnectClient, cloud_pb2.CloudResultBatch.DESCRIPTOR): batch = cloud_pb2.CloudResultBatch() if not response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR): - raise ValueError("Response extension is not of type CloudResultBatch.") + raise ValueError( + 'Response extension is not of type CloudResultBatch.') response.extension.Unpack(batch) result += [ Result(b.url, b.row_count, b.compressed_size, @@ -119,7 +123,7 @@ def to_cf(self: SparkConnectClient, def collect_as_cf(self: DataFrame, type: str = 'json') -> Tuple[List[Result], int, bool]: - """Collects the result of the DataFrame's execution plan as cloud fetch presigned URLs. + """Collects DataFrame execution plan as presigned URLs. This method is a wrapper around the `to_cf` method of SparkConnectClient. It takes the execution plan of the current DataFrame, converts it to a protocol buffer format, and then @@ -166,7 +170,7 @@ def run_query( cursor: Optional[Cursor] = None, spark: Optional[SparkSession] = None, collect: bool = True -) -> Union[List[Row], DataFrame, SparkDataFrame]: +) -> Optional[Union[List[Row], DataFrame, SparkDataFrame]]: """Run SQL query via databricks-connect or databricks-sql. 
Args: @@ -192,6 +196,7 @@ def run_query( else: raise ValueError(f'Unrecognized method: {method}') + def get_args(signed: List, json_output_path: str, columns: List) -> Iterable: for i, r in enumerate(signed): yield (i, r.url, json_output_path, columns) @@ -247,16 +252,11 @@ def download_starargs(args: Tuple) -> None: return download(*args) -def fetch_data(method: str, - cursor: Optional[Cursor], - sparkSession: Optional[SparkSession], - start: int, - end: int, - order_by: str, - tablename: str, - columns_str: str, +def fetch_data(method: str, cursor: Optional[Cursor], + sparkSession: Optional[SparkSession], start: int, end: int, + order_by: str, tablename: str, columns_str: str, json_output_path: str) -> None: - """Fetches a specified range of rows from a given table and writes the result to a json file. + """Fetches a specified range of rows from a given table to a json file. This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes, from a specified table and column set. The fetched data is then exported as a JSON file. @@ -300,7 +300,7 @@ def fetch_data(method: str, records = [r.asDict() for r in ans] # pyright: ignore pdf = pd.DataFrame.from_dict(records) - pdf.to_json(os.path.join(json_output_path, f'part_{s+1}_{e}.jsonl')) + pdf.to_json(os.path.join(json_output_path, f'part_{start+1}_{end}.jsonl')) def fetch( @@ -428,7 +428,8 @@ def fetch_DT(args: Namespace) -> None: if version.parse(runtime_version) < version.parse( MINIMUM_DBR_VERSION): raise RuntimeError( - f'You need at least {MINIMUM_DBR_VERSION} to use Databricks-connect' + + f'You need at least {MINIMUM_DBR_VERSION} to use Databricks-connect' + + ' to read delta table for FT API but got {res.spark_version}' ) sparkSession = DatabricksSession.builder.remote( From b7d9b04be6123466379e43eddafb239d2f4c0bd8 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Sun, 7 Jan 2024 22:36:49 -0800 Subject: [PATCH 49/62] change to use workspaceclient --- scripts/data_prep/convert_delta_to_json.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 0ff72a02ad..c11a095095 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -460,14 +460,6 @@ def fetch_DT(args: Namespace) -> None: required=True, type=str, help='Local path to save the converted json') - parser.add_argument('--DATABRICKS_HOST', - required=False, - type=str, - help='DATABRICKS_HOST') - parser.add_argument('--DATABRICKS_TOKEN', - required=False, - type=str, - help='DATABRICKS_TOKEN') parser.add_argument('--http_path', required=False, type=str, @@ -492,6 +484,11 @@ def fetch_DT(args: Namespace) -> None: ) args = parser.parse_args() + from databricks.sdk import WorkspaceClient + w = WorkspaceClient() + args.DATABRICKS_HOST = w.config.host + args.DATABRICKS_TOKEN = w.config.token + tik = time.time() fetch_DT(args) log.info('Elapsed time', time.time() - tik) From 626eeb1602143ec92b59c0b66c9da7c543a066b6 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 9 Jan 2024 13:43:29 -0800 Subject: [PATCH 50/62] Add CPT support --- scripts/data_prep/convert_delta_to_json.py | 228 +++++++++++++++++---- setup.py | 1 + 2 files changed, 190 insertions(+), 39 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index c11a095095..48655b04d8 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ 
b/scripts/data_prep/convert_delta_to_json.py @@ -1,16 +1,16 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import glob import logging import os -import time import urllib.parse from argparse import ArgumentParser, Namespace from collections import namedtuple from concurrent.futures import ProcessPoolExecutor -from typing import Iterable, List, Optional, Tuple, Union -from uuid import uuid4 +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +import datasets as hf_datasets import google.protobuf.any_pb2 as any_pb2 import lz4.frame import pandas as pd @@ -31,6 +31,10 @@ from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.dataframe import DataFrame as SparkDataFrame from pyspark.sql.types import Row +from streaming.base.converters import dataframe_to_mds +from transformers import AutoTokenizer + +from llmfoundry.data import ConcatTokensDataset MINIMUM_DBR_VERSION = '14.1.0' @@ -374,6 +378,94 @@ def fetch( cursor.close() +def pandas_processing_fn(df: pd.DataFrame, + **args: Any) -> Iterable[Dict[str, bytes]]: + """Tokenize helper function for dataframe_to_mds. + + Args: + df (pandas.DataFrame): The input pandas DataFrame that needs to be processed. + **args : Additional arguments to be passed to the 'process_some_data' function during processing. + + Returns: + iterable obj + """ + hf_dataset = hf_datasets.Dataset.from_pandas(df=df) + tokenizer = AutoTokenizer.from_pretrained(args['tokenizer']) + tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace + dataset = ConcatTokensDataset( + hf_dataset=hf_dataset, + max_length=args.get('concat_tokens', None), + tokenizer=tokenizer, + eos_text=args.get('eos_text', None), + bos_text=args.get('bos_text', None), + no_wrap=args.get('no_wrap', None), + ) + + for sample in dataset: # pyright: ignore + yield sample + + +def json_to_mds(json_output_path: str, + mds_output_path: str, + concat_tokens: int, + tokenizer_name: str, + eos_text: str = '<|endoftext|>', + compression: str = 'zstd', + no_wrap: bool = False, + bos_text: str = '') -> None: + """Convert a local folder of jsonl files into Streaming MDS dataset. 
+ + Args: + json_output_path (str): Folder that contains jsonl files to process + mds_output_path (str): Folder to write MDS shards to + concat_tokens (int): Concantenate up to this many tokens + tokenizer_name (str): Name of tokenizer to use + eos_text (str): Textend to append to each example to separate concatenated samples + compression (str): The compression algorithm to use for MDS writing + no_wrap: (bool): Whether to let text examples wrap across multiple training examples + bos_text (str): Text to prepend to each example to separate concatenated samples + """ + spark = SparkSession.builder.getOrCreate() # pyright: ignore + file_paths = glob.glob(json_output_path + '/*.jsonl') + + if len(file_paths) == 0: + raise FileNotFoundError(f'No jsonl files found in {json_output_path}') + + df = spark.read.json(file_paths) + mds_kwargs = { + 'out': mds_output_path, + 'columns': { + 'tokens': 'bytes' + }, + 'keep_local': True + } + udf_kwargs = { + 'concat_tokens': concat_tokens, + 'tokenizer': tokenizer_name, + 'eos_text': eos_text, + 'compression': compression, + 'no_wrap': no_wrap, + 'bos_text': bos_text, + } + + dataframe_to_mds(df, + merge_index=True, + mds_kwargs=mds_kwargs, + udf_iterable=pandas_processing_fn, + udf_kwargs=udf_kwargs) + + # Sanity Check + import numpy as np + from streaming import StreamingDataset + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace + dataset = StreamingDataset(local=mds_output_path, shuffle=False) + for i in range(5): + l = np.frombuffer(dataset[i]['tokens'], dtype=np.int64) + print(''.join(tokenizer.decode(l))) + print() + + def fetch_DT(args: Namespace) -> None: """Fetch UC Delta Table to local as jsonl.""" log.info(f'Start .... Convert delta to json') @@ -382,6 +474,8 @@ def fetch_DT(args: Namespace) -> None: if obj.scheme != '': raise ValueError( f'Check the json_output_path and verify it is a local path!') + if args.task_type == 'CONTINUED_PRETRAIN' and args.mds_output_path is None: + raise ValueError(f'Need to specify mds_output_path along with CPT') if os.path.exists(args.json_output_path): if not os.path.isdir(args.json_output_path) or os.listdir( @@ -394,12 +488,29 @@ def fetch_DT(args: Namespace) -> None: log.info(f'Directory {args.json_output_path} created.') - method = '' + method = 'dbsql' dbsql = None sparkSession = None - if hasattr(args, 'http_path') and args.http_path: - method = 'dbsql' + if args.http_path is None and args.cluster_id is not None: + w = WorkspaceClient() + res = w.clusters.get(cluster_id=args.cluster_id) + runtime_version = res.spark_version.split('-scala')[0].replace( + 'x-snapshot', '0').replace('x', '0') + if version.parse(runtime_version) >= version.parse(MINIMUM_DBR_VERSION): + method = 'dbconnect' + + if method == 'dbconnect': + try: + sparkSession = DatabricksSession.builder.remote( + host=args.DATABRICKS_HOST, + token=args.DATABRICKS_TOKEN, + cluster_id=args.cluster_id).getOrCreate() + except Exception as e: + raise RuntimeError( + 'Failed to create databricks connection. Check hostname and access token!' + ) from e + else: try: dbsql = sql.connect( server_hostname=args.DATABRICKS_HOST, @@ -408,34 +519,8 @@ def fetch_DT(args: Namespace) -> None: ) except Exception as e: raise RuntimeError( - 'Failed to create sql connection to db workspace. Check {server_hostname} and {http_path} and access token!' + 'Failed to create sql connection to db workspace. Check server_hostname and http_path and access token!' 
) from e - else: - method = 'dbconnect' - if not args.cluster_id: - session_id = str(uuid4()) - sparkSession = DatabricksSession.builder.host( - args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header( - 'x-databricks-session-id', session_id).getOrCreate() - else: - # IMPORTANT: make sure cluster has runtime newer than 14.1.0, the databricks-connect client version. - compute_id = args.cluster_id # "1115-130834-ms4m0yv" - valid 14.1.0 - w = WorkspaceClient() - res = w.clusters.get( - cluster_id=compute_id) # '0704-124501-tsc2fxq' - invalid 12.2.x - runtime_version = res.spark_version.split('-scala')[0].replace( - 'x-snapshot', '0').replace('x', '0') - if version.parse(runtime_version) < version.parse( - MINIMUM_DBR_VERSION): - raise RuntimeError( - f'You need at least {MINIMUM_DBR_VERSION} to use Databricks-connect' - + - ' to read delta table for FT API but got {res.spark_version}' - ) - sparkSession = DatabricksSession.builder.remote( - host=args.DATABRICKS_HOST, - token=args.DATABRICKS_TOKEN, - cluster_id=compute_id).getOrCreate() fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, args.processes, sparkSession, dbsql) @@ -443,9 +528,21 @@ def fetch_DT(args: Namespace) -> None: if dbsql is not None: dbsql.close() - iterative_combine_jsons( - args.json_output_path, - os.path.join(args.json_output_path, 'combined.jsonl')) + if args.task_type == 'INSTRUCTION_FINETUNE': + # combine downloaded jsonl into one big jsonl for IFT + iterative_combine_jsons( + args.json_output_path, + os.path.join(args.json_output_path, 'combined.jsonl')) + else: + # convert downloaded jsonl into a MDS dataset for CPT + json_to_mds(json_output_path=args.json_output_path, + mds_output_path=args.mds_output_path, + concat_tokens=args.concat_tokens, + tokenizer_name=args.tokenizer, + eos_text=args.eos_text, + compression=args.compression, + no_wrap=args.no_wrap, + bos_text=args.bos_text) if __name__ == '__main__': @@ -480,15 +577,68 @@ def fetch_DT(args: Namespace) -> None: type=str, default=None, help= - 'Use serverless if not present. IMPORTANT! make sure cluster has runtime newer than 14.1.0, the databricks-connect client version' + 'Use serverless if not present. IMPORTANT! 
make sure cluster has runtime newer than 14.1.0 to use databricks-connect' + ) + parser.add_argument('--task_type', + required=True, + type=str, + default='INSTRUCTION_FINETUNE', + help='INSTRUCTION_FINETUNE or CONTINUED_PRETRAIN') + parser.add_argument('--mds_output_path', + required=False, + type=str, + help='local or remote paths to save MDS dataset') + parser.add_argument( + '--compression', + type=str, + default='zstd', + help='The compression algorithm to use for MDS writing', + ) + parser.add_argument( + '--concat_tokens', + required=True, + type=int, + help='Convert text to tokens and concatenate up to this many tokens', + ) + parser.add_argument( + '--tokenizer', + required=False, + type=str, + help='The name of the tokenizer to use', + ) + parser.add_argument( + '--bos_text', + type=str, + required=False, + default=None, + help= + 'The text to prepend to each example to separate concatenated examples', + ) + parser.add_argument( + '--eos_text', + type=str, + required=False, + default=None, + help= + 'The text to append to each example to separate concatenated examples', + ) + parser.add_argument( + '--no_wrap', + default=False, + action='store_true', + help= + 'Whether to let text examples wrap across multiple training examples', ) args = parser.parse_args() + if args.bos_text is None: + args.bos_text = '' + if args.eos_text is None: + args.eos_text = '' + from databricks.sdk import WorkspaceClient w = WorkspaceClient() args.DATABRICKS_HOST = w.config.host args.DATABRICKS_TOKEN = w.config.token - tik = time.time() fetch_DT(args) - log.info('Elapsed time', time.time() - tik) diff --git a/setup.py b/setup.py index dcf7b70434..ff5da2542a 100644 --- a/setup.py +++ b/setup.py @@ -88,6 +88,7 @@ 'databricks-sql-connector>=3,<4', 'databricks-connect==14.1.0', 'lz4>=4,<5', + 'pyspark>=3,<4', ] extra_deps['tensorboard'] = [ From 85cc21d46d12026433861e6f0b53ff3736086250 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 9 Jan 2024 15:09:05 -0800 Subject: [PATCH 51/62] Rewire method assignment logic --- scripts/data_prep/convert_delta_to_json.py | 194 ++------------------- 1 file changed, 18 insertions(+), 176 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 48655b04d8..c5f72273c7 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -1,16 +1,15 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import glob import logging import os +import time import urllib.parse from argparse import ArgumentParser, Namespace from collections import namedtuple from concurrent.futures import ProcessPoolExecutor -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Tuple, Union -import datasets as hf_datasets import google.protobuf.any_pb2 as any_pb2 import lz4.frame import pandas as pd @@ -31,10 +30,6 @@ from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.dataframe import DataFrame as SparkDataFrame from pyspark.sql.types import Row -from streaming.base.converters import dataframe_to_mds -from transformers import AutoTokenizer - -from llmfoundry.data import ConcatTokensDataset MINIMUM_DBR_VERSION = '14.1.0' @@ -370,6 +365,7 @@ def fetch( elif method == 'dbsql' and cursor is not None: for start in range(0, nrows, batch_size): + print('start = ', start) end = min(start + batch_size, nrows) fetch_data(method, cursor, sparkSession, start, end, order_by, tablename, 
columns_str, json_output_path) @@ -378,94 +374,6 @@ def fetch( cursor.close() -def pandas_processing_fn(df: pd.DataFrame, - **args: Any) -> Iterable[Dict[str, bytes]]: - """Tokenize helper function for dataframe_to_mds. - - Args: - df (pandas.DataFrame): The input pandas DataFrame that needs to be processed. - **args : Additional arguments to be passed to the 'process_some_data' function during processing. - - Returns: - iterable obj - """ - hf_dataset = hf_datasets.Dataset.from_pandas(df=df) - tokenizer = AutoTokenizer.from_pretrained(args['tokenizer']) - tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace - dataset = ConcatTokensDataset( - hf_dataset=hf_dataset, - max_length=args.get('concat_tokens', None), - tokenizer=tokenizer, - eos_text=args.get('eos_text', None), - bos_text=args.get('bos_text', None), - no_wrap=args.get('no_wrap', None), - ) - - for sample in dataset: # pyright: ignore - yield sample - - -def json_to_mds(json_output_path: str, - mds_output_path: str, - concat_tokens: int, - tokenizer_name: str, - eos_text: str = '<|endoftext|>', - compression: str = 'zstd', - no_wrap: bool = False, - bos_text: str = '') -> None: - """Convert a local folder of jsonl files into Streaming MDS dataset. - - Args: - json_output_path (str): Folder that contains jsonl files to process - mds_output_path (str): Folder to write MDS shards to - concat_tokens (int): Concantenate up to this many tokens - tokenizer_name (str): Name of tokenizer to use - eos_text (str): Textend to append to each example to separate concatenated samples - compression (str): The compression algorithm to use for MDS writing - no_wrap: (bool): Whether to let text examples wrap across multiple training examples - bos_text (str): Text to prepend to each example to separate concatenated samples - """ - spark = SparkSession.builder.getOrCreate() # pyright: ignore - file_paths = glob.glob(json_output_path + '/*.jsonl') - - if len(file_paths) == 0: - raise FileNotFoundError(f'No jsonl files found in {json_output_path}') - - df = spark.read.json(file_paths) - mds_kwargs = { - 'out': mds_output_path, - 'columns': { - 'tokens': 'bytes' - }, - 'keep_local': True - } - udf_kwargs = { - 'concat_tokens': concat_tokens, - 'tokenizer': tokenizer_name, - 'eos_text': eos_text, - 'compression': compression, - 'no_wrap': no_wrap, - 'bos_text': bos_text, - } - - dataframe_to_mds(df, - merge_index=True, - mds_kwargs=mds_kwargs, - udf_iterable=pandas_processing_fn, - udf_kwargs=udf_kwargs) - - # Sanity Check - import numpy as np - from streaming import StreamingDataset - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace - dataset = StreamingDataset(local=mds_output_path, shuffle=False) - for i in range(5): - l = np.frombuffer(dataset[i]['tokens'], dtype=np.int64) - print(''.join(tokenizer.decode(l))) - print() - - def fetch_DT(args: Namespace) -> None: """Fetch UC Delta Table to local as jsonl.""" log.info(f'Start .... 
Convert delta to json') @@ -474,8 +382,6 @@ def fetch_DT(args: Namespace) -> None: if obj.scheme != '': raise ValueError( f'Check the json_output_path and verify it is a local path!') - if args.task_type == 'CONTINUED_PRETRAIN' and args.mds_output_path is None: - raise ValueError(f'Need to specify mds_output_path along with CPT') if os.path.exists(args.json_output_path): if not os.path.isdir(args.json_output_path) or os.listdir( @@ -492,13 +398,13 @@ def fetch_DT(args: Namespace) -> None: dbsql = None sparkSession = None - if args.http_path is None and args.cluster_id is not None: - w = WorkspaceClient() - res = w.clusters.get(cluster_id=args.cluster_id) - runtime_version = res.spark_version.split('-scala')[0].replace( - 'x-snapshot', '0').replace('x', '0') - if version.parse(runtime_version) >= version.parse(MINIMUM_DBR_VERSION): - method = 'dbconnect' + w = WorkspaceClient() + res = w.clusters.get(cluster_id=args.cluster_id) + runtime_version = res.spark_version.split('-scala')[0].replace( + 'x-snapshot', '0').replace('x', '0') + if args.http_path is None and version.parse( + runtime_version) >= version.parse(MINIMUM_DBR_VERSION): + method = 'dbconnect' if method == 'dbconnect': try: @@ -513,13 +419,13 @@ def fetch_DT(args: Namespace) -> None: else: try: dbsql = sql.connect( - server_hostname=args.DATABRICKS_HOST, + server_hostname=args.DATABRICKS_HOST.lstrip('https://'), http_path=args.http_path, access_token=args.DATABRICKS_TOKEN, ) except Exception as e: raise RuntimeError( - 'Failed to create sql connection to db workspace. Check server_hostname and http_path and access token!' + 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!' ) from e fetch(method, args.delta_table_name, args.json_output_path, args.batch_size, @@ -528,21 +434,10 @@ def fetch_DT(args: Namespace) -> None: if dbsql is not None: dbsql.close() - if args.task_type == 'INSTRUCTION_FINETUNE': - # combine downloaded jsonl into one big jsonl for IFT - iterative_combine_jsons( - args.json_output_path, - os.path.join(args.json_output_path, 'combined.jsonl')) - else: - # convert downloaded jsonl into a MDS dataset for CPT - json_to_mds(json_output_path=args.json_output_path, - mds_output_path=args.mds_output_path, - concat_tokens=args.concat_tokens, - tokenizer_name=args.tokenizer, - eos_text=args.eos_text, - compression=args.compression, - no_wrap=args.no_wrap, - bos_text=args.bos_text) + # combine downloaded jsonl into one big jsonl for IFT + iterative_combine_jsons( + args.json_output_path, + os.path.join(args.json_output_path, 'combined.jsonl')) if __name__ == '__main__': @@ -579,66 +474,13 @@ def fetch_DT(args: Namespace) -> None: help= 'Use serverless if not present. IMPORTANT! 
make sure cluster has runtime newer than 14.1.0 to use databricks-connect' ) - parser.add_argument('--task_type', - required=True, - type=str, - default='INSTRUCTION_FINETUNE', - help='INSTRUCTION_FINETUNE or CONTINUED_PRETRAIN') - parser.add_argument('--mds_output_path', - required=False, - type=str, - help='local or remote paths to save MDS dataset') - parser.add_argument( - '--compression', - type=str, - default='zstd', - help='The compression algorithm to use for MDS writing', - ) - parser.add_argument( - '--concat_tokens', - required=True, - type=int, - help='Convert text to tokens and concatenate up to this many tokens', - ) - parser.add_argument( - '--tokenizer', - required=False, - type=str, - help='The name of the tokenizer to use', - ) - parser.add_argument( - '--bos_text', - type=str, - required=False, - default=None, - help= - 'The text to prepend to each example to separate concatenated examples', - ) - parser.add_argument( - '--eos_text', - type=str, - required=False, - default=None, - help= - 'The text to append to each example to separate concatenated examples', - ) - parser.add_argument( - '--no_wrap', - default=False, - action='store_true', - help= - 'Whether to let text examples wrap across multiple training examples', - ) args = parser.parse_args() - if args.bos_text is None: - args.bos_text = '' - if args.eos_text is None: - args.eos_text = '' - from databricks.sdk import WorkspaceClient w = WorkspaceClient() args.DATABRICKS_HOST = w.config.host args.DATABRICKS_TOKEN = w.config.token + tik = time.time() fetch_DT(args) + log.info('Elapsed time', time.time() - tik) From d8c28acbef120aaac48a24a72661fec2c38942e9 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 9 Jan 2024 16:33:57 -0800 Subject: [PATCH 52/62] Fix bug in stripping https --- scripts/data_prep/convert_delta_to_json.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index c5f72273c7..932ddebe88 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -3,6 +3,7 @@ import logging import os +import re import time import urllib.parse from argparse import ArgumentParser, Namespace @@ -419,7 +420,7 @@ def fetch_DT(args: Namespace) -> None: else: try: dbsql = sql.connect( - server_hostname=args.DATABRICKS_HOST.lstrip('https://'), + server_hostname=re.compile(r"^https?://").sub('', args.DATABRICKS_HOST).strip(), # sqlconnect hangs if hostname starts with https http_path=args.http_path, access_token=args.DATABRICKS_TOKEN, ) From 3ff5d0e667d24a949e1f3d26a237bd7c225bd7e7 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 9 Jan 2024 16:34:18 -0800 Subject: [PATCH 53/62] Add tests for rewired method assignment logic --- .../data_prep/test_convert_delta_to_json.py | 147 +++++++++++++++++- 1 file changed, 145 insertions(+), 2 deletions(-) diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 532bf22164..d7302bd07e 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -11,7 +11,11 @@ from scripts.data_prep.convert_delta_to_json import (download, fetch_DT, iterative_combine_jsons, run_query) +from argparse import Namespace +from packaging import version +def mock_cluster_get_response(cluster_id, spark_version): + return {'cluster_id': cluster_id, 'spark_version': spark_version} class 
TestConverDeltaToJsonl(unittest.TestCase): @@ -19,7 +23,9 @@ class TestConverDeltaToJsonl(unittest.TestCase): @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') @patch('scripts.data_prep.convert_delta_to_json.fetch') - def test_stream_delta_to_json(self, mock_fetch: Any, + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + def test_stream_delta_to_json(self, mock_workspace_client: Any, + mock_fetch: Any, mock_combine_jsons: Any, mock_makedirs: Any, mock_sql_connect: Any): @@ -31,9 +37,13 @@ def test_stream_delta_to_json(self, mock_fetch: Any, args.http_path = 'test_path' args.batch_size = 1000 args.partitions = 1 - args.cluster_id = None + args.cluster_id = '1234' args.debug = False + mock_cluster_get = MagicMock() + mock_cluster_get.return_value = MagicMock(spark_version='14.1.0-scala2.12') + mock_workspace_client.return_value.clusters.get = mock_cluster_get + fetch_DT(args) mock_sql_connect.assert_called_once_with(server_hostname='test_host', http_path='test_path', @@ -131,3 +141,136 @@ def test_download_success(self, mock_sleep: Any, mock_join: Any, lines=True) mock_get.assert_called_once_with('http://fakeurl.com/data') + + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_dbconnect_called(self, mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path=None + args.cluster_id='1234' + args.DATABRICKS_HOST='host' + args.DATABRICKS_TOKEN='token' + + mock_cluster_response = Namespace(spark_version='14.1.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + mock_remote = MagicMock() + mock_remote.getOrCreate.return_value = MagicMock() # Mock return value for getOrCreate + mock_databricks_session.builder.remote.return_value = mock_remote + + fetch_DT(args) + mock_databricks_session.builder.remote.assert_called_once_with(host=args.DATABRICKS_HOST, + token=args.DATABRICKS_TOKEN, + cluster_id=args.cluster_id) + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_sqlconnect_called_dbr13(self, mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path='test_path' + args.cluster_id='1234' + args.DATABRICKS_HOST='host' + args.DATABRICKS_TOKEN='token' + + mock_cluster_response = 
Namespace(spark_version='13.0.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + fetch_DT(args) + mock_sql_connect.assert_called_once_with(server_hostname=args.DATABRICKS_HOST, + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN) + + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_sqlconnect_called_dbr14(self, mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path='test_path' + args.cluster_id='1234' + args.DATABRICKS_HOST='host' + args.DATABRICKS_TOKEN='token' + + mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + fetch_DT(args) + mock_sql_connect.assert_called_once_with(server_hostname=args.DATABRICKS_HOST, + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN) + + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_sqlconnect_called_https(self, mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path='test_path' + args.cluster_id='1234' + args.DATABRICKS_HOST='https://test-host' + args.DATABRICKS_TOKEN='token' + + mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') + mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + fetch_DT(args) + mock_sql_connect.assert_called_once_with(server_hostname="test-host", + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN) + + From fdc7b14fe33df8dbb4678667e70a6fe1458ae8cd Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 9 Jan 2024 22:15:37 -0800 Subject: [PATCH 54/62] Fix lints --- scripts/data_prep/convert_delta_to_json.py | 8 +- .../data_prep/test_convert_delta_to_json.py | 87 +++++++++---------- 2 files changed, 47 insertions(+), 48 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 932ddebe88..5a0bcb932d 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -300,7 +300,9 @@ def fetch_data(method: str, cursor: Optional[Cursor], records = [r.asDict() for r in ans] # pyright: ignore pdf = pd.DataFrame.from_dict(records) - pdf.to_json(os.path.join(json_output_path, f'part_{start+1}_{end}.jsonl')) + 
pdf.to_json(os.path.join(json_output_path, f'part_{start+1}_{end}.jsonl'), + orient='records', + lines=True) def fetch( @@ -420,7 +422,9 @@ def fetch_DT(args: Namespace) -> None: else: try: dbsql = sql.connect( - server_hostname=re.compile(r"^https?://").sub('', args.DATABRICKS_HOST).strip(), # sqlconnect hangs if hostname starts with https + server_hostname=re.compile(r'^https?://').sub( + '', args.DATABRICKS_HOST).strip( + ), # sqlconnect hangs if hostname starts with https http_path=args.http_path, access_token=args.DATABRICKS_TOKEN, ) diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index d7302bd07e..3b6a9d6b45 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -5,17 +5,14 @@ # spdx-license-identifier: apache-2.0 import unittest +from argparse import Namespace from typing import Any from unittest.mock import MagicMock, mock_open, patch from scripts.data_prep.convert_delta_to_json import (download, fetch_DT, iterative_combine_jsons, run_query) -from argparse import Namespace -from packaging import version -def mock_cluster_get_response(cluster_id, spark_version): - return {'cluster_id': cluster_id, 'spark_version': spark_version} class TestConverDeltaToJsonl(unittest.TestCase): @@ -25,9 +22,8 @@ class TestConverDeltaToJsonl(unittest.TestCase): @patch('scripts.data_prep.convert_delta_to_json.fetch') @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') def test_stream_delta_to_json(self, mock_workspace_client: Any, - mock_fetch: Any, - mock_combine_jsons: Any, mock_makedirs: Any, - mock_sql_connect: Any): + mock_fetch: Any, mock_combine_jsons: Any, + mock_makedirs: Any, mock_sql_connect: Any): args = MagicMock() args.delta_table_name = 'test_table' @@ -41,7 +37,8 @@ def test_stream_delta_to_json(self, mock_workspace_client: Any, args.debug = False mock_cluster_get = MagicMock() - mock_cluster_get.return_value = MagicMock(spark_version='14.1.0-scala2.12') + mock_cluster_get.return_value = MagicMock( + spark_version='14.1.0-scala2.12') mock_workspace_client.return_value.clusters.get = mock_cluster_get fetch_DT(args) @@ -142,17 +139,14 @@ def test_download_success(self, mock_sleep: Any, mock_join: Any, mock_get.assert_called_once_with('http://fakeurl.com/data') - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') @patch('scripts.data_prep.convert_delta_to_json.fetch') - def test_dbconnect_called(self, mock_fetch: Any, - mock_combine_jsons: Any, - mock_makedirs: Any, - mock_workspace_client: Any, + def test_dbconnect_called(self, mock_fetch: Any, mock_combine_jsons: Any, + mock_makedirs: Any, mock_workspace_client: Any, mock_databricks_session: Any, mock_sql_connect: Any): @@ -161,22 +155,24 @@ def test_dbconnect_called(self, mock_fetch: Any, args.delta_table_name = 'test_table' args.json_output_path = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path=None - args.cluster_id='1234' - args.DATABRICKS_HOST='host' - args.DATABRICKS_TOKEN='token' + args.http_path = None + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'host' + args.DATABRICKS_TOKEN = 'token' mock_cluster_response = 
Namespace(spark_version='14.1.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response mock_remote = MagicMock() - mock_remote.getOrCreate.return_value = MagicMock() # Mock return value for getOrCreate + mock_remote.getOrCreate.return_value = MagicMock( + ) # Mock return value for getOrCreate mock_databricks_session.builder.remote.return_value = mock_remote fetch_DT(args) - mock_databricks_session.builder.remote.assert_called_once_with(host=args.DATABRICKS_HOST, - token=args.DATABRICKS_TOKEN, - cluster_id=args.cluster_id) + mock_databricks_session.builder.remote.assert_called_once_with( + host=args.DATABRICKS_HOST, + token=args.DATABRICKS_TOKEN, + cluster_id=args.cluster_id) @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') @@ -196,19 +192,19 @@ def test_sqlconnect_called_dbr13(self, mock_fetch: Any, args.delta_table_name = 'test_table' args.json_output_path = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path='test_path' - args.cluster_id='1234' - args.DATABRICKS_HOST='host' - args.DATABRICKS_TOKEN='token' + args.http_path = 'test_path' + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'host' + args.DATABRICKS_TOKEN = 'token' mock_cluster_response = Namespace(spark_version='13.0.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response fetch_DT(args) - mock_sql_connect.assert_called_once_with(server_hostname=args.DATABRICKS_HOST, - http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN) - + mock_sql_connect.assert_called_once_with( + server_hostname=args.DATABRICKS_HOST, + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN) @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') @@ -228,19 +224,19 @@ def test_sqlconnect_called_dbr14(self, mock_fetch: Any, args.delta_table_name = 'test_table' args.json_output_path = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path='test_path' - args.cluster_id='1234' - args.DATABRICKS_HOST='host' - args.DATABRICKS_TOKEN='token' + args.http_path = 'test_path' + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'host' + args.DATABRICKS_TOKEN = 'token' mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response fetch_DT(args) - mock_sql_connect.assert_called_once_with(server_hostname=args.DATABRICKS_HOST, - http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN) - + mock_sql_connect.assert_called_once_with( + server_hostname=args.DATABRICKS_HOST, + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN) @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') @@ -260,17 +256,16 @@ def test_sqlconnect_called_https(self, mock_fetch: Any, args.delta_table_name = 'test_table' args.json_output_path = '/path/to/jsonl' # Execute function with http_path=None (should use dbconnect) - args.http_path='test_path' - args.cluster_id='1234' - args.DATABRICKS_HOST='https://test-host' - args.DATABRICKS_TOKEN='token' + args.http_path = 'test_path' + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'https://test-host' + args.DATABRICKS_TOKEN = 'token' mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') 
mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response fetch_DT(args) - mock_sql_connect.assert_called_once_with(server_hostname="test-host", - http_path=args.http_path, - access_token=args.DATABRICKS_TOKEN) - - + mock_sql_connect.assert_called_once_with( + server_hostname='test-host', + http_path=args.http_path, + access_token=args.DATABRICKS_TOKEN) From a543b8092bf4ee6bd06e0ebc1fca2315294ca197 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 9 Jan 2024 22:45:37 -0800 Subject: [PATCH 55/62] Fix lints --- scripts/data_prep/convert_delta_to_json.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 5a0bcb932d..bb47b728fa 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -489,3 +489,4 @@ def fetch_DT(args: Namespace) -> None: tik = time.time() fetch_DT(args) log.info('Elapsed time', time.time() - tik) + print('Elapsed time', time.time() - tik) From ba2dbc23242050c4840c9058464c0ed7aac4b146 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 10 Jan 2024 10:24:18 -0800 Subject: [PATCH 56/62] Removed logger set_level --- scripts/data_prep/convert_delta_to_json.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index bb47b728fa..6c0373d3f6 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -35,7 +35,6 @@ MINIMUM_DBR_VERSION = '14.1.0' log = logging.getLogger(__name__) -log.setLevel(logging.INFO) Result = namedtuple( 'Result', ['url', 'row_count', 'compressed_size', 'uncompressed_size' @@ -49,7 +48,7 @@ def to_cf(self: SparkConnectClient, plan: pb2.Plan, type: str = 'json') -> Tuple[List[Result], int, bool]: - """Executes plan object return as cloud fetch presigned URLS. + """Executes the query plans and return as presigned URLS for cloud fetch. It can handle the current output formats that are supported by the server. In contrast to the regular API methods of the client, this method does not @@ -77,7 +76,7 @@ def to_cf(self: SparkConnectClient, elif type == 'arrow': format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW else: - raise Exception( + raise ValueError( f'Only formats json, csv, and arrow are supported. Got invalid type {type}' ) @@ -368,7 +367,7 @@ def fetch( elif method == 'dbsql' and cursor is not None: for start in range(0, nrows, batch_size): - print('start = ', start) + log.info('start = ', start) end = min(start + batch_size, nrows) fetch_data(method, cursor, sparkSession, start, end, order_by, tablename, columns_str, json_output_path) From 88cad19545d4b0f5632ea51e115ecd56ea353f77 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 10 Jan 2024 10:39:21 -0800 Subject: [PATCH 57/62] Remove pyspark. 
It conflicts with databricks-connect --- scripts/data_prep/convert_delta_to_json.py | 1 - setup.py | 1 - 2 files changed, 2 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 6c0373d3f6..bf0ddc5f0a 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -488,4 +488,3 @@ def fetch_DT(args: Namespace) -> None: tik = time.time() fetch_DT(args) log.info('Elapsed time', time.time() - tik) - print('Elapsed time', time.time() - tik) diff --git a/setup.py b/setup.py index ff5da2542a..dcf7b70434 100644 --- a/setup.py +++ b/setup.py @@ -88,7 +88,6 @@ 'databricks-sql-connector>=3,<4', 'databricks-connect==14.1.0', 'lz4>=4,<5', - 'pyspark>=3,<4', ] extra_deps['tensorboard'] = [ From 21efd2272bedfa842b7ae54a4be3538668ca716f Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 10 Jan 2024 21:57:20 -0800 Subject: [PATCH 58/62] Update the comment --- scripts/data_prep/convert_delta_to_json.py | 33 ++++++++++++++++------ 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index bf0ddc5f0a..ad2a1939b6 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -10,6 +10,7 @@ from collections import namedtuple from concurrent.futures import ProcessPoolExecutor from typing import Iterable, List, Optional, Tuple, Union +from uuid import uuid4 import google.protobuf.any_pb2 as any_pb2 import lz4.frame @@ -32,7 +33,8 @@ from pyspark.sql.dataframe import DataFrame as SparkDataFrame from pyspark.sql.types import Row -MINIMUM_DBR_VERSION = '14.1.0' +MINIMUM_DB_CONNECT_DBR_VERSION = '14.1.0' +MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2.0' log = logging.getLogger(__name__) @@ -63,7 +65,7 @@ def to_cf(self: SparkConnectClient, - A list of Result namedtuples, each containing a URL, row count, compressed size, and uncompressed size of the part of the result. - Total row count of all parts of the result. - - A boolean indicating whether the result is truncated or overflowed. + - A boolean indicating whether the result has been truncated. """ req = self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) @@ -350,8 +352,7 @@ def fetch( if method == 'dbconnect' and sparkSession is not None: log.info('processes = ', processes) - df = sparkSession.table( - tablename) # "main.tpcds_sf100_delta.store_sales") + df = sparkSession.table(tablename) # Running the query and collecting the data as arrow or json. 
signed, _, _ = df.collect_cf('arrow') # pyright: ignore @@ -404,16 +405,30 @@ def fetch_DT(args: Namespace) -> None: res = w.clusters.get(cluster_id=args.cluster_id) runtime_version = res.spark_version.split('-scala')[0].replace( 'x-snapshot', '0').replace('x', '0') + if version.parse(runtime_version) < version.parse( + MINIMUM_SQ_CONNECT_DBR_VERSION): + raise ValueError( + f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}' + ) + if args.http_path is None and version.parse( - runtime_version) >= version.parse(MINIMUM_DBR_VERSION): + runtime_version) >= version.parse(MINIMUM_DB_CONNECT_DBR_VERSION): method = 'dbconnect' if method == 'dbconnect': try: - sparkSession = DatabricksSession.builder.remote( - host=args.DATABRICKS_HOST, - token=args.DATABRICKS_TOKEN, - cluster_id=args.cluster_id).getOrCreate() + if args.cluster_id == 'serverless': + session_id = str(uuid4()) + sparkSession = DatabricksSession.builder.host( + args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header( + 'x-databricks-session-id', session_id).getOrCreate() + + else: + sparkSession = DatabricksSession.builder.remote( + host=args.DATABRICKS_HOST, + token=args.DATABRICKS_TOKEN, + cluster_id=args.cluster_id).getOrCreate() + except Exception as e: raise RuntimeError( 'Failed to create databricks connection. Check hostname and access token!' From 90d77fd1f3227956899b6c22d14ed7a4bb099a4d Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 10 Jan 2024 22:33:40 -0800 Subject: [PATCH 59/62] skip cluster version check when cluster_id is serverless --- scripts/data_prep/convert_delta_to_json.py | 29 ++++++++++++---------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index ad2a1939b6..f19b7479cd 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -368,7 +368,7 @@ def fetch( elif method == 'dbsql' and cursor is not None: for start in range(0, nrows, batch_size): - log.info('start = ', start) + log.warning(f"batch {start}") end = min(start + batch_size, nrows) fetch_data(method, cursor, sparkSession, start, end, order_by, tablename, columns_str, json_output_path) @@ -401,19 +401,22 @@ def fetch_DT(args: Namespace) -> None: dbsql = None sparkSession = None - w = WorkspaceClient() - res = w.clusters.get(cluster_id=args.cluster_id) - runtime_version = res.spark_version.split('-scala')[0].replace( - 'x-snapshot', '0').replace('x', '0') - if version.parse(runtime_version) < version.parse( - MINIMUM_SQ_CONNECT_DBR_VERSION): - raise ValueError( - f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}' - ) - - if args.http_path is None and version.parse( - runtime_version) >= version.parse(MINIMUM_DB_CONNECT_DBR_VERSION): + if args.cluster_id == "serverless": method = 'dbconnect' + else: + w = WorkspaceClient() + res = w.clusters.get(cluster_id=args.cluster_id) + runtime_version = res.spark_version.split('-scala')[0].replace( + 'x-snapshot', '0').replace('x', '0') + if version.parse(runtime_version) < version.parse( + MINIMUM_SQ_CONNECT_DBR_VERSION): + raise ValueError( + f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}' + ) + + if args.http_path is None and version.parse( + runtime_version) >= version.parse(MINIMUM_DB_CONNECT_DBR_VERSION): + method = 'dbconnect' if method == 'dbconnect': try: From 
c8f83076090fe9f3821b21d6fa3ca142118c15a1 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 10 Jan 2024 23:16:22 -0800 Subject: [PATCH 60/62] Add use_serverless flag --- scripts/data_prep/convert_delta_to_json.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index f19b7479cd..8986849a42 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -332,7 +332,7 @@ def fetch( ans = run_query(f'SELECT COUNT(*) FROM {tablename}', method, cursor, sparkSession) nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore - log.debug(f'total_rows = {nrows}') + log.info(f'total_rows = {nrows}') except Exception as e: raise RuntimeError( f'Error in get total rows from {tablename}. Restart sparkSession and try again' @@ -368,7 +368,7 @@ def fetch( elif method == 'dbsql' and cursor is not None: for start in range(0, nrows, batch_size): - log.warning(f"batch {start}") + log.warning(f'batch {start}') end = min(start + batch_size, nrows) fetch_data(method, cursor, sparkSession, start, end, order_by, tablename, columns_str, json_output_path) @@ -401,7 +401,7 @@ def fetch_DT(args: Namespace) -> None: dbsql = None sparkSession = None - if args.cluster_id == "serverless": + if args.use_serverless: method = 'dbconnect' else: w = WorkspaceClient() @@ -415,12 +415,13 @@ def fetch_DT(args: Namespace) -> None: ) if args.http_path is None and version.parse( - runtime_version) >= version.parse(MINIMUM_DB_CONNECT_DBR_VERSION): + runtime_version) >= version.parse( + MINIMUM_DB_CONNECT_DBR_VERSION): method = 'dbconnect' if method == 'dbconnect': try: - if args.cluster_id == 'serverless': + if args.use_serverless: session_id = str(uuid4()) sparkSession = DatabricksSession.builder.host( args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header( @@ -494,7 +495,15 @@ def fetch_DT(args: Namespace) -> None: type=str, default=None, help= - 'Use serverless if not present. IMPORTANT! make sure cluster has runtime newer than 14.1.0 to use databricks-connect' + 'cluster id has runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.' + ) + parser.add_argument( + '--use_serverless', + required=False, + type=bool, + default=False, + help= + 'Use serverless or not. 
Make sure the workspace is entitled with serverless' ) args = parser.parse_args() From 9a524be7f9294f3b781aed2a8b587b0f79f5dafa Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 11 Jan 2024 00:38:55 -0800 Subject: [PATCH 61/62] update tests with use_serverless flag --- .../data_prep/test_convert_delta_to_json.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 3b6a9d6b45..28eb630a35 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -35,6 +35,7 @@ def test_stream_delta_to_json(self, mock_workspace_client: Any, args.partitions = 1 args.cluster_id = '1234' args.debug = False + args.use_serverless = False mock_cluster_get = MagicMock() mock_cluster_get.return_value = MagicMock( @@ -159,6 +160,7 @@ def test_dbconnect_called(self, mock_fetch: Any, mock_combine_jsons: Any, args.cluster_id = '1234' args.DATABRICKS_HOST = 'host' args.DATABRICKS_TOKEN = 'token' + args.use_serverless = False mock_cluster_response = Namespace(spark_version='14.1.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response @@ -196,6 +198,7 @@ def test_sqlconnect_called_dbr13(self, mock_fetch: Any, args.cluster_id = '1234' args.DATABRICKS_HOST = 'host' args.DATABRICKS_TOKEN = 'token' + args.use_serverless = False mock_cluster_response = Namespace(spark_version='13.0.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response @@ -228,6 +231,7 @@ def test_sqlconnect_called_dbr14(self, mock_fetch: Any, args.cluster_id = '1234' args.DATABRICKS_HOST = 'host' args.DATABRICKS_TOKEN = 'token' + args.use_serverless = False mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response @@ -260,6 +264,7 @@ def test_sqlconnect_called_https(self, mock_fetch: Any, args.cluster_id = '1234' args.DATABRICKS_HOST = 'https://test-host' args.DATABRICKS_TOKEN = 'token' + args.use_serverless = False mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response @@ -269,3 +274,36 @@ def test_sqlconnect_called_https(self, mock_fetch: Any, server_hostname='test-host', http_path=args.http_path, access_token=args.DATABRICKS_TOKEN) + + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + def test_serverless(self, mock_fetch: Any, + mock_combine_jsons: Any, + mock_makedirs: Any, + mock_workspace_client: Any, + mock_databricks_session: Any, + mock_sql_connect: Any): + + args = MagicMock() + + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/jsonl' + # Execute function with http_path=None (should use dbconnect) + args.http_path = 'test_path' + args.cluster_id = '1234' + args.DATABRICKS_HOST = 'https://test-host' + args.DATABRICKS_TOKEN = 'token' + args.use_serverless = True + + mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12') + 
mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response + + fetch_DT(args) + assert not mock_sql_connect.called + assert not mock_databricks_session.builder.remote.called + From 00ca45734fa5ef8e929f7b35e014c7f862746844 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 11 Jan 2024 01:11:09 -0800 Subject: [PATCH 62/62] Fix lints --- .../a_scripts/data_prep/test_convert_delta_to_json.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py index 28eb630a35..39bc5d8099 100644 --- a/tests/a_scripts/data_prep/test_convert_delta_to_json.py +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -275,19 +275,15 @@ def test_sqlconnect_called_https(self, mock_fetch: Any, http_path=args.http_path, access_token=args.DATABRICKS_TOKEN) - @patch('scripts.data_prep.convert_delta_to_json.sql.connect') @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession') @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') @patch('scripts.data_prep.convert_delta_to_json.fetch') - def test_serverless(self, mock_fetch: Any, - mock_combine_jsons: Any, - mock_makedirs: Any, - mock_workspace_client: Any, - mock_databricks_session: Any, - mock_sql_connect: Any): + def test_serverless(self, mock_fetch: Any, mock_combine_jsons: Any, + mock_makedirs: Any, mock_workspace_client: Any, + mock_databricks_session: Any, mock_sql_connect: Any): args = MagicMock() @@ -306,4 +302,3 @@ def test_serverless(self, mock_fetch: Any, fetch_DT(args) assert not mock_sql_connect.called assert not mock_databricks_session.builder.remote.called -
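
Note on the hostname fix in PATCH 52: the patch replaces args.DATABRICKS_HOST.lstrip('https://') with a regular-expression substitution because str.lstrip removes any leading characters drawn from the given set ('h', 't', 'p', 's', ':', '/'), not the literal prefix, so even a plain hostname that begins with one of those characters gets mangled. A minimal sketch of the corrected behaviour, assuming only the standard library (strip_scheme is an illustrative helper, not a name from the patch):

    import re

    def strip_scheme(host: str) -> str:
        # Remove only a literal leading "http://" or "https://" scheme;
        # the trailing .strip() mirrors the patch and drops stray whitespace.
        return re.sub(r'^https?://', '', host).strip()

    # str.lstrip strips characters from a set, not a prefix, so it also eats
    # the leading 'h' of the hostname itself:
    assert 'https://host.cloud.databricks.com'.lstrip('https://') == 'ost.cloud.databricks.com'
    assert strip_scheme('https://host.cloud.databricks.com') == 'host.cloud.databricks.com'
    assert strip_scheme('host.cloud.databricks.com') == 'host.cloud.databricks.com'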
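
Note on the per-batch write in PATCH 54: changing pdf.to_json(path) to pdf.to_json(path, orient='records', lines=True) matters because with the defaults pandas emits a single column-oriented JSON object, while orient='records' with lines=True emits one JSON object per row, which is presumably what a .jsonl part file is meant to contain. A small sketch, assuming pandas is installed and using made-up column names:

    import pandas as pd

    pdf = pd.DataFrame([{'prompt': 'a', 'response': 'b'},
                        {'prompt': 'c', 'response': 'd'}])

    # Default orientation: one nested object keyed by column, not JSON Lines.
    # pdf.to_json() -> {"prompt":{"0":"a","1":"c"},"response":{"0":"b","1":"d"}}

    # orient='records', lines=True: one JSON object per row, i.e. a .jsonl part file.
    pdf.to_json('part_1_2.jsonl', orient='records', lines=True)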
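
Note on the client selection introduced in PATCHes 58 through 61: --use_serverless short-circuits straight to Databricks Connect and skips the cluster runtime check; otherwise the cluster's DBR version is parsed, anything below MINIMUM_SQ_CONNECT_DBR_VERSION is rejected, and Databricks Connect is used only when no --http_path is given and the runtime is at least MINIMUM_DB_CONNECT_DBR_VERSION, with everything else falling back to the SQL connector. A condensed sketch of that gating, assuming the packaging library is available; choose_method is an illustrative helper rather than a function in the script:

    from packaging import version

    MINIMUM_DB_CONNECT_DBR_VERSION = '14.1.0'
    MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2.0'

    def choose_method(use_serverless: bool, http_path, spark_version: str) -> str:
        """Mirror the gating in fetch_DT: return 'dbconnect' or 'dbsql'."""
        if use_serverless:
            # Serverless goes through Databricks Connect and skips the
            # cluster runtime check entirely.
            return 'dbconnect'
        # spark_version comes from WorkspaceClient().clusters.get(...),
        # e.g. '14.1.x-scala2.12' or '14.2.x-snapshot-scala2.12'.
        runtime = spark_version.split('-scala')[0].replace('x-snapshot',
                                                           '0').replace('x', '0')
        if version.parse(runtime) < version.parse(MINIMUM_SQ_CONNECT_DBR_VERSION):
            raise ValueError(
                f'DBR {runtime} is older than {MINIMUM_SQ_CONNECT_DBR_VERSION}')
        if http_path is None and version.parse(runtime) >= version.parse(
                MINIMUM_DB_CONNECT_DBR_VERSION):
            return 'dbconnect'
        return 'dbsql'

    assert choose_method(True, 'sql/warehouse/path', '13.0.x-scala2.12') == 'dbconnect'
    assert choose_method(False, None, '14.1.x-scala2.12') == 'dbconnect'
    assert choose_method(False, 'sql/warehouse/path', '14.1.x-scala2.12') == 'dbsql'

When the serverless branch is taken, the patch builds the Spark session with DatabricksSession.builder.host(...).token(...).header('x-databricks-session-id', str(uuid4())) instead of builder.remote(...), and test_serverless asserts that neither sql.connect nor builder.remote is called in that case.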