import argparse
import logging
import os
import re
import shutil
from collections import defaultdict
from typing import Dict, List, Tuple

import pandas as pd
from pie_datasets import Dataset, IterableDataset, load_dataset
from pie_datasets.builders.brat import BratDocumentWithMergedSpans

logger = logging.getLogger(__name__)


def find_span_idx(raw_text: str, span_string: str) -> List[Tuple[int, int]]:
    """Match a span string against the raw text (document).

    Return a list of (start, end) character offsets: one tuple if the span matches exactly once,
    several tuples if it matches more than once, and an empty list if it does not match at all.
    """
    # str.strip() returns a new string, so the result has to be reassigned
    span_string = span_string.strip()
    # escape the span string so that any regex metacharacters in it are matched literally
    pattern = re.escape(span_string)
    return [(m.start(), m.end()) for m in re.finditer(pattern, raw_text)]
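
# A minimal usage sketch (hypothetical strings, not taken from the dataset):
#   find_span_idx("the cat sat on the mat", "cat")  -> [(4, 7)]
#   find_span_idx("a b a", "a")                     -> [(0, 1), (4, 5)]
#   find_span_idx("a b a", "z")                     -> []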


def append_spans_start_and_end(
    raw_text: str,
    pd_table: pd.DataFrame,
    input_cols: List[str],
    input_idx_cols: List[str],
    output_cols: List[str],
    doc_id_col: str = "doc ID",
) -> pd.DataFrame:
    """Create new column(s) in the pd.DataFrame that hold the span indices (i.e. (start, end)
    tuples) resolved from the span strings.

    Warn if
    1) a span string does not match anything in the document (the output cell stays empty), or
    2) a span string is not unique in the document and no string index is given to disambiguate it.
    """
    pd_table = pd_table.join(pd.DataFrame(columns=output_cols))
    for idx, pd_row in pd_table.iterrows():
        for in_col, idx_col, out_col in zip(input_cols, input_idx_cols, output_cols):
            span_indices = find_span_idx(raw_text, pd_row[in_col])
            str_idx = pd_row[idx_col]
            span_idx = None
            if not span_indices:
                logger.warning(
                    f'The span "{pd_row[in_col]}" in column "{in_col}" does not exist in {pd_row[doc_id_col]}.'
                )
            elif len(span_indices) == 1:
                # the match is unique, so a string index is redundant
                if pd.notna(str_idx):
                    logger.warning(f'Column "{idx_col}" is not empty. It has value: {str_idx}.')
                span_idx = span_indices.pop()
            else:
                # several matches: a string index is required to select one of them
                if pd.isna(str_idx):
                    logger.warning(
                        f'The span "{pd_row[in_col]}" in column "{in_col}" is not unique, '
                        f'but column "{idx_col}" is empty. '
                        f"Need a string index to specify the non-unique span."
                    )
                else:
                    span_idx = span_indices.pop(int(str_idx))

            if span_idx is not None:
                pd_table.at[idx, out_col] = span_idx

                # sanity check: the resolved offsets must reproduce the (stripped) span string
                search_string = pd_row[in_col].strip()
                reconstructed_string = raw_text[span_idx[0] : span_idx[1]]
                if search_string != reconstructed_string:
                    raise ValueError(
                        f"Reconstructed string does not match the original string. "
                        f"Original: {search_string}, Reconstructed: {reconstructed_string}"
                    )
    return pd_table
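
# A sketch of the disambiguation logic (hypothetical values): for the text
# "A regulates B. A binds C.", the span string "A" matches at (0, 1) and (15, 16);
# a row with a string index of 1 therefore resolves to the second occurrence, (15, 16).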


def get_texts_from_pie_dataset(
    doc_ids: List[str], **dataset_kwargs
) -> Dict[str, BratDocumentWithMergedSpans]:
    """Get documents from a PIE dataset for a list of document IDs.

    :param doc_ids: list of document IDs
    :param dataset_kwargs: keyword arguments to pass to load_dataset

    :return: a dictionary with document IDs as keys and the respective documents as values
    """
    text_based_dataset = load_dataset(**dataset_kwargs)
    if not isinstance(text_based_dataset, (Dataset, IterableDataset)):
        raise ValueError(
            f"Expected a PIE Dataset or PIE IterableDataset, but got a {type(text_based_dataset)} instead."
        )
    if not issubclass(text_based_dataset.document_type, BratDocumentWithMergedSpans):
        raise ValueError(
            f"Expected a PIE Dataset with BratDocumentWithMergedSpans as document type, "
            f"but got {text_based_dataset.document_type} instead."
        )
    doc_id2doc = {doc.id: doc for doc in text_based_dataset}
    return {doc_id: doc_id2doc[doc_id] for doc_id in doc_ids}
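
# A hedged usage sketch (the argument values mirror the call in main() below; the
# document ID "A01" is an assumption for illustration):
#   doc_id2doc = get_texts_from_pie_dataset(
#       doc_ids=["A01"],
#       path="pie/brat",
#       name="merge_fragmented_spans",
#       split="train",
#       base_dataset_kwargs=dict(data_dir="path/to/brat/data"),
#   )
#   doc_id2doc["A01"].text  # the base text that the span strings are matched against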


def set_span_annotation_ids(
    table: pd.DataFrame,
    doc_id2doc: Dict[str, BratDocumentWithMergedSpans],
    doc_id_col: str,
    span_idx_cols: List[str],
    span_id_cols: List[str],
) -> pd.DataFrame:
    """Create new column(s) for span annotation IDs in the pd.DataFrame from the span indices. The
    annotation IDs are retrieved from the BratDocumentWithMergedSpans objects via the span indices.

    :param table: pd.DataFrame with span indices, document IDs, and other columns
    :param doc_id2doc: dictionary with document IDs as keys and BratDocumentWithMergedSpans objects as values
    :param doc_id_col: column name that contains document IDs
    :param span_idx_cols: column names that contain span indices
    :param span_id_cols: column names for new span ID columns

    :return: pd.DataFrame with new columns for span annotation IDs
    """
    table = table.join(pd.DataFrame(columns=span_id_cols))
    # build a per-document lookup from (start, end) offsets to BRAT span IDs (e.g. "T1")
    span2id: Dict[str, Dict[Tuple[int, int], str]] = defaultdict(dict)
    for doc_id, doc in doc_id2doc.items():
        for span_id, span in zip(doc.metadata["span_ids"], doc.spans):
            span2id[doc_id][(span.start, span.end)] = span_id

    for span_idx_col, span_id_col in zip(span_idx_cols, span_id_cols):
        table[span_id_col] = table.apply(
            lambda row: span2id[row[doc_id_col]][tuple(row[span_idx_col])], axis=1
        )

    return table
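
# A sketch of the lookup (hypothetical IDs): if document "A01" contains a span "T3"
# covering the offsets (15, 16), a row with doc ID "A01" and a span-index value of
# (15, 16) is assigned the span ID "T3".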


def set_relation_annotation_ids(
    table: pd.DataFrame,
    doc_id2doc: Dict[str, BratDocumentWithMergedSpans],
    doc_id_col: str,
    relation_id_col: str,
) -> pd.DataFrame:
    """Create a new column for relation annotation IDs in the pd.DataFrame. These are simply fresh
    IDs that continue from the highest existing relation annotation ID in each document.

    :param table: pd.DataFrame with document IDs and other columns
    :param doc_id2doc: dictionary with document IDs as keys and BratDocumentWithMergedSpans objects as values
    :param doc_id_col: column name that contains document IDs
    :param relation_id_col: column name for new relation ID column

    :return: pd.DataFrame with new column for relation annotation IDs
    """
    table = table.join(pd.DataFrame(columns=[relation_id_col]))
    doc_id2highest_relation_id = defaultdict(int)

    for doc_id, doc in doc_id2doc.items():
        # BRAT relation IDs have the form "R<number>"; default=0 covers documents
        # without any pre-existing relation annotations
        doc_id2highest_relation_id[doc_id] = max(
            (int(relation_id[1:]) for relation_id in doc.metadata["relation_ids"]),
            default=0,
        )

    for idx, row in table.iterrows():
        doc_id = row[doc_id_col]
        doc_id2highest_relation_id[doc_id] += 1
        table.at[idx, relation_id_col] = f"R{doc_id2highest_relation_id[doc_id]}"

    return table
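
# A sketch of the numbering (hypothetical IDs): if document "A01" already contains the
# relations R1..R4, the first new row for "A01" is assigned "R5", the next "R6", and so on.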


def main(
    input_path: str,
    output_path: str,
    brat_data_dir: str,
    input_encoding: str,
    include_unsure: bool = False,
    doc_id_col: str = "doc ID",
    unsure_col: str = "unsure",
    span_str_cols: List[str] = ["head argument string", "tail argument string"],
    str_idx_cols: List[str] = ["head string index", "tail string index"],
    span_idx_cols: List[str] = ["head_span_idx", "tail_span_idx"],
    span_id_cols: List[str] = ["head_span_id", "tail_span_id"],
    relation_id_col: str = "relation_id",
    set_annotation_ids: bool = False,
    relation_type: str = "relation",
) -> None:
    """Convert cross-section annotations from a CSV/Excel file to JSON or BRAT format. The input
    table requires columns for document IDs, argument span strings, and string indexes (needed
    whenever a span string occurs multiple times in the base text). The argument span strings are
    matched against the base text to obtain the start and end indices of each span. The JSON output
    contains the same columns as the input file plus additional columns for the span indices (and,
    if requested, the BRAT annotation IDs); the BRAT output copies the base texts and annotation
    files and appends the new relation annotations.

    :param input_path: path to a CSV/Excel file that contains the annotations
    :param output_path: path to save the output: a .json file, or a directory for BRAT output
    :param brat_data_dir: directory where the BRAT data (base texts and annotations) is located
    :param input_encoding: encoding of the input file; only used for CSV files
    :param include_unsure: include annotations marked as unsure. Default: False
    :param doc_id_col: column name that contains document IDs. Default: "doc ID"
    :param unsure_col: column name that contains unsure annotations. Default: "unsure"
    :param span_str_cols: column names that contain span strings. Default: ["head argument string", "tail argument string"]
    :param str_idx_cols: column names that contain string indexes. Default: ["head string index", "tail string index"]
    :param span_idx_cols: column names for new span-index columns. Default: ["head_span_idx", "tail_span_idx"]
    :param span_id_cols: column names for new span-ID columns. Default: ["head_span_id", "tail_span_id"]
    :param relation_id_col: column name for new relation-ID column. Default: "relation_id"
    :param set_annotation_ids: set BRAT annotation IDs for the spans and relations. Default: False
    :param relation_type: the relation type to use for the BRAT output. Default: "relation"

    :return: None
    """

    if input_path.lower().endswith(".csv"):
        input_df = pd.read_csv(input_path, encoding=input_encoding)
    elif input_path.lower().endswith(".xlsx"):
        logger.warning(
            f"encoding parameter (--input-encoding={input_encoding}) is ignored for Excel files."
        )
        input_df = pd.read_excel(input_path)
    else:
        raise ValueError("Input file has unexpected format. Please provide a CSV or Excel file.")

    # remove rows marked as unsure, unless they are explicitly requested
    if not include_unsure:
        input_df = input_df[input_df[unsure_col].isna()]

    # drop columns that contain no values at all ...
    input_df = input_df.dropna(axis=1, how="all")
    # ... but re-add the string-index columns if they were removed, because the
    # span matching below expects them to exist even when completely empty
    for str_idx_col in str_idx_cols:
        if str_idx_col not in input_df.columns:
            input_df[str_idx_col] = None

    result_df = pd.DataFrame(columns=[*input_df.columns, *span_idx_cols])

    doc_ids = list(input_df[doc_id_col].unique())

    doc_id2doc = get_texts_from_pie_dataset(
        doc_ids=doc_ids,
        path="pie/brat",
        name="merge_fragmented_spans",
        split="train",
        revision="769a15e44e7d691148dd05e54ae2b058ceaed1f0",
        base_dataset_kwargs=dict(data_dir=brat_data_dir),
    )

    # resolve the span indices document by document
    for doc_id in doc_ids:
        doc_df = input_df[input_df[doc_id_col] == doc_id]
        input_df = input_df.drop(doc_df.index)

        doc_with_span_indices_df = append_spans_start_and_end(
            raw_text=doc_id2doc[doc_id].text,
            pd_table=doc_df,
            input_cols=span_str_cols,
            input_idx_cols=str_idx_cols,
            output_cols=span_idx_cols,
        )

        # pass None instead of an empty frame to avoid pandas' deprecation warning
        # about concatenating empty entries
        result_df = pd.concat(
            [result_df if not result_df.empty else None, doc_with_span_indices_df]
        )

    out_ext = os.path.splitext(output_path)[1]
    # an output path without a file extension is interpreted as a BRAT output directory
    save_as_brat = out_ext == ""

    # BRAT output always requires annotation IDs
    if set_annotation_ids or save_as_brat:
        result_df = set_span_annotation_ids(
            table=result_df,
            doc_id2doc=doc_id2doc,
            doc_id_col=doc_id_col,
            span_idx_cols=span_idx_cols,
            span_id_cols=span_id_cols,
        )
        result_df = set_relation_annotation_ids(
            table=result_df,
            doc_id2doc=doc_id2doc,
            doc_id_col=doc_id_col,
            relation_id_col=relation_id_col,
        )

    base_dir = os.path.dirname(output_path)
    # dirname is empty for bare file names such as "out.json", and makedirs would fail on ""
    if base_dir:
        os.makedirs(base_dir, exist_ok=True)

    if out_ext.lower() == ".json":
        logger.info(f"Saving output in JSON format to {output_path} ...")
        result_df.to_json(
            path_or_buf=output_path,
            orient="records",
            lines=True,
        )
    elif save_as_brat:
        logger.info(f"Saving output in BRAT format to {output_path} ...")
        os.makedirs(output_path, exist_ok=True)
        for doc_id in doc_ids:
            # copy the base text file unchanged
            shutil.copy(
                src=os.path.join(brat_data_dir, f"{doc_id}.txt"),
                dst=os.path.join(output_path, f"{doc_id}.txt"),
            )

            # read the existing annotations ...
            input_ann_path = os.path.join(brat_data_dir, f"{doc_id}.ann")
            with open(input_ann_path, "r") as f:
                ann_lines = f.readlines()
            # guard against a missing trailing newline so that appended annotations
            # start on their own line
            if ann_lines and not ann_lines[-1].endswith("\n"):
                ann_lines[-1] += "\n"

            # ... and append the new relation annotations
            doc_df = result_df[result_df[doc_id_col] == doc_id]
            logger.info(f"Adding {len(doc_df)} relation annotations to {doc_id}.ann ...")
            for idx, row in doc_df.iterrows():
                head_span_id = row[span_id_cols[0]]
                tail_span_id = row[span_id_cols[1]]
                relation_id = row[relation_id_col]
                # BRAT stand-off format for binary relations:
                # "<ID>\t<TYPE> Arg1:<HEAD_SPAN_ID> Arg2:<TAIL_SPAN_ID>"
                ann_line = (
                    f"{relation_id}\t{relation_type} Arg1:{head_span_id} Arg2:{tail_span_id}\n"
                )
                ann_lines.append(ann_line)

            output_ann_path = os.path.join(output_path, f"{doc_id}.ann")
            with open(output_ann_path, "w") as f:
                f.writelines(ann_lines)
    else:
        raise ValueError(
            "Output file has unexpected format. Please provide a JSON file or a directory."
        )

    logger.info("Done!")
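
# A hedged programmatic usage sketch (the values below mirror the argparse defaults;
# the command-line entry point that follows is the intended interface):
#   main(
#       input_path="data/annotations/sciarg-cross-section/aligned_input.csv",
#       output_path="data/annotations/sciarg-with-abstracts-and-cross-section-rels",
#       brat_data_dir="data/annotations/sciarg-abstracts/v0.9.3/alisa",
#       input_encoding="cp1252",
#       relation_type="semantically_same",
#   )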


if __name__ == "__main__":

    """
    example call:
        python src/data/prepare_sciarg_crosssection_annotations.py
    // or //
        python src/data/prepare_sciarg_crosssection_annotations.py \
            --input-path data/annotations/sciarg-cross-section/aligned_input.csv \
            --output-path data/annotations/sciarg-with-abstracts-and-cross-section-rels \
            --brat-data-dir data/annotations/sciarg-abstracts/v0.9.3/alisa
    """

    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(
        description="Read text files in a directory and a CSV file that contains cross-section annotations. "
        "Transform the annotations to JSON or BRAT format and save them at the specified output path."
    )
    parser.add_argument(
        "--input-path",
        type=str,
        default="data/annotations/sciarg-cross-section/aligned_input.csv",
        help="Path to a CSV/Excel file that contains the annotations.",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default="data/annotations/sciarg-with-abstracts-and-cross-section-rels",
        help="Path where the output will be saved. Should be a JSON file or a directory for BRAT output.",
    )
    parser.add_argument(
        "--brat-data-dir",
        type=str,
        default="data/annotations/sciarg-abstracts/v0.9.3/alisa",
        help="Directory where the BRAT data (base texts and annotations) is located.",
    )
    parser.add_argument(
        "--relation-type",
        type=str,
        default="semantically_same",
        help="The relation type to use for the BRAT output.",
    )
    parser.add_argument(
        "--input-encoding",
        type=str,
        default="cp1252",
        help="Encoding for reading the input file. Only used for CSV files.",
    )
    parser.add_argument(
        "--include-unsure",
        action="store_true",
        help="Include annotations marked as unsure.",
    )
    parser.add_argument(
        "--set-annotation-ids",
        action="store_true",
        help="Set BRAT annotation IDs for the spans and relations.",
    )
    args = parser.parse_args()
    kwargs = vars(args)

    main(**kwargs)