"""Retrieve related spans with a document-aware span retriever and optionally add the
results as predictions to a gold dataset."""

import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)

import argparse
import logging
import os
import sys
from typing import Dict, List, Optional, Tuple

import pandas as pd
from pie_datasets import Dataset, DatasetDict
from pytorch_ie import Annotation
from pytorch_ie.annotations import BinaryRelation, MultiSpan, Span

from document.types import (
    RelatedRelation,
    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
)
from src.demo.retriever_utils import (
    retrieve_all_relevant_spans,
    retrieve_all_relevant_spans_for_all_documents,
    retrieve_relevant_spans,
)
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations

logger = logging.getLogger(__name__)


def get_original_doc_id_and_offsets(doc_id: str) -> Tuple[str, int, Optional[int]]:
    """Split a retriever doc_id of the form "<doc_id>.<part>.<offsets>.<ext>" into the
    original document id and the character offsets of the part within that document.

    For "remaining" parts only the start offset is encoded; for "abstract" parts the
    offsets are encoded as "<start>_<end>".
    """
    original_doc_id, middle, start_end, ext = doc_id.split(".")
    if middle == "remaining":
        return original_doc_id, int(start_end), None
    elif middle == "abstract":
        start, end = start_end.split("_")
        return original_doc_id, int(start), int(end)
    else:
        raise ValueError(f"unexpected doc_id format: {doc_id}")


def add_base_annotations(
    documents: Dict[
        str, TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations
    ],
    retrieved_doc_ids: List[str],
    retriever: DocumentAwareSpanRetrieverWithRelations,
) -> Dict[Tuple[str, Annotation], Tuple[str, Annotation]]:
    """Copy the spans and binary relations predicted for the (partial) retriever
    documents into the corresponding original documents, shifting all span offsets
    by the offset of the part within the original document.

    Returns a mapping from (retrieved_doc_id, retrieved_annotation) to
    (original_doc_id, original_annotation).
    """
    annotation_mapping = {}
    for retrieved_doc_id in retrieved_doc_ids:
        pie_doc = retriever.get_document(retrieved_doc_id).metadata["pie_document"].copy()
        original_doc_id, offset, _ = get_original_doc_id_and_offsets(retrieved_doc_id)
        document = documents[original_doc_id]
        span_mapping = {}
        for span in pie_doc.labeled_multi_spans.predictions:
            if isinstance(span, MultiSpan):
                new_span = span.copy(
                    slices=[(start + offset, end + offset) for start, end in span.slices]
                )
            elif isinstance(span, Span):
                new_span = span.copy(start=span.start + offset, end=span.end + offset)
            else:
                raise ValueError(f"unexpected span type: {span}")
            span_mapping[span] = new_span
        document.labeled_multi_spans.predictions.extend(span_mapping.values())
        for relation in pie_doc.binary_relations.predictions:
            new_relation = relation.copy(
                head=span_mapping[relation.head], tail=span_mapping[relation.tail]
            )
            document.binary_relations.predictions.append(new_relation)
        for old_ann, new_ann in span_mapping.items():
            annotation_mapping[(retrieved_doc_id, old_ann)] = (original_doc_id, new_ann)
    return annotation_mapping


def get_doc_and_span_id2annotation_mapping(
    span_ids: pd.Series,
    doc_ids: pd.Series,
    retriever: DocumentAwareSpanRetrieverWithRelations,
    base_annotation_mapping: Dict[Tuple[str, Annotation], Tuple[str, Annotation]],
) -> Dict[Tuple[str, str], Tuple[str, Annotation]]:
    """Resolve pairs of (retriever doc id, retriever span id) to the corresponding
    (original doc id, original annotation) via the base annotation mapping."""
    if len(doc_ids) != len(span_ids):
        raise ValueError("doc_ids and span_ids must have the same length")
    doc_and_span_ids = zip(doc_ids.tolist(), span_ids.tolist())
    return {
        (doc_id, span_id): base_annotation_mapping[(doc_id, retriever.get_span_by_id(span_id))]
        for doc_id, span_id in set(doc_and_span_ids)
    }


def add_result_to_gold_data(
    result: pd.DataFrame,
    gold_dataset_dir: str,
    dataset_out_dir: str,
    retriever: DocumentAwareSpanRetrieverWithRelations,
    split: Optional[str] = None,
    link_relation_label: str = "semantically_same",
    reversed_relation_suffix: str = "_reversed",
):
    """Load the gold dataset, add the retrieved spans and base relations as well as
    the derived related relations as predictions, and save the enriched dataset."""
    if not os.path.exists(gold_dataset_dir):
        raise ValueError(f"gold dataset directory does not exist: {gold_dataset_dir}")
    dataset_dict = DatasetDict.from_json(data_dir=gold_dataset_dir)
    if split is None and len(dataset_dict) == 1:
        split = list(dataset_dict.keys())[0]
    if split is None:
        raise ValueError("need to provide split name to add results to gold dataset")
    dataset = dataset_dict[split]
    doc_id2doc = {doc.id: doc for doc in dataset}
    # deduplicate: a partial document may occur both as query and as target, and its
    # annotations should be added to the original document only once
    retriever_doc_ids = sorted(
        set(result["doc_id"].unique().tolist() + result["query_doc_id"].unique().tolist())
    )
    base_annotation_mapping = add_base_annotations(
        documents=doc_id2doc, retrieved_doc_ids=retriever_doc_ids, retriever=retriever
    )

    # (retriever_doc_id, retriever_span_id) -> (original_doc_id, original_span)
    doc_and_span_id2annotation = {}
    doc_and_span_id2annotation.update(
        get_doc_and_span_id2annotation_mapping(
            span_ids=result["span_id"],
            doc_ids=result["doc_id"],
            retriever=retriever,
            base_annotation_mapping=base_annotation_mapping,
        )
    )
    doc_and_span_id2annotation.update(
        get_doc_and_span_id2annotation_mapping(
            span_ids=result["ref_span_id"],
            doc_ids=result["doc_id"],
            retriever=retriever,
            base_annotation_mapping=base_annotation_mapping,
        )
    )
    doc_and_span_id2annotation.update(
        get_doc_and_span_id2annotation_mapping(
            span_ids=result["query_span_id"],
            doc_ids=result["query_doc_id"],
            retriever=retriever,
            base_annotation_mapping=base_annotation_mapping,
        )
    )

    # index the base relations of each document by their (head, tail) spans
    doc_id2head_tail2relation = {}
    for doc_id, doc in doc_id2doc.items():
        head_and_tail2relation = {}
        for relation in doc.binary_relations.predictions:
            head_and_tail2relation[(relation.head, relation.tail)] = relation
        doc_id2head_tail2relation[doc_id] = head_and_tail2relation

    for row in result.itertuples():
        query_doc_id, query_span = doc_and_span_id2annotation[
            (row.query_doc_id, row.query_span_id)
        ]
        doc_id, span = doc_and_span_id2annotation[(row.doc_id, row.span_id)]
        doc_id2, ref_span = doc_and_span_id2annotation[(row.doc_id, row.ref_span_id)]
        if doc_id != query_doc_id:
            raise ValueError("doc_id and query_doc_id must be the same")
        if doc_id != doc_id2:
            raise ValueError("doc_id and ref_doc_id must be the same")
        doc = doc_id2doc[doc_id]
        link_rel = BinaryRelation(
            head=query_span, tail=ref_span, label=link_relation_label, score=row.sim_score
        )
        doc.binary_relations.predictions.append(link_rel)
        # resolve the base relation between the retrieved span and its reference span;
        # for "<label>_reversed" rows the relation direction is inverted
        head_and_tail2relation = doc_id2head_tail2relation[doc_id]
        related_rel_label = row.type
        if related_rel_label.endswith(reversed_relation_suffix):
            base_rel = head_and_tail2relation[(span, ref_span)]
        else:
            base_rel = head_and_tail2relation[(ref_span, span)]
        related_rel = RelatedRelation(
            head=query_span,
            tail=span,
            link_relation=link_rel,
            relation=base_rel,
            label=related_rel_label,
            score=link_rel.score * base_rel.score,
        )
        doc.related_relations.predictions.append(related_rel)

    dataset = Dataset.from_documents(list(doc_id2doc.values()))
    dataset_dict = DatasetDict({split: dataset})
    os.makedirs(dataset_out_dir, exist_ok=True)
    dataset_dict.to_json(dataset_out_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument(
        "-k",
        "--top_k",
        type=int,
        default=10,
        help="Maximum number of spans to retrieve per query span.",
    )
    parser.add_argument(
        "-t",
        "--threshold",
        type=float,
        default=0.95,
        help="Minimum similarity score for retrieved spans.",
    )
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
        help="Path to write the retrieval results as JSON.",
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, only consider documents with these IDs.",
    )
    parser.add_argument(
        "--doc_id_blacklist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, ignore documents with these IDs.",
    )
    parser.add_argument(
        "--query_target_doc_id_pairs",
        type=str,
        nargs="+",
        default=None,
        help="One or more pairs of query and target document IDs "
        '(each separated by ":") to retrieve spans for. If provided, '
        "--query_doc_id and --query_span_id are ignored.",
    )
    parser.add_argument(
        "--gold_dataset_dir",
        type=str,
        default=None,
        help="If provided, add the spans and base relations from the retriever data as well "
        "as the related relations to the gold dataset.",
    )
    parser.add_argument(
        "--dataset_out_dir",
        type=str,
        default=None,
        help="If provided, save the enriched gold dataset to this directory.",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    if not args.output_path.endswith(".json"):
        raise ValueError("only support json output")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)

    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    if args.doc_id_whitelist is not None:
        search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
    if args.doc_id_blacklist is not None:
        search_kwargs["doc_id_blacklist"] = args.doc_id_blacklist
    logger.info(f"use search_kwargs: {search_kwargs}")

    if args.query_target_doc_id_pairs is not None:
        all_spans_for_all_documents = None
        for doc_id_pair in args.query_target_doc_id_pairs:
            query_doc_id, target_doc_id = doc_id_pair.split(":")
            # the per-pair target doc id takes precedence over a global --doc_id_whitelist
            # (passing both as keyword arguments would raise a TypeError)
            pair_search_kwargs = {**search_kwargs, "doc_id_whitelist": [target_doc_id]}
            current_result = retrieve_all_relevant_spans(
                retriever=retriever,
                query_doc_id=query_doc_id,
                **pair_search_kwargs,
            )
            if current_result is None:
                logger.warning(
                    f"no relevant spans found for query_doc_id={query_doc_id} and "
                    f"target_doc_id={target_doc_id}"
                )
                continue
            logger.info(
                f"retrieved {len(current_result)} spans for query_doc_id={query_doc_id} "
                f"and target_doc_id={target_doc_id}"
            )
            current_result["query_doc_id"] = query_doc_id
            if all_spans_for_all_documents is None:
                all_spans_for_all_documents = current_result
            else:
                all_spans_for_all_documents = pd.concat(
                    [all_spans_for_all_documents, current_result], ignore_index=True
                )
    elif args.query_span_id is not None:
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

    if all_spans_for_all_documents is None:
        logger.warning("no relevant spans found in any document")
        sys.exit(0)

    logger.info(f"dumping results ({len(all_spans_for_all_documents)}) to {args.output_path}...")
    # guard against a bare filename, where dirname() returns "" and makedirs() would fail
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    all_spans_for_all_documents.to_json(args.output_path)

    if args.gold_dataset_dir is not None:
        logger.info(
            f"reading gold data from {args.gold_dataset_dir} and adding results as predictions ..."
        )
        if args.dataset_out_dir is None:
            raise ValueError("need to provide --dataset_out_dir to save the enriched dataset")
        add_result_to_gold_data(
            all_spans_for_all_documents,
            gold_dataset_dir=args.gold_dataset_dir,
            dataset_out_dir=args.dataset_out_dir,
            retriever=retriever,
        )

    logger.info("done")
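
# Example invocation (a sketch, not taken from the repository: the script name and all
# paths below are placeholders, only the flags are defined above):
#
#   python retrieve_and_enrich_spans.py \
#       --data_path data/retriever_dump.zip \
#       --output_path results/relevant_spans.json \
#       --query_target_doc_id_pairs docA:docB docB:docA \
#       --gold_dataset_dir data/gold_dataset \
#       --dataset_out_dir data/gold_dataset_with_predictions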