Skip to content

Commit

Permalink
Re-organize TBE python files (#2583)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2583

- Re-organize TBE python files

Differential Revision: D57239244
  • Loading branch information
q10 authored and facebook-github-bot committed May 15, 2024
1 parent be752af commit 747661f
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 28 deletions.
5 changes: 1 addition & 4 deletions fbgemm_gpu/fbgemm_gpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,4 @@
from fbgemm_gpu.docs.version import __version__ # noqa: F401, E402

# Trigger meta operator registrations

from . import sparse_ops, split_embeddings_cache_ops # noqa: F401, E402

# from . import sparse_ops # noqa: F401, E402
from . import sparse_ops # noqa: F401, E402
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import torch # usort:skip
from torch import nn, Tensor # usort:skip

# Load cache ops
import fbgemm_gpu # noqa: F401
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
from fbgemm_gpu.runtime_monitor import (
AsyncSeriesTimer,
Expand Down
6 changes: 6 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
24 changes: 24 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from .split_embeddings_cache_ops import get_unique_indices

lib = torch.library.Library("fbgemm", "FRAGMENT")
lib.define(
"""
get_unique_indices(
Tensor linear_indices,
int max_indices,
bool compute_count=False,
bool compute_inverse_indices=False
) -> (Tensor, Tensor, Tensor?, Tensor?)
"""
)

lib.impl("get_unique_indices", get_unique_indices, "CUDA")
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,7 @@

import torch

lib = torch.library.Library("fbgemm", "FRAGMENT")
lib.define(
"""
get_unique_indices(
Tensor linear_indices,
int max_indices,
bool compute_count=False,
bool compute_inverse_indices=False
) -> (Tensor, Tensor, Tensor?, Tensor?)
"""
)


@torch.library.impl(lib, "get_unique_indices", "CUDA")
def get_unique_indices(
linear_indices: torch.Tensor,
max_indices: int,
Expand Down
3 changes: 0 additions & 3 deletions fbgemm_gpu/test/tbe/cache/cache_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import fbgemm_gpu.split_embeddings_cache_ops # noqa: F401

import numpy as np
import torch
from fbgemm_gpu.runtime_monitor import TBEStatsReporter, TBEStatsReporterConfig
from fbgemm_gpu.split_embedding_configs import SparseType

from fbgemm_gpu.split_embedding_utils import round_up
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
CacheAlgorithm,
Expand Down
3 changes: 0 additions & 3 deletions fbgemm_gpu/test/tbe/cache/linearize_cache_indices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,9 @@
import unittest
from typing import Optional

import fbgemm_gpu.split_embeddings_cache_ops # noqa: F401

import torch
from hypothesis import Verbosity

from .. import common # noqa E402
from ..common import open_source


Expand Down
3 changes: 0 additions & 3 deletions fbgemm_gpu/test/tbe/cache/lxu_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from itertools import accumulate
from typing import Tuple

import fbgemm_gpu.split_embeddings_cache_ops # noqa: F401

import hypothesis.strategies as st
import numpy as np
import torch
Expand All @@ -24,7 +22,6 @@
from hypothesis import given, settings, Verbosity
from torch import Tensor

from .. import common # noqa E402
from ..common import MAX_EXAMPLES, open_source

if open_source:
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/test/tbe/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import List, Tuple

import fbgemm_gpu
import fbgemm_gpu.tbe.cache # noqa: F401
import numpy as np
import torch
from hypothesis import settings, Verbosity
Expand Down

0 comments on commit 747661f

Please sign in to comment.