diff --git a/embedding.c b/embedding.c index 743df7f..929a43f 100644 --- a/embedding.c +++ b/embedding.c @@ -295,6 +295,7 @@ hnsw_gettuple(IndexScanDesc scan, ScanDirection dir) int n_items; size_t n_results; label_t* results; + bool search_succeeded; /* Safety check */ if (scan->orderByData == NULL) @@ -313,9 +314,12 @@ hnsw_gettuple(IndexScanDesc scan, ScanDirection dir) n_items, (int)so->hnsw->meta.dim); } - if (!hnsw_search(&so->hnsw->meta, (coord_t*)ARR_DATA_PTR(array), &n_results, &results)) - elog(ERROR, "HNSW index search failed"); + search_succeeded = hnsw_search(&so->hnsw->meta, (coord_t*)ARR_DATA_PTR(array), &n_results, &results); pfree(array); + + if (!search_succeeded) + elog(ERROR, "HNSW index search failed"); + so->results = (ItemPointer)palloc(n_results*sizeof(ItemPointerData)); so->n_results = n_results; for (size_t i = 0; i < n_results; i++) diff --git a/hnswalg.cpp b/hnswalg.cpp index 500b542..0bb2f58 100644 --- a/hnswalg.cpp +++ b/hnswalg.cpp @@ -260,6 +260,8 @@ bool hnsw_search(HnswMetadata* meta, const coord_t *point, size_t* n_results, la auto result = searchKnn(meta, point, meta->efSearch); size_t nResults = result.size(); *results = (label_t*)malloc(nResults*sizeof(label_t)); + if (*results == NULL) + return false; for (size_t i = nResults; i-- != 0;) { (*results)[i] = result.top().second;