from __future__ import annotations

import logging
from functools import partial
from typing import (
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
    overload,
)

from pie_datasets import Dataset
from pie_modules.utils import resolve_type
from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
from pytorch_ie.core import Document

logger = logging.getLogger(__name__)


D = TypeVar("D", bound=Document)


def clear_annotation_layers(doc: D, layer_names: List[str], predictions: bool = False) -> None:
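    """Remove all annotations from the given layers of the document.

    If ``predictions`` is True, the predicted annotations are cleared instead of the
    gold annotations.
    """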
    for layer_name in layer_names:
        if predictions:
            doc[layer_name].predictions.clear()
        else:
            doc[layer_name].clear()


def move_annotations_from_predictions(doc: D, layer_names: List[str]) -> None:
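    """Move all annotations of the given layers from the predictions to the main (gold)
    layers, replacing any annotations that were there before."""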
    for layer_name in layer_names:
        annotations = list(doc[layer_name].predictions)

        doc[layer_name].clear()

        doc[layer_name].predictions.clear()
        doc[layer_name].extend(annotations)


def move_annotations_to_predictions(doc: D, layer_names: List[str]) -> None:
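    """Move all annotations of the given layers from the main (gold) layers to the
    predictions, replacing any predicted annotations that were there before."""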
    for layer_name in layer_names:
        annotations = list(doc[layer_name])

        doc[layer_name].clear()

        doc[layer_name].predictions.clear()
        doc[layer_name].predictions.extend(annotations)


def _add_annotations_from_other_document(
    doc: D,
    from_predictions: bool,
    to_predictions: bool,
    clear_before: bool,
    other_doc: Optional[D] = None,
    other_docs_dict: Optional[Dict[str, D]] = None,
    layer_names: Optional[List[str]] = None,
) -> D:
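    """Copy annotations from another document into ``doc``.

    The other document is either passed directly via ``other_doc`` or looked up by
    ``doc.id`` in ``other_docs_dict``. ``from_predictions`` / ``to_predictions`` select
    whether the annotations are read from and written to the prediction layers, and
    ``clear_before`` removes any existing target annotations first. If ``layer_names``
    is None, all annotation layers are copied.
    """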
    if other_doc is None:
        if other_docs_dict is None:
            raise ValueError("Either other_doc or other_docs_dict must be provided")
        other_doc = other_docs_dict.get(doc.id)
        if other_doc is None:
            logger.warning(f"Document with ID {doc.id} not found in other_docs")
            return doc

    # Work on a copy so that the other document is not modified.
    other_doc_copy = type(other_doc).fromdict(other_doc.asdict())

    if layer_names is None:
        layer_names = [field.name for field in doc.annotation_fields()]

    for layer_name in layer_names:
        layer = doc[layer_name]
        if to_predictions:
            layer = layer.predictions
        if clear_before:
            layer.clear()
        other_layer = other_doc_copy[layer_name]
        if from_predictions:
            other_layer = other_layer.predictions
        other_annotations = list(other_layer)
        # Detach the annotations from the copy so that they can be attached to ``doc``.
        other_layer.clear()
        layer.extend(other_annotations)

    return doc


def add_annotations_from_other_documents(
    docs: Iterable[D],
    other_docs: Sequence[Document],
    get_other_doc_by_id: bool = False,
    **kwargs,
) -> Sequence[D]:
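    """Add annotations from ``other_docs`` to the documents in ``docs``.

    If ``get_other_doc_by_id`` is True, the other documents are matched by their ID;
    otherwise they are matched by position, which requires both sequences to be aligned.
    Remaining keyword arguments are passed on to ``_add_annotations_from_other_document``.
    """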
    other_id2doc = None
    if get_other_doc_by_id:
        other_id2doc = {doc.id: doc for doc in other_docs}

    if isinstance(docs, Dataset):
        if other_id2doc is None:
            raise ValueError("get_other_doc_by_id must be True when passing a Dataset")
        result = docs.map(
            _add_annotations_from_other_document,
            fn_kwargs=dict(other_docs_dict=other_id2doc, **kwargs),
        )
    elif isinstance(docs, list):
        result = []
        for i, doc in enumerate(docs):
            if other_id2doc is not None:
                other_doc = other_id2doc.get(doc.id)
                if other_doc is None:
                    logger.warning(f"Document with ID {doc.id} not found in other_docs")
                    continue
            else:
                other_doc = other_docs[i]

            doc_id = getattr(doc, "id", None)
            other_doc_id = getattr(other_doc, "id", None)
            if doc_id is not None and doc_id != other_doc_id:
                raise ValueError(
                    f"IDs of the documents do not match: {doc_id} != {other_doc_id}"
                )

            current_result = _add_annotations_from_other_document(
                doc, other_doc=other_doc, **kwargs
            )
            result.append(current_result)
    else:
        raise ValueError(f"Unsupported type: {type(docs)}")

    return result


DM = TypeVar("DM", bound=Dict[str, Iterable[Document]])


def add_annotations_from_other_documents_dict(
    docs: DM, other_docs: Dict[str, Sequence[Document]], **kwargs
) -> DM:
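    """Apply ``add_annotations_from_other_documents`` to each entry of a dictionary that
    maps keys (e.g. split names) to sequences of documents. The keys of ``docs`` and
    ``other_docs`` must match.
    """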
    if set(docs.keys()) != set(other_docs.keys()):
        raise ValueError("Keys of the documents do not match")

    result_dict = {
        key: add_annotations_from_other_documents(doc_list, other_docs[key], **kwargs)
        for key, doc_list in docs.items()
    }
    return type(docs)(result_dict)


def process_pipeline_steps(
    documents: Sequence[Document],
    processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]],
    verbose: bool = False,
) -> Sequence[Document]:
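    """Apply the given processors to the documents in order. A processor may return None,
    in which case it is assumed to have modified the documents in place and they are
    passed on unchanged to the next step.
    """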
    for step_name, processor in processors.items():
        if verbose:
            logger.info(f"process {step_name} ...")
        processed_documents = processor(documents)
        if processed_documents is not None:
            documents = processed_documents

    return documents


def process_documents(
    documents: List[Document], processor: Callable[..., Optional[Document]], **kwargs
) -> List[Document]:
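    """Apply the processor to each document individually. If the processor returns None,
    the (possibly modified in place) input document is kept instead.
    """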
    result = []
    for doc in documents:
        processed_doc = processor(doc, **kwargs)
        if processed_doc is not None:
            result.append(processed_doc)
        else:
            result.append(doc)
    return result


class DummyTaskmodule(WithDocumentTypeMixin):
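    """A minimal taskmodule stand-in that only carries a document type, so that the
    pipeline can expose its expected document type via ``WithDocumentTypeMixin``."""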
    def __init__(self, document_type: Optional[Union[Type[Document], str]]):
        if isinstance(document_type, str):
            self._document_type = resolve_type(document_type, expected_super_type=Document)
        else:
            self._document_type = document_type

    @property
    def document_type(self) -> Optional[Type[Document]]:
        return self._document_type


class NerRePipeline:
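    """A two-stage inference pipeline that first predicts entities with a NER model and
    then predicts relations between them with an RE model.

    The intermediate steps (clearing annotations, promoting predicted entities to gold
    annotations for the RE model, and re-attaching the original gold data) are implemented
    as named processors whose keyword arguments can be overridden via ``processor_kwargs``.

    Example (a minimal sketch; the model paths and layer names are placeholders):

        pipeline = NerRePipeline(
            ner_model_path="path/to/ner_model",
            re_model_path="path/to/re_model",
            entity_layer="entities",
            relation_layer="relations",
        )
        docs_with_predictions = pipeline(docs)
    """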
    def __init__(
        self,
        ner_model_path: str,
        re_model_path: str,
        entity_layer: str,
        relation_layer: str,
        device: Optional[int] = None,
        batch_size: Optional[int] = None,
        show_progress_bar: Optional[bool] = None,
        document_type: Optional[Union[Type[Document], str]] = None,
        verbose: bool = True,
        **processor_kwargs,
    ):
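        """See the class docstring for an overview. ``device``, ``batch_size``, and
        ``show_progress_bar`` are applied to both the NER and the RE pipeline unless they
        are already set for the respective pipeline in ``processor_kwargs``."""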
        self.taskmodule = DummyTaskmodule(document_type)
        self.ner_model_path = ner_model_path
        self.re_model_path = re_model_path
        self.processor_kwargs = processor_kwargs or {}
        self.entity_layer = entity_layer
        self.relation_layer = relation_layer
        self.verbose = verbose
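
        # Forward the shared settings (device, batch_size, show_progress_bar) to both
        # inference pipelines, unless they are already set explicitly in processor_kwargs.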
        for inference_pipeline in ["ner_pipeline", "re_pipeline"]:
            if inference_pipeline not in self.processor_kwargs:
                self.processor_kwargs[inference_pipeline] = {}
            if "device" not in self.processor_kwargs[inference_pipeline] and device is not None:
                self.processor_kwargs[inference_pipeline]["device"] = device
            if (
                "batch_size" not in self.processor_kwargs[inference_pipeline]
                and batch_size is not None
            ):
                self.processor_kwargs[inference_pipeline]["batch_size"] = batch_size
            if (
                "show_progress_bar" not in self.processor_kwargs[inference_pipeline]
                and show_progress_bar is not None
            ):
                self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar
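
        # Load both inference pipelines from the pretrained model paths.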
        self.ner_pipeline = AutoPipeline.from_pretrained(
            self.ner_model_path, **self.processor_kwargs.get("ner_pipeline", {})
        )
        self.re_pipeline = AutoPipeline.from_pretrained(
            self.re_model_path, **self.processor_kwargs.get("re_pipeline", {})
        )

    @overload
    def __call__(
        self, documents: Sequence[Document], inplace: bool = False
    ) -> Sequence[Document]: ...

    @overload
    def __call__(self, documents: Document, inplace: bool = False) -> Document: ...

    def __call__(
        self, documents: Union[Sequence[Document], Document], inplace: bool = False
    ) -> Union[Sequence[Document], Document]:
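        """Run the NER and RE models on the given document(s) and return them with
        predicted entities and relations attached as predictions.

        If ``inplace`` is True, the input documents are modified; otherwise copies are
        processed and returned. A single document may be passed instead of a sequence,
        in which case a single document is returned.
        """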
        is_single_doc = False
        if isinstance(documents, Document):
            documents = [documents]
            is_single_doc = True

        input_docs: Sequence[Document]
        # Keep an unmodified version of the documents so that the gold annotations can be
        # re-attached after prediction.
        original_docs: Sequence[Document]
        if inplace:
            input_docs = documents
            original_docs = [doc.copy() for doc in documents]
        else:
            input_docs = [doc.copy() for doc in documents]
            original_docs = documents
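
        # Run the full chain: clear annotations, predict entities, promote them to gold
        # entities for the RE model, predict relations, and finally re-attach the
        # original gold annotations.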
        docs_with_predictions = process_pipeline_steps(
            documents=input_docs,
            processors={
                "clear_annotations": partial(
                    process_documents,
                    processor=clear_annotation_layers,
                    layer_names=[self.entity_layer, self.relation_layer],
                    **self.processor_kwargs.get("clear_annotations", {}),
                ),
                "ner_pipeline": self.ner_pipeline,
                "use_predicted_entities": partial(
                    process_documents,
                    processor=move_annotations_from_predictions,
                    layer_names=[self.entity_layer],
                    **self.processor_kwargs.get("use_predicted_entities", {}),
                ),
                "re_pipeline": self.re_pipeline,

                "clear_candidate_relations": partial(
                    process_documents,
                    processor=clear_annotation_layers,
                    layer_names=[self.relation_layer],
                    **self.processor_kwargs.get("clear_candidate_relations", {}),
                ),
                "move_entities_to_predictions": partial(
                    process_documents,
                    processor=move_annotations_to_predictions,
                    layer_names=[self.entity_layer],
                    **self.processor_kwargs.get("move_entities_to_predictions", {}),
                ),
                "re_add_gold_data": partial(
                    add_annotations_from_other_documents,
                    other_docs=original_docs,
                    from_predictions=False,
                    to_predictions=False,
                    clear_before=False,
                    layer_names=[self.entity_layer, self.relation_layer],
                    **self.processor_kwargs.get("re_add_gold_data", {}),
                ),
            },
            verbose=self.verbose,
        )
        if is_single_doc:
            return docs_with_predictions[0]
        return docs_with_predictions