Skip to content

Commit

Permalink
Fix gen_ai only build (#2603)
Browse files Browse the repository at this point in the history
Summary:
with-proxy CUDA_HOME=/usr/local/cuda-12.1 python setup.py install -DTORCH_CUDA_ARCH_LIST="8.0;9.0;9.0a" -DFBGEMM_GENAI_ONLY=ON -j 60  --record files.txt 2>&1 | tee out.log

python FBGEMM/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

Pull Request resolved: #2603

Reviewed By: jiawenliu64

Differential Revision: D57515269

Pulled By: jianyuh

fbshipit-source-id: febbaf5b0adc6f7d69630b17bdff4626ba47d65c
  • Loading branch information
jianyuh authored and facebook-github-bot committed May 19, 2024
1 parent 06e78e1 commit 37c283c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
4 changes: 4 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,7 @@ def test_quantize_fp8_per_tensor_with_ub(

zq_ref = (x @ w.T).to(torch.bfloat16)
torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3)


if __name__ == "__main__":
unittest.main()
21 changes: 17 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,18 @@

try:
torch.ops.load_library(os.path.join(os.path.dirname(__file__), "fbgemm_gpu_py.so"))
except Exception as e:
print(e)
except Exception as error_ranking:
try:
torch.ops.load_library(
os.path.join(
os.path.dirname(__file__),
"experimental/gen_ai/fbgemm_gpu_experimental_gen_ai_py.so",
)
)
except Exception as error_gen_ai:
# When both ranking/gen_ai so files are not available, print the error logs
print(error_ranking)
print(error_gen_ai)

# Since __init__.py is only used in OSS context, we define `open_source` here
# and use its existence to determine whether or not we are in OSS context
Expand All @@ -24,5 +34,8 @@
# Export the version string from the version file auto-generated by setup.py
from fbgemm_gpu.docs.version import __version__ # noqa: F401, E402

# Trigger meta operator registrations
from . import sparse_ops # noqa: F401, E402
try:
# Trigger meta operator registrations
from . import sparse_ops # noqa: F401, E402
except Exception:
pass
5 changes: 4 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/docs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@
# LICENSE file in the root directory of this source tree.

# Trigger the manual addition of docstrings to pybind11-generated operators
from . import jagged_tensor_ops, table_batched_embedding_ops # noqa: F401
try:
from . import jagged_tensor_ops, table_batched_embedding_ops # noqa: F401
except Exception:
pass

0 comments on commit 37c283c

Please sign in to comment.