Skip to content

Commit

Permalink
Add skeleton for a multi-pass source mapper for Jaxprs/HLO to jax.exp…
Browse files Browse the repository at this point in the history
…erimental.

PiperOrigin-RevId: 721119935
  • Loading branch information
justinjfu authored and Google-ML-Automation committed Jan 29, 2025
1 parent 152099e commit b01111d
Show file tree
Hide file tree
Showing 9 changed files with 642 additions and 0 deletions.
14 changes: 14 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,20 @@ pytype_strict_library(
] + py_deps("numpy"),
)

pytype_strict_library(
name = "source_mapper",
srcs = glob(include = ["experimental/source_mapper/**/*.py"]),
visibility = [
"//visibility:public",
],
deps = [
":config",
":core",
":jax",
":source_info_util",
] + py_deps("absl/flags"),
)

pytype_strict_library(
name = "pallas",
srcs = glob(
Expand Down
29 changes: 29 additions & 0 deletions jax/experimental/source_mapper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax._src.sourcemap import SourceMap as SourceMap
from jax._src.sourcemap import MappingsGenerator as MappingsGenerator
from jax.experimental.source_mapper.common import Pass as Pass
from jax.experimental.source_mapper.common import register_pass as register_pass
from jax.experimental.source_mapper.common import all_passes as all_passes
from jax.experimental.source_mapper.common import filter_passes as filter_passes
from jax.experimental.source_mapper.common import compile_with_env as compile_with_env
from jax.experimental.source_mapper.common import SourceMapDump as SourceMapDump
from jax.experimental.source_mapper.generate_map import generate_sourcemaps as generate_sourcemaps
from jax.experimental.source_mapper.mlir import create_mlir_sourcemap as create_mlir_sourcemap

# We import the jaxpr and hlo passes to register them.
import jax.experimental.source_mapper.jaxpr # pylint: disable=unused-import # noqa: F401
from jax.experimental.source_mapper.jaxpr import canonicalize_filename as canonicalize_filename
import jax.experimental.source_mapper.hlo # pylint: disable=unused-import # noqa: F401
91 changes: 91 additions & 0 deletions jax/experimental/source_mapper/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for generating source maps."""
import contextlib
import dataclasses
import re
from typing import Any, Protocol, Sequence

from absl import flags
import jax
from jax._src import sourcemap


@dataclasses.dataclass(frozen=True)
class SourceMapDump:
"""A container for a source map and the paired generated code."""
source_map: sourcemap.SourceMap
generated_code: str
pass_name: str


class CompileFn(Protocol):

def __call__(self, work_dir, fn, f_args, f_kwargs) -> Any:
...


class GenerateDumpFn(Protocol):

def __call__(self, compile_result: Any, **kwargs) -> SourceMapDump:
...


@dataclasses.dataclass(frozen=True)
class Pass:
name: str
compile_fn: CompileFn
generate_dump: GenerateDumpFn


_pass_registry = {}


def register_pass(pass_: Pass):
if pass_.name in _pass_registry:
raise ValueError(f"Pass {pass_.name} already registered")
_pass_registry[pass_.name] = pass_


def all_passes() -> Sequence[Pass]:
return list(_pass_registry.values())


def filter_passes(regex: str) -> Sequence[Pass]:
"""Gets all registered passes whose display name matches the given regex."""
return [
pass_
for pass_ in _pass_registry.values()
if re.match(regex, pass_.name)
]


@contextlib.contextmanager
def flag_env(**kwargs):
"""A context manager for setting and restoring flags."""
old_flags = {kwarg: getattr(flags.FLAGS, kwarg) for kwarg in kwargs}
for kwarg, new_value in kwargs.items():
setattr(flags.FLAGS, kwarg, new_value)
try:
yield
finally:
for kwarg, old_value in old_flags.items():
setattr(flags.FLAGS, kwarg, old_value)


def compile_with_env(f, f_args, f_kwargs, env_flags, compiler_flags):
with flag_env(**env_flags):
jax.jit(lambda *args, **kwargs: f(*args, **kwargs)).lower( # pylint: disable=unnecessary-lambda
*f_args, **f_kwargs
).compile(compiler_flags)
55 changes: 55 additions & 0 deletions jax/experimental/source_mapper/generate_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates source maps for JAX functions."""
import os
import tempfile
from typing import Sequence, Protocol

from jax.experimental.source_mapper import common


class SourceMapGeneratorFn(Protocol):
def __call__(self, *args, **kwargs) -> Sequence[common.SourceMapDump]:
...


def generate_sourcemaps(
f,
passes: Sequence[common.Pass],
**kwargs
) -> SourceMapGeneratorFn:
"""Generates a SourceMapBundle for the specified compiler passes.
Args:
f: The function to compile.
passes: Which compiler passes to generate sourcemaps for.
**kwargs: Keyword arguments for generate_dump passes.
"""
def wrapper(*args, **kwargs) -> Sequence[common.SourceMapDump]:
pass_results: list[common.SourceMapDump] = []
compile_cache = {}
with tempfile.TemporaryDirectory() as work_dir:
for pass_to_eval in passes:
if pass_to_eval.compile_fn not in compile_cache:
pass_work_dir = os.path.join(work_dir, pass_to_eval.name)
os.makedirs(pass_work_dir, exist_ok=False)
compile_result = pass_to_eval.compile_fn(
pass_work_dir, f, args, kwargs
)
compile_cache[pass_to_eval.compile_fn] = compile_result
compile_result = compile_cache[pass_to_eval.compile_fn]
pass_results.append(pass_to_eval.generate_dump(compile_result,
**kwargs))
return pass_results
return wrapper
134 changes: 134 additions & 0 deletions jax/experimental/source_mapper/hlo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Source mapping generator for HLO dialects."""
import enum
import re
from typing import Any

import jax
from jax._src import sourcemap

from jax.experimental.source_mapper import common
from jax.experimental.source_mapper import mlir


class HloPass(enum.Enum):
STABLE_HLO = "hlo:stable-hlo"
ORIGINAL = "hlo:original"
OPTIMIZED = "hlo:optimized"


METADATA_REGEX = re.compile(
r"metadata={op_name=\"(?P<scope>.*)\" source_file=\"(?P<src_file>.*)\""
r" source_line=(?P<src_line>[0-9]+)\}"
)


def parse_hlo_dump(text: str) -> sourcemap.SourceMap:
mappings = sourcemap.MappingsGenerator()
used_source_files = []
for line in text.split("\n"):
mappings.new_group()
match = METADATA_REGEX.search(line)
if match:
match_dict = match.groupdict()
_ = match_dict["scope"] # Unused
src_file = match_dict["src_file"]
src_line = int(match_dict["src_line"])
if src_file not in used_source_files:
used_source_files.append(src_file)
src_file_idx = used_source_files.index(src_file)
src_line -= 1 # Segments are zero-indexed
first_col = line.index(line.strip()[0])
mappings.new_segment(first_col, src_file_idx, src_line, 0)
mappings.new_group()

return sourcemap.SourceMap(
version=3,
sources=used_source_files,
sources_content=[],
mappings=mappings.mappings(),
names=[],
)


def trace_and_lower(work_dir, f, f_args, f_kwargs):
lowered = jax.jit(lambda *args: f(*args, **f_kwargs)).lower(*f_args)
return (lowered, work_dir)


def stable_hlo_generate_dump(args: tuple[Any, str],
**_) -> common.SourceMapDump:
lowered, work_dir = args
del work_dir
hlo_text = lowered.as_text(debug_info=True)
source_map = mlir.create_mlir_sourcemap(hlo_text)
return common.SourceMapDump(
source_map=source_map,
generated_code=hlo_text,
pass_name=HloPass.STABLE_HLO.value,
)


common.register_pass(
common.Pass(
name=HloPass.STABLE_HLO.value,
compile_fn=trace_and_lower,
generate_dump=stable_hlo_generate_dump, # type: ignore[arg-type]
)
)


def original_hlo_generate_dump(args: tuple[Any, str],
**_) -> common.SourceMapDump:
lowered, work_dir = args
del work_dir
hlo_text = lowered.as_text(dialect="hlo", debug_info=True)
source_map = parse_hlo_dump(hlo_text)
return common.SourceMapDump(
source_map=source_map,
generated_code=hlo_text,
pass_name=HloPass.ORIGINAL.value,
)


common.register_pass(
common.Pass(
name=HloPass.ORIGINAL.value,
compile_fn=trace_and_lower,
generate_dump=original_hlo_generate_dump, # type: ignore[arg-type]
)
)


def optimized_generate_dump(args: tuple[Any, str],
**_) -> common.SourceMapDump:
lowered, work_dir = args
compilation_args = {"xla_dump_to": work_dir}
hlo_text = lowered.compile(compilation_args).as_text()
source_map = parse_hlo_dump(hlo_text)
return common.SourceMapDump(
source_map=source_map,
generated_code=hlo_text,
pass_name=HloPass.OPTIMIZED.value,
)


common.register_pass(
common.Pass(
name=HloPass.OPTIMIZED.value,
compile_fn=trace_and_lower,
generate_dump=optimized_generate_dump, # type: ignore[arg-type]
)
)
Loading

0 comments on commit b01111d

Please sign in to comment.