-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add skeleton for a multi-pass source mapper for Jaxprs/HLO to jax.exp…
…erimental. PiperOrigin-RevId: 721119935
- Loading branch information
1 parent
152099e
commit b01111d
Showing
9 changed files
with
642 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright 2025 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from jax._src.sourcemap import SourceMap as SourceMap | ||
from jax._src.sourcemap import MappingsGenerator as MappingsGenerator | ||
from jax.experimental.source_mapper.common import Pass as Pass | ||
from jax.experimental.source_mapper.common import register_pass as register_pass | ||
from jax.experimental.source_mapper.common import all_passes as all_passes | ||
from jax.experimental.source_mapper.common import filter_passes as filter_passes | ||
from jax.experimental.source_mapper.common import compile_with_env as compile_with_env | ||
from jax.experimental.source_mapper.common import SourceMapDump as SourceMapDump | ||
from jax.experimental.source_mapper.generate_map import generate_sourcemaps as generate_sourcemaps | ||
from jax.experimental.source_mapper.mlir import create_mlir_sourcemap as create_mlir_sourcemap | ||
|
||
# We import the jaxpr and hlo passes to register them. | ||
import jax.experimental.source_mapper.jaxpr # pylint: disable=unused-import # noqa: F401 | ||
from jax.experimental.source_mapper.jaxpr import canonicalize_filename as canonicalize_filename | ||
import jax.experimental.source_mapper.hlo # pylint: disable=unused-import # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Copyright 2025 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Common utilities for generating source maps.""" | ||
import contextlib | ||
import dataclasses | ||
import re | ||
from typing import Any, Protocol, Sequence | ||
|
||
from absl import flags | ||
import jax | ||
from jax._src import sourcemap | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class SourceMapDump: | ||
"""A container for a source map and the paired generated code.""" | ||
source_map: sourcemap.SourceMap | ||
generated_code: str | ||
pass_name: str | ||
|
||
|
||
class CompileFn(Protocol): | ||
|
||
def __call__(self, work_dir, fn, f_args, f_kwargs) -> Any: | ||
... | ||
|
||
|
||
class GenerateDumpFn(Protocol): | ||
|
||
def __call__(self, compile_result: Any, **kwargs) -> SourceMapDump: | ||
... | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class Pass: | ||
name: str | ||
compile_fn: CompileFn | ||
generate_dump: GenerateDumpFn | ||
|
||
|
||
_pass_registry = {} | ||
|
||
|
||
def register_pass(pass_: Pass): | ||
if pass_.name in _pass_registry: | ||
raise ValueError(f"Pass {pass_.name} already registered") | ||
_pass_registry[pass_.name] = pass_ | ||
|
||
|
||
def all_passes() -> Sequence[Pass]: | ||
return list(_pass_registry.values()) | ||
|
||
|
||
def filter_passes(regex: str) -> Sequence[Pass]: | ||
"""Gets all registered passes whose display name matches the given regex.""" | ||
return [ | ||
pass_ | ||
for pass_ in _pass_registry.values() | ||
if re.match(regex, pass_.name) | ||
] | ||
|
||
|
||
@contextlib.contextmanager | ||
def flag_env(**kwargs): | ||
"""A context manager for setting and restoring flags.""" | ||
old_flags = {kwarg: getattr(flags.FLAGS, kwarg) for kwarg in kwargs} | ||
for kwarg, new_value in kwargs.items(): | ||
setattr(flags.FLAGS, kwarg, new_value) | ||
try: | ||
yield | ||
finally: | ||
for kwarg, old_value in old_flags.items(): | ||
setattr(flags.FLAGS, kwarg, old_value) | ||
|
||
|
||
def compile_with_env(f, f_args, f_kwargs, env_flags, compiler_flags): | ||
with flag_env(**env_flags): | ||
jax.jit(lambda *args, **kwargs: f(*args, **kwargs)).lower( # pylint: disable=unnecessary-lambda | ||
*f_args, **f_kwargs | ||
).compile(compiler_flags) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright 2025 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Generates source maps for JAX functions.""" | ||
import os | ||
import tempfile | ||
from typing import Sequence, Protocol | ||
|
||
from jax.experimental.source_mapper import common | ||
|
||
|
||
class SourceMapGeneratorFn(Protocol): | ||
def __call__(self, *args, **kwargs) -> Sequence[common.SourceMapDump]: | ||
... | ||
|
||
|
||
def generate_sourcemaps( | ||
f, | ||
passes: Sequence[common.Pass], | ||
**kwargs | ||
) -> SourceMapGeneratorFn: | ||
"""Generates a SourceMapBundle for the specified compiler passes. | ||
Args: | ||
f: The function to compile. | ||
passes: Which compiler passes to generate sourcemaps for. | ||
**kwargs: Keyword arguments for generate_dump passes. | ||
""" | ||
def wrapper(*args, **kwargs) -> Sequence[common.SourceMapDump]: | ||
pass_results: list[common.SourceMapDump] = [] | ||
compile_cache = {} | ||
with tempfile.TemporaryDirectory() as work_dir: | ||
for pass_to_eval in passes: | ||
if pass_to_eval.compile_fn not in compile_cache: | ||
pass_work_dir = os.path.join(work_dir, pass_to_eval.name) | ||
os.makedirs(pass_work_dir, exist_ok=False) | ||
compile_result = pass_to_eval.compile_fn( | ||
pass_work_dir, f, args, kwargs | ||
) | ||
compile_cache[pass_to_eval.compile_fn] = compile_result | ||
compile_result = compile_cache[pass_to_eval.compile_fn] | ||
pass_results.append(pass_to_eval.generate_dump(compile_result, | ||
**kwargs)) | ||
return pass_results | ||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# Copyright 2025 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Source mapping generator for HLO dialects.""" | ||
import enum | ||
import re | ||
from typing import Any | ||
|
||
import jax | ||
from jax._src import sourcemap | ||
|
||
from jax.experimental.source_mapper import common | ||
from jax.experimental.source_mapper import mlir | ||
|
||
|
||
class HloPass(enum.Enum): | ||
STABLE_HLO = "hlo:stable-hlo" | ||
ORIGINAL = "hlo:original" | ||
OPTIMIZED = "hlo:optimized" | ||
|
||
|
||
METADATA_REGEX = re.compile( | ||
r"metadata={op_name=\"(?P<scope>.*)\" source_file=\"(?P<src_file>.*)\"" | ||
r" source_line=(?P<src_line>[0-9]+)\}" | ||
) | ||
|
||
|
||
def parse_hlo_dump(text: str) -> sourcemap.SourceMap: | ||
mappings = sourcemap.MappingsGenerator() | ||
used_source_files = [] | ||
for line in text.split("\n"): | ||
mappings.new_group() | ||
match = METADATA_REGEX.search(line) | ||
if match: | ||
match_dict = match.groupdict() | ||
_ = match_dict["scope"] # Unused | ||
src_file = match_dict["src_file"] | ||
src_line = int(match_dict["src_line"]) | ||
if src_file not in used_source_files: | ||
used_source_files.append(src_file) | ||
src_file_idx = used_source_files.index(src_file) | ||
src_line -= 1 # Segments are zero-indexed | ||
first_col = line.index(line.strip()[0]) | ||
mappings.new_segment(first_col, src_file_idx, src_line, 0) | ||
mappings.new_group() | ||
|
||
return sourcemap.SourceMap( | ||
version=3, | ||
sources=used_source_files, | ||
sources_content=[], | ||
mappings=mappings.mappings(), | ||
names=[], | ||
) | ||
|
||
|
||
def trace_and_lower(work_dir, f, f_args, f_kwargs): | ||
lowered = jax.jit(lambda *args: f(*args, **f_kwargs)).lower(*f_args) | ||
return (lowered, work_dir) | ||
|
||
|
||
def stable_hlo_generate_dump(args: tuple[Any, str], | ||
**_) -> common.SourceMapDump: | ||
lowered, work_dir = args | ||
del work_dir | ||
hlo_text = lowered.as_text(debug_info=True) | ||
source_map = mlir.create_mlir_sourcemap(hlo_text) | ||
return common.SourceMapDump( | ||
source_map=source_map, | ||
generated_code=hlo_text, | ||
pass_name=HloPass.STABLE_HLO.value, | ||
) | ||
|
||
|
||
common.register_pass( | ||
common.Pass( | ||
name=HloPass.STABLE_HLO.value, | ||
compile_fn=trace_and_lower, | ||
generate_dump=stable_hlo_generate_dump, # type: ignore[arg-type] | ||
) | ||
) | ||
|
||
|
||
def original_hlo_generate_dump(args: tuple[Any, str], | ||
**_) -> common.SourceMapDump: | ||
lowered, work_dir = args | ||
del work_dir | ||
hlo_text = lowered.as_text(dialect="hlo", debug_info=True) | ||
source_map = parse_hlo_dump(hlo_text) | ||
return common.SourceMapDump( | ||
source_map=source_map, | ||
generated_code=hlo_text, | ||
pass_name=HloPass.ORIGINAL.value, | ||
) | ||
|
||
|
||
common.register_pass( | ||
common.Pass( | ||
name=HloPass.ORIGINAL.value, | ||
compile_fn=trace_and_lower, | ||
generate_dump=original_hlo_generate_dump, # type: ignore[arg-type] | ||
) | ||
) | ||
|
||
|
||
def optimized_generate_dump(args: tuple[Any, str], | ||
**_) -> common.SourceMapDump: | ||
lowered, work_dir = args | ||
compilation_args = {"xla_dump_to": work_dir} | ||
hlo_text = lowered.compile(compilation_args).as_text() | ||
source_map = parse_hlo_dump(hlo_text) | ||
return common.SourceMapDump( | ||
source_map=source_map, | ||
generated_code=hlo_text, | ||
pass_name=HloPass.OPTIMIZED.value, | ||
) | ||
|
||
|
||
common.register_pass( | ||
common.Pass( | ||
name=HloPass.OPTIMIZED.value, | ||
compile_fn=trace_and_lower, | ||
generate_dump=optimized_generate_dump, # type: ignore[arg-type] | ||
) | ||
) |
Oops, something went wrong.