diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py
new file mode 100644
index 0000000000..8986849a42
--- /dev/null
+++ b/scripts/data_prep/convert_delta_to_json.py
@@ -0,0 +1,517 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+import re
+import time
+import urllib.parse
+from argparse import ArgumentParser, Namespace
+from collections import namedtuple
+from concurrent.futures import ProcessPoolExecutor
+from typing import Iterable, List, Optional, Tuple, Union
+from uuid import uuid4
+
+import google.protobuf.any_pb2 as any_pb2
+import lz4.frame
+import pandas as pd
+import pyarrow as pa
+import pyspark.sql.connect.proto as pb2
+import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2
+import requests
+from databricks import sql
+from databricks.connect import DatabricksSession
+from databricks.sdk import WorkspaceClient
+from databricks.sql.client import Connection as Connection
+from databricks.sql.client import Cursor as Cursor
+from packaging import version
+from pyspark.sql import SparkSession
+from pyspark.sql.connect.client.core import SparkConnectClient
+from pyspark.sql.connect.client.reattach import \
+    ExecutePlanResponseReattachableIterator
+from pyspark.sql.connect.dataframe import DataFrame
+from pyspark.sql.dataframe import DataFrame as SparkDataFrame
+from pyspark.sql.types import Row
+
+MINIMUM_DB_CONNECT_DBR_VERSION = '14.1.0'
+MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2.0'
+
+log = logging.getLogger(__name__)
+
+Result = namedtuple(
+    'Result',
+    ['url', 'row_count', 'compressed_size', 'uncompressed_size'])  # pyright: ignore
+
+# ``collect_as_cf`` is a new feature added as a monkey patch on top of the DB Connect package.
+# It allows the client to fetch results from the server in different formats.
+# To use this code, make sure this module is not overridden by DB Connect classes.
+
+
+def to_cf(self: SparkConnectClient,
+          plan: pb2.Plan,
+          type: str = 'json') -> Tuple[List[Result], int, bool]:
+    """Executes the query plan and returns presigned URLs for cloud fetch.
+
+    It can handle the current output formats that are supported by the server.
+    In contrast to the regular API methods of the client, this method does not
+    return the schema and drops all other responses.
+
+    Args:
+        plan (pb2.Plan): The plan object to be executed by Spark.
+        type (str): The output format of the result; supported formats are 'json', 'csv', and 'arrow'.
+
+    Returns:
+        Tuple[List[Result], int, bool]: A tuple containing:
+            - A list of Result namedtuples, each containing a URL, row count, compressed size,
+              and uncompressed size of one part of the result.
+            - The total row count over all parts of the result.
+            - A boolean indicating whether the result has been truncated.
+    """
+    req = self._execute_plan_request_with_metadata()
+    req.plan.CopyFrom(plan)
+
+    # Add the request options
+    if type == 'json':
+        format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_JSON
+    elif type == 'csv':
+        format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_CSV
+    elif type == 'arrow':
+        format = cloud_pb2.ResultOptions.CloudOptions.FORMAT_ARROW
+    else:
+        raise ValueError(
+            f'Only formats json, csv, and arrow are supported. Got invalid type {type}'
+        )
+
+    ro = cloud_pb2.ResultOptions(
+        type=cloud_pb2.ResultOptions.TYPE_CLOUD,
+        cloudOptions=cloud_pb2.ResultOptions.CloudOptions(
+            format=format,
+            useCompression=False,
+        ))
+    cloud_option = any_pb2.Any()
+    cloud_option.Pack(ro)
+    req.request_options.append(
+        pb2.ExecutePlanRequest.RequestOption(extension=cloud_option))
+
+    # Create the iterator
+    iterator = ExecutePlanResponseReattachableIterator(req, self._stub,
+                                                       self._retry_policy,
+                                                       self._builder.metadata())
+    # Iterate over the response
+    result = []
+    row_count = 0
+    is_overflow = False
+
+    for response in iterator:
+        if response.HasField('extension') and response.extension.Is(
+                cloud_pb2.CloudResultBatch.DESCRIPTOR):
+            batch = cloud_pb2.CloudResultBatch()
+            if not response.extension.Is(cloud_pb2.CloudResultBatch.DESCRIPTOR):
+                raise ValueError(
+                    'Response extension is not of type CloudResultBatch.')
+            response.extension.Unpack(batch)
+            result += [
+                Result(b.url, b.row_count, b.compressed_size,
+                       b.uncompressed_size) for b in batch.results
+            ]
+            row_count += sum(result.row_count for result in batch.results)
+            is_overflow |= batch.truncated
+    return result, row_count, is_overflow
+
+
+SparkConnectClient.to_cf = to_cf  # pyright: ignore
+
+
+def collect_as_cf(self: DataFrame,
+                  type: str = 'json') -> Tuple[List[Result], int, bool]:
+    """Collects the DataFrame's execution plan results as presigned URLs.
+
+    This method is a wrapper around the `to_cf` method of SparkConnectClient. It takes the
+    execution plan of the current DataFrame, converts it to a protocol buffer format, and then
+    uses the `to_cf` method to execute the plan and fetch results as presigned URLs.
+
+    Args:
+        type (str): The output format of the result; supported formats are 'json', 'csv', and 'arrow'.
+
+    Returns:
+        Tuple[List[Result], int, bool]: A tuple containing:
+            - A list of Result namedtuples, each containing a URL, row count, compressed size,
+              and uncompressed size of one part of the result.
+            - The total row count over all parts of the result.
+            - A boolean indicating whether the result is truncated or overflowed.
+    """
+    query = self._plan.to_proto(self._session.client)  # pyright: ignore
+    return self._session.client.to_cf(query, type)  # pyright: ignore
+
+
+DataFrame.collect_cf = collect_as_cf  # pyright: ignore
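+
+# A minimal usage sketch of the monkey-patched method above (illustrative only;
+# ``spark`` is assumed to be an existing Databricks Connect session and the
+# table name is hypothetical):
+#
+#   df = spark.sql('SELECT * FROM main.my_schema.my_table')
+#   signed, total_rows, truncated = df.collect_cf('arrow')
+#   for part in signed:
+#       print(part.url, part.row_count, part.uncompressed_size)
+#
+# Each element of ``signed`` is a presigned URL to one chunk of the result,
+# which is what ``fetch`` below downloads in parallel.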
+
+
+def iterative_combine_jsons(json_directory: str, output_file: str) -> None:
+    """Combine jsonl files in json_directory into one big jsonl file.
+
+    This function does not work for nested subdirectories.
+
+    Args:
+        json_directory(str): Directory containing the JSONL files.
+        output_file(str): Path to the output combined JSONL file.
+    """
+    json_files = [f for f in os.listdir(json_directory) if f.endswith('.jsonl')]
+    with open(output_file, 'w') as outfile:
+        for file_name in json_files:
+            with open(os.path.join(json_directory, file_name), 'r') as infile:
+                for line in infile:
+                    outfile.write(line)
+    log.info('JSON files have been combined into a JSONL file.')
+
+
+def run_query(
+    query: str,
+    method: str,
+    cursor: Optional[Cursor] = None,
+    spark: Optional[SparkSession] = None,
+    collect: bool = True
+) -> Optional[Union[List[Row], DataFrame, SparkDataFrame]]:
+    """Run an SQL query via databricks-connect or databricks-sql.
+
+    Args:
+        query (str): SQL query to run.
+        method (str): One of 'dbsql' or 'dbconnect'.
+        cursor (Optional[Cursor]): connection.cursor, required for the 'dbsql' method.
+        spark (Optional[SparkSession]): Spark session, required for the 'dbconnect' method.
+        collect (bool): Whether to collect the underlying data from the Spark dataframe.
+    """
+    if method == 'dbsql':
+        if cursor is None:
+            raise ValueError('cursor cannot be None if using method dbsql')
+        cursor.execute(query)
+        if collect:
+            return cursor.fetchall()
+    elif method == 'dbconnect':
+        if spark is None:
+            raise ValueError('sparkSession is required for dbconnect')
+        df = spark.sql(query)
+        if collect:
+            return df.collect()
+        return df
+    else:
+        raise ValueError(f'Unrecognized method: {method}')
+
+
+def get_args(signed: List, json_output_path: str, columns: List) -> Iterable:
+    for i, r in enumerate(signed):
+        yield (i, r.url, json_output_path, columns)
+
+
+def download(ipart: int,
+             url: str,
+             json_output_path: str,
+             columns: Optional[List] = None,
+             resp_format: str = 'arrow',
+             compressed: bool = False) -> None:
+    """Download a presigned URL and save the result locally as JSONL.
+
+    Args:
+        ipart (int): Presigned URL id.
+        url (str): Presigned URL.
+        json_output_path (str): Directory to save the ipart-th segment of the dataframe.
+        columns (list): Schema to use when saving to JSON.
+        resp_format (str): Whether the response is collected as arrow or json.
+        compressed (bool): Whether the data was compressed before download; it is decompressed if compressed=True.
+    """
+    resp = requests.get(url)
+    if resp.status_code == 200:
+        if resp_format == 'json':
+            data = resp.json()
+            pd.DataFrame(data, columns=columns).to_json(os.path.join(
+                json_output_path, 'part_' + str(ipart) + '.jsonl'),
+                                                        orient='records',
+                                                        lines=True)
+            return
+
+        # When resp_format is arrow:
+        if compressed:
+            # The data is in lz4-compressed arrow format.
+            # Decompress the data
+            decompressed_data = lz4.frame.decompress(resp.content)
+            # Convert the decompressed data into a PyArrow table
+            reader = pa.ipc.open_stream(decompressed_data)
+        else:
+            reader = pa.ipc.open_stream(resp.content)
+        table = reader.read_all()
+
+        # Convert the PyArrow table into a pandas DataFrame
+        df = table.to_pandas()
+        df.to_json(os.path.join(json_output_path,
+                                'part_' + str(ipart) + '.jsonl'),
+                   orient='records',
+                   lines=True,
+                   force_ascii=False)
+
+
+def download_starargs(args: Tuple) -> None:
+    return download(*args)
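+
+
+# The query built in ``fetch_data`` below pages through the table with a window
+# function. As a rough illustration only (hypothetical table and column names),
+# a call with start=0, end=3, order_by='id', columns_str='id,text' renders as:
+#
+#   WITH NumberedRows AS (
+#       SELECT *, ROW_NUMBER() OVER (ORDER BY id) AS rn
+#       FROM main.my_schema.my_table
+#   )
+#   SELECT id,text FROM NumberedRows WHERE rn BETWEEN 1 AND 3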
+
+
+def fetch_data(method: str, cursor: Optional[Cursor],
+               sparkSession: Optional[SparkSession], start: int, end: int,
+               order_by: str, tablename: str, columns_str: str,
+               json_output_path: str) -> None:
+    """Fetches a specified range of rows from a given table to a json file.
+
+    This function executes a SQL query to retrieve a range of rows, determined by the 'start' and 'end' indexes,
+    from a specified table and column set. The fetched data is then exported as a JSON file.
+
+    Args:
+        method (str): The method to use for fetching data, either 'dbconnect' or 'dbsql'.
+        cursor (Optional[Cursor]): The cursor object for executing queries in the 'dbsql' method.
+        sparkSession (Optional[SparkSession]): The Spark session object for executing queries in the 'dbconnect' method.
+        start (int): The starting index for row fetching.
+        end (int): The ending index for row fetching.
+        order_by (str): The column name to use for ordering the rows.
+        tablename (str): The name of the table from which to fetch the data.
+        columns_str (str): The string representation of the columns to select from the table.
+        json_output_path (str): The file path where the resulting JSON file will be saved.
+
+    Returns:
+        None: The function doesn't return any value, but writes the result to a JSONL file.
+    """
+    query = f"""
+    WITH NumberedRows AS (
+        SELECT
+            *,
+            ROW_NUMBER() OVER (ORDER BY {order_by}) AS rn
+        FROM
+            {tablename}
+    )
+    SELECT {columns_str}
+    FROM NumberedRows
+    WHERE rn BETWEEN {start+1} AND {end}"""
+
+    if method == 'dbconnect':
+        spark_df = run_query(query, method, cursor, sparkSession, collect=False)
+        if spark_df is None:
+            raise RuntimeError(
+                f'Expected a Spark dataframe for query {query} but got None')
+        pdf = spark_df.toPandas()  # pyright: ignore
+    else:  # method == 'dbsql':
+        ans = run_query(query, method, cursor, sparkSession, collect=True)
+        if ans is None:
+            raise RuntimeError(f'Got empty results with {query}')
+        records = [r.asDict() for r in ans]  # pyright: ignore
+        pdf = pd.DataFrame.from_dict(records)
+
+    pdf.to_json(os.path.join(json_output_path, f'part_{start+1}_{end}.jsonl'),
+                orient='records',
+                lines=True)
+
+
+def fetch(
+    method: str,
+    tablename: str,
+    json_output_path: str,
+    batch_size: int = 1 << 30,
+    processes: int = 1,
+    sparkSession: Optional[SparkSession] = None,
+    dbsql: Optional[Connection] = None,
+) -> None:
+    """Fetch a UC delta table with databricks-connect or databricks-sql as JSONL.
+
+    Args:
+        method (str): dbconnect or dbsql.
+        tablename (str): catalog.schema.tablename on UC.
+        json_output_path (str): Path to write the result json files to.
+        batch_size (int): Number of rows that dbsql fetches each time to avoid OOM.
+        processes (int): Max number of processes to use to parallelize the fetch.
+        sparkSession (pyspark.sql.SparkSession): Spark session.
+        dbsql (databricks.sql.connect): dbsql session.
+    """
+    cursor = dbsql.cursor() if dbsql is not None else None
+
+    try:
+        ans = run_query(f'SELECT COUNT(*) FROM {tablename}', method, cursor,
+                        sparkSession)
+        nrows = [row.asDict() for row in ans][0].popitem()[1]  # pyright: ignore
+        log.info(f'total_rows = {nrows}')
+    except Exception as e:
+        raise RuntimeError(
+            f'Error getting total rows from {tablename}. Restart sparkSession and try again'
+        ) from e
+
+    try:
+        ans = run_query(f'SHOW COLUMNS IN {tablename}', method, cursor,
+                        sparkSession)
+        columns = [row.asDict().popitem()[1] for row in ans]  # pyright: ignore
+        order_by = columns[0]
+        columns_str = ','.join(columns)
+        log.info(f'order by column {order_by}')
+    except Exception as e:
+        raise RuntimeError(
+            f'Error getting columns from {tablename}. Restart sparkSession and try again'
+        ) from e
+
+    if method == 'dbconnect' and sparkSession is not None:
+        log.info(f'processes = {processes}')
+        df = sparkSession.table(tablename)
+
+        # Running the query and collecting the data as arrow or json.
+        signed, _, _ = df.collect_cf('arrow')  # pyright: ignore
+        log.info(f'len(signed) = {len(signed)}')
+
+        args = get_args(signed, json_output_path, columns)
+
+        # Stopping the SparkSession to avoid spilling connection state into the subprocesses.
+        sparkSession.stop()
+
+        with ProcessPoolExecutor(max_workers=processes) as executor:
+            list(executor.map(download_starargs, args))
+
+    elif method == 'dbsql' and cursor is not None:
+        for start in range(0, nrows, batch_size):
+            log.warning(f'batch {start}')
+            end = min(start + batch_size, nrows)
+            fetch_data(method, cursor, sparkSession, start, end, order_by,
+                       tablename, columns_str, json_output_path)
+
+    if cursor is not None:
+        cursor.close()
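+
+
+# Example (sketch only, with hypothetical values): fetching over databricks-sql
+# in batches of 100k rows, assuming ``connection`` is an open databricks.sql
+# Connection created by sql.connect(...):
+#
+#   fetch('dbsql', 'main.my_schema.my_table', '/tmp/json_out',
+#         batch_size=100_000, processes=1, dbsql=connection)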
+
+
+def fetch_DT(args: Namespace) -> None:
+    """Fetch a UC Delta Table to local as jsonl."""
+    log.info('Start converting delta table to json.')
+
+    obj = urllib.parse.urlparse(args.json_output_path)
+    if obj.scheme != '':
+        raise ValueError(
+            'Check the json_output_path and verify it is a local path!')
+
+    if os.path.exists(args.json_output_path):
+        if not os.path.isdir(args.json_output_path) or os.listdir(
+                args.json_output_path):
+            raise RuntimeError(
+                f'A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!'
+            )
+
+    os.makedirs(args.json_output_path, exist_ok=True)
+
+    log.info(f'Directory {args.json_output_path} created.')
+
+    method = 'dbsql'
+    dbsql = None
+    sparkSession = None
+
+    if args.use_serverless:
+        method = 'dbconnect'
+    else:
+        w = WorkspaceClient()
+        res = w.clusters.get(cluster_id=args.cluster_id)
+        runtime_version = res.spark_version.split('-scala')[0].replace(
+            'x-snapshot', '0').replace('x', '0')
+        if version.parse(runtime_version) < version.parse(
+                MINIMUM_SQ_CONNECT_DBR_VERSION):
+            raise ValueError(
+                f'The minimum DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}'
+            )
+
+        if args.http_path is None and version.parse(
+                runtime_version) >= version.parse(
+                    MINIMUM_DB_CONNECT_DBR_VERSION):
+            method = 'dbconnect'
+
+    if method == 'dbconnect':
+        try:
+            if args.use_serverless:
+                session_id = str(uuid4())
+                sparkSession = DatabricksSession.builder.host(
+                    args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header(
+                        'x-databricks-session-id', session_id).getOrCreate()
+
+            else:
+                sparkSession = DatabricksSession.builder.remote(
+                    host=args.DATABRICKS_HOST,
+                    token=args.DATABRICKS_TOKEN,
+                    cluster_id=args.cluster_id).getOrCreate()
+
+        except Exception as e:
+            raise RuntimeError(
+                'Failed to create databricks connection. Check hostname and access token!'
+            ) from e
+    else:
+        try:
+            dbsql = sql.connect(
+                server_hostname=re.compile(r'^https?://').sub(
+                    '', args.DATABRICKS_HOST).strip(
+                    ),  # sqlconnect hangs if the hostname starts with https
+                http_path=args.http_path,
+                access_token=args.DATABRICKS_TOKEN,
+            )
+        except Exception as e:
+            raise RuntimeError(
+                'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
+            ) from e
+
+    fetch(method, args.delta_table_name, args.json_output_path, args.batch_size,
+          args.processes, sparkSession, dbsql)
+
+    if dbsql is not None:
+        dbsql.close()
+
+    # Combine the downloaded jsonl files into one big jsonl for IFT.
+    iterative_combine_jsons(
+        args.json_output_path,
+        os.path.join(args.json_output_path, 'combined.jsonl'))
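+
+
+# Example invocation (sketch; the table name, paths, cluster id, and http_path
+# below are hypothetical placeholders):
+#
+#   python convert_delta_to_json.py \
+#       --delta_table_name 'main.my_schema.my_table' \
+#       --json_output_path /tmp/delta_json_out \
+#       --cluster_id 1234-567890-abcde123 \
+#       --http_path /sql/1.0/warehouses/abc123 \
+#       --batch_size 1000000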
+ ) + parser.add_argument( + '--use_serverless', + required=False, + type=bool, + default=False, + help= + 'Use serverless or not. Make sure the workspace is entitled with serverless' + ) + args = parser.parse_args() + + from databricks.sdk import WorkspaceClient + w = WorkspaceClient() + args.DATABRICKS_HOST = w.config.host + args.DATABRICKS_TOKEN = w.config.token + + tik = time.time() + fetch_DT(args) + log.info('Elapsed time', time.time() - tik) diff --git a/setup.py b/setup.py index 3de80f2292..5444352cf7 100644 --- a/setup.py +++ b/setup.py @@ -84,7 +84,10 @@ ] extra_deps['databricks'] = [ - 'mosaicml[databricks]>=0.17.2,<0.18', + 'mosaicml[databricks]>=0.17.1,<0.18', + 'databricks-sql-connector>=3,<4', + 'databricks-connect==14.1.0', + 'lz4>=4,<5', ] extra_deps['tensorboard'] = [ diff --git a/tests/a_scripts/data_prep/test_convert_delta_to_json.py b/tests/a_scripts/data_prep/test_convert_delta_to_json.py new file mode 100644 index 0000000000..39bc5d8099 --- /dev/null +++ b/tests/a_scripts/data_prep/test_convert_delta_to_json.py @@ -0,0 +1,304 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +# copyright 2022 mosaicml llm foundry authors +# spdx-license-identifier: apache-2.0 + +import unittest +from argparse import Namespace +from typing import Any +from unittest.mock import MagicMock, mock_open, patch + +from scripts.data_prep.convert_delta_to_json import (download, fetch_DT, + iterative_combine_jsons, + run_query) + + +class TestConverDeltaToJsonl(unittest.TestCase): + + @patch('scripts.data_prep.convert_delta_to_json.sql.connect') + @patch('scripts.data_prep.convert_delta_to_json.os.makedirs') + @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons') + @patch('scripts.data_prep.convert_delta_to_json.fetch') + @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient') + def test_stream_delta_to_json(self, mock_workspace_client: Any, + mock_fetch: Any, mock_combine_jsons: Any, + mock_makedirs: Any, mock_sql_connect: Any): + + args = MagicMock() + args.delta_table_name = 'test_table' + args.json_output_path = '/path/to/jsonl' + args.DATABRICKS_HOST = 'test_host' + args.DATABRICKS_TOKEN = 'test_token' + args.http_path = 'test_path' + args.batch_size = 1000 + args.partitions = 1 + args.cluster_id = '1234' + args.debug = False + args.use_serverless = False + + mock_cluster_get = MagicMock() + mock_cluster_get.return_value = MagicMock( + spark_version='14.1.0-scala2.12') + mock_workspace_client.return_value.clusters.get = mock_cluster_get + + fetch_DT(args) + mock_sql_connect.assert_called_once_with(server_hostname='test_host', + http_path='test_path', + access_token='test_token') + mock_makedirs.assert_called_once_with('/path/to/jsonl', exist_ok=True) + mock_fetch.assert_called_once() + mock_combine_jsons.assert_called_once_with( + '/path/to/jsonl', '/path/to/jsonl/combined.jsonl') + + @patch('scripts.data_prep.convert_delta_to_json.os.listdir') + @patch('builtins.open', + new_callable=mock_open, + read_data='{"key": "value"}') + def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any): + mock_listdir.return_value = ['file1.jsonl', 'file2.jsonl'] + json_directory = '/fake/dir' + output_file = '/fake/output.jsonl' + + iterative_combine_jsons(json_directory, output_file) + + mock_listdir.assert_called_once_with(json_directory) + mock_file.assert_called() + """ + Diagnostic print + for call_args in mock_file().write.call_args_list: + print(call_args) + -------------------- + call('{') + call('"key"') + 
+
+    @patch('scripts.data_prep.convert_delta_to_json.os.listdir')
+    @patch('builtins.open',
+           new_callable=mock_open,
+           read_data='{"key": "value"}')
+    def test_iterative_combine_jsons(self, mock_file: Any, mock_listdir: Any):
+        mock_listdir.return_value = ['file1.jsonl', 'file2.jsonl']
+        json_directory = '/fake/dir'
+        output_file = '/fake/output.jsonl'
+
+        iterative_combine_jsons(json_directory, output_file)
+
+        mock_listdir.assert_called_once_with(json_directory)
+        mock_file.assert_called()
+        """
+        Diagnostic print
+        for call_args in mock_file().write.call_args_list:
+            print(call_args)
+        --------------------
+        call('{')
+        call('"key"')
+        call(': ')
+        call('"value"')
+        call('}')
+        call('\n')
+        call('{')
+        call('"key"')
+        call(': ')
+        call('"value"')
+        call('}')
+        call('\n')
+        --------------------
+        """
+        self.assertEqual(mock_file().write.call_count, 2)
+
+    @patch('scripts.data_prep.convert_delta_to_json.SparkSession')
+    def test_run_query_dbconnect(self, mock_spark: Any):
+        method = 'dbconnect'
+        mock_cursor = None
+        mock_spark.sql.return_value.collect.return_value = 'result'
+
+        result = run_query('SELECT * FROM table',
+                           method,
+                           cursor=mock_cursor,
+                           spark=mock_spark)
+
+        mock_spark.sql.assert_called_once_with('SELECT * FROM table')
+        self.assertEqual(result, 'result')
+
+    @patch('scripts.data_prep.convert_delta_to_json.Cursor')
+    def test_run_query_dbsql(self, mock_cursor: Any):
+        method = 'dbsql'
+        mock_cursor.fetchall.return_value = 'result'
+        mock_spark = None
+
+        result = run_query('SELECT * FROM table',
+                           method,
+                           cursor=mock_cursor,
+                           spark=mock_spark)
+
+        mock_cursor.execute.assert_called_once_with('SELECT * FROM table')
+        self.assertEqual(result, 'result')
+
+    @patch('scripts.data_prep.convert_delta_to_json.requests.get')
+    @patch('scripts.data_prep.convert_delta_to_json.pd.DataFrame.to_json')
+    @patch('scripts.data_prep.convert_delta_to_json.os.path.join',
+           return_value='/fake/path/part_1.jsonl')
+    @patch('scripts.data_prep.convert_delta_to_json.time.sleep'
+          )  # Mock sleep to speed up the test
+    def test_download_success(self, mock_sleep: Any, mock_join: Any,
+                              mock_to_json: Any, mock_get: Any):
+        mock_response = MagicMock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = [['val1.1', 'val1.2'],
+                                           ['val2.1', 'val2.2']]
+        mock_get.return_value = mock_response
+
+        download(1,
+                 'http://fakeurl.com/data',
+                 '/fake/path', ['A', 'B'],
+                 resp_format='json')
+
+        mock_get.assert_called_with('http://fakeurl.com/data')
+        mock_join.assert_called_with('/fake/path', 'part_1.jsonl')
+        mock_to_json.assert_called_with('/fake/path/part_1.jsonl',
+                                        orient='records',
+                                        lines=True)
+
+        mock_get.assert_called_once_with('http://fakeurl.com/data')
+
+    @patch('scripts.data_prep.convert_delta_to_json.sql.connect')
+    @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession')
+    @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient')
+    @patch('scripts.data_prep.convert_delta_to_json.os.makedirs')
+    @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons')
+    @patch('scripts.data_prep.convert_delta_to_json.fetch')
+    def test_dbconnect_called(self, mock_fetch: Any, mock_combine_jsons: Any,
+                              mock_makedirs: Any, mock_workspace_client: Any,
+                              mock_databricks_session: Any,
+                              mock_sql_connect: Any):
+
+        args = MagicMock()
+
+        args.delta_table_name = 'test_table'
+        args.json_output_path = '/path/to/jsonl'
+        # Execute function with http_path=None (should use dbconnect)
+        args.http_path = None
+        args.cluster_id = '1234'
+        args.DATABRICKS_HOST = 'host'
+        args.DATABRICKS_TOKEN = 'token'
+        args.use_serverless = False
+
+        mock_cluster_response = Namespace(spark_version='14.1.0-scala2.12')
+        mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response
+
+        mock_remote = MagicMock()
+        mock_remote.getOrCreate.return_value = MagicMock(
+        )  # Mock return value for getOrCreate
+        mock_databricks_session.builder.remote.return_value = mock_remote
+
+        fetch_DT(args)
+        mock_databricks_session.builder.remote.assert_called_once_with(
+            host=args.DATABRICKS_HOST,
+            token=args.DATABRICKS_TOKEN,
+            cluster_id=args.cluster_id)
+
+    @patch('scripts.data_prep.convert_delta_to_json.sql.connect')
+    @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession')
+    @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient')
+    @patch('scripts.data_prep.convert_delta_to_json.os.makedirs')
+    @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons')
+    @patch('scripts.data_prep.convert_delta_to_json.fetch')
+    def test_sqlconnect_called_dbr13(self, mock_fetch: Any,
+                                     mock_combine_jsons: Any,
+                                     mock_makedirs: Any,
+                                     mock_workspace_client: Any,
+                                     mock_databricks_session: Any,
+                                     mock_sql_connect: Any):
+
+        args = MagicMock()
+
+        args.delta_table_name = 'test_table'
+        args.json_output_path = '/path/to/jsonl'
+        # Execute function with http_path set (should use dbsql on DBR 13)
+        args.http_path = 'test_path'
+        args.cluster_id = '1234'
+        args.DATABRICKS_HOST = 'host'
+        args.DATABRICKS_TOKEN = 'token'
+        args.use_serverless = False
+
+        mock_cluster_response = Namespace(spark_version='13.0.0-scala2.12')
+        mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response
+
+        fetch_DT(args)
+        mock_sql_connect.assert_called_once_with(
+            server_hostname=args.DATABRICKS_HOST,
+            http_path=args.http_path,
+            access_token=args.DATABRICKS_TOKEN)
+
+    @patch('scripts.data_prep.convert_delta_to_json.sql.connect')
+    @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession')
+    @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient')
+    @patch('scripts.data_prep.convert_delta_to_json.os.makedirs')
+    @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons')
+    @patch('scripts.data_prep.convert_delta_to_json.fetch')
+    def test_sqlconnect_called_dbr14(self, mock_fetch: Any,
+                                     mock_combine_jsons: Any,
+                                     mock_makedirs: Any,
+                                     mock_workspace_client: Any,
+                                     mock_databricks_session: Any,
+                                     mock_sql_connect: Any):
+
+        args = MagicMock()
+
+        args.delta_table_name = 'test_table'
+        args.json_output_path = '/path/to/jsonl'
+        # Execute function with http_path set (should use dbsql on DBR 14)
+        args.http_path = 'test_path'
+        args.cluster_id = '1234'
+        args.DATABRICKS_HOST = 'host'
+        args.DATABRICKS_TOKEN = 'token'
+        args.use_serverless = False
+
+        mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12')
+        mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response
+
+        fetch_DT(args)
+        mock_sql_connect.assert_called_once_with(
+            server_hostname=args.DATABRICKS_HOST,
+            http_path=args.http_path,
+            access_token=args.DATABRICKS_TOKEN)
+
+    @patch('scripts.data_prep.convert_delta_to_json.sql.connect')
+    @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession')
+    @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient')
+    @patch('scripts.data_prep.convert_delta_to_json.os.makedirs')
+    @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons')
+    @patch('scripts.data_prep.convert_delta_to_json.fetch')
+    def test_sqlconnect_called_https(self, mock_fetch: Any,
+                                     mock_combine_jsons: Any,
+                                     mock_makedirs: Any,
+                                     mock_workspace_client: Any,
+                                     mock_databricks_session: Any,
+                                     mock_sql_connect: Any):
+
+        args = MagicMock()
+
+        args.delta_table_name = 'test_table'
+        args.json_output_path = '/path/to/jsonl'
+        # Execute function with an https host (scheme should be stripped for dbsql)
+        args.http_path = 'test_path'
+        args.cluster_id = '1234'
+        args.DATABRICKS_HOST = 'https://test-host'
+        args.DATABRICKS_TOKEN = 'token'
+        args.use_serverless = False
+
+        mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12')
+        mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response
+
+        fetch_DT(args)
+        mock_sql_connect.assert_called_once_with(
+            server_hostname='test-host',
+            http_path=args.http_path,
+            access_token=args.DATABRICKS_TOKEN)
+
+    @patch('scripts.data_prep.convert_delta_to_json.sql.connect')
+    @patch('scripts.data_prep.convert_delta_to_json.DatabricksSession')
+    @patch('scripts.data_prep.convert_delta_to_json.WorkspaceClient')
+    @patch('scripts.data_prep.convert_delta_to_json.os.makedirs')
+    @patch('scripts.data_prep.convert_delta_to_json.iterative_combine_jsons')
+    @patch('scripts.data_prep.convert_delta_to_json.fetch')
+    def test_serverless(self, mock_fetch: Any, mock_combine_jsons: Any,
+                        mock_makedirs: Any, mock_workspace_client: Any,
+                        mock_databricks_session: Any, mock_sql_connect: Any):
+
+        args = MagicMock()
+
+        args.delta_table_name = 'test_table'
+        args.json_output_path = '/path/to/jsonl'
+        # Execute function with use_serverless=True (should use dbconnect via serverless)
+        args.http_path = 'test_path'
+        args.cluster_id = '1234'
+        args.DATABRICKS_HOST = 'https://test-host'
+        args.DATABRICKS_TOKEN = 'token'
+        args.use_serverless = True
+
+        mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12')
+        mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response
+
+        fetch_DT(args)
+        assert not mock_sql_connect.called
+        assert not mock_databricks_session.builder.remote.called