Make a best-effort attempt to initialise all Databricks globals
kartikgupta-db committed Feb 22, 2024

1 parent 3a2798f · commit 3bf489f
Showing 3 changed files with 92 additions and 13 deletions.
databricks/sdk/runtime/__init__.py: 90 changes (78 additions, 12 deletions)
@@ -1,7 +1,10 @@
from __future__ import annotations

import logging
from typing import Dict, Union
from types import FunctionType
from typing import Callable, Dict, Union

from databricks.sdk.service import sql

logger = logging.getLogger('databricks.sdk')
is_local_implementation = True
@@ -86,23 +89,86 @@ def inner() -> Dict[str, str]:
_globals[var] = userNamespaceGlobals[var]
is_local_implementation = False
except ImportError:
from typing import cast

# OSS implementation
is_local_implementation = True

from databricks.sdk.dbutils import RemoteDbUtils
try:
# We expect this to fail at runtime and do it only to provide types;
# accessing attributes on the None sqlContext raises, and the except below swallows it
from pyspark.sql.context import SQLContext
sqlContext: SQLContext = None # type: ignore
sql = sqlContext.sql
table = sqlContext.table
except Exception:
pass

# The next few try-except blocks initialise globals in a best-effort
# manner. We keep them separate so that as many of them as possible succeed.
try:
from pyspark.sql.functions import udf # type: ignore
except ImportError:
pass

try:
from databricks.connect import DatabricksSession # type: ignore
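# getOrCreate() resolves connection details through the SDK's unified
# authentication chain (environment variables, config profiles, etc.),
# so this succeeds only when such configuration is present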
spark = DatabricksSession.builder.getOrCreate()
sc = spark.sparkContext
except Exception:
# We ignore all failures here because the user might want to initialize
# the Spark session themselves, and we don't want to interfere with that
pass

try:
from IPython import display as IPDisplay

def display(input=None, *args, **kwargs) -> None:  # type: ignore
"""
Display plots or data.
Display plot:
- display() # no-op
- display(matplotlib.figure.Figure)
Display dataset:
- display(spark.DataFrame)
- display(list) # if list can be converted to DataFrame, e.g., list of named tuples
- display(pandas.DataFrame)
- display(koalas.DataFrame)
- display(pyspark.pandas.DataFrame)
Display any other value that has a _repr_html_() method
For Spark 2.0 and 2.1:
- display(DataFrame, streamName='optional', trigger=optional pyspark.sql.streaming.Trigger,
checkpointLocation='optional')
For Spark 2.2+:
- display(DataFrame, streamName='optional', trigger=optional interval like '1 second',
checkpointLocation='optional')
"""
return IPDisplay.display(input, *args, **kwargs) # type: ignore

def displayHTML(html) -> None: # type: ignore
"""
Display HTML data.
Parameters
----------
html : URL or HTML string
If html is a URL, display the resource at that URL; the resource is loaded dynamically by the browser.
Otherwise html should be the HTML to be displayed.
See also:
IPython.display.HTML
IPython.display.display_html
"""
return IPDisplay.display_html(html, raw=True) # type: ignore

except ImportError:
pass


# We want to propagate any error in initialising dbutils, because dbutils
# is core functionality of the SDK
from databricks.sdk.dbutils import RemoteDbUtils
from . import dbutils_stub

from typing import cast
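# dbutils is typed as the union of the notebook stub (for IDE autocompletion)
# and RemoteDbUtils, the implementation actually used on this OSS path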
dbutils_type = Union[dbutils_stub.dbutils, RemoteDbUtils]

try:
from .stub import *
except (ImportError, NameError):
# this assumes that all environment variables are set
dbutils = RemoteDbUtils()

dbutils = RemoteDbUtils()
dbutils = cast(dbutils_type, dbutils)
getArgument = dbutils.widgets.getArgument

__all__ = ['dbutils'] if is_local_implementation else dbruntime_objects
__all__ = dbruntime_objects
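
For context, a minimal usage sketch of what this change enables (not part of the commit; the host, token, and cluster id below are placeholders): with databricks-connect installed and standard Databricks auth variables set before import, the runtime module now yields working spark, display, displayHTML, and dbutils globals outside a notebook.

import os

# Placeholder credentials; any standard Databricks auth mechanism works.
os.environ.setdefault("DATABRICKS_HOST", "https://example.cloud.databricks.com")
os.environ.setdefault("DATABRICKS_TOKEN", "dapi-placeholder")
os.environ.setdefault("DATABRICKS_CLUSTER_ID", "0000-000000-placeholder")

# Import after configuration so the best-effort initialisers can succeed.
from databricks.sdk.runtime import dbutils, display, displayHTML, spark

display(spark.sql("SELECT 1 AS x"))  # spark built by DatabricksSession above
displayHTML("<b>hello</b>")          # rendered via IPython.display.display_html
print(dbutils.fs.ls("/"))            # dbutils works even if spark init failed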
setup.py: 3 changes (2 additions, 1 deletion)
@@ -16,7 +16,8 @@
install_requires=["requests>=2.28.1,<3", "google-auth~=2.0"],
extras_require={"dev": ["pytest", "pytest-cov", "pytest-xdist", "pytest-mock",
"yapf", "pycodestyle", "autoflake", "isort", "wheel",
"ipython", "ipywidgets", "requests-mock", "pyfakefs"],
"ipython", "ipywidgets", "requests-mock", "pyfakefs",
"databricks-connect", "ipython"],
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"]},
author="Serge Smertin",
author_email="[email protected]",
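
Side note (not from the diff): with databricks-connect added to the dev extras, a contributor would install them with something like `pip install -e ".[dev]"` from a checkout; the notebook extra remains the lighter option for display support.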
tests/integration/test_runtime_globals.py: 12 changes (12 additions, 0 deletions)
@@ -0,0 +1,12 @@
def test_runtime_spark(w, env_or_skip):
env_or_skip("SPARK_CONNECT_CLUSTER_ID")

from databricks.sdk.runtime import spark
assert spark.sql("SELECT 1").collect()[0][0] == 1

def test_runtime_display(w, env_or_skip):
from databricks.sdk.runtime import display, displayHTML

# assert no errors
display("test")
displayHTML("test")
