Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add skeleton for a multi-pass source mapper for Jaxprs/HLO to jax.exp…
Browse files Browse the repository at this point in the history
…erimental.

PiperOrigin-RevId: 718108389
justinjfu authored and Google-ML-Automation committed Jan 29, 2025
1 parent 152099e commit f57097f
Showing 9 changed files with 639 additions and 0 deletions.
14 changes: 14 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
@@ -598,6 +598,20 @@ pytype_strict_library(
] + py_deps("numpy"),
)

pytype_strict_library(
name = "source_mapper",
srcs = glob(include = ["experimental/source_mapper/**/*.py"]),
visibility = [
"//visibility:public",
],
deps = [
":config",
":core",
":jax",
":source_info_util",
] + py_deps("absl/flags"),
)

pytype_strict_library(
name = "pallas",
srcs = glob(
29 changes: 29 additions & 0 deletions jax/experimental/source_mapper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax._src.sourcemap import SourceMap as SourceMap
from jax._src.sourcemap import MappingsGenerator as MappingsGenerator
from jax.experimental.source_mapper.common import Pass as Pass
from jax.experimental.source_mapper.common import register_pass as register_pass
from jax.experimental.source_mapper.common import all_passes as all_passes
from jax.experimental.source_mapper.common import filter_passes as filter_passes
from jax.experimental.source_mapper.common import compile_with_env as compile_with_env
from jax.experimental.source_mapper.common import SourceMapDump as SourceMapDump
from jax.experimental.source_mapper.generate_map import generate_sourcemaps as generate_sourcemaps
from jax.experimental.source_mapper.mlir import create_mlir_sourcemap as create_mlir_sourcemap

# We import the jaxpr and hlo passes to register them.
import jax.experimental.source_mapper.jaxpr # pylint: disable=unused-import # noqa: F401
from jax.experimental.source_mapper.jaxpr import canonicalize_filename as canonicalize_filename
import jax.experimental.source_mapper.hlo # pylint: disable=unused-import # noqa: F401
91 changes: 91 additions & 0 deletions jax/experimental/source_mapper/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for generating source maps."""
import contextlib
import dataclasses
import re
from typing import Any, Protocol, Sequence

from absl import flags
import jax
from jax._src import sourcemap


@dataclasses.dataclass(frozen=True)
class SourceMapDump:
"""A container for a source map and the paired generated code."""
source_map: sourcemap.SourceMap
generated_code: str
pass_name: str


class CompileFn(Protocol):

def __call__(self, work_dir, fn, f_args, f_kwargs) -> Any:
...


class GenerateDumpFn(Protocol):

def __call__(self, compile_result: Any, **kwargs) -> SourceMapDump:
...


@dataclasses.dataclass(frozen=True)
class Pass:
name: str
compile_fn: CompileFn
generate_dump: GenerateDumpFn


_pass_registry = {}


def register_pass(pass_: Pass):
if pass_.name in _pass_registry:
raise ValueError(f"Pass {pass_.name} already registered")
_pass_registry[pass_.name] = pass_


def all_passes() -> Sequence[Pass]:
return list(_pass_registry.values())


def filter_passes(regex: str) -> Sequence[Pass]:
"""Gets all registered passes whose display name matches the given regex."""
return [
pass_
for pass_ in _pass_registry.values()
if re.match(regex, pass_.name)
]


@contextlib.contextmanager
def flag_env(**kwargs):
"""A context manager for setting and restoring flags."""
old_flags = {kwarg: getattr(flags.FLAGS, kwarg) for kwarg in kwargs}
for kwarg, new_value in kwargs.items():
setattr(flags.FLAGS, kwarg, new_value)
try:
yield
finally:
for kwarg, old_value in old_flags.items():
setattr(flags.FLAGS, kwarg, old_value)


def compile_with_env(f, f_args, f_kwargs, env_flags, compiler_flags):
with flag_env(**env_flags):
jax.jit(lambda *args, **kwargs: f(*args, **kwargs)).lower( # pylint: disable=unnecessary-lambda
*f_args, **f_kwargs
).compile(compiler_flags)
55 changes: 55 additions & 0 deletions jax/experimental/source_mapper/generate_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates source maps for JAX functions."""
import os
import tempfile
from typing import Sequence, Protocol

from jax.experimental.source_mapper import common


class SourceMapGeneratorFn(Protocol):
def __call__(self, *args, **kwargs) -> Sequence[common.SourceMapDump]:
...


def generate_sourcemaps(
f,
passes: Sequence[common.Pass],
**kwargs
) -> SourceMapGeneratorFn:
"""Generates a SourceMapBundle for the specified compiler passes.
Args:
f: The function to compile.
passes: Which compiler passes to generate sourcemaps for.
**kwargs: Keyword arguments for generate_dump passes.
"""
def wrapper(*args, **kwargs) -> Sequence[common.SourceMapDump]:
pass_results: list[common.SourceMapDump] = []
compile_cache = {}
with tempfile.TemporaryDirectory() as work_dir:
for pass_to_eval in passes:
if pass_to_eval.compile_fn not in compile_cache:
pass_work_dir = os.path.join(work_dir, pass_to_eval.name)
os.makedirs(pass_work_dir, exist_ok=False)
compile_result = pass_to_eval.compile_fn(
pass_work_dir, f, args, kwargs
)
compile_cache[pass_to_eval.compile_fn] = compile_result
compile_result = compile_cache[pass_to_eval.compile_fn]
pass_results.append(pass_to_eval.generate_dump(compile_result,
**kwargs))
return pass_results
return wrapper
134 changes: 134 additions & 0 deletions jax/experimental/source_mapper/hlo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Source mapping generator for HLO dialects."""
import enum
import re
from typing import Any

import jax
from jax._src import sourcemap

from jax.experimental.source_mapper import common
from jax.experimental.source_mapper import mlir


class HloPass(enum.Enum):
STABLE_HLO = "hlo:stable-hlo"
ORIGINAL = "hlo:original"
OPTIMIZED = "hlo:optimized"


METADATA_REGEX = re.compile(
r"metadata={op_name=\"(?P<scope>.*)\" source_file=\"(?P<src_file>.*)\""
r" source_line=(?P<src_line>[0-9]+)\}"
)


def parse_hlo_dump(text: str) -> sourcemap.SourceMap:
mappings = sourcemap.MappingsGenerator()
used_source_files = []
for line in text.split("\n"):
mappings.new_group()
match = METADATA_REGEX.search(line)
if match:
match_dict = match.groupdict()
_ = match_dict["scope"] # Unused
src_file = match_dict["src_file"]
src_line = int(match_dict["src_line"])
if src_file not in used_source_files:
used_source_files.append(src_file)
src_file_idx = used_source_files.index(src_file)
src_line -= 1 # Segments are zero-indexed
first_col = line.index(line.strip()[0])
mappings.new_segment(first_col, src_file_idx, src_line, 0)
mappings.new_group()

return sourcemap.SourceMap(
version=3,
sources=used_source_files,
sources_content=[],
mappings=mappings.mappings(),
names=[],
)


def trace_and_lower(work_dir, f, f_args, f_kwargs):
lowered = jax.jit(lambda *args: f(*args, **f_kwargs)).lower(*f_args)
return (lowered, work_dir)


def stable_hlo_generate_dump(args: tuple[Any, str],
**_) -> common.SourceMapDump:
lowered, work_dir = args
del work_dir
hlo_text = lowered.as_text(debug_info=True)
source_map = mlir.create_mlir_sourcemap(hlo_text)
return common.SourceMapDump(
source_map=source_map,
generated_code=hlo_text,
pass_name=HloPass.STABLE_HLO.value,
)


common.register_pass(
common.Pass(
name=HloPass.STABLE_HLO.value,
compile_fn=trace_and_lower,
generate_dump=stable_hlo_generate_dump, # type: ignore[arg-type]
)
)


def original_hlo_generate_dump(args: tuple[Any, str],
**_) -> common.SourceMapDump:
lowered, work_dir = args
del work_dir
hlo_text = lowered.as_text(dialect="hlo", debug_info=True)
source_map = parse_hlo_dump(hlo_text)
return common.SourceMapDump(
source_map=source_map,
generated_code=hlo_text,
pass_name=HloPass.ORIGINAL.value,
)


common.register_pass(
common.Pass(
name=HloPass.ORIGINAL.value,
compile_fn=trace_and_lower,
generate_dump=original_hlo_generate_dump, # type: ignore[arg-type]
)
)


def optimized_generate_dump(args: tuple[Any, str],
**_) -> common.SourceMapDump:
lowered, work_dir = args
compilation_args = {"xla_dump_to": work_dir}
hlo_text = lowered.compile(compilation_args).as_text()
source_map = parse_hlo_dump(hlo_text)
return common.SourceMapDump(
source_map=source_map,
generated_code=hlo_text,
pass_name=HloPass.OPTIMIZED.value,
)


common.register_pass(
common.Pass(
name=HloPass.OPTIMIZED.value,
compile_fn=trace_and_lower,
generate_dump=optimized_generate_dump, # type: ignore[arg-type]
)
)
80 changes: 80 additions & 0 deletions jax/experimental/source_mapper/jaxpr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Source mapping generator for Jaxprs."""
import re
from typing import Any

import jax
from jax._src import config
from jax._src import core
from jax._src import source_info_util
from jax._src import sourcemap
from jax.experimental.source_mapper import common

source_info_util.register_exclusion(__file__)


def compile_jaxpr(work_dir, f, f_args, f_kwargs):
del work_dir
return jax.make_jaxpr(f)(*f_args, **f_kwargs)


def canonicalize_filename(file_name: str):
pattern = config.hlo_source_file_canonicalization_regex.value
if pattern:
file_name = re.sub(pattern, '', file_name)
return file_name


def make_jaxpr_dump(jaxpr: core.Jaxpr, **_) -> common.SourceMapDump:
pprint_mappings: list[list[tuple[int, int, Any]]] = []
pprint_str = jaxpr.pretty_print(source_map=pprint_mappings)
used_source_files = []
mappings = sourcemap.MappingsGenerator()
for pprint_map_line in pprint_mappings:
mappings.new_group()
for pprint_segment in pprint_map_line:
start_col, end_col, frame = pprint_segment
del end_col
file_name = canonicalize_filename(frame.file_name)
if file_name not in used_source_files:
used_source_files.append(file_name)
file_idx = used_source_files.index(file_name)
src_line = frame.start_line - 1 # Zero-indexed
src_col = frame.start_column
# A segment is a tuple of the form:
# (generated_col, src_file_idx, src_line, src_col)
mappings.new_segment(start_col, file_idx, src_line, src_col)
mappings.new_group()
source_map = sourcemap.SourceMap(
version=3,
sources=used_source_files,
sources_content=[],
mappings=mappings.mappings(),
names=[],
)
return common.SourceMapDump(
source_map=source_map,
generated_code=pprint_str,
pass_name='jaxpr',
)


common.register_pass(
common.Pass(
name='jaxpr',
compile_fn=compile_jaxpr,
generate_dump=make_jaxpr_dump, # type: ignore[arg-type]
)
)
140 changes: 140 additions & 0 deletions jax/experimental/source_mapper/mlir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generating source mappings for MLIR dialects."""
import collections
import re
from typing import cast

from jax._src import sourcemap


# TODO(justinfu): Make a proper parser for MLIR dumps.
LOC_REGEX = re.compile(r"loc\(#loc(?P<id>[0-9]+)\)")

SRC_REGEX = re.compile(
r"#loc(?P<id>[0-9]+) ="
r" loc\(\"(?P<file>.*)\":(?P<line>[0-9]+):(?P<col>[0-9]+)\)"
)

SCOPED_REGEX = re.compile(
r"#loc(?P<id>[0-9]+) = loc\(\"(?P<scope>.*)\"\(#loc(?P<tgt_id>[0-9]+)\)\)"
)

CALLSITE_REGEX = re.compile(
r"#loc(?P<id>[0-9]+) = loc\(callsite\(#loc(?P<callee>[0-9]+) at"
r" #loc(?P<caller>[0-9]+)\)\)"
)

Location = collections.namedtuple("Location", ["file", "line", "col"])
Redirect = collections.namedtuple("Redirect", ["tgt_id"])


def create_mlir_sourcemap(mlir_dump: str) -> sourcemap.SourceMap:
mappings = sourcemap.MappingsGenerator()
dump_lines: list[str] = mlir_dump.split("\n")

segment_dict, sources = parse_mlir_locations(dump_lines)
used_sources = []
used_sources_filenames = []
for line in dump_lines:
mappings.new_group()
match = LOC_REGEX.search(line)
if match:
loc_id = int(match.group("id"))
if loc_id not in segment_dict:
# TODO(justinfu): This happens on fusion locations - need to implement.
continue
segment = list(segment_dict[loc_id])
first_col = line.index(line.strip()[0])
segment[0] = first_col
# Remap the sourcefile index to only sourcefiles that are used.
# This is optional but makes the mapping file smaller by pruning
# unused sourcefiles.
source_idx = segment[1]
if source_idx not in used_sources:
used_sources.append(source_idx)
used_sources_filenames.append(sources[source_idx])
segment[1] = used_sources.index(source_idx)
mappings.new_segment(*segment)
mappings.new_group()

return sourcemap.SourceMap(
version=3,
sources=used_sources_filenames,
sources_content=[''] * len(used_sources_filenames),
mappings=mappings.mappings(),
names=[],
)


def parse_mlir_locations(
mlir_dump: list[str],
) -> tuple[dict[int, sourcemap.Segment], list[str]]:
locations: dict[int, Location | Redirect] = {}
source_files = []
for line in mlir_dump:
if line.startswith("#loc"):
src_match = SRC_REGEX.match(line)
if src_match:
match_dict = src_match.groupdict()
filename = match_dict["file"]
locations[int(match_dict["id"])] = Location(
file=filename,
line=int(match_dict["line"]),
col=int(match_dict["col"]),
)
if filename not in source_files:
source_files.append(filename)
continue
scoped_match = SCOPED_REGEX.match(line)
if scoped_match:
match_dict = scoped_match.groupdict()
locations[int(match_dict["id"])] = Redirect(
tgt_id=int(match_dict["tgt_id"])
)
continue
callsite_match = CALLSITE_REGEX.match(line)
if callsite_match:
match_dict = callsite_match.groupdict()
locations[int(match_dict["id"])] = Redirect(
tgt_id=int(match_dict["callee"])
)
continue
if "loc(unknown)" in line:
continue
# Resolve redirects
while True:
new_locations: dict[int, Location | Redirect] = {}
updated = False
for loc_id, loc in locations.items():
if isinstance(loc, Redirect):
new_locations[loc_id] = locations[loc.tgt_id]
updated = True
else:
new_locations[loc_id] = loc
locations = new_locations
if not updated:
break
segment_dict: dict[int, sourcemap.Segment] = {}
for id_, loc in locations.items():
# A segment is a tuple of the form:
# (generated_col, src_file_idx, src_line, src_col)
loc = cast(Location, loc)
segment_dict[id_] = (
0,
source_files.index(loc.file),
loc.line - 1, # Zero-indexed, so offset by 1.
loc.col,
)
return segment_dict, source_files
10 changes: 10 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
@@ -1584,6 +1584,16 @@ jax_py_test(
],
)

jax_py_test(
name = "source_mapper_test",
srcs = ["source_mapper_test.py"],
deps = [
"//jax",
"//jax:source_mapper",
"//jax:test_util",
],
)

jax_py_test(
name = "sourcemap_test",
srcs = ["sourcemap_test.py"],
86 changes: 86 additions & 0 deletions tests/source_mapper_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax.experimental import source_mapper


class SourceMapperTest(jtu.JaxTestCase):

def test_jaxpr_pass(self):
def jax_fn(x, y):
return x + y
test_x = jnp.array([1, 2, 3])
test_y = jnp.array([4, 5, 6])
source_maps = source_mapper.generate_sourcemaps(
jax_fn,
passes=source_mapper.filter_passes("jaxpr"))(test_x, test_y)
self.assertLen(source_maps, 1)
dump = source_maps[0]
self.assertEqual(dump.pass_name, "jaxpr")
self.assertIn("add a b", dump.generated_code)
source_map = dump.source_map
self.assertLen(source_map.sources, 1)
self.assertEqual(source_map.sources[0],
source_mapper.canonicalize_filename(__file__))
mappings = source_map.mappings
self.assertLen(mappings, len(dump.generated_code.split("\n")) + 1)
gen_col, file_idx, src_line, src_col = mappings[0][0]
# It's hard to guarantee at what column the add instruction will be
# generated in the dump. We just sanity-check that it's greater than 0.
self.assertGreater(gen_col, 0)
# There is only one file, so we should map to that
self.assertEqual(file_idx, 0)
# These should line up with the function definition of jax_fn above.
self.assertEqual(src_line, jax_fn.__code__.co_firstlineno)
self.assertEqual(src_col, 13)

@parameterized.parameters(
("hlo:stable-hlo", "stablehlo.add", 13),
("hlo:original", "add", 0),
("hlo:optimized", "add", 0),
)
def test_hlo_passes(self, pass_name, expected_hlo_op, expected_col):
def jax_fn(x, y):
return x + y
test_x = jnp.array([1, 2, 3])
test_y = jnp.array([4, 5, 6])
source_maps = source_mapper.generate_sourcemaps(
jax_fn,
passes=source_mapper.filter_passes(pass_name))(test_x, test_y)
self.assertLen(source_maps, 1)
dump = source_maps[0]
self.assertEqual(dump.pass_name, pass_name)
self.assertIn(expected_hlo_op, dump.generated_code)
source_map = dump.source_map
self.assertLen(source_map.sources, 1)
self.assertEqual(source_map.sources[0],
source_mapper.canonicalize_filename(__file__))
mappings = source_map.mappings
self.assertLen(mappings, len(dump.generated_code.split("\n")) + 1)
nonempty_mappings = [m for m in mappings if m]
self.assertLen(nonempty_mappings, 1)
gen_col, file_idx, src_line, src_col = nonempty_mappings[0][0]
self.assertGreater(gen_col, 0)
# There is only one file, so we should map to that
self.assertEqual(file_idx, 0)
# These should line up with the function definition of jax_fn above.
self.assertEqual(src_line, jax_fn.__code__.co_firstlineno)
self.assertEqual(src_col, expected_col)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit f57097f

Please sign in to comment.