Skip to content

Commit

Permalink
fix: removed record type properties
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Mar 6, 2024
1 parent 2eaaec5 commit 80de710
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/xpmir/conversation/learning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def conversations(self):
def __post_init__(self):
super().__post_init__()

self._recordtypes = RecordTypesCache()
self._recordtypes = RecordTypesCache("Conversation", ConversationHistoryItem)

def __iter__(self) -> RandomSerializableIterator[TopicConversationRecord]:
def generator(random: np.random.RandomState):
Expand Down
6 changes: 2 additions & 4 deletions src/xpmir/conversation/models/cosplade.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,13 @@ def forward(self, records: List[TopicConversationRecord]):
)

# List of query/answer couples
answer: Optional[AnswerConversationRecord] = None
answer: Optional[AnswerEntry] = None
for item, _ in zip(
c_record[ConversationHistoryItem].history,
range(self.history_size or sys.maxsize),
):
if isinstance(item, TopicRecord) and answer is not None:
query_answer_pairs.append(
(item[TextItem].text, answer[AnswerEntry].answer)
)
query_answer_pairs.append((item[TextItem].text, answer.answer))
pair_origins.append(ix)
elif isinstance(item, AnswerConversationRecord):
if (answer := item.get(AnswerEntry)) is None:
Expand Down
14 changes: 10 additions & 4 deletions src/xpmir/index/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def index(
# Get next range
next_range = queues[current.rank].get() # type: DocumentRange
if next_range:
logger.debug("Got next range: %s", next_range)
heapq.heappushpop(heap, next_range)
else:
logger.info("Iterator %d is over", current.rank)
Expand Down Expand Up @@ -358,11 +359,16 @@ def device_execute(
with torch.no_grad():
for batch in iter_batches:
# Signals the output range
queue.put(
DocumentRange(
device_information.rank, batch[0][0], batch[-1][0]
)
document_range = DocumentRange(
device_information.rank, batch[0][0], batch[-1][0]
)
logger.debug(
"Starting range [%d] %s",
device_information.rank,
document_range,
)
queue.put(document_range)

# Outputs the documents
batcher.process(
batch,
Expand Down
15 changes: 13 additions & 2 deletions src/xpmir/utils/iter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from abc import ABC, abstractmethod
from queue import Full
from queue import Full, Empty
import torch.multiprocessing as mp
from typing import (
Generic,
Expand Down Expand Up @@ -330,6 +330,7 @@ def mp_iterate(iterator, queue: mp.Queue, event: mp.Event):
break

except StopIteration:
logger.info("End of multi-process iterator")
queue.put(STOP_ITERATION)
except Exception as e:
logger.exception("Exception while iterating")
Expand All @@ -340,19 +341,29 @@ class QueueBasedMultiprocessIterator(Iterator[T]):
def __init__(self, queue: "mp.Queue[T]", stop_process: mp.Event):
self.queue = queue
self.stop_process = stop_process
self.stop_iteration = mp.Event()

def __next__(self):
# Get the next element
element = self.queue.get()
while True:
try:
element = self.queue.get(timeout=1)
break
except Empty:
if self.stop_iteration.is_set():
self.stop_process.set()
raise StopIteration()

# Last element
if isinstance(element, StopIterationClass):
# Just in case
self.stop_process.set()
self.stop_iteration.set()
raise StopIteration()

# An exception occurred
elif isinstance(element, Exception):
self.stop_iteration.set()
self.stop_process.set()
raise RuntimeError("Error in iterator process") from element

Expand Down

0 comments on commit 80de710

Please sign in to comment.