# ScientificArgumentRecommender / src/demo/retrieve_and_dump_all_relevant.py
# Author: ArneBinder
# upload https://github.com/ArneBinder/pie-document-level/pull/452
# commit e7eaeed (verified)
import pyrootutils

# Locate the project root (the directory marked by a ".project-root" file) and, via
# pythonpath=True, add it to the Python path. This has to happen before the project-local
# imports below (`document.types`, `src.*`) so that they can be resolved. dotenv=True
# presumably also loads a ".env" file from the root (see pyrootutils docs).
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)
import argparse
import logging
import os
from typing import Dict, List, Optional, Tuple
import pandas as pd
from pie_datasets import Dataset, DatasetDict
from pytorch_ie import Annotation
from pytorch_ie.annotations import BinaryRelation, MultiSpan, Span
from document.types import (
RelatedRelation,
TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
)
from src.demo.retriever_utils import (
retrieve_all_relevant_spans,
retrieve_all_relevant_spans_for_all_documents,
retrieve_relevant_spans,
)
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations
# Module-level logger; its format and level are set by logging.basicConfig in the
# __main__ section when this file is run as a script.
logger = logging.getLogger(__name__)
def get_original_doc_id_and_offsets(doc_id: str) -> Tuple[str, int, Optional[int]]:
    """Split a retriever document ID into the original document ID and character offsets.

    Two formats are supported (the trailing extension is ignored):

    - ``"<original_doc_id>.remaining.<start>.<ext>"`` -> ``(original_doc_id, start, None)``
    - ``"<original_doc_id>.abstract.<start>_<end>.<ext>"`` -> ``(original_doc_id, start, end)``

    The original document ID may itself contain dots (e.g. ``"2101.00001"``); only the
    last three dot-separated components are interpreted.

    Args:
        doc_id: the ID of a document as stored in the retriever.

    Returns:
        A tuple ``(original_doc_id, start, end)``, where ``end`` is ``None`` for
        ``remaining`` documents.

    Raises:
        ValueError: if ``doc_id`` does not follow one of the formats above.
    """
    # rsplit keeps dots that belong to the original document ID intact; a plain
    # split(".") would fail to unpack for such IDs
    parts = doc_id.rsplit(".", 3)
    if len(parts) != 4:
        raise ValueError(f"unexpected doc_id format: {doc_id}")
    original_doc_id, middle, start_end, _ext = parts
    if middle == "remaining":
        return original_doc_id, int(start_end), None
    if middle == "abstract":
        start, end = start_end.split("_")
        return original_doc_id, int(start), int(end)
    raise ValueError(f"unexpected doc_id format: {doc_id}")
def add_base_annotations(
    documents: Dict[
        str, TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations
    ],
    retrieved_doc_ids: List[str],
    retriever: DocumentAwareSpanRetrieverWithRelations,
) -> Dict[Tuple[str, Annotation], Tuple[str, Annotation]]:
    """Copy predicted spans and binary relations of retriever documents into the
    corresponding original documents.

    Each retriever document ID is mapped to its original document with
    ``get_original_doc_id_and_offsets``. Span offsets are shifted by the start offset of the
    retriever document so that they point into the original document text. The shifted spans
    are appended to ``labeled_multi_spans.predictions`` of the original document. The
    relations between them, rewired to the shifted spans, are appended to
    ``binary_relations.predictions``. The documents in ``documents`` are modified in place.

    Args:
        documents: mapping from original document ID to the document to enrich.
        retrieved_doc_ids: IDs of the retriever documents whose annotations should be added.
            Duplicate IDs are processed only once.
        retriever: the retriever holding the documents (accessed via ``get_document``).

    Returns:
        A mapping ``(retrieved_doc_id, retrieved_span) -> (original_doc_id, shifted_span)``.

    Raises:
        ValueError: if a span is neither a ``MultiSpan`` nor a ``Span``, or a document ID has
            an unexpected format.
        KeyError: if the original document is missing from ``documents``, or a relation
            references a span that is not among the predicted spans.
    """
    # (retrieved_doc_id, retrieved_annotation) -> (original_doc_id, original_annotation)
    annotation_mapping = {}
    # dict.fromkeys drops duplicates while keeping the order: processing the same retriever
    # document twice would add its spans and relations twice to the original document
    for retrieved_doc_id in dict.fromkeys(retrieved_doc_ids):
        pie_doc = retriever.get_document(retrieved_doc_id).metadata["pie_document"].copy()
        # offset = start of the retriever document within the original document text
        original_doc_id, offset, _ = get_original_doc_id_and_offsets(retrieved_doc_id)
        document = documents[original_doc_id]

        span_mapping = {}
        for span in pie_doc.labeled_multi_spans.predictions:
            if isinstance(span, MultiSpan):
                new_span = span.copy(
                    slices=[(start + offset, end + offset) for start, end in span.slices]
                )
            elif isinstance(span, Span):
                new_span = span.copy(start=span.start + offset, end=span.end + offset)
            else:
                raise ValueError(f"unexpected span type: {span}")
            span_mapping[span] = new_span
        document.labeled_multi_spans.predictions.extend(span_mapping.values())

        for relation in pie_doc.binary_relations.predictions:
            new_relation = relation.copy(
                head=span_mapping[relation.head], tail=span_mapping[relation.tail]
            )
            document.binary_relations.predictions.append(new_relation)

        for old_ann, new_ann in span_mapping.items():
            annotation_mapping[(retrieved_doc_id, old_ann)] = (original_doc_id, new_ann)
    return annotation_mapping
def get_doc_and_span_id2annotation_mapping(
    span_ids: pd.Series,
    doc_ids: pd.Series,
    retriever: DocumentAwareSpanRetrieverWithRelations,
    base_annotation_mapping: Dict[Tuple[str, Annotation], Tuple[str, Annotation]],
) -> Dict[Tuple[str, str], Tuple[str, Annotation]]:
    """Resolve (retriever document ID, retriever span ID) pairs to original annotations.

    Args:
        span_ids: retriever span IDs, aligned element-wise with ``doc_ids``.
        doc_ids: retriever document IDs the spans belong to.
        retriever: used to look up the span annotation for each span ID.
        base_annotation_mapping: mapping ``(retriever_doc_id, retriever_span) ->
            (original_doc_id, original_span)``, as returned by ``add_base_annotations``.

    Returns:
        A mapping ``(retriever_doc_id, retriever_span_id) -> (original_doc_id, original_span)``
        with one entry per distinct ID pair.

    Raises:
        ValueError: if ``doc_ids`` and ``span_ids`` differ in length.
        KeyError: if a resolved span is missing from ``base_annotation_mapping``.
    """
    if len(span_ids) != len(doc_ids):
        raise ValueError("doc_ids and span_ids must have the same length")

    resolved = {}
    # every distinct (doc_id, span_id) pair is resolved exactly once
    for id_pair in set(zip(doc_ids.tolist(), span_ids.tolist())):
        doc_id, span_id = id_pair
        retrieved_span = retriever.get_span_by_id(span_id)
        resolved[id_pair] = base_annotation_mapping[(doc_id, retrieved_span)]
    return resolved
def add_result_to_gold_data(
    result: pd.DataFrame,
    gold_dataset_dir: str,
    dataset_out_dir: str,
    retriever: DocumentAwareSpanRetrieverWithRelations,
    split: Optional[str] = None,
    link_relation_label: str = "semantically_same",
    reversed_relation_suffix: str = "_reversed",
):
    """Add retrieval results as predictions to a gold dataset and save the enriched dataset.

    First, for every retriever document referenced in ``result`` (columns ``doc_id`` and
    ``query_doc_id``), the predicted spans and binary relations are copied into the
    corresponding original document (see ``add_base_annotations``). Then, for every row of
    ``result``:

    - a ``BinaryRelation`` labeled ``link_relation_label`` is added from the query span to the
      reference span, scored with ``sim_score``;
    - a ``RelatedRelation`` labeled with the row's ``type`` is added from the query span to
      the retrieved span. It combines the link relation with the base relation between the
      reference span and the retrieved span, and its score is the product of both scores.

    For labels ending with ``reversed_relation_suffix``, the base relation is looked up with
    the retrieved span as head and the reference span as tail. Otherwise, the reference span
    is the head and the retrieved span is the tail.

    Only the selected split is written to ``dataset_out_dir``.

    Args:
        result: retrieval results with the columns ``doc_id``, ``span_id``, ``ref_span_id``,
            ``query_doc_id``, ``query_span_id``, ``sim_score`` and ``type``.
        gold_dataset_dir: directory with the gold dataset (read via ``DatasetDict.from_json``).
        dataset_out_dir: directory to write the enriched dataset to (created if missing).
        retriever: the retriever the results were produced with.
        split: dataset split to enrich; may be omitted if the gold dataset has only one split.
        link_relation_label: label of the added link relations.
        reversed_relation_suffix: label suffix that marks reversed base relations (see above).

    Raises:
        ValueError: if ``gold_dataset_dir`` does not exist, the split cannot be determined,
            or the query span and the retrieved spans of a row belong to different original
            documents.
    """
    if not os.path.exists(gold_dataset_dir):
        raise ValueError(f"gold dataset directory does not exist: {gold_dataset_dir}")
    dataset_dict = DatasetDict.from_json(data_dir=gold_dataset_dir)
    if split is None and len(dataset_dict) == 1:
        split = list(dataset_dict.keys())[0]
    if split is None:
        raise ValueError("need to provide split name to add results to gold dataset")
    doc_id2doc = {doc.id: doc for doc in dataset_dict[split]}

    # A retriever document may occur both as target and as query document. Deduplicate
    # (keeping the order) so that its annotations are not added twice.
    retriever_doc_ids = list(
        dict.fromkeys(
            result["doc_id"].unique().tolist() + result["query_doc_id"].unique().tolist()
        )
    )
    base_annotation_mapping = add_base_annotations(
        documents=doc_id2doc, retrieved_doc_ids=retriever_doc_ids, retriever=retriever
    )

    # (retriever_doc_id, retriever_span_id) -> (original_doc_id, original_span)
    doc_and_span_id2annotation = {}
    for span_id_column, doc_id_column in [
        ("span_id", "doc_id"),
        ("ref_span_id", "doc_id"),
        ("query_span_id", "query_doc_id"),
    ]:
        doc_and_span_id2annotation.update(
            get_doc_and_span_id2annotation_mapping(
                span_ids=result[span_id_column],
                doc_ids=result[doc_id_column],
                retriever=retriever,
                base_annotation_mapping=base_annotation_mapping,
            )
        )

    # index the (base) relations of each document by their (head, tail) pair; this is built
    # before any link relations are added below, so it contains only the base relations
    doc_id2head_tail2relation = {
        doc_id: {
            (relation.head, relation.tail): relation
            for relation in doc.binary_relations.predictions
        }
        for doc_id, doc in doc_id2doc.items()
    }

    for row in result.itertuples():
        query_doc_id, query_span = doc_and_span_id2annotation[
            (row.query_doc_id, row.query_span_id)
        ]
        doc_id, span = doc_and_span_id2annotation[(row.doc_id, row.span_id)]
        doc_id2, ref_span = doc_and_span_id2annotation[(row.doc_id, row.ref_span_id)]
        if doc_id != query_doc_id:
            raise ValueError("doc_id and query_doc_id must be the same")
        if doc_id != doc_id2:
            raise ValueError("doc_id and ref_doc_id must be the same")
        doc = doc_id2doc[doc_id]

        link_rel = BinaryRelation(
            head=query_span, tail=ref_span, label=link_relation_label, score=row.sim_score
        )
        doc.binary_relations.predictions.append(link_rel)

        head_and_tail2relation = doc_id2head_tail2relation[doc_id]
        related_rel_label = row.type
        if related_rel_label.endswith(reversed_relation_suffix):
            base_rel = head_and_tail2relation[(span, ref_span)]
        else:
            base_rel = head_and_tail2relation[(ref_span, span)]
        related_rel = RelatedRelation(
            head=query_span,
            tail=span,
            link_relation=link_rel,
            relation=base_rel,
            label=related_rel_label,
            score=link_rel.score * base_rel.score,
        )
        doc.related_relations.predictions.append(related_rel)

    dataset = Dataset.from_documents(list(doc_id2doc.values()))
    dataset_dict = DatasetDict({split: dataset})
    os.makedirs(dataset_out_dir, exist_ok=True)
    dataset_dict.to_json(dataset_out_dir)
if __name__ == "__main__":
    # Retrieve relevant spans with a DocumentAwareSpanRetrieverWithRelations loaded from a
    # dump, write the results as JSON and optionally add them as predictions to a gold
    # dataset.
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument("-k", "--top_k", type=int, default=10)
    parser.add_argument("-t", "--threshold", type=float, default=0.95)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, only consider documents with these IDs.",
    )
    parser.add_argument(
        "--doc_id_blacklist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, ignore documents with these IDs.",
    )
    parser.add_argument(
        "--query_target_doc_id_pairs",
        type=str,
        nargs="+",
        default=None,
        help="One or more pairs of query and target document IDs "
        '(each separated by ":") to retrieve spans for. If provided, '
        "--query_doc_id, --query_span_id and --doc_id_whitelist are ignored.",
    )
    parser.add_argument(
        "--gold_dataset_dir",
        type=str,
        default=None,
        help="If provided, add the spans and base relations from the retriever data as well "
        "as the related relations to the gold dataset.",
    )
    parser.add_argument(
        "--dataset_out_dir",
        type=str,
        default=None,
        help="If provided, save the enriched gold dataset to this directory.",
    )
    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # validate the output options up front, before the (potentially expensive) retrieval
    if not args.output_path.endswith(".json"):
        raise ValueError("only support json output")
    if args.gold_dataset_dir is not None and args.dataset_out_dir is None:
        raise ValueError("need to provide --dataset_out_dir to save the enriched dataset")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)

    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    if args.doc_id_whitelist is not None:
        search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
    if args.doc_id_blacklist is not None:
        search_kwargs["doc_id_blacklist"] = args.doc_id_blacklist
    logger.info(f"use search_kwargs: {search_kwargs}")

    if args.query_target_doc_id_pairs is not None:
        if args.doc_id_whitelist is not None:
            logger.warning(
                "--doc_id_whitelist is ignored because --query_target_doc_id_pairs is given"
            )
        results_per_pair = []
        for doc_id_pair in args.query_target_doc_id_pairs:
            pair_ids = doc_id_pair.split(":")
            if len(pair_ids) != 2:
                raise ValueError(
                    'expected a query and a target document ID separated by ":", '
                    f"got: {doc_id_pair}"
                )
            query_doc_id, target_doc_id = pair_ids
            # Restrict the search to the target document. This replaces any whitelist from
            # --doc_id_whitelist; passing both would raise a TypeError (duplicate keyword).
            pair_search_kwargs = {**search_kwargs, "doc_id_whitelist": [target_doc_id]}
            current_result = retrieve_all_relevant_spans(
                retriever=retriever,
                query_doc_id=query_doc_id,
                **pair_search_kwargs,
            )
            if current_result is None:
                logger.warning(
                    f"no relevant spans found for query_doc_id={query_doc_id} and "
                    f"target_doc_id={target_doc_id}"
                )
                continue
            logger.info(
                f"retrieved {len(current_result)} spans for query_doc_id={query_doc_id} "
                f"and target_doc_id={target_doc_id}"
            )
            current_result["query_doc_id"] = query_doc_id
            results_per_pair.append(current_result)
        # concatenate once instead of growing the frame inside the loop
        if not results_per_pair:
            all_spans_for_all_documents = None
        elif len(results_per_pair) == 1:
            all_spans_for_all_documents = results_per_pair[0]
        else:
            all_spans_for_all_documents = pd.concat(results_per_pair, ignore_index=True)
    elif args.query_span_id is not None:
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )

    if all_spans_for_all_documents is None:
        logger.warning("no relevant spans found in any document")
        sys.exit(0)

    logger.info(f"dumping results ({len(all_spans_for_all_documents)}) to {args.output_path}...")
    # os.makedirs("") raises FileNotFoundError, so only create a directory if there is one
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    all_spans_for_all_documents.to_json(args.output_path)

    if args.gold_dataset_dir is not None:
        logger.info(
            f"reading gold data from {args.gold_dataset_dir} and adding results as predictions ..."
        )
        add_result_to_gold_data(
            all_spans_for_all_documents,
            gold_dataset_dir=args.gold_dataset_dir,
            dataset_out_dir=args.dataset_out_dir,
            retriever=retriever,
        )

    logger.info("done")