#!/usr/bin/env python # -*- coding: utf-8 -*- # Created by zd302 at 08/07/2024 import gradio as gr import tqdm import torch import numpy as np from time import sleep import threading import gc import os import json import pytorch_lightning as pl from urllib.parse import urlparse from accelerate import Accelerator from transformers import BartTokenizer, BartForConditionalGeneration from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification from transformers import RobertaTokenizer, RobertaForSequenceClassification from rank_bm25 import BM25Okapi # import bm25s # import Stemmer # optional: for stemming from html2lines import url2lines from googleapiclient.discovery import build from averitec.models.DualEncoderModule import DualEncoderModule from averitec.models.SequenceClassificationModule import SequenceClassificationModule from averitec.models.JustificationGenerationModule import JustificationGenerationModule from averitec.data.sample_claims import CLAIMS_Type # --------------------------------------------------------------------------- # load .env from utils import create_user_id user_id = create_user_id() from datetime import datetime from azure.storage.fileshare import ShareServiceClient try: from dotenv import load_dotenv load_dotenv() except Exception as e: pass account_url = os.environ["AZURE_ACCOUNT_URL"] credential = { "account_key": os.environ['AZURE_ACCOUNT_KEY'], "account_name": os.environ['AZURE_ACCOUNT_NAME'] } file_share_name = "averitec" azure_service = ShareServiceClient(account_url=account_url, credential=credential) azure_share_client = azure_service.get_share_client(file_share_name) # ---------- Setting ---------- import requests from bs4 import BeautifulSoup import wikipediaapi wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en') import nltk nltk.download('punkt') from nltk import pos_tag, word_tokenize, sent_tokenize import spacy os.system("python -m spacy download en_core_web_sm") nlp = spacy.load("en_core_web_sm") # --------------------------------------------------------------------------- # Load sample dict for AVeriTeC search # all_samples_dict = json.load(open('averitec/data/all_samples.json', 'r')) # --------------------------------------------------------------------------- # ---------- Load pretrained models ---------- # ---------- load Evidence retrieval model ---------- # from drqa import retriever # db_class = retriever.get_class('sqlite') # doc_db = db_class("averitec/data/wikipedia_dumps/enwiki.db") # ranker = retriever.get_class('tfidf')(tfidf_path="averitec/data/wikipedia_dumps/enwiki-tfidf-with-id-title.npz") # ---------- Load Veracity and Justification prediction model ---------- print("Loading models ...") LABEL = [ "Supported", "Refuted", "Not Enough Evidence", "Conflicting Evidence/Cherrypicking", ] # Veracity device = "cuda:0" if torch.cuda.is_available() else "cpu" veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification") veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt", tokenizer=veracity_tokenizer, model=bert_model).to(device) # Justification justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True) bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large") best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt' justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device) # --------------------------------------------------------------------------- # Set up Gradio Theme theme = gr.themes.Base( primary_hue="blue", secondary_hue="red", font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"], ) # ---------- Setting ---------- class Docs: def __init__(self, metadata=dict(), page_content=""): self.metadata = metadata self.page_content = page_content def make_html_source(source, i): meta = source.metadata content = source.page_content.strip() card = f"""

Doc {i} - URL: {meta['url']}

{content}

""" return card # ----- veracity_prediction ----- class SequenceClassificationDataLoader(pl.LightningDataModule): def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False): super().__init__() self.tokenizer = tokenizer self.data_file = data_file self.batch_size = batch_size self.add_extra_nee = add_extra_nee def tokenize_strings( self, source_sentences, max_length=400, pad_to_max_length=False, return_tensors="pt", ): encoded_dict = self.tokenizer( source_sentences, max_length=max_length, padding="max_length" if pad_to_max_length else "longest", truncation=True, return_tensors=return_tensors, ) input_ids = encoded_dict["input_ids"] attention_masks = encoded_dict["attention_mask"] return input_ids, attention_masks def quadruple_to_string(self, claim, question, answer, bool_explanation=""): if bool_explanation is not None and len(bool_explanation) > 0: bool_explanation = ", because " + bool_explanation.lower().strip() else: bool_explanation = "" return ( "[CLAIM] " + claim.strip() + " [QUESTION] " + question.strip() + " " + answer.strip() + bool_explanation ) def averitec_veracity_prediction(claim, qa_evidence): bert_model_name = "bert-base-uncased" tokenizer = BertTokenizer.from_pretrained(bert_model_name) bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4, problem_type="single_label_classification") device = "cuda:0" if torch.cuda.is_available() else "cpu" trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt", tokenizer=tokenizer, model=bert_model).to(device) dataLoader = SequenceClassificationDataLoader( tokenizer=tokenizer, data_file="this_is_discontinued", batch_size=32, add_extra_nee=False, ) evidence_strings = [] for evidence in qa_evidence: evidence_strings.append( dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], "")) if len(evidence_strings) == 0: # If we found no evidence e.g. because google returned 0 pages, just output NEI. pred_label = "Not Enough Evidence" return pred_label tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings) example_support = torch.argmax( trained_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1) has_unanswerable = False has_true = False has_false = False for v in example_support: if v == 0: has_true = True if v == 1: has_false = True if v in (2, 3,): # TODO another hack -- we cant have different labels for train and test so we do this has_unanswerable = True if has_unanswerable: answer = 2 elif has_true and not has_false: answer = 0 elif not has_true and has_false: answer = 1 else: answer = 3 pred_label = LABEL[answer] return pred_label def fever_veracity_prediction(claim, evidence): tokenizer = RobertaTokenizer.from_pretrained('Dzeniks/roberta-fact-check') model = RobertaForSequenceClassification.from_pretrained('Dzeniks/roberta-fact-check') evidence_string = "" for evi in evidence: evidence_string += evi.metadata['title'] + evi.metadata['evidence'] + ' ' input_sequence = tokenizer.encode_plus(claim, evidence_string, return_tensors="pt") with torch.no_grad(): prediction = model(**input_sequence) label = torch.argmax(prediction[0]).item() pred_label = LABEL[label] return pred_label def veracity_prediction(claim, qa_evidence): # bert_model_name = "bert-base-uncased" # tokenizer = BertTokenizer.from_pretrained(bert_model_name) # bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=4, # problem_type="single_label_classification") # # device = "cuda:0" if torch.cuda.is_available() else "cpu" # trained_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt", # tokenizer=tokenizer, model=bert_model).to(device) dataLoader = SequenceClassificationDataLoader( tokenizer=veracity_tokenizer, data_file="this_is_discontinued", batch_size=32, add_extra_nee=False, ) evidence_strings = [] for evidence in qa_evidence: evidence_strings.append( dataLoader.quadruple_to_string(claim, evidence.metadata["query"], evidence.metadata["answer"], "")) if len(evidence_strings) == 0: # If we found no evidence e.g. because google returned 0 pages, just output NEI. pred_label = "Not Enough Evidence" return pred_label tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings) example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1) has_unanswerable = False has_true = False has_false = False for v in example_support: if v == 0: has_true = True if v == 1: has_false = True if v in (2, 3,): # TODO another hack -- we cant have different labels for train and test so we do this has_unanswerable = True if has_unanswerable: answer = 2 elif has_true and not has_false: answer = 0 elif not has_true and has_false: answer = 1 else: answer = 3 pred_label = LABEL[answer] return pred_label def extract_claim_str(claim, qa_evidence, verdict_label): claim_str = "[CLAIM] " + claim + " [EVIDENCE] " for evidence in qa_evidence: q_text = evidence.metadata['query'].strip() if len(q_text) == 0: continue if not q_text[-1] == "?": q_text += "?" answer_strings = [] answer_strings.append(evidence.metadata['answer']) claim_str += q_text for a_text in answer_strings: if a_text: if not a_text[-1] == ".": a_text += "." claim_str += " " + a_text.strip() claim_str += " " claim_str += " [VERDICT] " + verdict_label return claim_str def averitec_justification_generation(claim, qa_evidence, verdict_label): # claim_str = extract_claim_str(claim, qa_evidence, verdict_label) claim_str.strip() device = "cuda:0" if torch.cuda.is_available() else "cpu" tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True) bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large") best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt' trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, model=bart_model).to(device) pred_justification = trained_model.generate(claim_str, device=device) return pred_justification.strip() def justification_generation(claim, qa_evidence, verdict_label): # claim_str = extract_claim_str(claim, qa_evidence, verdict_label) claim_str.strip() # device = "cuda:0" if torch.cuda.is_available() else "cpu" # tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True) # bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large") # # best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt' # trained_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, # model=bart_model).to(device) pred_justification = justification_model.generate(claim_str, device=device) return pred_justification.strip() def QAprediction(claim, evidence, sources): parts = [] # evidence_title = f"""
Retrieved Evidence:
""" for i, evi in enumerate(evidence, 1): part = f"""Doc {i}""" subpart = f"""{i}""" # subpart = f"""{i}""" subparts = "".join([part, subpart]) parts.append(subparts) evidence_part = ", ".join(parts) prediction_title = f"""
Prediction:
""" # if 'Google' in sources or 'AVeriTeC' in sources: # verdict_label = averitec_veracity_prediction(claim, evidence) # justification_label = averitec_justification_generation(claim, evidence, verdict_label) # # justification_label = "See retrieved docs." # justification_part = f"""Justification: {justification_label}""" # if 'WikiPedia' in sources: # # verdict_label = fever_veracity_prediction(claim, evidence) # justification_label = averitec_justification_generation(claim, evidence, verdict_label) # # justification_label = "See retrieved docs." # justification_part = f"""Justification: {justification_label}""" verdict_label = veracity_prediction(claim, evidence) justification_label = justification_generation(claim, evidence, verdict_label) # justification_label = "See retrieved docs." justification_part = f"""Justification: {justification_label}""" verdict_part = f"""Verdict: {verdict_label}.
""" content_parts = "".join([evidence_title, evidence_part, prediction_title, verdict_part, justification_part]) # content_parts = "".join([evidence_title, evidence_part, verdict_title, verdict_part, justification_title, justification_part]) return content_parts, [verdict_label, justification_label] # ----------GoogleAPIretriever--------- def generate_reference_corpus(reference_file): with open(reference_file) as f: j = json.load(f) train_examples = j all_data_corpus = [] tokenized_corpus = [] for train_example in train_examples: train_claim = train_example["claim"] speaker = train_example["speaker"].strip() if train_example["speaker"] is not None and len( train_example["speaker"]) > 1 else "they" questions = [q["question"] for q in train_example["questions"]] claim_dict_builder = {} claim_dict_builder["claim"] = train_claim claim_dict_builder["speaker"] = speaker claim_dict_builder["questions"] = questions tokenized_corpus.append(nltk.word_tokenize(claim_dict_builder["claim"])) all_data_corpus.append(claim_dict_builder) return tokenized_corpus, all_data_corpus def doc2prompt(doc): prompt_parts = "Outrageously, " + doc["speaker"] + " claimed that \"" + doc[ "claim"].strip() + "\". Criticism includes questions like: " questions = [q.strip() for q in doc["questions"]] return prompt_parts + " ".join(questions) def docs2prompt(top_docs): return "\n\n".join([doc2prompt(d) for d in top_docs]) def prompt_question_generation(test_claim, speaker="they", topk=10): # reference_file = "averitec_code/data/train.json" tokenized_corpus, all_data_corpus = generate_reference_corpus(reference_file) bm25 = BM25Okapi(tokenized_corpus) # Define the bloom model: accelerator = Accelerator() accel_device = accelerator.device device = "cuda:0" if torch.cuda.is_available() else "cpu" tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1") model = BloomForCausalLM.from_pretrained("bigscience/bloom-7b1", torch_dtype=torch.bfloat16).to(device) # -------------------------------------------------- # test claim s = bm25.get_scores(nltk.word_tokenize(test_claim)) top_n = np.argsort(s)[::-1][:topk] docs = [all_data_corpus[i] for i in top_n] # -------------------------------------------------- prompt = docs2prompt(docs) + "\n\n" + "Outrageously, " + speaker + " claimed that \"" + test_claim.strip() + \ "\". Criticism includes questions like: " sentences = [prompt] inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device) outputs = model.generate(inputs["input_ids"], max_length=2000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True) tgt_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] in_len = len(sentences[0]) questions_str = tgt_text[in_len:].split("\n")[0] qs = questions_str.split("?") qs = [q.strip() + "?" for q in qs if q.strip() and len(q.strip()) < 300] # generate_question = [{"question": q, "answers": []} for q in qs] return generate_question def check_claim_date(check_date): try: year, month, date = check_date.split("-") except: month, date, year = "01", "01", "2022" if len(year) == 2 and int(year) <= 30: year = "20" + year elif len(year) == 2: year = "19" + year elif len(year) == 1: year = "200" + year if len(month) == 1: month = "0" + month if len(date) == 1: date = "0" + date sort_date = year + month + date return sort_date def string_to_search_query(text, author): parts = word_tokenize(text.strip()) tags = pos_tag(parts) keep_tags = ["CD", "JJ", "NN", "VB"] if author is not None: search_string = author.split() else: search_string = [] for token, tag in zip(parts, tags): for keep_tag in keep_tags: if tag[1].startswith(keep_tag): search_string.append(token) search_string = " ".join(search_string) return search_string def google_search(search_term, api_key, cse_id, **kwargs): service = build("customsearch", "v1", developerKey=api_key) res = service.cse().list(q=search_term, cx=cse_id, **kwargs).execute() if "items" in res: return res['items'] else: return [] def get_domain_name(url): if '://' not in url: url = 'http://' + url domain = urlparse(url).netloc if domain.startswith("www."): return domain[4:] else: return domain def get_and_store(url_link, fp, worker, worker_stack): page_lines = url2lines(url_link) with open(fp, "w") as out_f: print("\n".join([url_link] + page_lines), file=out_f) worker_stack.append(worker) gc.collect() def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0): search_results = [] for i in range(3): try: search_results += google_search( search_string, api_key, search_engine_id, num=10, start=0 + 10 * page, sort="date:r:19000101:" + sort_date, dateRestrict=None, gl="US" ) break except: sleep(3) return search_results def averitec_search(claim, generate_question, speaker="they", check_date="2024-01-01", n_pages=1): # n_pages=3 # default config api_key = os.environ["GOOGLE_API_KEY"] search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"] blacklist = [ "jstor.org", # Blacklisted because their pdfs are not labelled as such, and clog up the download "facebook.com", # Blacklisted because only post titles can be scraped, but the scraper doesn't know this, "ftp.cs.princeton.edu", # Blacklisted because it hosts many large NLP corpora that keep showing up "nlp.cs.princeton.edu", "huggingface.co" ] blacklist_files = [ # Blacklisted some NLP nonsense that crashes my machine with OOM errors "/glove.", "ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt", "https://web.mit.edu/adamrose/Public/googlelist", ] # save to folder store_folder = "averitec_code/store/retrieved_docs" # index = 0 questions = [q["question"] for q in generate_question] # check the date of the claim sort_date = check_claim_date(check_date) # check_date="2022-01-01" # search_strings = [] search_types = [] search_string_2 = string_to_search_query(claim, None) search_strings += [search_string_2, claim, ] search_types += ["claim", "claim-noformat", ] search_strings += questions search_types += ["question" for _ in questions] # start to search search_results = [] visited = {} store_counter = 0 worker_stack = list(range(10)) retrieve_evidence = [] for this_search_string, this_search_type in zip(search_strings, search_types): for page_num in range(n_pages): search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date, this_search_string, page=page_num) for result in search_results: link = str(result["link"]) domain = get_domain_name(link) if domain in blacklist: continue broken = False for b_file in blacklist_files: if b_file in link: broken = True if broken: continue if link.endswith(".pdf") or link.endswith(".doc"): continue store_file_path = "" if link in visited: store_file_path = visited[link] else: store_counter += 1 store_file_path = store_folder + "/search_result_" + str(index) + "_" + str( store_counter) + ".store" visited[link] = store_file_path while len(worker_stack) == 0: # Wait for a wrrker to become available. Check every second. sleep(1) worker = worker_stack.pop() t = threading.Thread(target=get_and_store, args=(link, store_file_path, worker, worker_stack)) t.start() line = [str(index), claim, link, str(page_num), this_search_string, this_search_type, store_file_path] retrieve_evidence.append(line) return retrieve_evidence def claim2prompts(example): claim = example["claim"] # claim_str = "Claim: " + claim + "||Evidence: " claim_str = "Evidence: " for question in example["questions"]: q_text = question["question"].strip() if len(q_text) == 0: continue if not q_text[-1] == "?": q_text += "?" answer_strings = [] for a in question["answers"]: if a["answer_type"] in ["Extractive", "Abstractive"]: answer_strings.append(a["answer"]) if a["answer_type"] == "Boolean": answer_strings.append(a["answer"] + ", because " + a["boolean_explanation"].lower().strip()) for a_text in answer_strings: if not a_text[-1] in [".", "!", ":", "?"]: a_text += "." # prompt_lookup_str = claim + " " + a_text prompt_lookup_str = a_text this_q_claim_str = claim_str + " " + a_text.strip() + "||Question answered: " + q_text yield (prompt_lookup_str, this_q_claim_str.replace("\n", " ").replace("||", "\n")) def generate_step2_reference_corpus(reference_file): with open(reference_file) as f: train_examples = json.load(f) prompt_corpus = [] tokenized_corpus = [] for example in train_examples: for lookup_str, prompt in claim2prompts(example): entry = nltk.word_tokenize(lookup_str) tokenized_corpus.append(entry) prompt_corpus.append(prompt) return tokenized_corpus, prompt_corpus def decorate_with_questions(claim, retrieve_evidence, top_k=10): # top_k=100 # reference_file = "averitec_code/data/train.json" tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file) prompt_bm25 = BM25Okapi(tokenized_corpus) # Define the bloom model: accelerator = Accelerator() accel_device = accelerator.device device = "cuda:0" if torch.cuda.is_available() else "cpu" tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1") model = BloomForCausalLM.from_pretrained( "bigscience/bloom-7b1", device_map="auto", torch_dtype=torch.bfloat16, offload_folder="./offload" ) # tokenized_corpus = [] all_data_corpus = [] for retri_evi in tqdm.tqdm(retrieve_evidence): store_file = retri_evi[-1] with open(store_file, 'r') as f: first = True for line in f: line = line.strip() if first: first = False location_url = line continue if len(line) > 3: entry = nltk.word_tokenize(line) if (location_url, line) not in all_data_corpus: tokenized_corpus.append(entry) all_data_corpus.append((location_url, line)) if len(tokenized_corpus) == 0: print("") bm25 = BM25Okapi(tokenized_corpus) s = bm25.get_scores(nltk.word_tokenize(claim)) top_n = np.argsort(s)[::-1][:top_k] docs = [all_data_corpus[i] for i in top_n] generate_qa_pairs = [] # Then, generate questions for those top 50: for doc in tqdm.tqdm(docs): # prompt_lookup_str = example["claim"] + " " + doc[1] prompt_lookup_str = doc[1] prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str)) prompt_n = 10 prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n] prompt_docs = [prompt_corpus[i] for i in prompt_top_n] claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: " prompt = "\n\n".join(prompt_docs + [claim_prompt]) sentences = [prompt] inputs = tokenizer(sentences, padding=True, return_tensors="pt").to(device) outputs = model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2, early_stopping=True) tgt_text = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0] # We are not allowed to generate more than 250 characters: tgt_text = tgt_text[:250] qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]] generate_qa_pairs.append(qa_pair) return generate_qa_pairs def triple_to_string(x): return " ".join([item.strip() for item in x]) def rerank_questions(claim, bm25_qas, topk=3): # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2, problem_type="single_label_classification") # Must specify single_label for some reason best_checkpoint = "averitec_code/pretrained_models/bert_dual_encoder.ckpt" device = "cuda:0" if torch.cuda.is_available() else "cpu" trained_model = DualEncoderModule.load_from_checkpoint(best_checkpoint, tokenizer=tokenizer, model=bert_model).to( device) # strs_to_score = [] values = [] for question, answer, source in bm25_qas: str_to_score = triple_to_string([claim, question, answer]) strs_to_score.append(str_to_score) values.append([question, answer, source]) if len(bm25_qas) > 0: encoded_dict = tokenizer(strs_to_score, max_length=512, padding="longest", truncation=True, return_tensors="pt").to(device) input_ids = encoded_dict['input_ids'] attention_masks = encoded_dict['attention_mask'] scores = torch.softmax(trained_model(input_ids, attention_mask=attention_masks).logits, axis=-1)[:, 1] top_n = torch.argsort(scores, descending=True)[:topk] pass_through = [{"question": values[i][0], "answers": values[i][1], "source_url": values[i][2]} for i in top_n] else: pass_through = [] top3_qa_pairs = pass_through return top3_qa_pairs def GoogleAPIretriever(query): # ----- Generate QA pairs using AVeriTeC top3_qa_pairs_path = "averitec_code/top3_qa_pairs1.json" if not os.path.exists(top3_qa_pairs_path): # step 1: generate questions for the query/claim using Bloom generate_question = prompt_question_generation(query) # step 2: retrieve evidence for the generated questions using Google API retrieve_evidence = averitec_search(query, generate_question) # step 3: generate QA pairs for each retrieved document bm25_qa_pairs = decorate_with_questions(query, retrieve_evidence) # step 4: rerank QA pairs top3_qa_pairs = rerank_questions(query, bm25_qa_pairs) else: top3_qa_pairs = json.load(open(top3_qa_pairs_path, 'r')) # Add score to metadata results = [] for i, qa in enumerate(top3_qa_pairs): metadata = dict() metadata['name'] = qa['question'] metadata['url'] = qa['source_url'] metadata['cached_source_url'] = qa['source_url'] metadata['short_name'] = "Evidence {}".format(i + 1) metadata['page_number'] = "" metadata['query'] = qa['question'] metadata['answer'] = qa['answers'] metadata['page_content'] = "Question: " + qa['question'] + "
" + "Answer: " + qa['answers'] page_content = f"""{metadata['page_content']}""" results.append((metadata, page_content)) return results # ----------GoogleAPIretriever--------- # ----------Wikipediaretriever--------- def bm25_retriever(query, corpus, topk=3): bm25 = BM25Okapi(corpus) # query_tokens = word_tokenize(query) scores = bm25.get_scores(query_tokens) top_n = np.argsort(scores)[::-1][:topk] top_n_scores = [scores[i] for i in top_n] return top_n, top_n_scores def bm25s_retriever(query, corpus, topk=3): # optional: create a stemmer stemmer = Stemmer.Stemmer("english") # Tokenize the corpus and only keep the ids (faster and saves memory) corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer) # Create the BM25 model and index the corpus retriever = bm25s.BM25() retriever.index(corpus_tokens) # Query the corpus query_tokens = bm25s.tokenize(query, stemmer=stemmer) # Get top-k results as a tuple of (doc ids, scores). Both are arrays of shape (n_queries, k) results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=topk) top_n = [corpus.index(res) for res in results[0]] return top_n, scores def find_evidence_from_wikipedia_dumps(claim): # doc = nlp(claim) entities_in_claim = [str(ent).lower() for ent in doc.ents] title2id = ranker.doc_dict[0] wiki_intro, ent_list = [], [] for ent in entities_in_claim: if ent in title2id.keys(): ids = title2id[ent] introduction = doc_db.get_doc_intro(ids) wiki_intro.append([ent, introduction]) # fulltext = doc_db.get_doc_text(ids) # evidence.append([ent, fulltext]) ent_list.append(ent) if len(wiki_intro) < 5: evidence_tfidf = process_topk(claim, title2id, ent_list, k=5) wiki_intro.extend(evidence_tfidf) return wiki_intro, doc def relevant_sentence_retrieval(query, wiki_intro, k): # 1. Create corpus here corpus, sentences = [], [] titles = [] for i, (title, intro) in enumerate(wiki_intro): sents_in_intro = sent_tokenize(intro) for sent in sents_in_intro: corpus.append(word_tokenize(sent)) sentences.append(sent) titles.append(title) # # ----- BM25 bm25_top_n, bm25_top_n_scores = bm25_retriever(query, corpus, topk=k) bm25_top_n_sents = [sentences[i] for i in bm25_top_n] bm25_top_n_titles = [titles[i] for i in bm25_top_n] # ----- BM25s # bm25s_top_n, bm25s_top_n_scores = bm25s_retriever(query, sentences, topk=k) # corpus->sentences # bm25s_top_n_sents = [sentences[i] for i in bm25s_top_n] # bm25s_top_n_titles = [titles[i] for i in bm25s_top_n] return bm25_top_n_sents, bm25_top_n_titles def process_topk(query, title2id, ent_list, k=1): doc_names, doc_scores = ranker.closest_docs(query, k) evidence_tfidf = [] for _name in doc_names: if _name not in ent_list and len(ent_list) < 5: ent_list.append(_name) idx = title2id[_name] introduction = doc_db.get_doc_intro(idx) evidence_tfidf.append([_name, introduction]) # fulltext = doc_db.get_doc_text(idx) # evidence_tfidf.append([_name,fulltext]) return evidence_tfidf def WikipediaDumpsretriever(claim): # # 1. extract relevant wikipedia pages from wikipedia dumps wiki_intro, doc = find_evidence_from_wikipedia_dumps(claim) # wiki_intro = [['trump', "'''Trump''' most commonly refers to:\n* Donald Trump (born 1946), President of the United States from 2017 to 2021 \n* Trump (card games), any playing card given an ad-hoc high rank\n\n'''Trump''' may also refer to:"]] # 2. extract relevant sentences from extracted wikipedia pages sents, titles = relevant_sentence_retrieval(claim, wiki_intro, k=3) # results = [] for i, (sent, title) in enumerate(zip(sents, titles)): metadata = dict() metadata['name'] = claim metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split()) metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split()) metadata['short_name'] = "Evidence {}".format(i + 1) metadata['page_number'] = "" metadata['query'] = sent metadata['title'] = title metadata['evidence'] = sent metadata['answer'] = "" metadata['page_content'] = "Title: " + str(metadata['title']) + "
" + "Evidence: " + metadata[ 'evidence'] page_content = f"""{metadata['page_content']}""" results.append(Docs(metadata, page_content)) return results # ----------WikipediaAPIretriever--------- def clean_str(p): return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8") def get_page_obs(page): # find all paragraphs paragraphs = page.split("\n") paragraphs = [p.strip() for p in paragraphs if p.strip()] # # find all sentence # sentences = [] # for p in paragraphs: # sentences += p.split('. ') # sentences = [s.strip() + '.' for s in sentences if s.strip()] # # return ' '.join(sentences[:5]) # return ' '.join(sentences) return ' '.join(paragraphs[:5]) def search_entity_wikipeida(entity): find_evidence = [] page_py = wiki_wiki.page(entity) if page_py.exists(): introduction = page_py.summary find_evidence.append([str(entity), introduction]) return find_evidence def search_step(entity): ent_ = entity.replace(" ", "+") search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}" response_text = requests.get(search_url).text soup = BeautifulSoup(response_text, features="html.parser") result_divs = soup.find_all("div", {"class": "mw-search-result-heading"}) find_evidence = [] if result_divs: # mismatch # If the wikipeida page of the entity is not exist, find similar wikipedia pages. result_titles = [clean_str(div.get_text().strip()) for div in result_divs] similar_titles = result_titles[:5] for _t in similar_titles: if len(find_evidence) < 5: _evi = search_step(_t) find_evidence.extend(_evi) else: page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")] if any("may refer to:" in p for p in page): _evi = search_step("[" + entity + "]") find_evidence.extend(_evi) else: # page_py = wiki_wiki.page(entity) # # if page_py.exists(): # introduction = page_py.summary # else: page_text = "" for p in page: if len(p.split(" ")) > 2: page_text += clean_str(p) if not p.endswith("\n"): page_text += "\n" introduction = get_page_obs(page_text) find_evidence.append([entity, introduction]) return find_evidence def find_similar_wikipedia(entity, relevant_wikipages): # If the relevant wikipeida page of the entity is less than 5, find similar wikipedia pages. ent_ = entity.replace(" ", "+") search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}&title=Special:Search&profile=advanced&fulltext=1&ns0=1" response_text = requests.get(search_url).text soup = BeautifulSoup(response_text, features="html.parser") result_divs = soup.find_all("div", {"class": "mw-search-result-heading"}) if result_divs: result_titles = [clean_str(div.get_text().strip()) for div in result_divs] similar_titles = result_titles[:5] saved_titles = [ent[0] for ent in relevant_wikipages] if relevant_wikipages else relevant_wikipages for _t in similar_titles: if _t not in saved_titles and len(relevant_wikipages) < 5: _evi = search_entity_wikipeida(_t) # _evi = search_step(_t) relevant_wikipages.extend(_evi) return relevant_wikipages def find_evidence_from_wikipedia(claim): # doc = nlp(claim) # wikipedia_page = [] for ent in doc.ents: relevant_wikipages = search_entity_wikipeida(ent) if len(relevant_wikipages) < 5: relevant_wikipages = find_similar_wikipedia(str(ent), relevant_wikipages) wikipedia_page.extend(relevant_wikipages) return wikipedia_page def relevant_wikipedia_API_retriever(claim): # doc = nlp(claim) wiki_intro = [] for ent in doc.ents: page_py = wiki_wiki.page(ent) if page_py.exists(): introduction = page_py.summary else: introduction = "No documents found." wiki_intro.append([str(ent), introduction]) return wiki_intro, doc def Wikipediaretriever(claim, sources): # # 1. extract relevant wikipedia pages from wikipedia dumps if "Dump" in sources: wikipedia_page = find_evidence_from_wikipedia_dumps(claim) else: wikipedia_page = find_evidence_from_wikipedia(claim) # wiki_intro, doc = relevant_wikipedia_API_retriever(claim) # 2. extract relevant sentences from extracted wikipedia pages sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3) # results = [] for i, (sent, title) in enumerate(zip(sents, titles)): metadata = dict() metadata['name'] = claim metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split()) metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title) metadata['short_name'] = "Evidence {}".format(i + 1) metadata['page_number'] = "" metadata['query'] = sent metadata['title'] = title metadata['evidence'] = sent metadata['answer'] = "" metadata['page_content'] = "Title: " + str(metadata['title']) + "
" + "Evidence: " + metadata['evidence'] page_content = f"""{metadata['page_content']}""" results.append(Docs(metadata, page_content)) return results def log_on_azure(file, logs, azure_share_client): logs = json.dumps(logs) file_client = azure_share_client.get_file_client(file) file_client.upload_file(logs) def chat(claim, history, sources): evidence = [] # if 'Google' in sources: # evidence = GoogleAPIretriever(query) # if 'WikiPediaDumps' in sources: # evidence = WikipediaDumpsretriever(query) if 'WikiPedia' in sources: evidence = Wikipediaretriever(claim, sources) answer_set, answer_output = QAprediction(claim, evidence, sources) docs_html = "" if len(evidence) > 0: docs_html = [] for i, evi in enumerate(evidence, 1): docs_html.append(make_html_source(evi, i)) docs_html = "".join(docs_html) else: print("No documents found") url_of_evidence = "" output_language = "English" output_query = claim history[-1] = (claim, answer_set) history = [tuple(x) for x in history] ############################################################ evi_list = [] for evi in evidence: title_str = evi.metadata['title'] evi_str = evi.metadata['evidence'] evi_list.append([title_str, evi_str]) try: # Log answer on Azure Blob Storage # IF AZURE_ISSAVE=TRUE, save the logs into the Azure share client. if bool(os.environ["AZURE_ISSAVE"]): timestamp = str(datetime.now().timestamp()) # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") file = timestamp + ".json" logs = { "user_id": str(user_id), "claim": claim, "sources": sources, "evidence": evi_list, "url": url_of_evidence, "answer": answer_output, "time": timestamp, } log_on_azure(file, logs, azure_share_client) except Exception as e: print(f"Error logging on Azure Blob Storage: {e}") raise gr.Error( f"AVeriTeC Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") ########## return history, docs_html, output_query, output_language def main(): init_prompt = """ Hello, I am a fact-checking assistant designed to help you find appropriate evidence to predict the veracity of claims. What do you want to fact-check? """ with gr.Blocks(title="AVeriTeC fact-checker", css="style.css", theme=theme, elem_id="main-component") as demo: with gr.Tab("AVeriTeC"): with gr.Row(elem_id="chatbot-row"): with gr.Column(scale=2): chatbot = gr.Chatbot( value=[(None, init_prompt)], show_copy_button=True, show_label=False, elem_id="chatbot", layout="panel", avatar_images=(None, "assets/averitec.png") ) # avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"), with gr.Row(elem_id="input-message"): textbox = gr.Textbox(placeholder="Ask me what claim do you want to check!", show_label=False, scale=7, lines=1, interactive=True, elem_id="input-textbox") # submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png") with gr.Column(scale=1, variant="panel", elem_id="right-panel"): with gr.Tabs() as tabs: with gr.TabItem("Examples", elem_id="tab-examples", id=0): examples_hidden = gr.Textbox(visible=False) first_key = list(CLAIMS_Type.keys())[0] dropdown_samples = gr.Dropdown(CLAIMS_Type.keys(), value=first_key, interactive=True, show_label=True, label="Select claim type", elem_id="dropdown-samples") samples = [] for i, key in enumerate(CLAIMS_Type.keys()): examples_visible = True if i == 0 else False with gr.Row(visible=examples_visible) as group_examples: examples_questions = gr.Examples( CLAIMS_Type[key], [examples_hidden], examples_per_page=8, run_on_click=False, elem_id=f"examples{i}", api_name=f"examples{i}", # label = "Click on the example question or enter your own", # cache_examples=True, ) samples.append(group_examples) with gr.Tab("Sources", elem_id="tab-citations", id=1): sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox") docs_textbox = gr.State("") with gr.Tab("Configuration", elem_id="tab-config", id=2): gr.Markdown("Reminder: We currently only support fact-checking in English!") # dropdown_sources = gr.Radio( # ["AVeriTeC", "WikiPediaDumps", "Google", "WikiPediaAPI"], # label="Select source", # value="WikiPediaAPI", # interactive=True, # ) dropdown_sources = gr.Radio( ["Google", "WikiPedia"], label="Select source", value="WikiPedia", interactive=True, ) dropdown_retriever = gr.Dropdown( ["BM25", "BM25s"], label="Select evidence retriever", multiselect=False, value="BM25", interactive=True, ) output_query = gr.Textbox(label="Query used for retrieval", show_label=True, elem_id="reformulated-query", lines=2, interactive=False) output_language = gr.Textbox(label="Language", show_label=True, elem_id="language", lines=1, interactive=False) with gr.Tab("About", elem_classes="max-height other-tabs"): with gr.Row(): with gr.Column(scale=1): gr.Markdown("See more info at [https://fever.ai/task.html](https://fever.ai/task.html)") def start_chat(query, history): history = history + [(query, None)] history = [tuple(x) for x in history] return (gr.update(interactive=False), gr.update(selected=1), history) def finish_chat(): return (gr.update(interactive=True, value="")) (textbox .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox") .then(chat, [textbox, chatbot, dropdown_sources], [chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_textbox") .then(finish_chat, None, [textbox], api_name="finish_chat_textbox") ) (examples_hidden .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples") .then(chat, [examples_hidden, chatbot, dropdown_sources], [chatbot, sources_textbox, output_query, output_language], concurrency_limit=8, api_name="chat_examples") .then(finish_chat, None, [textbox], api_name="finish_chat_examples") ) def change_sample_questions(key): index = list(CLAIMS_Type.keys()).index(key) visible_bools = [False] * len(samples) visible_bools[index] = True return [gr.update(visible=visible_bools[i]) for i in range(len(samples))] dropdown_samples.change(change_sample_questions, dropdown_samples, samples) demo.queue() demo.launch(share=True) if __name__ == "__main__": main()