From 3a7faf67a8ca434a127a6fe8542afcb595b09349 Mon Sep 17 00:00:00 2001
From: eric
Date: Wed, 27 Nov 2024 18:09:04 -0500
Subject: [PATCH] elastic rewrite

---
 fiftyone/brain/internal/core/elasticsearch.py | 562 +++++++++++++-----
 1 file changed, 404 insertions(+), 158 deletions(-)

diff --git a/fiftyone/brain/internal/core/elasticsearch.py b/fiftyone/brain/internal/core/elasticsearch.py
index 3b45621f..a484d74c 100644
--- a/fiftyone/brain/internal/core/elasticsearch.py
+++ b/fiftyone/brain/internal/core/elasticsearch.py
@@ -58,6 +58,19 @@ class ElasticsearchSimilarityConfig(SimilarityConfig):
         bearer_auth (None): a bearer token to use
         ssl_assert_fingerprint (None): a SHA256 fingerprint to use
         verify_certs (None): whether to verify SSL certificates
+        key_field ("filepath"): the name of the FiftyOne sample field used as
+            the unique identifier to match elastic documents
+        patch_key_field ("id"): the name of the FiftyOne patch attribute
+            field used as the unique identifier to match elastic documents,
+            if ``patches_field`` is provided
+        backend_key_field ("fiftyone_sample"): the name of the elastic
+            document source field used as the unique identifier to match
+            embeddings with FiftyOne samples
+        backend_patch_key_field ("fiftyone_patch"): the name of the elastic
+            document source field used as the unique identifier to match
+            embeddings with FiftyOne patches, if ``patches_field`` is provided
+        backend_vector_field ("vector"): the name of the elastic document
+            source field storing the embedding vector
         **kwargs: keyword arguments for :class:`SimilarityConfig`
     """
 
@@ -78,6 +91,11 @@ def __init__(
         bearer_auth=None,
         ssl_assert_fingerprint=None,
         verify_certs=None,
+        key_field="filepath",
+        patch_key_field="id",
+        backend_key_field="fiftyone_sample",
+        backend_patch_key_field="fiftyone_patch",
+        backend_vector_field="vector",
         **kwargs,
     ):
         if metric not in _SUPPORTED_METRICS:
@@ -86,6 +104,15 @@
                 % (metric, tuple(_SUPPORTED_METRICS.keys()))
             )
 
+        if (
+            backend_key_field == backend_patch_key_field
+            and patches_field is not None
+        ):
+            raise ValueError(
+                "The backend_key_field and backend_patch_key_field cannot have"
+                " the same value '%s'" % backend_key_field
+            )
+
         super().__init__(
             embeddings_field=embeddings_field,
             model=model,
@@ -96,6 +123,11 @@
 
         self.index_name = index_name
         self.metric = metric
+        self.key_field = key_field
+        self.patch_key_field = patch_key_field
+        self.backend_key_field = backend_key_field
+        self.backend_patch_key_field = backend_patch_key_field
+        self.backend_vector_field = backend_vector_field
 
         self._hosts = hosts
         self._cloud_id = cloud_id
@@ -255,6 +287,10 @@ def __init__(self, samples, config, brain_key, backend=None):
         self._metric = None
         self._initialize()
 
+    @property
+    def is_patch_index(self):
+        return self.config.patches_field is not None
+
     @property
     def total_index_size(self):
         try:
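
NOTE: The five new config fields above make the elastic document schema
configurable instead of hard-coded to `_id`/`sample_id`/`vector`. A minimal
sketch of how they might be supplied, assuming the standard
`fiftyone.brain.compute_similarity()` entrypoint forwards backend-specific
kwargs to `ElasticsearchSimilarityConfig` (the dataset name, model, and brain
key below are hypothetical):

    import fiftyone as fo
    import fiftyone.brain as fob

    dataset = fo.load_dataset("quickstart")  # hypothetical dataset

    # kwargs not consumed by compute_similarity() itself are passed
    # through to the similarity backend's config
    fob.compute_similarity(
        dataset,
        model="clip-vit-base32-torch",
        backend="elasticsearch",
        brain_key="elastic_index",
        key_field="filepath",                 # FiftyOne-side sample key
        backend_key_field="fiftyone_sample",  # elastic source field
        backend_vector_field="vector",        # elastic dense_vector field
    )
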
@@ -309,40 +345,11 @@ def _initialize(self):
     def _get_index_names(self):
         return self._client.indices.get_alias().keys()
 
-    def _get_index_ids(self, batch_size=1000):
-        sample_ids = []
-        label_ids = []
-        for batch in range(0, self.total_index_size, batch_size):
-            response = self._client.search(
-                index=self.config.index_name,
-                body={
-                    "fields": ["sample_id"],
-                    "from": batch,
-                    "query": {
-                        "bool": {
-                            "must": [
-                                {"exists": {"field": "vector"}},
-                                {"exists": {"field": "sample_id"}},
-                            ]
-                        }
-                    },
-                },
-                source=False,
-                size=batch_size,
-            )
-            for doc in response["hits"]["hits"]:
-                sample_id = doc["fields"]["sample_id"][0]
-                sample_or_label_id = doc["_id"]
-                sample_ids.append(sample_id)
-                label_ids.append(sample_or_label_id)
-
-        return sample_ids, label_ids
-
     def _get_dimension(self):
         if self.total_index_size == 0:
             return None
 
-        if self.config.patches_field is not None:
+        if self.is_patch_index:
             embeddings, _, _ = self.get_embeddings(
                 label_ids=self._label_ids[:1]
             )
@@ -360,7 +367,9 @@ def _get_metric(self):
             # we may be working with a preexisting index
             self._metric = self._client.indices.get_mapping(
                 index=self.config.index_name
-            )[self.config.index_name]["mappings"]["properties"]["vector"][
+            )[self.config.index_name]["mappings"]["properties"][
+                self.config.backend_vector_field
+            ][
                 "similarity"
             ]
         except:
@@ -381,23 +390,80 @@ def _create_index(self, dimension):
         metric = _SUPPORTED_METRICS[self.config.metric]
         mappings = {
             "properties": {
-                "vector": {
+                self.config.backend_vector_field: {
                     "type": "dense_vector",
                     "dims": dimension,
                     "index": "true",
                     "similarity": metric,
-                }
+                },
             }
         }
+        if self.config.backend_key_field != "_id":
+            mappings["properties"][self.config.backend_key_field] = {
+                "type": "keyword"
+            }
+        if (
+            self.is_patch_index
+            and self.config.backend_patch_key_field != "_id"
+        ):
+            mappings["properties"][self.config.backend_patch_key_field] = {
+                "type": "keyword"
+            }
+
         self._client.indices.create(
             index=self.config.index_name, mappings=mappings
         )
         self._metric = metric
 
-    def _get_existing_ids(self, ids):
-        docs = [{"_index": self.config.index_name, "_id": i} for i in ids]
-        resp = self._client.mget(docs=docs)
-        return [d["_id"] for d in resp["docs"] if d["found"]]
+    def _get_existing_ids(self, ids, is_labels=None):
+        return_labels = False
+        if self.is_patch_index and is_labels:
+            key_field = self.config.backend_patch_key_field
+            return_labels = True
+        else:
+            key_field = self.config.backend_key_field
+
+        sample_ids, label_ids, index_ids, _ = self._get_docs(
+            values=ids, key_field=key_field, include_embeddings=False
+        )
+        if return_labels:
+            return label_ids, index_ids
+
+        return sample_ids, index_ids
+
+    def _remap_ids_to_keys(self, ids, key_field, incoming_key_field):
+        unwind = False
+        if key_field == self.config.patch_key_field:
+            unwind = True
+        if ids is not None and key_field != incoming_key_field:
+            value_map = dict(
+                zip(
+                    *self.samples.values(
+                        [incoming_key_field, key_field], unwind=unwind
+                    )
+                )
+            )
+            _ids = [value_map[i] for i in ids]
+        else:
+            _ids = ids
+        return _ids
+
+    def _remap_keys_to_ids(self, keys, key_field, outgoing_key_field):
+        unwind = False
+        if key_field == self.config.patch_key_field:
+            unwind = True
+        if keys is not None and key_field != outgoing_key_field:
+            value_map = dict(
+                zip(
+                    *self.samples.values(
+                        [key_field, outgoing_key_field], unwind=unwind
+                    )
+                )
+            )
+            _ids = [value_map.get(i, None) for i in keys]
+        else:
+            _ids = keys
+        return _ids
 
     def add_to_index(
         self,
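
NOTE: The `_remap_ids_to_keys()` / `_remap_keys_to_ids()` helpers added above
translate between FiftyOne IDs and the configured key fields by building a
value map over the collection. For the default `key_field="filepath"`, the
remapping amounts to roughly the following (a sketch; the dataset and the
incoming `sample_ids` are hypothetical):

    import fiftyone as fo

    dataset = fo.load_dataset("quickstart")  # hypothetical dataset

    # values() returns parallel lists for the requested fields
    ids, keys = dataset.values(["id", "filepath"])
    value_map = dict(zip(ids, keys))

    # translate incoming "id" values into "filepath" keys for the index
    sample_ids = ids[:10]
    keys_for_index = [value_map[i] for i in sample_ids]

The reverse direction uses `value_map.get(i, None)`, so keys that no longer
exist in the collection map to `None` and can be filtered out by callers.
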
@@ -409,7 +475,22 @@ def add_to_index(
         warn_existing=False,
         reload=True,
         batch_size=500,
+        _incoming_key_field="id",
+        _incoming_patch_key_field="id",
     ):
+        if self.is_patch_index and label_ids is None:
+            raise ValueError(
+                "Label IDs are required to add embeddings to a patch index but"
+                " none were provided. The patch field for this index is: %s"
+                % self.config.patches_field
+            )
+
+        sample_ids = self._remap_ids_to_keys(
+            sample_ids, self.config.key_field, _incoming_key_field
+        )
+        label_ids = self._remap_ids_to_keys(
+            label_ids, self.config.patch_key_field, _incoming_patch_key_field
+        )
+
         if not self._index_exists():
             self._create_index(embeddings.shape[1])
 
@@ -468,10 +549,28 @@ def add_to_index(
         ):
             operations = []
             for _e, _id, _sid in zip(_embeddings, _ids, _sample_ids):
-                operations.append(
-                    {"index": {"_index": self.config.index_name, "_id": _id}}
-                )
-                operations.append({"sample_id": _sid, "vector": _e})
+                skip_key_field = False
+                skip_patch_key_field = False
+
+                op1 = {"index": {"_index": self.config.index_name}}
+                if self.config.backend_key_field == "_id":
+                    op1["index"]["_id"] = _sid
+                    skip_key_field = True
+                if (
+                    self.is_patch_index
+                    and self.config.backend_patch_key_field == "_id"
+                ):
+                    op1["index"]["_id"] = _id
+                    skip_patch_key_field = True
+
+                op2 = {self.config.backend_vector_field: _e}
+                if not skip_key_field:
+                    op2[self.config.backend_key_field] = _sid
+                if self.is_patch_index and not skip_patch_key_field:
+                    op2[self.config.backend_patch_key_field] = _id
+
+                operations.append(op1)
+                operations.append(op2)
 
             self._client.bulk(
                 index=self.config.index_name,
@@ -489,14 +588,26 @@ def remove_from_index(
         allow_missing=True,
         warn_missing=False,
         reload=True,
+        _incoming_key_field="id",
+        _incoming_patch_key_field="id",
     ):
+        sample_ids = self._remap_ids_to_keys(
+            sample_ids, self.config.key_field, _incoming_key_field
+        )
+        label_ids = self._remap_ids_to_keys(
+            label_ids, self.config.patch_key_field, _incoming_patch_key_field
+        )
+
+        is_labels = False
         if label_ids is not None:
             ids = label_ids
+            is_labels = True
         else:
             ids = sample_ids
 
+        existing_ids, index_ids = self._get_existing_ids(
+            ids, is_labels=is_labels
+        )
         if not allow_missing or warn_missing:
-            existing_ids = self._get_existing_ids(ids)
             missing_ids = set(ids) - set(existing_ids)
             num_missing = len(missing_ids)
 
@@ -513,11 +624,9 @@ def remove_from_index(
                     num_missing,
                 )
 
-        ids = existing_ids
-
         operations = [
             {"delete": {"_index": self.config.index_name, "_id": i}}
-            for i in ids
+            for i in index_ids
         ]
         self._client.bulk(body=operations, refresh=True)
 
@@ -530,37 +639,37 @@ def get_embeddings(
         label_ids=None,
         allow_missing=True,
         warn_missing=False,
+        _incoming_key_field="id",
+        _incoming_patch_key_field="id",
     ):
+        sample_ids = self._remap_ids_to_keys(
+            sample_ids, self.config.key_field, _incoming_key_field
+        )
+        label_ids = self._remap_ids_to_keys(
+            label_ids, self.config.patch_key_field, _incoming_patch_key_field
+        )
+
+        ids = sample_ids
+        is_labels = False
         if label_ids is not None:
-            if self.config.patches_field is None:
+            ids = label_ids
+            is_labels = True
+            if not self.is_patch_index:
+                # This index was initially created on full sample
+                # embeddings, so patch embeddings cannot be
+                # retrieved from it
                 raise ValueError("This index does not support label IDs")
-
             if sample_ids is not None:
                 logger.warning(
                     "Ignoring sample IDs when label IDs are provided"
                 )
 
-        if sample_ids is not None and self.config.patches_field is not None:
-            (
-                embeddings,
-                sample_ids,
-                label_ids,
-                missing_ids,
-            ) = self._get_patch_embeddings_from_sample_ids(sample_ids)
-        elif self.config.patches_field is not None:
-            (
-                embeddings,
-                sample_ids,
-                label_ids,
-                missing_ids,
-            ) = self._get_patch_embeddings_from_label_ids(label_ids)
-        else:
-            (
-                embeddings,
-                sample_ids,
-                label_ids,
-                missing_ids,
-            ) = self._get_sample_embeddings(sample_ids)
+        (
+            sample_ids,
+            label_ids,
+            embeddings,
+            missing_ids,
+        ) = self._get_embeddings(ids=ids, is_labels=is_labels)
 
         num_missing_ids = len(missing_ids)
         if num_missing_ids > 0:
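
NOTE: With the default config, each embedding handled by the `add_to_index()`
loop above now produces an action/document pair along these lines (a sketch
with hypothetical values; when `backend_key_field` is `"_id"`, the sample key
moves into the action's `_id` and is skipped in the document source):

    operations = [
        {"index": {"_index": "fiftyone-index"}},
        {
            "vector": [0.12, -0.34, 0.56],
            "fiftyone_sample": "/data/images/000001.jpg",
            "fiftyone_patch": "5f3c2ad6b0e5c3b7a1d2e4f9",
        },
    ]
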
@@ -576,108 +685,171 @@
                 num_missing_ids,
             )
 
+        sample_ids = self._remap_keys_to_ids(
+            sample_ids, self.config.key_field, _incoming_key_field
+        )
         embeddings = np.array(embeddings)
         sample_ids = np.array(sample_ids)
         if label_ids is not None:
+            label_ids = self._remap_keys_to_ids(
+                label_ids,
+                self.config.patch_key_field,
+                _incoming_patch_key_field,
+            )
             label_ids = np.array(label_ids)
 
         return embeddings, sample_ids, label_ids
 
-    def _parse_embeddings_response(self, response, label_id=True):
-        found_embeddings = []
-        found_sample_ids = []
-        found_label_ids = []
-        for r in response:
-            if r.get("found", True):
-                found_embeddings.append(r["_source"]["vector"])
-                if label_id:
-                    found_sample_ids.append(r["_source"]["sample_id"])
-                    found_label_ids.append(r["_id"])
-                else:
-                    found_sample_ids.append(r["_id"])
-
-        return found_embeddings, found_sample_ids, found_label_ids
-
-    def _get_sample_embeddings(self, sample_ids, batch_size=1000):
-        found_embeddings = []
-        found_sample_ids = []
-
-        if sample_ids is None:
-            sample_ids, label_ids = self._get_index_ids(batch_size=batch_size)
-
-        for batch_ids in fou.iter_batches(sample_ids, batch_size):
-            response = self._client.mget(
-                index=self.config.index_name, ids=batch_ids, source=True
+    def _get_embeddings(self, ids=None, batch_size=1000, is_labels=False):
+        key_field = None
+        missing_ids = []
+        if ids is not None:
+            key_field = (
+                self.config.backend_patch_key_field
+                if is_labels
+                else self.config.backend_key_field
             )
+        (
+            found_sample_ids,
+            found_label_ids,
+            _,
+            found_embeddings,
+        ) = self._get_docs(
+            values=ids, key_field=key_field, batch_size=batch_size
+        )
 
-            (
-                _found_embeddings,
-                _found_sample_ids,
-                _,
-            ) = self._parse_embeddings_response(
-                response["docs"], label_id=False
-            )
-            found_embeddings += _found_embeddings
-            found_sample_ids += _found_sample_ids
-
-        missing_ids = list(set(sample_ids) - set(found_sample_ids))
-
-        return found_embeddings, found_sample_ids, None, missing_ids
+        if ids is not None:
+            found_ids = found_label_ids if is_labels else found_sample_ids
+            missing_ids = list(set(ids) - set(found_ids))
 
-    def _get_patch_embeddings_from_label_ids(self, label_ids, batch_size=1000):
-        found_embeddings = []
-        found_sample_ids = []
-        found_label_ids = []
+        return found_sample_ids, found_label_ids, found_embeddings, missing_ids
 
-        if label_ids is None:
-            sample_ids, label_ids = self._get_index_ids(batch_size=batch_size)
+    def _get_docs(
+        self,
+        values=None,
+        key_field=None,
+        batch_size=1000,
+        include_embeddings=True,
+    ):
+        must_filter = [{"exists": {"field": self.config.backend_vector_field}}]
+        if key_field:
+            must_filter.append({"exists": {"field": key_field}})
 
-        for batch_ids in fou.iter_batches(label_ids, batch_size):
-            response = self._client.mget(
-                index=self.config.index_name, ids=batch_ids, source=True
+        if (
+            self.is_patch_index
+            and key_field != self.config.backend_patch_key_field
+        ):
+            must_filter.append(
+                {"exists": {"field": self.config.backend_patch_key_field}}
+            )
+        elif key_field != self.config.backend_key_field:
+            must_filter.append(
+                {"exists": {"field": self.config.backend_key_field}}
             )
 
-            (
-                _found_embeddings,
-                _found_sample_ids,
-                _found_label_ids,
-            ) = self._parse_embeddings_response(response["docs"])
-            found_embeddings += _found_embeddings
-            found_sample_ids += _found_sample_ids
-            found_label_ids += _found_label_ids
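
NOTE: `_get_docs()`/`_get_docs_query()` replace the old `mget`-by-`_id`
lookups with filtered `search` requests keyed on the configured source
fields. For a patch index under the default config, one batched request body
looks roughly like the following (hypothetical values; `_parse_query_field()`
resolves whether a `.keyword` suffix is required for the `terms` clause):

    body = {
        "fields": ["fiftyone_sample", "fiftyone_patch", "vector"],
        "size": 1000,
        "_source": False,
        "query": {
            "bool": {
                "must": [
                    {"exists": {"field": "vector"}},
                    {"exists": {"field": "fiftyone_sample"}},
                    {
                        "terms": {
                            "fiftyone_sample": ["/data/images/000001.jpg"]
                        }
                    },
                ]
            }
        },
    }
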
+        fields = [
+            self.config.backend_key_field,
+            self.config.backend_patch_key_field,
+        ]
+        if include_embeddings:
+            fields.append(self.config.backend_vector_field)
 
-        missing_ids = list(set(label_ids) - set(found_label_ids))
+        if values is not None and key_field is not None:
+            hits = self._get_docs_query(
+                values, key_field, fields, must_filter, batch_size=batch_size
+            )
+        else:
+            hits = self._get_docs_all(
+                fields, must_filter, batch_size=batch_size
+            )
 
-        return found_embeddings, found_sample_ids, found_label_ids, missing_ids
+        return self._parse_hits(hits)
 
-    def _get_patch_embeddings_from_sample_ids(
-        self, sample_ids, batch_size=100
+    def _get_docs_query(
+        self, values, key_field, fields, must_filter, batch_size=1000
     ):
-        found_embeddings = []
-        found_sample_ids = []
-        found_label_ids = []
+        query_field = self._parse_query_field(key_field)
 
-        if sample_ids is None:
-            sample_ids, label_ids = self._get_index_ids(batch_size=batch_size)
+        hits = []
+        for batch_ids in fou.iter_batches(values, batch_size):
+            terms = {query_field: batch_ids}
+            response = self._client.search(
+                index=self.config.index_name,
+                body={
+                    "fields": fields,
+                    "size": batch_size,
+                    "_source": False,
+                    "query": {
+                        "bool": {"must": [*must_filter, {"terms": terms}]},
+                    },
+                },
+            )
+            hits.extend(response["hits"]["hits"])
+        return hits
 
-        for batch_ids in fou.iter_batches(sample_ids, batch_size):
+    def _get_docs_all(self, fields, must_filter, batch_size=1000):
+        hits = []
+        for batch in range(0, self.total_index_size, batch_size):
             response = self._client.search(
                 index=self.config.index_name,
-                body={"query": {"terms": {"sample_id": sample_ids}}},
+                body={
+                    "fields": fields,
+                    "from": batch,
+                    "size": batch_size,
+                    "_source": False,
+                    "query": {"bool": {"must": must_filter}},
+                },
             )
+            hits.extend(response["hits"]["hits"])
+        return hits
 
-            (
-                _found_embeddings,
-                _found_sample_ids,
-                _found_label_ids,
-            ) = self._parse_embeddings_response(response["hits"]["hits"])
-            found_embeddings += _found_embeddings
-            found_sample_ids += _found_sample_ids
-            found_label_ids += _found_label_ids
+    def _parse_hits(self, hits):
+        sample_ids = []
+        label_ids = []
+        index_ids = []
+        embeddings = []
 
-        missing_ids = list(set(sample_ids) - set(found_sample_ids))
+        for hit in hits:
+            if hit.get("found", True):
+                sample_id, label_id, vector_id, embedding = self._parse_hit(
+                    hit
+                )
+                sample_ids.append(sample_id)
+                label_ids.append(label_id)
+                index_ids.append(vector_id)
+                embeddings.append(embedding)
 
-        return found_embeddings, found_sample_ids, found_label_ids, missing_ids
+        return sample_ids, label_ids, index_ids, embeddings
+
+    def _parse_hit(self, hit):
+        label_id = None
+        source_field = "_source" if "_source" in hit else "fields"
+
+        if self.is_patch_index:
+            if self.config.backend_patch_key_field == "_id":
+                label_id = hit["_id"]
+            else:
+                label_id = hit[source_field].get(
+                    self.config.backend_patch_key_field, None
+                )
+                if isinstance(label_id, list) and len(label_id) > 0:
+                    label_id = label_id[0]
+
+        if self.config.backend_key_field == "_id":
+            sample_id = hit["_id"]
+        else:
+            sample_id = hit[source_field].get(
+                self.config.backend_key_field, None
+            )
+            if isinstance(sample_id, list) and len(sample_id) > 0:
+                sample_id = sample_id[0]
+
+        embedding = hit[source_field].get(
+            self.config.backend_vector_field, None
+        )
+        vector_id = hit["_id"]
+
+        return sample_id, label_id, vector_id, embedding
 
     def cleanup(self):
         self._client.indices.delete(
@@ -716,23 +888,31 @@ def _kneighbors(
             query = [query]
 
         if self.has_view:
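
NOTE: As before, queries against a `dotproduct` index are unit-normalized in
`_kneighbors()` below, since Elasticsearch's `dot_product` similarity expects
unit-length vectors; e.g. (a sketch with a hypothetical query vector):

    import numpy as np

    q = np.array([0.12, -0.34, 0.56])
    q /= np.linalg.norm(q)  # unit-normalize before the kNN query
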
-            if self.config.patches_field is not None:
-                index_ids = self.current_label_ids
+            if self.is_patch_index:
+                key_field = self.config.patch_key_field
+                backend_key_field = self.config.backend_patch_key_field
+                current_ids = self.current_label_ids
             else:
-                index_ids = self.current_sample_ids
+                key_field = self.config.key_field
+                backend_key_field = self.config.backend_key_field
+                current_ids = self.current_sample_ids
 
-            _filter = {"terms": {"_id": list(index_ids)}}
+            index_ids = self._remap_ids_to_keys(current_ids, key_field, "id")
+            filter_field = self._parse_query_field(backend_key_field)
+            _filter = {"terms": {filter_field: list(index_ids)}}
         else:
             _filter = None
 
-        ids = []
+        sample_ids = []
+        label_ids = [] if self.is_patch_index else None
         dists = []
         for q in query:
             if self._get_metric() == _SUPPORTED_METRICS["dotproduct"]:
                 q /= np.linalg.norm(q)
 
             knn = {
-                "field": "vector",
+                "field": self.config.backend_vector_field,
                 "query_vector": q.tolist(),
                 "k": k,
                 "num_candidates": 10 * k,
@@ -744,16 +924,56 @@
                 index=self.config.index_name,
                 knn=knn,
                 size=k,
+                fields=[
+                    self.config.backend_key_field,
+                    self.config.backend_patch_key_field,
+                ],
+            )
+
+            _sample_ids, _label_ids, _, _ = self._parse_hits(
+                response["hits"]["hits"]
             )
-            ids.append([r["_id"] for r in response["hits"]["hits"]])
+            _dists = [r["_score"] for r in response["hits"]["hits"]]
+
+            _sample_ids = self._remap_keys_to_ids(
+                _sample_ids, self.config.key_field, "id"
+            )
+            missing_inds = [
+                ind for ind, _id in enumerate(_sample_ids) if _id is None
+            ]
+            if _label_ids is not None and self.is_patch_index:
+                _label_ids = self._remap_keys_to_ids(
+                    _label_ids, self.config.patch_key_field, "id"
+                )
+                missing_inds.extend(
+                    [ind for ind, _id in enumerate(_label_ids) if _id is None]
+                )
+
+            for i in sorted(set(missing_inds), reverse=True):
+                del _sample_ids[i]
+                if return_dists:
+                    del _dists[i]
+                if label_ids is not None:
+                    del _label_ids[i]
 
             if return_dists:
-                dists.append([r["_score"] for r in response["hits"]["hits"]])
+                dists.append(_dists)
+            sample_ids.append(_sample_ids)
+            if self.is_patch_index:
+                label_ids.append(_label_ids)
 
         if single_query:
-            ids = ids[0]
+            sample_ids = sample_ids[0]
+            if label_ids is not None:
+                label_ids = label_ids[0]
             if return_dists:
                 dists = dists[0]
 
+        if self.is_patch_index:
+            ids = label_ids
+        else:
+            ids = sample_ids
+
         if return_dists:
             return ids, dists
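
NOTE: The kNN request now asks elastic to return the key fields rather than
relying on `_id`, so results can be remapped back to FiftyOne IDs. A
representative request under the default config (a sketch; the host, index
name, and values are hypothetical):

    from elasticsearch import Elasticsearch

    client = Elasticsearch("http://localhost:9200")  # hypothetical host

    knn = {
        "field": "vector",
        "query_vector": [0.12, -0.34, 0.56],
        "k": 10,
        "num_candidates": 100,  # 10 * k
        "filter": {
            "terms": {"fiftyone_sample": ["/data/images/000001.jpg"]}
        },
    }
    response = client.search(
        index="fiftyone-index",
        knn=knn,
        size=10,
        fields=["fiftyone_sample", "fiftyone_patch"],
    )
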
@@ -774,18 +994,44 @@ def _parse_neighbors_query(self, query):
             single_query = False
 
         # Query by ID(s)
-        response = self._client.mget(
-            index=self.config.index_name, ids=query_ids, source=True
-        )
-        query = np.array(
-            [r["_source"]["vector"] for r in response["docs"] if r["found"]]
+        if self.is_patch_index:
+            key_field = self.config.patch_key_field
+            backend_key_field = self.config.backend_patch_key_field
+        else:
+            key_field = self.config.key_field
+            backend_key_field = self.config.backend_key_field
+
+        query_ids = self._remap_ids_to_keys(query_ids, key_field, "id")
+        _, _, _, embeddings = self._get_docs(
+            values=query_ids,
+            key_field=backend_key_field,
+            include_embeddings=True,
         )
+        query = np.array(embeddings)
 
         if single_query:
             query = query[0, :]
 
         return query
 
+    def _parse_query_field(self, key_field):
+        # Text fields in elastic must have `.keyword` appended to them in
+        # order to be used in `terms` queries
+        mapping = self._client.indices.get_mapping(
+            index=self.config.index_name
+        )
+        properties = mapping[self.config.index_name]["mappings"]["properties"]
+        if key_field not in properties:
+            raise ValueError(
+                "Field %s not found in elastic index %s"
+                % (key_field, self.config.index_name)
+            )
+        field_type = properties[key_field]["type"]
+        if field_type == "text":
+            return key_field + ".keyword"
+        else:
+            return key_field
+
     @classmethod
     def _from_dict(cls, d, samples, config, brain_key):
         return cls(samples, config, brain_key)
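
NOTE: None of the above changes the public query API; since
`_parse_neighbors_query()` now remaps incoming IDs to the configured keys,
sorting by similarity should still work unchanged. A sketch, assuming an
index was created under the hypothetical brain key from the earlier note:

    # query by the ID of a sample in the index
    sample_id = dataset.first().id
    view = dataset.sort_by_similarity(
        sample_id, brain_key="elastic_index", k=10
    )
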