from __future__ import annotations

import logging
from functools import partial
from typing import (
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
    overload,
)

from pie_datasets import Dataset
from pie_modules.utils import resolve_type
from pytorch_ie import AutoPipeline, WithDocumentTypeMixin
from pytorch_ie.core import Document

logger = logging.getLogger(__name__)

D = TypeVar("D", bound=Document)


def clear_annotation_layers(doc: D, layer_names: List[str], predictions: bool = False) -> None:
    """Remove all annotations (or, if `predictions` is True, all predictions) from the given
    layers of the document."""
    for layer_name in layer_names:
        if predictions:
            doc[layer_name].predictions.clear()
        else:
            doc[layer_name].clear()


def move_annotations_from_predictions(doc: D, layer_names: List[str]) -> None:
    """Replace the annotations of the given layers with their predictions."""
    for layer_name in layer_names:
        annotations = list(doc[layer_name].predictions)
        # remove any previous annotations
        doc[layer_name].clear()
        # each annotation can be attached to just one annotation container,
        # so we need to clear the predictions as well
        doc[layer_name].predictions.clear()
        doc[layer_name].extend(annotations)


def move_annotations_to_predictions(doc: D, layer_names: List[str]) -> None:
    """Move the annotations of the given layers into their predictions."""
    for layer_name in layer_names:
        annotations = list(doc[layer_name])
        # each annotation can be attached to just one annotation container,
        # so we need to clear the layer
        doc[layer_name].clear()
        # remove any previous predictions
        doc[layer_name].predictions.clear()
        doc[layer_name].predictions.extend(annotations)


def _add_annotations_from_other_document(
    doc: D,
    from_predictions: bool,
    to_predictions: bool,
    clear_before: bool,
    other_doc: Optional[D] = None,
    other_docs_dict: Optional[Dict[str, D]] = None,
    layer_names: Optional[List[str]] = None,
) -> D:
    """Copy the annotations of `other_doc` (or of the entry in `other_docs_dict` with the
    same ID as `doc`) into `doc`. All annotation layers are copied if `layer_names` is not
    given."""
    if other_doc is None:
        if other_docs_dict is None:
            raise ValueError("Either other_doc or other_docs_dict must be provided")
        other_doc = other_docs_dict.get(doc.id)
        if other_doc is None:
            logger.warning(f"Document with ID {doc.id} not found in other_docs")
            return doc

    # copy the other document to not modify the input
    other_doc_copy = type(other_doc).fromdict(other_doc.asdict())

    if layer_names is None:
        layer_names = [field.name for field in doc.annotation_fields()]

    for layer_name in layer_names:
        layer = doc[layer_name]
        if to_predictions:
            layer = layer.predictions
        if clear_before:
            layer.clear()
        other_layer = other_doc_copy[layer_name]
        if from_predictions:
            other_layer = other_layer.predictions
        other_annotations = list(other_layer)
        # detach the annotations from the other document before attaching them to `doc`
        other_layer.clear()
        layer.extend(other_annotations)

    return doc


def add_annotations_from_other_documents(
    docs: Iterable[D],
    other_docs: Sequence[Document],
    get_other_doc_by_id: bool = False,
    **kwargs,
) -> Sequence[D]:
    """Copy annotations from `other_docs` into `docs`. The documents are paired up
    positionally, or by their ID if `get_other_doc_by_id` is True."""
    other_id2doc = None
    if get_other_doc_by_id:
        other_id2doc = {doc.id: doc for doc in other_docs}

    if isinstance(docs, Dataset):
        if other_id2doc is None:
            raise ValueError("get_other_doc_by_id must be True when passing a Dataset")
        result = docs.map(
            _add_annotations_from_other_document,
            fn_kwargs=dict(other_docs_dict=other_id2doc, **kwargs),
        )
    elif isinstance(docs, list):
        result = []
        for i, doc in enumerate(docs):
            if other_id2doc is not None:
                other_doc = other_id2doc.get(doc.id)
                if other_doc is None:
                    logger.warning(f"Document with ID {doc.id} not found in other_docs")
                    continue
            else:
                other_doc = other_docs[i]
                # check if the IDs of the positionally paired documents match
                doc_id = getattr(doc, "id", None)
                other_doc_id = getattr(other_doc, "id", None)
                if doc_id is not None and doc_id != other_doc_id:
                    raise ValueError(
                        f"IDs of the documents do not match: {doc_id} != {other_doc_id}"
                    )
            current_result = _add_annotations_from_other_document(
                doc, other_doc=other_doc, **kwargs
            )
            result.append(current_result)
    else:
        raise ValueError(f"Unsupported type: {type(docs)}")

    return result
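# Example (a minimal sketch, not part of the original module): copy the gold
# annotations of an "entities" layer from `gold_docs` onto `predicted_docs`,
# pairing documents by ID. The variable names and the layer name are assumptions
# for illustration only.
#
#     merged_docs = add_annotations_from_other_documents(
#         predicted_docs,
#         other_docs=gold_docs,
#         get_other_doc_by_id=True,
#         from_predictions=False,
#         to_predictions=False,
#         clear_before=False,
#         layer_names=["entities"],
#     )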
DM = TypeVar("DM", bound=Dict[str, Iterable[Document]])


def add_annotations_from_other_documents_dict(
    docs: DM, other_docs: Dict[str, Sequence[Document]], **kwargs
) -> DM:
    """Like add_annotations_from_other_documents, but for dicts that map keys (e.g. split
    names) to collections of documents."""
    if set(docs.keys()) != set(other_docs.keys()):
        raise ValueError("Keys of the documents do not match")

    result_dict = {
        key: add_annotations_from_other_documents(doc_list, other_docs[key], **kwargs)
        for key, doc_list in docs.items()
    }
    return type(docs)(result_dict)


def process_pipeline_steps(
    documents: Sequence[Document],
    processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]],
    verbose: bool = False,
) -> Sequence[Document]:
    """Apply the processors to the documents in the order they are provided. If a processor
    returns None, the documents are passed on unchanged."""
    for step_name, processor in processors.items():
        if verbose:
            logger.info(f"process {step_name} ...")
        processed_documents = processor(documents)
        if processed_documents is not None:
            documents = processed_documents

    return documents


def process_documents(
    documents: List[Document], processor: Callable[..., Optional[Document]], **kwargs
) -> List[Document]:
    """Apply a per-document processor to each document. If the processor returns None, the
    (possibly modified) input document is kept."""
    result = []
    for doc in documents:
        processed_doc = processor(doc, **kwargs)
        if processed_doc is not None:
            result.append(processed_doc)
        else:
            result.append(doc)

    return result
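# Example (a minimal sketch with a hypothetical processor): process_pipeline_steps
# runs the given callables in order, threading the documents through; a step that
# returns None leaves the current documents unchanged. `normalize_doc` below is an
# assumed per-document processor, lifted to the sequence level via process_documents.
#
#     processed = process_pipeline_steps(
#         documents=docs,
#         processors={
#             "normalize": partial(process_documents, processor=normalize_doc),
#         },
#         verbose=True,
#     )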
class DummyTaskmodule(WithDocumentTypeMixin):
    """Minimal stand-in for a taskmodule that only carries the document type."""

    def __init__(self, document_type: Optional[Union[Type[Document], str]]):
        if isinstance(document_type, str):
            self._document_type = resolve_type(document_type, expected_super_type=Document)
        else:
            self._document_type = document_type

    @property
    def document_type(self) -> Optional[Type[Document]]:
        return self._document_type


class NerRePipeline:
    """Chains a NER model and a relation extraction model: entities are predicted first and
    used as candidates for relation extraction; afterwards, the predicted entities and
    relations are kept in the prediction containers and the original gold annotations are
    re-attached."""

    def __init__(
        self,
        ner_model_path: str,
        re_model_path: str,
        entity_layer: str,
        relation_layer: str,
        device: Optional[int] = None,
        batch_size: Optional[int] = None,
        show_progress_bar: Optional[bool] = None,
        document_type: Optional[Union[Type[Document], str]] = None,
        verbose: bool = True,
        **processor_kwargs,
    ):
        self.taskmodule = DummyTaskmodule(document_type)
        self.ner_model_path = ner_model_path
        self.re_model_path = re_model_path
        self.processor_kwargs = processor_kwargs or {}
        self.entity_layer = entity_layer
        self.relation_layer = relation_layer
        self.verbose = verbose

        # pass device, batch_size, and show_progress_bar on to the inference pipelines,
        # if provided and not already set explicitly
        for inference_pipeline in ["ner_pipeline", "re_pipeline"]:
            if inference_pipeline not in self.processor_kwargs:
                self.processor_kwargs[inference_pipeline] = {}
            if "device" not in self.processor_kwargs[inference_pipeline] and device is not None:
                self.processor_kwargs[inference_pipeline]["device"] = device
            if (
                "batch_size" not in self.processor_kwargs[inference_pipeline]
                and batch_size is not None
            ):
                self.processor_kwargs[inference_pipeline]["batch_size"] = batch_size
            if (
                "show_progress_bar" not in self.processor_kwargs[inference_pipeline]
                and show_progress_bar is not None
            ):
                self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar

        self.ner_pipeline = AutoPipeline.from_pretrained(
            self.ner_model_path, **self.processor_kwargs.get("ner_pipeline", {})
        )
        self.re_pipeline = AutoPipeline.from_pretrained(
            self.re_model_path, **self.processor_kwargs.get("re_pipeline", {})
        )

    @overload
    def __call__(
        self, documents: Sequence[Document], inplace: bool = False
    ) -> Sequence[Document]: ...

    @overload
    def __call__(self, documents: Document, inplace: bool = False) -> Document: ...

    def __call__(
        self, documents: Union[Sequence[Document], Document], inplace: bool = False
    ) -> Union[Sequence[Document], Document]:
        is_single_doc = False
        if isinstance(documents, Document):
            documents = [documents]
            is_single_doc = True

        input_docs: Sequence[Document]
        # we need to keep the original documents to add the gold data back at the end
        original_docs: Sequence[Document]
        if inplace:
            input_docs = documents
            original_docs = [doc.copy() for doc in documents]
        else:
            input_docs = [doc.copy() for doc in documents]
            original_docs = documents

        docs_with_predictions = process_pipeline_steps(
            documents=input_docs,
            processors={
                "clear_annotations": partial(
                    process_documents,
                    processor=clear_annotation_layers,
                    layer_names=[self.entity_layer, self.relation_layer],
                    **self.processor_kwargs.get("clear_annotations", {}),
                ),
                "ner_pipeline": self.ner_pipeline,
                "use_predicted_entities": partial(
                    process_documents,
                    processor=move_annotations_from_predictions,
                    layer_names=[self.entity_layer],
                    **self.processor_kwargs.get("use_predicted_entities", {}),
                ),
                "re_pipeline": self.re_pipeline,
                # clear the candidate relations, otherwise we cannot move the entities
                # back to the predictions
                "clear_candidate_relations": partial(
                    process_documents,
                    processor=clear_annotation_layers,
                    layer_names=[self.relation_layer],
                    **self.processor_kwargs.get("clear_candidate_relations", {}),
                ),
                "move_entities_to_predictions": partial(
                    process_documents,
                    processor=move_annotations_to_predictions,
                    layer_names=[self.entity_layer],
                    **self.processor_kwargs.get("move_entities_to_predictions", {}),
                ),
                "re_add_gold_data": partial(
                    add_annotations_from_other_documents,
                    other_docs=original_docs,
                    from_predictions=False,
                    to_predictions=False,
                    clear_before=False,
                    layer_names=[self.entity_layer, self.relation_layer],
                    **self.processor_kwargs.get("re_add_gold_data", {}),
                ),
            },
            verbose=self.verbose,
        )
        if is_single_doc:
            return docs_with_predictions[0]
        return docs_with_predictions
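# Example usage (a minimal sketch; the model paths and layer names below are
# assumptions and need to be replaced with actual trained models and the layer
# names of the document type at hand):
#
#     pipeline = NerRePipeline(
#         ner_model_path="path/to/ner_model",
#         re_model_path="path/to/re_model",
#         entity_layer="entities",
#         relation_layer="relations",
#         batch_size=8,
#         show_progress_bar=True,
#     )
#     docs_with_predictions = pipeline(documents)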