# Source: https://github.com/ArneBinder/pie-document-level/pull/452 (commit e7eaeed)
from __future__ import annotations
import logging
from functools import partial
from typing import (
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
overload,
)
from pie_datasets import Dataset
from pie_modules.utils import resolve_type
from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
from pytorch_ie.core import Document
logger = logging.getLogger(__name__)
D = TypeVar("D", bound=Document)

def clear_annotation_layers(doc: D, layer_names: List[str], predictions: bool = False) -> None:
    """Clear the given annotation layers of the document. If `predictions` is True, the
    prediction containers of those layers are cleared instead of the gold annotations."""
for layer_name in layer_names:
if predictions:
doc[layer_name].predictions.clear()
else:
doc[layer_name].clear()
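

# A minimal usage sketch for the helper above (assumes the stock pytorch_ie example
# types, which are not imported by this module):
#
#   from pytorch_ie.annotations import LabeledSpan
#   from pytorch_ie.documents import TextDocumentWithLabeledSpans
#
#   doc = TextDocumentWithLabeledSpans(text="Jane lives in Berlin.")
#   doc.labeled_spans.predictions.append(LabeledSpan(start=0, end=4, label="PER"))
#   clear_annotation_layers(doc, layer_names=["labeled_spans"], predictions=True)
#   assert len(doc.labeled_spans.predictions) == 0
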
def move_annotations_from_predictions(doc: D, layer_names: List[str]) -> None:
    """Replace the gold annotations of the given layers with their current predictions."""
for layer_name in layer_names:
annotations = list(doc[layer_name].predictions)
# remove any previous annotations
doc[layer_name].clear()
# each annotation can be attached to just one annotation container, so we need to clear the predictions
doc[layer_name].predictions.clear()
doc[layer_name].extend(annotations)
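

# Continuing the sketch above: after an inference step has filled
# `doc.labeled_spans.predictions`, this promotes the predictions to gold annotations:
#
#   move_annotations_from_predictions(doc, layer_names=["labeled_spans"])
#   # the spans now live in doc.labeled_spans; the predictions container is empty
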
def move_annotations_to_predictions(doc: D, layer_names: List[str]) -> None:
    """Replace the predictions of the given layers with the current gold annotations."""
for layer_name in layer_names:
annotations = list(doc[layer_name])
# each annotation can be attached to just one annotation container, so we need to clear the layer
doc[layer_name].clear()
# remove any previous annotations
doc[layer_name].predictions.clear()
doc[layer_name].predictions.extend(annotations)
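

# The inverse operation, demoting the current gold spans to predictions:
#
#   move_annotations_to_predictions(doc, layer_names=["labeled_spans"])
#   assert len(doc.labeled_spans) == 0
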
def _add_annotations_from_other_document(
doc: D,
from_predictions: bool,
to_predictions: bool,
clear_before: bool,
other_doc: Optional[D] = None,
other_docs_dict: Optional[Dict[str, D]] = None,
layer_names: Optional[List[str]] = None,
) -> D:
    """Add annotations from `other_doc` (or from a document with a matching ID in
    `other_docs_dict`) to `doc`. `from_predictions`/`to_predictions` select whether the
    source/target is the prediction container of each layer; `clear_before` empties the
    target first. If `layer_names` is None, all annotation layers are copied."""
if other_doc is None:
if other_docs_dict is None:
raise ValueError("Either other_doc or other_docs_dict must be provided")
other_doc = other_docs_dict.get(doc.id)
if other_doc is None:
logger.warning(f"Document with ID {doc.id} not found in other_docs")
return doc
# copy to not modify the input
other_doc_copy = type(other_doc).fromdict(other_doc.asdict())
if layer_names is None:
layer_names = [field.name for field in doc.annotation_fields()]
for layer_name in layer_names:
layer = doc[layer_name]
if to_predictions:
layer = layer.predictions
if clear_before:
layer.clear()
other_layer = other_doc_copy[layer_name]
if from_predictions:
other_layer = other_layer.predictions
other_annotations = list(other_layer)
other_layer.clear()
layer.extend(other_annotations)
return doc
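

# Sketch (hypothetical documents `doc` and `other` of the same type): copy the gold
# spans of `other` into the predictions of `doc`. The function works on a copy of
# `other`, so the source document stays untouched, while `doc` is modified in place
# (and also returned):
#
#   _add_annotations_from_other_document(
#       doc,
#       other_doc=other,
#       from_predictions=False,
#       to_predictions=True,
#       clear_before=True,
#       layer_names=["labeled_spans"],
#   )
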
def add_annotations_from_other_documents(
docs: Iterable[D],
other_docs: Sequence[Document],
get_other_doc_by_id: bool = False,
**kwargs,
) -> Sequence[D]:
    """Add annotations from `other_docs` to `docs`, either pairing documents by position
    or, if `get_other_doc_by_id` is True, by matching document IDs. If a document has no
    match, a warning is logged; in the list case, the document is dropped from the
    result. Remaining keyword arguments are passed through to
    `_add_annotations_from_other_document`."""
other_id2doc = None
if get_other_doc_by_id:
other_id2doc = {doc.id: doc for doc in other_docs}
if isinstance(docs, Dataset):
if other_id2doc is None:
raise ValueError("get_other_doc_by_id must be True when passing a Dataset")
result = docs.map(
_add_annotations_from_other_document,
fn_kwargs=dict(other_docs_dict=other_id2doc, **kwargs),
)
elif isinstance(docs, list):
result = []
for i, doc in enumerate(docs):
if other_id2doc is not None:
other_doc = other_id2doc.get(doc.id)
if other_doc is None:
logger.warning(f"Document with ID {doc.id} not found in other_docs")
continue
else:
other_doc = other_docs[i]
# check if the IDs of the documents match
doc_id = getattr(doc, "id", None)
other_doc_id = getattr(other_doc, "id", None)
if doc_id is not None and doc_id != other_doc_id:
raise ValueError(
f"IDs of the documents do not match: {doc_id} != {other_doc_id}"
)
current_result = _add_annotations_from_other_document(
doc, other_doc=other_doc, **kwargs
)
result.append(current_result)
else:
raise ValueError(f"Unsupported type: {type(docs)}")
return result
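

# Sketch: re-attach annotations from the original documents to processed copies,
# matching by document ID (`processed_docs` and `original_docs` are illustrative):
#
#   enriched = add_annotations_from_other_documents(
#       processed_docs,
#       other_docs=original_docs,
#       get_other_doc_by_id=True,
#       from_predictions=False,
#       to_predictions=False,
#       clear_before=False,
#   )
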
DM = TypeVar("DM", bound=Dict[str, Iterable[Document]])
def add_annotations_from_other_documents_dict(
docs: DM, other_docs: Dict[str, Sequence[Document]], **kwargs
) -> DM:
    """Apply `add_annotations_from_other_documents` per key to a mapping of document
    collections (e.g. dataset splits). Both mappings must have the same keys."""
if set(docs.keys()) != set(other_docs.keys()):
raise ValueError("Keys of the documents do not match")
result_dict = {
key: add_annotations_from_other_documents(doc_list, other_docs[key], **kwargs)
for key, doc_list in docs.items()
}
return type(docs)(result_dict)
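

# Sketch: the same operation applied per key, e.g. to dataset splits (all names
# illustrative):
#
#   enriched_splits = add_annotations_from_other_documents_dict(
#       {"train": train_docs, "validation": val_docs},
#       other_docs={"train": train_gold, "validation": val_gold},
#       get_other_doc_by_id=True,
#       from_predictions=False,
#       to_predictions=False,
#       clear_before=False,
#   )
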
def process_pipeline_steps(
documents: Sequence[Document],
processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]],
verbose: bool = False,
) -> Sequence[Document]:
    """Call the processors on the documents in the order they are provided. If a
    processor returns None, the previous documents are passed on unchanged."""
for step_name, processor in processors.items():
if verbose:
logger.info(f"process {step_name} ...")
processed_documents = processor(documents)
if processed_documents is not None:
documents = processed_documents
return documents
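

# Sketch: processors run in the insertion order of the dict; each one either returns a
# new sequence of documents or None to signal in-place modification (`ner_pipeline` is
# a hypothetical pytorch_ie pipeline instance):
#
#   docs = process_pipeline_steps(
#       documents=docs,
#       processors={
#           "clear_entities": partial(
#               process_documents,
#               processor=clear_annotation_layers,
#               layer_names=["labeled_spans"],
#           ),
#           "ner": ner_pipeline,
#       },
#       verbose=True,
#   )
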
def process_documents(
documents: List[Document], processor: Callable[..., Optional[Document]], **kwargs
) -> List[Document]:
    """Apply the processor to each document individually. If the processor returns None
    (e.g. because it modifies the document in place), the input document is kept."""
result = []
for doc in documents:
processed_doc = processor(doc, **kwargs)
if processed_doc is not None:
result.append(processed_doc)
else:
result.append(doc)
return result
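

# Sketch: apply an in-place helper to every document; `clear_annotation_layers` returns
# None, so the (modified) input documents themselves are collected:
#
#   docs = process_documents(
#       docs, processor=clear_annotation_layers, layer_names=["labeled_spans"]
#   )
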
class DummyTaskmodule(WithDocumentTypeMixin):
    """Minimal taskmodule stand-in that only carries a document type, so that the
    pipeline can expose its expected input type via `WithDocumentTypeMixin`."""

def __init__(self, document_type: Optional[Union[Type[Document], str]]):
if isinstance(document_type, str):
self._document_type = resolve_type(document_type, expected_super_type=Document)
else:
self._document_type = document_type
@property
def document_type(self) -> Optional[Type[Document]]:
return self._document_type
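

# Sketch: the document type can be passed as a class or as an import path string, e.g.
#
#   taskmodule = DummyTaskmodule("pytorch_ie.documents.TextDocumentWithLabeledSpans")
#   assert taskmodule.document_type.__name__ == "TextDocumentWithLabeledSpans"
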
class NerRePipeline:
    """Two-stage inference pipeline: first predict entities with the NER model, then
    predict relations between them with the RE model. The gold annotations of the input
    documents are retained, and the new annotations end up in the `predictions`
    containers of the respective layers."""

def __init__(
self,
ner_model_path: str,
re_model_path: str,
entity_layer: str,
relation_layer: str,
device: Optional[int] = None,
batch_size: Optional[int] = None,
show_progress_bar: Optional[bool] = None,
document_type: Optional[Union[Type[Document], str]] = None,
verbose: bool = True,
**processor_kwargs,
):
self.taskmodule = DummyTaskmodule(document_type)
self.ner_model_path = ner_model_path
self.re_model_path = re_model_path
self.processor_kwargs = processor_kwargs or {}
self.entity_layer = entity_layer
self.relation_layer = relation_layer
self.verbose = verbose
# set some values for the inference processors, if provided
for inference_pipeline in ["ner_pipeline", "re_pipeline"]:
if inference_pipeline not in self.processor_kwargs:
self.processor_kwargs[inference_pipeline] = {}
if "device" not in self.processor_kwargs[inference_pipeline] and device is not None:
self.processor_kwargs[inference_pipeline]["device"] = device
if (
"batch_size" not in self.processor_kwargs[inference_pipeline]
and batch_size is not None
):
self.processor_kwargs[inference_pipeline]["batch_size"] = batch_size
if (
"show_progress_bar" not in self.processor_kwargs[inference_pipeline]
and show_progress_bar is not None
):
self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar
self.ner_pipeline = AutoPipeline.from_pretrained(
self.ner_model_path, **self.processor_kwargs.get("ner_pipeline", {})
)
self.re_pipeline = AutoPipeline.from_pretrained(
self.re_model_path, **self.processor_kwargs.get("re_pipeline", {})
)
@overload
def __call__(
self, documents: Sequence[Document], inplace: bool = False
) -> Sequence[Document]: ...
@overload
def __call__(self, documents: Document, inplace: bool = False) -> Document: ...
def __call__(
self, documents: Union[Sequence[Document], Document], inplace: bool = False
) -> Union[Sequence[Document], Document]:
is_single_doc = False
if isinstance(documents, Document):
documents = [documents]
is_single_doc = True
input_docs: Sequence[Document]
# we need to keep the original documents to add the gold data back
original_docs: Sequence[Document]
if inplace:
input_docs = documents
original_docs = [doc.copy() for doc in documents]
else:
input_docs = [doc.copy() for doc in documents]
original_docs = documents
docs_with_predictions = process_pipeline_steps(
documents=input_docs,
processors={
"clear_annotations": partial(
process_documents,
processor=clear_annotation_layers,
layer_names=[self.entity_layer, self.relation_layer],
**self.processor_kwargs.get("clear_annotations", {}),
),
"ner_pipeline": self.ner_pipeline,
"use_predicted_entities": partial(
process_documents,
processor=move_annotations_from_predictions,
layer_names=[self.entity_layer],
**self.processor_kwargs.get("use_predicted_entities", {}),
),
"re_pipeline": self.re_pipeline,
                # the candidate relations in the main layer reference the (currently
                # gold) entities; they must be cleared, otherwise we can not move the
                # entities back to predictions
"clear_candidate_relations": partial(
process_documents,
processor=clear_annotation_layers,
layer_names=[self.relation_layer],
**self.processor_kwargs.get("clear_candidate_relations", {}),
),
"move_entities_to_predictions": partial(
process_documents,
processor=move_annotations_to_predictions,
layer_names=[self.entity_layer],
**self.processor_kwargs.get("move_entities_to_predictions", {}),
),
"re_add_gold_data": partial(
add_annotations_from_other_documents,
other_docs=original_docs,
from_predictions=False,
to_predictions=False,
clear_before=False,
layer_names=[self.entity_layer, self.relation_layer],
**self.processor_kwargs.get("re_add_gold_data", {}),
),
},
verbose=self.verbose,
)
if is_single_doc:
return docs_with_predictions[0]
return docs_with_predictions
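

# Usage sketch (model paths are placeholders, not real checkpoints; the layer names
# must match the annotation layers of the chosen document type):
#
#   pipeline = NerRePipeline(
#       ner_model_path="path/to/ner_model",
#       re_model_path="path/to/re_model",
#       entity_layer="labeled_spans",
#       relation_layer="binary_relations",
#       device=-1,  # CPU
#       batch_size=8,
#   )
#   docs_with_predictions = pipeline(documents)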