From a8fecacbe48e024a811213b81bf51efa298f38c3 Mon Sep 17 00:00:00 2001
From: Benjamin Piwowarski
Date: Wed, 28 Feb 2024 15:56:33 +0100
Subject: [PATCH] feat: multi-gpu sparse index building

---
 src/xpmir/index/sparse.py              | 198 +++++++++++++++++++++----
 src/xpmir/learning/devices.py          |  36 ++++-
 src/xpmir/neural/__init__.py           |   9 +-
 src/xpmir/neural/interaction/common.py |   6 +-
 src/xpmir/utils/multiprocessing.py     |  47 ++++++
 5 files changed, 249 insertions(+), 47 deletions(-)
 create mode 100644 src/xpmir/utils/multiprocessing.py

diff --git a/src/xpmir/index/sparse.py b/src/xpmir/index/sparse.py
index 5b3c62d..7a215cb 100644
--- a/src/xpmir/index/sparse.py
+++ b/src/xpmir/index/sparse.py
@@ -1,10 +1,14 @@
 """Index for sparse models"""
+import heapq
 import torch
+from queue import Empty
+import torch.multiprocessing as mp
 import numpy as np
 import sys
 from pathlib import Path
-from typing import Dict, List, Tuple, Generic
+from typing import Dict, List, Tuple, Generic, Iterator, Union
+from attrs import define
 from experimaestro import (
     Annotated,
     Config,
@@ -19,10 +23,11 @@
 from xpmir.learning import ModuleInitMode
 from xpmir.learning.batchers import Batcher
 from xpmir.utils.utils import batchiter, easylog
-from xpmir.letor import Device, DEFAULT_DEVICE
+from xpmir.letor import Device, DeviceInformation, DEFAULT_DEVICE
 from xpmir.text.encoders import TextEncoderBase, TextsRepresentationOutput, InputType
 from xpmir.rankers import Retriever, TopicRecord, ScoredDocument
 from xpmir.utils.iter import MultiprocessIterator
+from xpmir.utils.multiprocessing import StoppableQueue
 import xpmir_rust

 logger = easylog()
@@ -123,6 +128,22 @@ def retrieve(self, query: TopicRecord, top_k=None) -> List[ScoredDocument]:
         return self.index.retrieve(query, top_k or self.topk)


+@define(frozen=True)
+class EncodedDocument:
+    docid: int
+    value: np.ndarray
+
+
+@define(frozen=True)
+class DocumentRange:
+    rank: int
+    start: int
+    end: int
+
+    def __lt__(self, other: "DocumentRange"):
+        return self.start < other.start
+
+
 class SparseRetrieverIndexBuilder(Task, Generic[InputType]):
     """Builds an index from a sparse representation

@@ -147,6 +168,7 @@ class SparseRetrieverIndexBuilder(Task, Generic[InputType]):
     fast top-k strategies"""

     device: Meta[Device] = DEFAULT_DEVICE
+    """The device for building the index"""

     max_postings: Meta[int] = 16384
     """Maximum number of postings (per term) before flushing to disk"""
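The two records above define the protocol between the encoder processes and the single indexer process: each encoder first announces a `DocumentRange` (the contiguous docid span of its next batch), then sends one `EncodedDocument` per docid in that span. Since each encoder emits its ranges in increasing order, the indexer can merge the per-encoder streams with a heap keyed on `start` (hence the `__lt__` above) and add postings in increasing docid order. A minimal, self-contained sketch of that merge, with plain iterators standing in for the real queues:

```python
import heapq
from dataclasses import dataclass, field

@dataclass(frozen=True, order=True)
class Range:
    start: int  # first docid in the batch
    end: int    # last docid (inclusive)
    rank: int = field(default=0, compare=False)  # which producer emitted it

# Each producer emits ranges sorted by start; the heap keeps the globally
# smallest pending range on top, so documents are processed in increasing
# docid order across producers.
streams = [
    iter([Range(0, 1, 0), Range(4, 5, 0)]),
    iter([Range(2, 3, 1)]),
]
heap = [next(stream) for stream in streams]
heapq.heapify(heap)
while heap:
    current = heap[0]
    print(f"indexing docids {current.start} to {current.end} from rank {current.rank}")
    following = next(streams[current.rank], None)
    if following is not None:
        heapq.heappushpop(heap, following)
    else:
        heapq.heappop(heap)  # this producer is exhausted
```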
@@ -172,52 +194,166 @@ def task_outputs(self, dep):
         )

     def execute(self):
-        # Encode all documents
-        logger.info(
-            f"Load the encoder and transfer to the target device {self.device.value}"
+        max_docs = (
+            self.documents.documentcount
+            if self.max_docs == 0
+            else min(self.max_docs, self.documents.documentcount)
         )
-        self.encoder.initialize(ModuleInitMode.DEFAULT.to_options(None))
-        self.encoder.to(self.device.value).eval()
-
-        batcher = self.batcher.initialize(self.batch_size)
-
-        doc_iter = tqdm(
-            zip(
-                range(sys.maxsize if self.max_docs == 0 else self.max_docs),
-                MultiprocessIterator(self.documents.iter_documents()),
+
+        iter_batches = tqdm(
+            MultiprocessIterator(
+                batchiter(
+                    self.batch_size,
+                    zip(
+                        range(sys.maxsize if self.max_docs == 0 else self.max_docs),
+                        MultiprocessIterator(self.documents.iter_documents()).start(),
+                    ),
+                )
             ),
-            total=self.documents.documentcount
-            if self.max_docs == 0
-            else min(self.max_docs, self.documents.documentcount),
+            total=max_docs // self.batch_size,
+            unit_scale=self.batch_size,
+            unit="documents",
             desc="Building the index",
         )

-        # Create the index builder
+        self.encoder.initialize(ModuleInitMode.DEFAULT.to_options(None))
+
+        closed = mp.Event()
+        queues = [
+            StoppableQueue(2 * self.batch_size + 1, closed)
+            for _ in range(self.device.n_processes)
+        ]
+
+        # Cleanup the index before starting
        from shutil import rmtree
-        import xpmir_rust

        if self.index_path.is_dir():
            rmtree(self.index_path)
        self.index_path.mkdir(parents=True)

-        self.indexer = xpmir_rust.index.SparseIndexer(str(self.index_path))
+        # Start the indexing process
+        index_process = mp.Process(
+            target=self.index,
+            args=(queues,),
+            daemon=True,
+        )
+        index_process.start()
+
+        # Run the encoders and wait for them to finish
+        logger.info(f"Starting to index {max_docs} documents")
+
+        try:
+            self.device.execute(self.device_execute, iter_batches, queues)
+        finally:
+            logger.info("Waiting for the index process to stop")
+            index_process.join()
+            if index_process.exitcode != 0:
+                logger.warning(
+                    "Indexer process has finished with exit code %d",
+                    index_process.exitcode,
+                )
+                raise RuntimeError("Failure")
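The `closed` event shared by all the `StoppableQueue`s is what keeps this orchestration from deadlocking: if either side dies, `stop()` sets the event and every blocked `put`/`get` eventually raises instead of waiting forever, so the `finally` block above gets to inspect the indexer's exit code. A standalone sketch of the same pattern using only the standard library (hypothetical `consumer`, not xpmir code):

```python
import multiprocessing as mp
from queue import Empty

def consumer(queue, closed):
    # Drain the queue until a None sentinel arrives; give up if the
    # producer signalled a failure through the shared event.
    while True:
        try:
            item = queue.get(timeout=1.0)
        except Empty:
            if closed.is_set():
                return
            continue
        if item is None:
            return

if __name__ == "__main__":
    closed = mp.Event()
    queue = mp.Queue(8)
    process = mp.Process(target=consumer, args=(queue, closed), daemon=True)
    process.start()
    try:
        for item in range(32):
            queue.put(item)
        queue.put(None)  # end-of-stream sentinel
    except Exception:
        closed.set()  # unblock the consumer instead of deadlocking
        raise
    finally:
        process.join()
```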
stopping") + raise + except Exception: + # Close all the queues + logger.exception( + "Got an exception in the indexing process, closing the queues" + ) + queues[0].stop() + raise + + def device_execute( + self, + device_information: DeviceInformation, + iter_batches: Iterator[List[Tuple[int, DocumentRecord]]], + queues: List[StoppableQueue], + ): + try: + # Encode all documents + logger.info( + "Load the encoder and " + f"transfer to the target device {self.device.value}" + ) - # Build the index - self.indexer.build(self.in_memory) + encoder = self.encoder.to(self.device.value).eval() + queue = queues[device_information.rank] + batcher = self.batcher.initialize(self.batch_size) - def encode_documents(self, batch: List[Tuple[int, DocumentRecord]]): + # Index + with torch.no_grad(): + for batch in iter_batches: + # Signals the output range + queue.put( + DocumentRange( + device_information.rank, batch[0][0], batch[-1][0] + ) + ) + # Outputs the documents + batcher.process(batch, self.encode_documents, encoder, queue) + + # Build the index + logger.info("Closing queue %d", device_information.rank) + queue.put(None) + except Exception: + queue.stop() + raise + + def encode_documents( + self, + batch: List[Tuple[int, DocumentRecord]], + encoder: TextEncoderBase[InputType, TextsRepresentationOutput], + queue: "mp.Queue[EncodedDocument]", + ): # Assumes for now dense vectors vectors = ( - self.encoder([d[TextItem].text for _, d in batch]).value.cpu().numpy() + encoder([d[TextItem].text for _, d in batch]).value.cpu().numpy() ) # bs * vocab for vector, (docid, _) in zip(vectors, batch): - (nonzero_ix,) = vector.nonzero() - self.indexer.add(docid, nonzero_ix.astype(np.uint64), vector[nonzero_ix]) + queue.put(EncodedDocument(docid, vector)) diff --git a/src/xpmir/learning/devices.py b/src/xpmir/learning/devices.py index da91585..2e5dd79 100644 --- a/src/xpmir/learning/devices.py +++ b/src/xpmir/learning/devices.py @@ -20,6 +20,12 @@ class DeviceInformation: main: bool """Flag for the main process (all other are slaves)""" + count: int = 1 + """Number of processes""" + + rank: int = 0 + """Rank when using multiple processes""" + class ComputationContext(Context): device_information: DeviceInformation @@ -40,11 +46,14 @@ def value(self): return torch.device("cpu") - def execute(self, callback): - return callback(DeviceInformation(self.value, True)) + n_processes = 1 + """Number of processes""" + + def execute(self, callback, *args, **kwargs): + callback(DeviceInformation(self.value, True), *args, **kwargs) -def mp_launcher(rank, path, world_size, device, callback, taskenv): +def mp_launcher(rank, path, world_size, callback, taskenv, args, kwargs): logger.warning("Launcher of rank %d [%s]", rank, path) TaskEnv._instance = taskenv taskenv.slave = rank == 0 @@ -52,7 +61,12 @@ def mp_launcher(rank, path, world_size, device, callback, taskenv): dist.init_process_group( "gloo", init_method=f"file://{path}", rank=rank, world_size=world_size ) - callback(DistributedDeviceInformation(device, rank == 0, rank)) + device = torch.device(f"cuda:{rank}") + callback( + DistributedDeviceInformation(device, rank == 0, rank, count=world_size), + *args, + **kwargs, + ) # Cleanup dist.destroy_process_group() @@ -94,12 +108,19 @@ def value(self): return torch.device("cuda") - def execute(self, callback): + @cached_property + def n_processes(self): + """Number of processes""" + if self.distributed: + return torch.cuda.device_count() + return 1 + + def execute(self, callback, *args, **kwargs): # Setup distributed 
diff --git a/src/xpmir/learning/devices.py b/src/xpmir/learning/devices.py
index da91585..2e5dd79 100644
--- a/src/xpmir/learning/devices.py
+++ b/src/xpmir/learning/devices.py
@@ -20,6 +20,12 @@ class DeviceInformation:
     main: bool
     """Flag for the main process (all other are slaves)"""

+    count: int = 1
+    """Number of processes"""
+
+    rank: int = 0
+    """Rank when using multiple processes"""
+

 class ComputationContext(Context):
     device_information: DeviceInformation
@@ -40,11 +46,14 @@ def value(self):

         return torch.device("cpu")

-    def execute(self, callback):
-        return callback(DeviceInformation(self.value, True))
+    n_processes = 1
+    """Number of processes"""
+
+    def execute(self, callback, *args, **kwargs):
+        callback(DeviceInformation(self.value, True), *args, **kwargs)


-def mp_launcher(rank, path, world_size, device, callback, taskenv):
+def mp_launcher(rank, path, world_size, callback, taskenv, args, kwargs):
     logger.warning("Launcher of rank %d [%s]", rank, path)
     TaskEnv._instance = taskenv
     taskenv.slave = rank == 0
@@ -52,7 +61,12 @@ def mp_launcher(rank, path, world_size, callback, taskenv):
     dist.init_process_group(
         "gloo", init_method=f"file://{path}", rank=rank, world_size=world_size
     )
-    callback(DistributedDeviceInformation(device, rank == 0, rank))
+    device = torch.device(f"cuda:{rank}")
+    callback(
+        DistributedDeviceInformation(device, rank == 0, rank, count=world_size),
+        *args,
+        **kwargs,
+    )

     # Cleanup
     dist.destroy_process_group()
@@ -94,12 +108,19 @@ def value(self):

         return torch.device("cuda")

-    def execute(self, callback):
+    @cached_property
+    def n_processes(self):
+        """Number of processes"""
+        if self.distributed:
+            return torch.cuda.device_count()
+        return 1
+
+    def execute(self, callback, *args, **kwargs):
         # Setup distributed computation
         # See https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
         n_gpus = torch.cuda.device_count()
         if n_gpus == 1 or not self.distributed:
-            callback(DeviceInformation(self.value, True), *args, **kwargs)
+            callback(DeviceInformation(self.value, True), *args, **kwargs)
         else:
             with tempfile.NamedTemporaryFile() as temporary:
                 logger.info("Setting up distributed CUDA computing (%d GPUs)", n_gpus)
@@ -108,9 +129,10 @@ def execute(self, callback):
                     args=(
                         temporary.name,
                         n_gpus,
-                        self.value,
                         callback,
                         TaskEnv.instance(),
+                        args,
+                        kwargs,
                     ),
                     nprocs=n_gpus,
                     join=True,
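The signature change in `devices.py` is what lets the index builder pass its batch iterator and queues down to every worker: `torch.multiprocessing.spawn` forwards a single `args` tuple (prepending the process rank), so the extra positional and keyword arguments have to travel inside that tuple and be re-expanded by `mp_launcher`. A minimal sketch of the same forwarding pattern (hypothetical `launch`/`greet`, not the xpmir API):

```python
import torch.multiprocessing as mp

def worker(rank: int, callback, args: tuple, kwargs: dict) -> None:
    # mp.spawn calls worker(rank, *spawn_args): the rank comes first,
    # the user arguments are re-expanded from the tuple and dict
    callback(rank, *args, **kwargs)

def launch(callback, *args, nprocs: int = 2, **kwargs) -> None:
    mp.spawn(worker, args=(callback, args, kwargs), nprocs=nprocs, join=True)

def greet(rank: int, message: str) -> None:
    print(f"[rank {rank}] {message}")

if __name__ == "__main__":
    launch(greet, "hello from a spawned process")
```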
diff --git a/src/xpmir/neural/__init__.py b/src/xpmir/neural/__init__.py
index e28b77f..4698006 100644
--- a/src/xpmir/neural/__init__.py
+++ b/src/xpmir/neural/__init__.py
@@ -1,15 +1,14 @@
 from abc import abstractmethod
 import itertools
-from typing import Iterable, Union, List, Optional, TypeVar, Generic
+from typing import Iterable, Union, List, Optional, TypeVar, Generic, Sequence
 import torch
 from datamaestro_text.data.ir import TextItem
-from xpmir.learning.batchers import Sliceable
 from xpmir.learning.context import TrainerContext
 from xpmir.letor.records import BaseRecords, ProductRecords, TopicRecord, DocumentRecord
 from xpmir.rankers import LearnableScorer

-QueriesRep = TypeVar("QueriesRep", bound=Sliceable["QueriesRep"])
-DocsRep = TypeVar("DocsRep", bound=Sliceable["DocsRep"])
+QueriesRep = TypeVar("QueriesRep", bound=Sequence)
+DocsRep = TypeVar("DocsRep", bound=Sequence)


 class DualRepresentationScorer(LearnableScorer, Generic[QueriesRep, DocsRep]):
@@ -57,7 +56,7 @@ def encode_documents(self, records: Iterable[DocumentRecord]) -> DocsRep:
     def encode_queries(self, records: Iterable[TopicRecord]) -> QueriesRep:
         """Encode a list of texts (document or query)

-        The return value is model dependent, but should be sliceable
+        The return value is model dependent, but should be a sequence

         By default, uses `merge`
         """
diff --git a/src/xpmir/neural/interaction/common.py b/src/xpmir/neural/interaction/common.py
index 2daa163..7403865 100644
--- a/src/xpmir/neural/interaction/common.py
+++ b/src/xpmir/neural/interaction/common.py
@@ -1,16 +1,14 @@
 from abc import ABC, abstractmethod
-from typing import List, Union
+from typing import List, Union, Sequence
 from attrs import evolve

 import torch
 from attrs import define
 from experimaestro import Config

-from xpmir.learning.batchers import Sliceable
-

 @define
-class SimilarityInput(Sliceable["SimilarityInput"]):
+class SimilarityInput(Sequence["SimilarityInput"]):
     value: torch.Tensor
     """A 3D tensor (batch x max_length x dim)"""
diff --git a/src/xpmir/utils/multiprocessing.py b/src/xpmir/utils/multiprocessing.py
new file mode 100644
index 0000000..4005c6f
--- /dev/null
+++ b/src/xpmir/utils/multiprocessing.py
@@ -0,0 +1,47 @@
+"""Multiprocessing utilities"""
+
+from queue import Full, Empty
+import torch.multiprocessing as mp
+from typing import Generic, TypeVar
+from xpmir.utils.logging import easylog
+
+logger = easylog()
+
+T = TypeVar("T")
+
+
+class StoppableQueue(Generic[T]):
+    """Queue with a stop event flag"""
+
+    def __init__(self, maxsize: int, stopping_event: mp.Event):
+        self.queue = mp.Queue(maxsize)
+        self._stopping_event = stopping_event
+
+    def get(self, timeout=1.0) -> T:
+        item = None
+        while True:
+            try:
+                item = self.queue.get(timeout=timeout)
+                break
+            except Empty:
+                # Re-raise if the other processes have been stopped
+                if self._stopping_event.is_set():
+                    raise
+
+        return item
+
+    def put(self, item: T, timeout=1.0):
+        while True:
+            try:
+                self.queue.put(item, timeout=timeout)
+                break
+            except Full:
+                # Try again, unless the queue has been stopped
+                if self._stopping_event.is_set():
+                    raise
+
+    def close(self):
+        self.queue.close()
+
+    def stop(self):
+        logger.warning("Stopping the queue")
+        self._stopping_event.set()
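A possible usage sketch for `StoppableQueue` (hypothetical `producer`, outside the index builder): both ends share the same event, so a failure on either side unblocks the other instead of hanging the pipeline.

```python
import torch.multiprocessing as mp
from xpmir.utils.multiprocessing import StoppableQueue

def producer(queue: StoppableQueue) -> None:
    try:
        for item in range(10):
            queue.put(item)
        queue.put(None)  # end-of-stream sentinel, as in the indexer
    except Exception:
        queue.stop()     # wake up the consumer instead of deadlocking
        raise

if __name__ == "__main__":
    queue = StoppableQueue(4, mp.Event())
    process = mp.Process(target=producer, args=(queue,), daemon=True)
    process.start()
    while (item := queue.get()) is not None:
        print("received item", item)
    process.join()
```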