|
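"""Retrieve related spans across documents with a span retriever and dump the
results to a JSON file. Optionally, add the retrieved spans and relations as
predictions to a gold dataset and save the enriched dataset."""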
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".project-root"],
    pythonpath=True,
    dotenv=True,
)
|
import argparse
import logging
import os
import sys
from typing import Dict, List, Optional, Tuple

import pandas as pd
from pie_datasets import Dataset, DatasetDict
from pytorch_ie import Annotation
from pytorch_ie.annotations import BinaryRelation, MultiSpan, Span

from document.types import (
    RelatedRelation,
    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
)
from src.demo.retriever_utils import (
    retrieve_all_relevant_spans,
    retrieve_all_relevant_spans_for_all_documents,
    retrieve_relevant_spans,
)
from src.langchain_modules import DocumentAwareSpanRetrieverWithRelations

logger = logging.getLogger(__name__)
|
|
def get_original_doc_id_and_offsets(doc_id: str) -> Tuple[str, int, Optional[int]]: |
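    """Split a retriever doc_id of the form
    "<original_doc_id>.<middle>.<offsets>.<ext>" into the original document id
    and the character offsets of the part within the original document.

    For "remaining" parts, <offsets> encodes only the start position, so the end
    offset is returned as None; for "abstract" parts, it is "<start>_<end>".
    """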
    original_doc_id, middle, start_end, ext = doc_id.split(".")
    if middle == "remaining":
        return original_doc_id, int(start_end), None
    elif middle == "abstract":
        start, end = start_end.split("_")
        return original_doc_id, int(start), int(end)
    else:
        raise ValueError(f"unexpected doc_id format: {doc_id}")
|
|
def add_base_annotations( |
    documents: Dict[
        str, TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations
    ],
    retrieved_doc_ids: List[str],
    retriever: DocumentAwareSpanRetrieverWithRelations,
) -> Dict[Tuple[str, Annotation], Tuple[str, Annotation]]: |
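    """Copy the predicted spans and binary relations of the given retriever
    documents into the corresponding original documents, shifting all span
    offsets by the start offset encoded in the retriever doc_id.

    Returns a mapping from (retriever_doc_id, retriever_annotation) to
    (original_doc_id, shifted_annotation).
    """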
    annotation_mapping = {}
    for retrieved_doc_id in retrieved_doc_ids:
        pie_doc = retriever.get_document(retrieved_doc_id).metadata["pie_document"].copy()
        original_doc_id, offset, _ = get_original_doc_id_and_offsets(retrieved_doc_id)
        document = documents[original_doc_id]
        span_mapping = {}
        for span in pie_doc.labeled_multi_spans.predictions:
            if isinstance(span, MultiSpan):
                new_span = span.copy(
                    slices=[(start + offset, end + offset) for start, end in span.slices]
                )
            elif isinstance(span, Span):
                new_span = span.copy(start=span.start + offset, end=span.end + offset)
            else:
                raise ValueError(f"unexpected span type: {span}")
            span_mapping[span] = new_span
        document.labeled_multi_spans.predictions.extend(span_mapping.values())
        for relation in pie_doc.binary_relations.predictions:
            new_relation = relation.copy(
                head=span_mapping[relation.head], tail=span_mapping[relation.tail]
            )
            document.binary_relations.predictions.append(new_relation)
        for old_ann, new_ann in span_mapping.items():
            annotation_mapping[(retrieved_doc_id, old_ann)] = (original_doc_id, new_ann)

    return annotation_mapping
|
|
def get_doc_and_span_id2annotation_mapping( |
    span_ids: pd.Series,
    doc_ids: pd.Series,
    retriever: DocumentAwareSpanRetrieverWithRelations,
    base_annotation_mapping: Dict[Tuple[str, Annotation], Tuple[str, Annotation]],
) -> Dict[Tuple[str, str], Tuple[str, Annotation]]: |
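    """Resolve (doc_id, span_id) pairs to the annotations that were added to the
    original documents, by looking up each span via the retriever and mapping it
    through base_annotation_mapping."""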
    if len(doc_ids) != len(span_ids):
        raise ValueError("doc_ids and span_ids must have the same length")
    doc_and_span_ids = zip(doc_ids.tolist(), span_ids.tolist())
    return {
        (doc_id, span_id): base_annotation_mapping[(doc_id, retriever.get_span_by_id(span_id))]
        for doc_id, span_id in set(doc_and_span_ids)
    }
|
|
def add_result_to_gold_data( |
    result: pd.DataFrame,
    gold_dataset_dir: str,
    dataset_out_dir: str,
    retriever: DocumentAwareSpanRetrieverWithRelations,
    split: Optional[str] = None,
    link_relation_label: str = "semantically_same",
    reversed_relation_suffix: str = "_reversed",
): |
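    """Add the retrieval result as predictions to the gold dataset: for each
    result row, create a link relation (with label link_relation_label) between
    the query span and the retrieved reference span, as well as a related
    relation that transfers the base relation of the reference span to the
    query span. Finally, save the enriched dataset to dataset_out_dir."""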
    if not os.path.exists(gold_dataset_dir):
        raise ValueError(f"gold dataset directory does not exist: {gold_dataset_dir}")

    dataset_dict = DatasetDict.from_json(data_dir=gold_dataset_dir)
    if split is None and len(dataset_dict) == 1:
        split = list(dataset_dict.keys())[0]
    if split is None:
        raise ValueError("need to provide split name to add results to gold dataset")

    dataset = dataset_dict[split]

    doc_id2doc = {doc.id: doc for doc in dataset}
    # deduplicate: the same retriever document may occur both as a query and as a
    # target document, but its annotations should be added only once
    retriever_doc_ids = sorted(
        set(result["doc_id"].unique().tolist() + result["query_doc_id"].unique().tolist())
    )
    base_annotation_mapping = add_base_annotations(
        documents=doc_id2doc, retrieved_doc_ids=retriever_doc_ids, retriever=retriever
    )

    doc_and_span_id2annotation = {}
    doc_and_span_id2annotation.update(
        get_doc_and_span_id2annotation_mapping(
            span_ids=result["span_id"],
            doc_ids=result["doc_id"],
            retriever=retriever,
            base_annotation_mapping=base_annotation_mapping,
        )
    )
    doc_and_span_id2annotation.update(
        get_doc_and_span_id2annotation_mapping(
            span_ids=result["ref_span_id"],
            doc_ids=result["doc_id"],
            retriever=retriever,
            base_annotation_mapping=base_annotation_mapping,
        )
    )
    doc_and_span_id2annotation.update(
        get_doc_and_span_id2annotation_mapping(
            span_ids=result["query_span_id"],
            doc_ids=result["query_doc_id"],
            retriever=retriever,
            base_annotation_mapping=base_annotation_mapping,
        )
    )
    doc_id2head_tail2relation = {}
    for doc_id, doc in doc_id2doc.items():
        head_and_tail2relation = {}
        for relation in doc.binary_relations.predictions:
            head_and_tail2relation[(relation.head, relation.tail)] = relation
        doc_id2head_tail2relation[doc_id] = head_and_tail2relation
|
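    # For each result row, add two predictions to the target document: a link
    # relation between the query span and the retrieved reference span, and a
    # related relation that connects the query span to the other argument of the
    # base relation in which the reference span participates.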
    for row in result.itertuples():
        query_doc_id, query_span = doc_and_span_id2annotation[
            (row.query_doc_id, row.query_span_id)
        ]
        doc_id, span = doc_and_span_id2annotation[(row.doc_id, row.span_id)]
        doc_id2, ref_span = doc_and_span_id2annotation[(row.doc_id, row.ref_span_id)]
        if doc_id != query_doc_id:
            raise ValueError("doc_id and query_doc_id must be the same")
        if doc_id != doc_id2:
            raise ValueError("doc_id and ref_doc_id must be the same")
        doc = doc_id2doc[doc_id]
        link_rel = BinaryRelation(
            head=query_span, tail=ref_span, label=link_relation_label, score=row.sim_score
        )
        doc.binary_relations.predictions.append(link_rel)
        head_and_tail2relation = doc_id2head_tail2relation[doc_id]
        related_rel_label = row.type
        if related_rel_label.endswith(reversed_relation_suffix):
            base_rel = head_and_tail2relation[(span, ref_span)]
        else:
            base_rel = head_and_tail2relation[(ref_span, span)]
        related_rel = RelatedRelation(
            head=query_span,
            tail=span,
            link_relation=link_rel,
            relation=base_rel,
            label=related_rel_label,
            score=link_rel.score * base_rel.score,
        )
        doc.related_relations.predictions.append(related_rel)
|
    dataset = Dataset.from_documents(list(doc_id2doc.values()))
    dataset_dict = DatasetDict({split: dataset})
    os.makedirs(dataset_out_dir, exist_ok=True)

    dataset_dict.to_json(dataset_out_dir)
|
|
if __name__ == "__main__": |
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=str,
        default="configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to a zip or directory containing a retriever dump.",
    )
    parser.add_argument("-k", "--top_k", type=int, default=10)
    parser.add_argument("-t", "--threshold", type=float, default=0.95)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--query_doc_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query document.",
    )
    parser.add_argument(
        "--query_span_id",
        type=str,
        default=None,
        help="If provided, retrieve all spans for only this query span.",
    )
    parser.add_argument(
        "--doc_id_whitelist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, only consider documents with these IDs.",
    )
    parser.add_argument(
        "--doc_id_blacklist",
        type=str,
        nargs="+",
        default=None,
        help="If provided, ignore documents with these IDs.",
    )
    parser.add_argument(
        "--query_target_doc_id_pairs",
        type=str,
        nargs="+",
        default=None,
        help="One or more pairs of query and target document IDs "
        '(each separated by ":") to retrieve spans for. If provided, '
        "--query_doc_id and --query_span_id are ignored.",
    )
    parser.add_argument(
        "--gold_dataset_dir",
        type=str,
        default=None,
        help="If provided, add the spans and base relations from the retriever data as well "
        "as the related relations to the gold dataset.",
    )
    parser.add_argument(
        "--dataset_out_dir",
        type=str,
        default=None,
        help="If provided, save the enriched gold dataset to this directory.",
    )
    args = parser.parse_args()
|
    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
|
    if not args.output_path.endswith(".json"):
        raise ValueError("only json output is supported")

    logger.info(f"instantiating retriever from {args.config_path}...")
    retriever = DocumentAwareSpanRetrieverWithRelations.instantiate_from_config_file(
        args.config_path
    )
    logger.info(f"loading data from {args.data_path}...")
    retriever.load_from_disc(args.data_path)
|
    search_kwargs = {"k": args.top_k, "score_threshold": args.threshold}
    if args.doc_id_whitelist is not None:
        search_kwargs["doc_id_whitelist"] = args.doc_id_whitelist
    if args.doc_id_blacklist is not None:
        search_kwargs["doc_id_blacklist"] = args.doc_id_blacklist
    logger.info(f"use search_kwargs: {search_kwargs}")
|
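    # Retrieval modes: explicit query:target document pairs, a single query span,
    # a single query document, or (the default) all spans for all documents.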
    if args.query_target_doc_id_pairs is not None:
        all_spans_for_all_documents = None
        for doc_id_pair in args.query_target_doc_id_pairs:
            query_doc_id, target_doc_id = doc_id_pair.split(":")
            # restrict retrieval to the current target document; overriding any
            # --doc_id_whitelist entry here also avoids passing the keyword twice
            pair_search_kwargs = {**search_kwargs, "doc_id_whitelist": [target_doc_id]}
            current_result = retrieve_all_relevant_spans(
                retriever=retriever,
                query_doc_id=query_doc_id,
                **pair_search_kwargs,
            )
            if current_result is None:
                logger.warning(
                    f"no relevant spans found for query_doc_id={query_doc_id} and "
                    f"target_doc_id={target_doc_id}"
                )
                continue
            logger.info(
                f"retrieved {len(current_result)} spans for query_doc_id={query_doc_id} "
                f"and target_doc_id={target_doc_id}"
            )
            current_result["query_doc_id"] = query_doc_id
            if all_spans_for_all_documents is None:
                all_spans_for_all_documents = current_result
            else:
                all_spans_for_all_documents = pd.concat(
                    [all_spans_for_all_documents, current_result], ignore_index=True
                )
|
    elif args.query_span_id is not None:
        logger.warning(f"retrieving results for single span: {args.query_span_id}")
        all_spans_for_all_documents = retrieve_relevant_spans(
            retriever=retriever, query_span_id=args.query_span_id, **search_kwargs
        )
    elif args.query_doc_id is not None:
        logger.warning(f"retrieving results for single document: {args.query_doc_id}")
        all_spans_for_all_documents = retrieve_all_relevant_spans(
            retriever=retriever, query_doc_id=args.query_doc_id, **search_kwargs
        )
    else:
        all_spans_for_all_documents = retrieve_all_relevant_spans_for_all_documents(
            retriever=retriever, **search_kwargs
        )
|
    if all_spans_for_all_documents is None:
        logger.warning("no relevant spans found in any document")
        sys.exit(0)
|
    logger.info(f"dumping results ({len(all_spans_for_all_documents)}) to {args.output_path}...")
    # the dirname may be empty if output_path is a bare file name
    os.makedirs(os.path.dirname(args.output_path) or ".", exist_ok=True)
    all_spans_for_all_documents.to_json(args.output_path)
|
    if args.gold_dataset_dir is not None:
        logger.info(
            f"reading gold data from {args.gold_dataset_dir} and adding results as predictions ..."
        )
        if args.dataset_out_dir is None:
            raise ValueError("need to provide --dataset_out_dir to save the enriched dataset")
        add_result_to_gold_data(
            all_spans_for_all_documents,
            gold_dataset_dir=args.gold_dataset_dir,
            dataset_out_dir=args.dataset_out_dir,
            retriever=retriever,
        )
|
    logger.info("done")