# src/openfactcheck/solvers/tutorial/search_engine_evidence_retriever.py
import time
import json
import openai
from typing import List, Dict, Any, Tuple
from .utils.prompt_base import QGEN_PROMPT
from .utils.api import chatgpt, search_google, search_bing
from .utils.web_util import scrape_url, select_doc_by_keyword_coverage, select_passages_by_semantic_similarity
from openfactcheck import FactCheckerState, StandardTaskSolver, Solver


@Solver.register("search_engine_evidence_retriever", "claims", "evidences")
class SearchEngineEvidenceRetriever(StandardTaskSolver):
def __init__(self, args):
super().__init__(args)
self.search_engine = args.get("search_engine", "google")
        self.search_engine_func = {
            "google": search_google,
            "bing": search_bing,
        }.get(self.search_engine, search_google)  # fall back to Google search if unrecognized
self.url_merge_method = args.get("url_merge_method", "union")

    def __call__(self, state: FactCheckerState, *args, **kwargs):
claims = state.get(self.input_name)
queries = self.generate_questions_as_query(claims)
evidences = self.search_evidence(claims, queries)
state.set(self.output_name, evidences)
return True, state

    # Generate questions/queries for each claim.
    def generate_questions_as_query(self, claims: List[str],
                                    num_retries: int = 3) -> List[list]:
        """
        num_retries: number of retries when an error occurs during an OpenAI API call
        """
        query_list = []
        for claim in claims:
            response = ""  # treat the claim as not check-worthy if all retries fail
            for _ in range(num_retries):
                try:
                    response = chatgpt(QGEN_PROMPT + claim)
                    break
                except openai.OpenAIError as exception:
                    print(f"{exception}. Retrying...")
                    time.sleep(1)
            query_list.append(response)
        # Convert each OpenAI output (a string) into a list of questions/queries.
        # Not check-worthy claims get "" as the query response and yield [];
        # other responses are split into a list of questions/queries.
automatic_query_list = []
for query in query_list:
if query == "":
automatic_query_list.append([])
else:
                new_tmp = []
                for q in query.split("\n"):
                    q = q.strip()
                    if q == "" or q == "Output:":
                        continue
                    elif q.startswith("Output"):
                        q = q[len("Output:"):].strip()
                    new_tmp.append(q)
automatic_query_list.append(new_tmp)
return automatic_query_list
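
    # Illustrative example (assumed model output format): a response such as
    #   "Output: When was SpaceX founded?\nOutput: Who founded SpaceX?"
    # parses into ["When was SpaceX founded?", "Who founded SpaceX?"], while an
    # empty response yields [] and the claim is treated as not check-worthy.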
# ----------------------------------------------------------
# Evidence Retrieval
# ----------------------------------------------------------
    def collect_claim_url_list(self, queries: List[str]) -> Tuple[List[str], Dict[str, list]]:
        """
        Collect URLs for a claim given its query list.
        queries: a list of queries or questions for a claim
        The configured search engine (Google or Bing) retrieves URLs for each
        query, and the per-query URL lists are merged by 'union' or
        'intersection'; intersection URLs tend to be less relevant than expected.
        """
        if len(queries) == 0:
            print("Invalid queries: []")
            return [], {}
        urls_list: List[list] = []  # one list of URLs per query
        url_query_dict: Dict[str, list] = {}  # URL -> list of queries that returned it
        url_union, url_intersection = [], []
for query in queries:
urls = self.search_engine_func(query)
urls_list.append(urls)
        for i, urls in enumerate(urls_list):
            for url in urls:
                url_query_dict.setdefault(url, []).append(queries[i])
if self.url_merge_method == "union":
for urls in urls_list:
url_union += urls
url_union = list(set(url_union))
assert (len(url_union) == len(url_query_dict.keys()))
return list(url_query_dict.keys()), url_query_dict
elif self.url_merge_method == "intersection":
url_intersection = urls_list[0]
for urls in urls_list[1:]:
url_intersection = list(set(url_intersection).intersection(set(urls)))
return url_intersection, url_query_dict
        else:
            print("Invalid url_merge_method, please choose from 'union' and 'intersection'.")
            return None, url_query_dict
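
    # Illustrative example (hypothetical URLs): queries ["q1", "q2"] returning
    # [urlA, urlB] and [urlB, urlC] merge to the union [urlA, urlB, urlC] with
    # url_query_dict[urlB] == ["q1", "q2"], or to the intersection [urlB].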
def search_evidence(self,
decontextualised_claims: List[str],
automatic_query_list: List[list],
path_save_evidence: str = "evidence.json",
save_web_text: bool = False) -> Dict[str, Dict[str, Any]]:
assert (len(decontextualised_claims) == len(automatic_query_list))
claim_info: Dict[str, Dict[str, Any]] = {}
for i, claim in enumerate(decontextualised_claims):
queries = automatic_query_list[i]
if len(queries) == 0:
claim_info[claim] = {"claim": claim, "automatic_queries": queries, "evidence_list": []}
print("Claim: {} This is an opinion, not check-worthy.".format(claim))
continue
# for each checkworthy claim, first gather urls of related web pages
urls, url_query_dict = self.collect_claim_url_list(queries)
            docs: List[dict] = []
            for url in urls:
                web_text, _ = scrape_url(url)
                if web_text is not None:
                    docs.append({"query": url_query_dict[url], "url": url, "web_text": web_text})
            print("Claim: {}\nWe retrieved {} urls, {} web pages are accessible.".format(claim, len(urls), len(docs)))
            # We could directly use the first k entries of url_query_dict, since that
            # is the order the search engine returned. Here we instead select the
            # most relevant top-k docs against the claim by keyword coverage; the
            # returned indices refer to positions in docs.
            if len(docs) != 0:
                docs_text = [d['web_text'] for d in docs]
                selected_docs_index = select_doc_by_keyword_coverage(claim, docs_text)
                print("Selected doc indices:", selected_docs_index)
else:
# no related web articles collected for this claim, continue to next claim
claim_info[claim] = {"claim": claim, "automatic_queries": queries, "evidence_list": []}
continue
            selected_docs = [docs_text[idx] for idx in selected_docs_index]
            # Score the passages of the selected docs and keep the top-5: returns the
            # passage texts and, for each passage, a list of doc ids that index into
            # selected_docs (whose positions in docs are given by selected_docs_index,
            # e.g. [4, 25, 28, 32, 33]).
            topk_passages, passage_doc_id = select_passages_by_semantic_similarity(claim, selected_docs)
            # Recover each doc id to its original index in docs, which records the
            # detailed information of a doc.
            passage_doc_index = []
            for ids in passage_doc_id:
                passage_doc_index.append([selected_docs_index[doc_id] for doc_id in ids])
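            # Illustrative example (hypothetical indices): with selected_docs_index
            # = [4, 25, 28] and passage_doc_id = [[0, 2], [1]], the recovered
            # passage_doc_index is [[4, 28], [25]], i.e. indices into docs.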
            # Build the evidence list; optionally keep the full web text of the
            # source pages alongside each passage.
            evidence_list: List[dict] = []
            for pid, p in enumerate(topk_passages):
                doc_ids = passage_doc_index[pid]
                evidence_list.append({
                    "evidence_id": pid,
                    "web_page_snippet_manual": p,
                    "query": [docs[doc_id]["query"] for doc_id in doc_ids],
                    "url": [docs[doc_id]["url"] for doc_id in doc_ids],
                    "web_text": [docs[doc_id]["web_text"] for doc_id in doc_ids] if save_web_text else [],
                })
claim_info[claim] = {"claim": claim, "automatic_queries": queries, "evidence_list": evidence_list}
        # Serialize claim_info and write it to a JSON file (path_save_evidence).
        json_object = json.dumps(claim_info, indent=4)
with open(path_save_evidence, "w") as outfile:
outfile.write(json_object)
return claim_info
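
# Minimal usage sketch (hypothetical wiring: the exact FactCheckerState and
# solver-config APIs may differ in a real OpenFactCheck pipeline):
#
#   solver = SearchEngineEvidenceRetriever({"search_engine": "google",
#                                           "url_merge_method": "union"})
#   state = FactCheckerState()
#   state.set("claims", ["SpaceX was founded in 2002."])
#   ok, state = solver(state)
#   print(state.get("evidences"))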