Skip to content

Commit

Permalink
Move train, build and search to their respective operators (facebookr…
Browse files Browse the repository at this point in the history
…esearch#3934)

Summary:
Pull Request resolved: facebookresearch#3934

The initial thought was to be able to call individual operations on the execution operator, but it makes sense to keep a single interface, 'execute', and move all these implementations to their respective operators.

Reviewed By: satymish

Differential Revision: D63290104

fbshipit-source-id: d1f0b1391c38552c5cdb0a8ea935e23d0d0cb75b
  • Loading branch information
kuarora authored and facebook-github-bot committed Oct 10, 2024
1 parent d243e62 commit 61eaf19
Showing 1 changed file with 114 additions and 110 deletions.
224 changes: 114 additions & 110 deletions benchs/bench_fw/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,10 @@ def build_index_wrapper(self, codec_desc: CodecDescriptor):
else:
assert codec_desc.is_trained()

def train(
def train_one(
self, codec_desc: CodecDescriptor, results: Dict[str, Any], dry_run=False
):
faiss.omp_set_num_threads(codec_desc.num_threads)
self.build_index_wrapper(codec_desc)
if codec_desc.is_trained():
return results, None
Expand All @@ -274,6 +275,16 @@ def train(
results["indices"][codec_desc.get_name()] = meta
return results, requires

def train(self, results, dry_run=False):
    """Run ``train_one`` for every descriptor in ``self.codec_descs``.

    Args:
        results: Accumulator passed to, and replaced by the first element
            returned from, each ``train_one`` call.
        dry_run: Forwarded to ``train_one``. When true, iteration stops at
            the first descriptor that reports an outstanding requirement.

    Returns:
        ``(results, requires)``. ``requires`` is the first non-``None``
        requirement seen during a dry run, otherwise ``None``.
    """
    for codec_desc in self.codec_descs:
        results, requires = self.train_one(
            codec_desc, results, dry_run=dry_run
        )
        if not dry_run:
            # A real run is expected to leave nothing outstanding.
            assert requires is None
        elif requires is not None:
            # Dry run: surface the first unmet requirement immediately.
            return results, requires
    return results, None


@dataclass
class BuildOperator(IndexOperator):
Expand Down Expand Up @@ -322,17 +333,25 @@ def build_index_wrapper(self, index_desc: IndexDescriptor):
else:
assert index_desc.is_built()

def build(self, index_desc: IndexDescriptor, results: Dict[str, Any]):
def build_one(self, index_desc: IndexDescriptor, results: Dict[str, Any]):
    """Build the index behind a single descriptor if it is not built yet.

    Args:
        index_desc: Descriptor to build; its ``num_threads`` is applied to
            faiss before any work starts.
        results: Accepted for interface symmetry; not read in this method.

    Returns:
        None.
    """
    faiss.omp_set_num_threads(index_desc.num_threads)
    self.build_index_wrapper(index_desc)
    if not index_desc.is_built():
        # Only the call is needed here; its return value is not used.
        index_desc.index.get_index()

def build(self, results: Dict[str, Any]):
    """Run ``build_one`` for every descriptor in ``self.index_descs``.

    Args:
        results: Accumulator forwarded to each ``build_one`` call and
            returned unchanged by this method.

    Returns:
        ``(results, None)``.
    """
    # TODO: add support for dry_run
    for index_desc in self.index_descs:
        self.build_one(index_desc, results)
    return results, None


@dataclass
class SearchOperator(IndexOperator):
knn_descs: List[KnnDescriptor] = field(default_factory=lambda: [])
range: bool = False
compute_gt: bool = True

def get_desc(self, name: str) -> Optional[KnnDescriptor]:
for desc in self.knn_descs:
Expand Down Expand Up @@ -655,85 +674,16 @@ def range_search_benchmark(
index=index,
)


@dataclass
class ExecutionOperator:
distance_metric: str = "L2"
num_threads: int = 1
train_op: Optional[TrainOperator] = None
build_op: Optional[BuildOperator] = None
search_op: Optional[SearchOperator] = None
compute_gt: bool = True

def __post_init__(self):
    """Translate ``self.distance_metric`` into the faiss metric constant.

    Sets ``self.distance_metric_type``.

    Raises:
        ValueError: If ``distance_metric`` is neither ``"IP"`` nor ``"L2"``.
    """
    if self.distance_metric == "IP":
        self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
    elif self.distance_metric == "L2":
        self.distance_metric_type = faiss.METRIC_L2
    else:
        raise ValueError(
            f"Unsupported distance_metric {self.distance_metric!r}; "
            "expected 'IP' or 'L2'"
        )

def set_io(self, io: BenchmarkIO):
    """Attach ``io`` to this operator and to each configured sub-operator.

    The metric name and metric type held by this operator are copied onto
    ``io`` before it is handed to the train, build and search operators.

    Args:
        io: Object stored as ``self.io`` and passed to each sub-operator's
            ``set_io``.
    """
    self.io = io
    self.io.distance_metric = self.distance_metric
    self.io.distance_metric_type = self.distance_metric_type
    # Same order as before: train, then build, then search; falsy ops skipped.
    for sub_op in (self.train_op, self.build_op, self.search_op):
        if sub_op:
            sub_op.set_io(io)

def train_one(self, codec_desc: CodecDescriptor, results: Dict[str, Any], dry_run):
    """Train a single codec through ``self.train_op``.

    Args:
        codec_desc: Descriptor forwarded to ``self.train_op.train``.
        results: Accumulator forwarded to ``self.train_op.train``.
        dry_run: Forwarded unchanged.

    Returns:
        Whatever ``self.train_op.train`` returns. It was previously
        discarded, so callers could not see the outcome.
    """
    faiss.omp_set_num_threads(self.num_threads)
    assert self.train_op is not None
    return self.train_op.train(codec_desc, results, dry_run)

def train(self, results, dry_run=False):
    """Call ``train_one`` for every codec descriptor of ``self.train_op``.

    Does nothing beyond setting the faiss thread count when no train
    operator is configured.

    Args:
        results: Accumulator forwarded to each ``train_one`` call.
        dry_run: Forwarded to each ``train_one`` call.

    Returns:
        None.
    """
    faiss.omp_set_num_threads(self.num_threads)
    if self.train_op is None:
        return

    for codec_desc in self.train_op.codec_descs:
        self.train_one(codec_desc, results, dry_run)

def build_one(self, results: Dict[str, Any], index_desc: IndexDescriptor):
    """Build a single index through ``self.build_op``.

    Args:
        results: Accumulator forwarded to ``self.build_op.build``.
        index_desc: Descriptor forwarded to ``self.build_op.build``.

    Returns:
        None.
    """
    faiss.omp_set_num_threads(self.num_threads)
    assert self.build_op is not None
    # Note the order: the delegate takes the descriptor first, then results.
    self.build_op.build(index_desc, results)

def build(self, results: Dict[str, Any]):
    """Call ``build_one`` for every index descriptor of ``self.build_op``.

    Does nothing beyond setting the faiss thread count when no build
    operator is configured.

    Args:
        results: Accumulator forwarded to each ``build_one`` call.

    Returns:
        None.
    """
    faiss.omp_set_num_threads(self.num_threads)
    if self.build_op is None:
        return

    for index_desc in self.build_op.index_descs:
        # Keyword arguments: ``build_one`` takes ``results`` before
        # ``index_desc``, so the former positional call swapped the two.
        self.build_one(results=results, index_desc=index_desc)

def search(self):
    """Call ``search_one`` for every knn descriptor of ``self.search_op``.

    Does nothing beyond setting the faiss thread count when no search
    operator is configured.

    Returns:
        None.
    """
    faiss.omp_set_num_threads(self.num_threads)
    if self.search_op is None:
        return

    for knn_desc in self.search_op.knn_descs:
        # NOTE(review): the visible ``search_one`` signature also has a
        # required ``results`` parameter that this call does not supply.
        # This method has no ``results`` to pass - confirm the intended
        # source before relying on this path.
        self.search_one(knn_desc)

def search_one(
self,
knn_desc: KnnDescriptor,
results: Dict[str, Any],
dry_run=False,
range=False,
):
faiss.omp_set_num_threads(self.num_threads)
assert self.search_op is not None

if not dry_run and self.compute_gt:
self.create_gt_knn(knn_desc)
self.create_range_ref_knn(knn_desc)

self.search_op.build_index_wrapper(knn_desc)
faiss.omp_set_num_threads(knn_desc.num_threads)

self.build_index_wrapper(knn_desc)
# results, requires = self.reconstruct_benchmark(
# dry_run=True,
# results=results,
Expand All @@ -749,7 +699,7 @@ def search_one(
# index=index_desc.index,
# )
# assert requires is None
results, requires = self.search_op.knn_search_benchmark(
results, requires = self.knn_search_benchmark(
dry_run=True,
results=results,
knn_desc=knn_desc,
Expand All @@ -758,7 +708,7 @@ def search_one(
if dry_run:
return results, requires
else:
results, requires = self.search_op.knn_search_benchmark(
results, requires = self.knn_search_benchmark(
dry_run=False,
results=results,
knn_desc=knn_desc,
Expand All @@ -771,7 +721,7 @@ def search_one(
):
return results, None

ref_index_desc = self.search_op.get_desc(knn_desc.range_ref_index_desc)
ref_index_desc = self.get_desc(knn_desc.range_ref_index_desc)
if ref_index_desc is None:
raise ValueError(
f"{knn_desc.get_name()}: Unknown range index {knn_desc.range_ref_index_desc}"
Expand All @@ -786,17 +736,18 @@ def search_one(
range_search_metric_function,
coefficients,
coefficients_training_data,
) = self.search_op.range_search_reference(
) = self.range_search_reference(
ref_index_desc.index,
ref_index_desc.search_params,
range_metric,
query_dataset=knn_desc.query_dataset,
)
gt_rsm = None
if self.compute_gt:
gt_rsm = self.search_op.range_ground_truth(
gt_rsm = self.range_ground_truth(
gt_radius, range_search_metric_function
)
results, requires = self.search_op.range_search_benchmark(
results, requires = self.range_search_benchmark(
dry_run=True,
results=results,
index=knn_desc.index,
Expand All @@ -805,13 +756,13 @@ def search_one(
gt_radius=gt_radius,
range_search_metric_function=range_search_metric_function,
gt_rsm=gt_rsm,
query_vectors=knn_desc.query_dataset,
query_dataset=knn_desc.query_dataset,
)
if range and requires is not None:
if dry_run:
return results, requires
else:
results, requires = self.search_op.range_search_benchmark(
results, requires = self.range_search_benchmark(
dry_run=False,
results=results,
index=knn_desc.index,
Expand All @@ -820,12 +771,62 @@ def search_one(
gt_radius=gt_radius,
range_search_metric_function=range_search_metric_function,
gt_rsm=gt_rsm,
query_vectors=knn_desc.query_dataset,
query_dataset=knn_desc.query_dataset,
)
assert requires is None

return results, None

def search(
    self,
    results: Dict[str, Any],
    dry_run: bool = False,
):
    """Run ``search_one`` for every descriptor in ``self.knn_descs``.

    Args:
        results: Accumulator passed to, and replaced by the first element
            returned from, each ``search_one`` call.
        dry_run: Forwarded to ``search_one``. When true, iteration stops at
            the first descriptor that reports an outstanding requirement.

    Returns:
        ``(results, requires)``. ``requires`` is the first non-``None``
        requirement seen during a dry run, otherwise ``None``.
    """
    for knn_desc in self.knn_descs:
        results, requires = self.search_one(
            knn_desc=knn_desc,
            results=results,
            dry_run=dry_run,
            range=self.range,
        )
        if not dry_run:
            # A real run is expected to leave nothing outstanding.
            assert requires is None
        elif requires is not None:
            # Dry run: surface the first unmet requirement immediately.
            return results, requires
    return results, None


@dataclass
class ExecutionOperator:
distance_metric: str = "L2"
num_threads: int = 1
train_op: Optional[TrainOperator] = None
build_op: Optional[BuildOperator] = None
search_op: Optional[SearchOperator] = None
compute_gt: bool = True

def __post_init__(self):
    """Resolve the faiss metric constant and sync ``compute_gt``.

    Sets ``self.distance_metric_type`` and, when a search operator is
    configured, copies ``self.compute_gt`` onto it.

    Raises:
        ValueError: If ``distance_metric`` is neither ``"IP"`` nor ``"L2"``.
    """
    if self.distance_metric == "IP":
        self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
    elif self.distance_metric == "L2":
        self.distance_metric_type = faiss.METRIC_L2
    else:
        raise ValueError(
            f"Unsupported distance_metric {self.distance_metric!r}; "
            "expected 'IP' or 'L2'"
        )

    if self.search_op is not None:
        # Keep the search operator's flag consistent with this operator's.
        self.search_op.compute_gt = self.compute_gt

def set_io(self, io: BenchmarkIO):
    """Attach ``io`` to this operator and to each configured sub-operator.

    The metric name and metric type held by this operator are copied onto
    ``io`` before it is handed to the train, build and search operators.

    Args:
        io: Object stored as ``self.io`` and passed to each sub-operator's
            ``set_io``.
    """
    self.io = io
    self.io.distance_metric = self.distance_metric
    self.io.distance_metric_type = self.distance_metric_type
    # Same order as before: train, then build, then search; falsy ops skipped.
    for sub_op in (self.train_op, self.build_op, self.search_op):
        if sub_op:
            sub_op.set_io(io)

def create_gt_codec(
self, codec_desc, results, train=True
) -> Optional[CodecDescriptor]:
Expand All @@ -841,7 +842,7 @@ def create_gt_codec(
)
self.train_op.codec_descs.insert(0, gt_codec_desc)
if train:
self.train_op.train(gt_codec_desc, results, dry_run=False)
self.train_op.train_one(gt_codec_desc, results, dry_run=False)

return gt_codec_desc

Expand All @@ -865,7 +866,7 @@ def create_gt_index(
)
self.build_op.index_descs.insert(0, gt_index_desc)
if build:
self.build_op.build(gt_index_desc, results)
self.build_op.build_one(gt_index_desc, results)

return gt_index_desc

Expand Down Expand Up @@ -906,7 +907,9 @@ def create_range_ref_knn(self, knn_desc):
return

if knn_desc.range_ref_index_desc is not None:
ref_index_desc = self.get_desc(knn_desc.range_ref_index_desc)
ref_index_desc = (
self.search_op.get_desc(knn_desc.range_ref_index_desc)
)
if ref_index_desc is None:
raise ValueError(f"Unknown range index {knn_desc.range_ref_index_desc}")
if ref_index_desc.range_metrics is None:
Expand All @@ -921,19 +924,20 @@ def create_range_ref_knn(self, knn_desc):
range_search_metric_function,
coefficients,
coefficients_training_data,
) = self.range_search_reference(
) = self.search_op.range_search_reference(
knn_desc.index, knn_desc.search_params, range_metric
)
results["metrics"][metric_key] = {
"coefficients": coefficients,
"training_data": coefficients_training_data,
}
knn_desc.gt_rsm = self.range_ground_truth(
knn_desc.gt_rsm = self.search_op.range_ground_truth(
knn_desc.gt_radius, range_search_metric_function
)

def create_ground_truths(self, results: Dict[str, Any]):
# TODO: Create all ground truth descriptors and put them in index descriptor as reference
# TODO: Create all ground truth descriptors and
# put them in index descriptor as reference
if self.train_op is not None:
for codec_desc in self.train_op.codec_descs:
self.create_gt_codec(codec_desc, results)
Expand All @@ -949,33 +953,33 @@ def create_ground_truths(self, results: Dict[str, Any]):
self.create_gt_knn(knn_desc, results)
self.create_range_ref_knn(knn_desc)

def execute(self, results: Dict[str, Any], dry_run: False):
def prepare_gt_or_range_knn(self, results: Dict[str, Any]):
    """Prepare ground-truth and range-reference data for each knn descriptor.

    For every descriptor in ``self.search_op.knn_descs`` this calls
    ``create_gt_knn`` and then ``create_range_ref_knn``. It does nothing
    when no search operator is configured.

    Args:
        results: Accumulator forwarded to ``create_gt_knn``.

    Returns:
        None.
    """
    if self.search_op is None:
        return
    for knn_desc in self.search_op.knn_descs:
        self.create_gt_knn(knn_desc, results)
        self.create_range_ref_knn(knn_desc)

def execute(self, results: Dict[str, Any], dry_run: bool = False):
    """Run the configured train, build and search operators in that order.

    Each stage is skipped when its operator is ``None``. Per-descriptor
    iteration is delegated to the operators themselves.

    Args:
        results: Accumulator threaded through the stages; the train and
            search stages return the value that replaces it.
        dry_run: Forwarded to the train and search stages. When true, the
            first truthy ``requires`` from either stage is returned
            immediately.

    Returns:
        ``(results, requires)``. ``requires`` is ``None`` unless a dry run
        reported an outstanding requirement.
    """
    faiss.omp_set_num_threads(self.num_threads)

    if self.train_op is not None:
        results, requires = self.train_op.train(
            results=results, dry_run=dry_run
        )
        if dry_run and requires:
            return results, requires

    if self.build_op is not None:
        # The build stage takes no dry_run flag; its return value is unused.
        self.build_op.build(results)

    if self.search_op is not None:
        if not dry_run and self.compute_gt:
            # Ground truth is prepared only for real runs that asked for it.
            self.prepare_gt_or_range_knn(results)

        results, requires = self.search_op.search(
            results=results, dry_run=dry_run
        )
        if dry_run and requires:
            return results, requires

    return results, None

def execute_2(self, result_file=None):
Expand Down

0 comments on commit 61eaf19

Please sign in to comment.