Skip to content

Commit

Permalink
Add rich HTML representation to estimators (#5630)
Browse files Browse the repository at this point in the history
This adds a Jupyter (and other notebook) rich display hook that produces a HTML widget to represent an estimator in notebooks.

This adds the basics of having estimators displayed as HTML widgets in notebooks and other editors that use the Jupyter notebook "rich display" system.

<img width="304" alt="Screenshot 2023-10-26 at 15 33 37" src="https://github.com/rapidsai/cuml/assets/1448859/9eccb547-d37f-44b9-b284-fe284707764c">

This doesn't yet contain the cool feature of changing colour depending on fit status or the link to the documentation. For that we'd have to depend on a newer version of scikit-learn (or vendor the logic). In this case "newer" actually means "the next version to be released".

WDYT?

Authors:
  - Tim Head (https://github.com/betatim)
  - Simon Adorf (https://github.com/csadorf)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Simon Adorf (https://github.com/csadorf)

URL: #5630
  • Loading branch information
betatim authored Nov 9, 2023
1 parent b79c09f commit 9fed69e
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/cuml/internals/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ from cuml.internals.safe_imports import (
np = cpu_only_import('numpy')
nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)

from sklearn.utils import estimator_html_repr

import cuml
import cuml.common
import cuml.internals.logger as logger
Expand Down Expand Up @@ -443,6 +445,12 @@ class Base(TagsMixin,
return {'preserves_dtype': [self.dtype]}
return {}

def _repr_mimebundle_(self, **kwargs):
"""Prepare representations used by jupyter kernels to display estimator"""
output = {"text/plain": repr(self)}
output["text/html"] = estimator_html_repr(self)
return output

def set_nvtx_annotations(self):
for func_name in ['fit', 'transform', 'predict', 'fit_transform',
'fit_predict']:
Expand Down

0 comments on commit 9fed69e

Please sign in to comment.