Skip to content

Commit

Permalink
Adding support for index builder (#3800)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3800

In this diff,
1. codec can be referred both using desc name or remote path in IndexFromCodec
2. expose serialization of full index through BuildOperator
3. Rename get_local_filename to get_local_filepath.

Reviewed By: satymish

Differential Revision: D61813717
  • Loading branch information
kuarora authored and facebook-github-bot committed Aug 27, 2024
1 parent 084496a commit e708c22
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 15 deletions.
2 changes: 2 additions & 0 deletions benchs/bench_fw/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def train(
@dataclass
class BuildOperator(IndexOperator):
index_descs: List[IndexDescriptor] = field(default_factory=lambda: [])
serialize_index: bool = False

def get_desc(self, name: str) -> Optional[IndexDescriptor]:
for desc in self.index_descs:
Expand Down Expand Up @@ -312,6 +313,7 @@ def build_index_wrapper(self, index_desc: IndexDescriptor):
path=index_desc.codec_desc.path,
index_name=index_desc.get_name(),
codec_name=index_desc.codec_desc.get_name(),
serialize_full_index=self.serialize_index,
)
index.set_io(self.io)
index_desc.index = index
Expand Down
17 changes: 8 additions & 9 deletions benchs/bench_fw/benchmark_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def merge_rcq_itq(

@dataclass
class BenchmarkIO:
path: str
path: str # local path

def __init__(self, path: str):
self.path = path
Expand All @@ -54,8 +54,7 @@ def __init__(self, path: str):
def clone(self):
return BenchmarkIO(path=self.path)

# TODO(kuarora): rename it as get_local_file
def get_local_filename(self, filename):
def get_local_filepath(self, filename):
if len(filename) > 184:
fn, ext = os.path.splitext(filename)
filename = (
Expand All @@ -72,7 +71,7 @@ def download_file_from_blobstore(
bucket: Optional[str] = None,
path: Optional[str] = None,
):
return self.get_local_filename(filename)
return self.get_local_filepath(filename)

def upload_file_to_blobstore(
self,
Expand All @@ -84,7 +83,7 @@ def upload_file_to_blobstore(
pass

def file_exist(self, filename: str):
fn = self.get_local_filename(filename)
fn = self.get_local_filepath(filename)
exists = os.path.exists(fn)
logger.info(f"{filename} {exists=}")
return exists
Expand Down Expand Up @@ -112,7 +111,7 @@ def write_file(
values: List[Any],
overwrite: bool = False,
):
fn = self.get_local_filename(filename)
fn = self.get_local_filepath(filename)
with ZipFile(fn, "w") as zip_file:
for key, value in zip(keys, values, strict=True):
with zip_file.open(key, "w", force_zip64=True) as f:
Expand Down Expand Up @@ -187,7 +186,7 @@ def write_nparray(
nparray: np.ndarray,
filename: str,
):
fn = self.get_local_filename(filename)
fn = self.get_local_filepath(filename)
logger.info(f"Saving nparray {nparray.shape} to {fn}")
np.save(fn, nparray)
self.upload_file_to_blobstore(filename)
Expand All @@ -209,7 +208,7 @@ def write_json(
filename: str,
overwrite: bool = False,
):
fn = self.get_local_filename(filename)
fn = self.get_local_filepath(filename)
logger.info(f"Saving json {json_dict} to {fn}")
with open(fn, "w") as fp:
json.dump(json_dict, fp)
Expand Down Expand Up @@ -239,7 +238,7 @@ def write_index(
index: faiss.Index,
filename: str,
):
fn = self.get_local_filename(filename)
fn = self.get_local_filepath(filename)
logger.info(f"Saving index to {fn}")
faiss.write_index(index, fn)
self.upload_file_to_blobstore(filename)
Expand Down
2 changes: 1 addition & 1 deletion benchs/bench_fw/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def name_from_path(self):
name = filename
return name

def alias(self, benchmark_io : BenchmarkIO):
def alias(self, benchmark_io: BenchmarkIO):
if hasattr(benchmark_io, "bucket"):
return CodecDescriptor(desc_name=self.get_name(), bucket=benchmark_io.bucket, path=self.get_path(benchmark_io), d=self.d, metric=self.metric)
return CodecDescriptor(desc_name=self.get_name(), d=self.d, metric=self.metric)
Expand Down
17 changes: 12 additions & 5 deletions benchs/bench_fw/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,12 +786,12 @@ def is_flat_index(self):
# are used to wrap pre-trained Faiss indices (codecs)
@dataclass
class IndexFromCodec(Index):
path: Optional[str] = None
path: Optional[str] = None # remote or local path to the codec

def __post_init__(self):
super().__post_init__()
if self.path is None:
raise ValueError("path is not set")
if self.path is None and self.codec_name is None:
raise ValueError("path or desc_name is not set")

def get_quantizer(self):
if not self.is_ivf():
Expand All @@ -814,10 +814,17 @@ def fetch_meta(self, dry_run=False):
return None, None

def fetch_codec(self):
if self.path is not None:
codec_filename = os.path.basename(self.path)
remote_path = os.path.dirname(self.path)
else:
codec_filename = self.get_codec_name() + "codec"
remote_path = None

codec = self.io.read_index(
os.path.basename(self.path),
codec_filename,
self.bucket,
os.path.dirname(self.path),
remote_path,
)
assert self.d == codec.d
assert self.metric_type == codec.metric_type
Expand Down

0 comments on commit e708c22

Please sign in to comment.