|
import timeit |
|
from collections.abc import Iterable, Sequence |
|
from typing import Any, Dict, Optional, Union |
|
|
|
from pytorch_ie import Document, Pipeline |
|
|
|
from src.serializer.interface import DocumentSerializer |
|
|
|
from .logging_utils import get_pylogger |
|
|
|
log = get_pylogger(__name__) |
|
|
|
|
|
def document_batch_iter( |
|
dataset: Iterable[Document], batch_size: int |
|
) -> Iterable[Sequence[Document]]: |
|
if isinstance(dataset, Sequence): |
|
for i in range(0, len(dataset), batch_size): |
|
yield dataset[i : i + batch_size] |
|
elif isinstance(dataset, Iterable): |
|
docs = [] |
|
for doc in dataset: |
|
docs.append(doc) |
|
if len(docs) == batch_size: |
|
yield docs |
|
docs = [] |
|
if docs: |
|
yield docs |
|
else: |
|
raise ValueError(f"Unsupported dataset type: {type(dataset)}") |
|
|
|
|
|
def predict_and_serialize( |
|
pipeline: Optional[Pipeline], |
|
serializer: Optional[DocumentSerializer], |
|
dataset: Iterable[Document], |
|
document_batch_size: Optional[int] = None, |
|
) -> Dict[str, Any]: |
|
result: Dict[str, Any] = {} |
|
if pipeline is not None: |
|
log.info("Starting inference!") |
|
prediction_time = 0.0 |
|
else: |
|
log.warning("No prediction pipeline is defined, skip inference!") |
|
prediction_time = None |
|
docs_batch: Union[Iterable[Document], Sequence[Document]] |
|
|
|
batch_iter: Union[Sequence[Iterable[Document]], Iterable[Sequence[Document]]] |
|
if document_batch_size is None: |
|
batch_iter = [dataset] |
|
else: |
|
batch_iter = document_batch_iter(dataset=dataset, batch_size=document_batch_size) |
|
for docs_batch in batch_iter: |
|
if pipeline is not None: |
|
t_start = timeit.default_timer() |
|
docs_batch = pipeline(docs_batch, inplace=False) |
|
prediction_time += timeit.default_timer() - t_start |
|
|
|
|
|
if serializer is not None: |
|
|
|
|
|
serializer_result = serializer(docs_batch) |
|
if "serializer" in result and result["serializer"] != serializer_result: |
|
log.warning( |
|
f"serializer result changed from {result['serializer']} to {serializer_result}" |
|
" during prediction. Only the last result is returned." |
|
) |
|
result["serializer"] = serializer_result |
|
|
|
if prediction_time is not None: |
|
result["prediction_time"] = prediction_time |
|
|
|
return result |
|
|