Skip to content

Commit

Permalink
Get title for related notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtomlinson committed Jan 25, 2023
1 parent c8da79d commit 9e37293
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions extensions/rapids_related_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from docutils import nodes
from docutils.parsers.rst.states import RSTState
from docutils.statemachine import ViewList
from markdown_it import MarkdownIt
from sphinx.application import Sphinx
from sphinx.environment import BuildEnvironment
from sphinx.util.docutils import SphinxDirective
Expand All @@ -21,7 +22,7 @@ def read_notebook_tags(path: str) -> list[str]:


def generate_notebook_grid_myst(
notebooks: list[str], env: BuildEnvironment
notebooks: list[str], env: BuildEnvironment, state: RSTState
) -> list[str]:
"""Generate sphinx-design grid of notebooks in MyST markdown.
Expand All @@ -38,13 +39,10 @@ def generate_notebook_grid_myst(
md.append("````{grid-item-card}")
md.append(":link: /" + notebook)
md.append(":link-type: doc")

# FIXME Would prefer to use titles but can't do this because not all titles have necessarily been read yet
# The following line works on rebuilds because titles are cached in the environment but fails on a clean build
# md.append(str(env.titles[notebook].children[0]))
# Using the notebook docname instead for now
md.append(notebook)

try:
md.append(get_title_for_notebook(env.doc2path(notebook), state=state))
except ValueError:
md.append(notebook)
md.append("^" * len(notebook))
md.append("")
for tag in read_notebook_tags(env.doc2path(notebook)):
Expand All @@ -66,6 +64,23 @@ def parse_markdown(markdown: list[str], state: RSTState) -> list[nodes.Node]:
return node.children


def get_title_for_notebook(path: str, state: RSTState) -> str:
"""Read a notebook file and find the top-level heading."""
notebook = nbformat.read(path, as_version=4)
for cell in notebook.cells:
if cell["cell_type"] == "markdown":
cell_source = MarkdownIt().parse(cell["source"])
for i, token in enumerate(cell_source):
next_token = cell_source[i + 1]
if (
token.type == "heading_open"
and token.tag == "h1"
and next_token.type == "inline"
):
return next_token.content
raise ValueError("No top-level heading found")


class RelatedExamples(SphinxDirective):
def run(self) -> list[nodes.Node]:
output = nodes.section(ids=["relatedexamples"])
Expand All @@ -75,6 +90,7 @@ def run(self) -> list[nodes.Node]:
grid_markdown = generate_notebook_grid_myst(
notebooks=self.env.notebook_tag_map[self.env.docname],
env=self.env,
state=self.state,
)
for node in parse_markdown(
markdown=grid_markdown,
Expand Down

0 comments on commit 9e37293

Please sign in to comment.