import json
import os
import random
import re
import time
from functools import lru_cache

import numpy as np
import openai
import torch

try:
    import transformers
except ImportError:
    import sys
    from ditk import logging
    logging.warning("transformers not found, please install it via: pip install transformers")
    sys.exit(1)
def sample_logits(out: torch.Tensor, temperature: float = 1.0, top_p: float = 0.8) -> int:
    # Sample a token id from the logits using nucleus (top-p) sampling.
    probs = torch.softmax(out, dim=-1).cpu().numpy()
    # Keep only the smallest set of tokens whose cumulative probability exceeds top_p.
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        probs = np.power(probs, 1.0 / temperature)
    probs = probs / np.sum(probs)
    token = np.random.choice(a=len(probs), p=probs)
    return int(token)
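
# Usage sketch (made-up logits): sample one token id from a 4-token vocabulary;
# with top_p=0.8 only the highest-probability tokens survive the cutoff.
#   logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
#   token_id = sample_logits(logits, temperature=1.0, top_p=0.8)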

def calc_rwkv(
        model: transformers.RwkvForCausalLM,
        tokenizer: transformers.AutoTokenizer,
        prompt: str,
        max_len: int = 10
) -> str:
    # Use RWKV to generate a continuation of the given prompt.
    orig_len = len(prompt)
    inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
    outputs = model(**inputs, labels=inputs["input_ids"])
    out, state = outputs.logits, outputs.state
    # Recurrent generation: sample one token at a time, append it to the prompt,
    # and re-encode the full prompt each step (simple but quadratic in length).
    with torch.no_grad():
        for i in range(max_len):
            token = sample_logits(out[0, -1])
            tmp = tokenizer.decode([token])
            prompt = prompt + tmp
            inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
            outputs = model(**inputs, labels=inputs["input_ids"])
            out, state = outputs.logits, outputs.state
    # Return only the newly generated text.
    return prompt[orig_len:]
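
# Usage sketch (the checkpoint name is illustrative; any RWKV checkpoint with a
# matching tokenizer works, and a CUDA device is assumed):
#   model = transformers.RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile").cuda()
#   tokenizer = transformers.AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
#   continuation = calc_rwkv(model, tokenizer, "Question: 1 + 1 =", max_len=5)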

def calc_internlm(model, tokenizer, prompt: str, args):
    # Use InternLM to generate a continuation of the given prompt.
    inputs = tokenizer(prompt, return_tensors="pt")
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    gen_kwargs = {
        "max_length": args.max_tokens,
        "top_p": args.top_p,
        "temperature": args.temperature,
        "do_sample": True,
        "repetition_penalty": args.frequency_penalty
    }
    output = model.generate(**inputs, **gen_kwargs)
    # ``generate`` returns a batch of token-id sequences; decode the first one.
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return output
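
# Usage sketch (``args`` is any object exposing these attributes; the values
# below are illustrative):
#   args = argparse.Namespace(max_tokens=512, top_p=0.8, temperature=1.0, frequency_penalty=1.02)
#   answer = calc_internlm(model, tokenizer, prompt, args)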

def load_data(args: dict) -> tuple:
    # Load the TabMWP dataset, downloading it first if it is not cached locally.
    random.seed(args.seed)
    data_root = 'dizoo/tabmwp/data'
    if not os.path.exists(data_root):
        # ``makedirs`` creates intermediate directories, which ``mkdir`` would not.
        os.makedirs(data_root)
    if not os.path.exists(os.path.join(data_root, 'problems_train.json')):
        os.system(
            'wget https://opendilab.net/download/DI-zoo/tabmwp/problems_train.json -O ' +
            os.path.join(data_root, 'problems_train.json') + ' --no-check-certificate'
        )
    with open(os.path.join(data_root, 'problems_train.json')) as f:
        problems = json.load(f)
    pids = list(problems.keys())
    # Randomly split the sampled problem ids into training and candidate pools.
    samples = random.sample(pids, args.train_number + args.cand_number)
    train_pids = samples[:args.train_number]
    cand_pids = samples[args.train_number:]
    return problems, cand_pids, train_pids
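
# Usage sketch (hypothetical sizes): sample 80 training and 20 candidate problems.
#   args = argparse.Namespace(seed=0, train_number=80, cand_number=20)
#   problems, cand_pids, train_pids = load_data(args)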

def get_gpt3_output(prompt: str, args: dict) -> str:
    # Thin wrapper that forwards the generation options from ``args`` to call_gpt3.
    return call_gpt3(
        args.engine, prompt, args.temperature, args.max_tokens, args.top_p, args.frequency_penalty,
        args.presence_penalty
    )

def call_gpt3(
        engine: str, prompt: str, temperature: float, max_tokens: int, top_p: float, frequency_penalty: float,
        presence_penalty: float
) -> str:
    # Query the OpenAI completion API, retrying on transient errors.
    patience = 100
    while True:
        try:
            response = openai.Completion.create(
                engine=engine,
                prompt=prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                stop=["\n"]
            )
            output = response["choices"][0]["text"].strip()
            break
        except Exception:
            patience -= 1
            if not patience:
                # Fail loudly instead of looping forever with no result.
                raise RuntimeError("ran out of patience waiting for OpenAI")
            time.sleep(0.1)
    return output
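
# Usage sketch (uses the legacy openai<1.0 Completion API, matching the call
# above; the engine name is illustrative):
#   openai.api_key = os.environ["OPENAI_API_KEY"]
#   text = call_gpt3("text-davinci-002", "Question: 1 + 1 =\nAnswer:", 0.0, 64, 1.0, 0.0, 0.0)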

def get_table_text(problem: dict) -> str:
    # Linearize the table, prepending the table title when one exists.
    table = problem['table']
    title = problem['table_title']
    if title and len(title) > 0:
        table = f"[TITLE]: {title}\n{table}"
    return table

def get_question_text(problem: dict, option_inds: list) -> str:
    # Build the question text, appending the unit and the answer options when present.
    question = problem['question']
    unit = problem['unit']
    if unit and len(unit) > 0:
        question = f"{question} (Unit: {unit})"
    choices = problem['choices']
    if choices and len(choices) > 0:
        choice_list = []
        for i, c in enumerate(choices):
            choice_list.append("({}) {}".format(option_inds[i], c))
        options = " ".join(choice_list)
        question = f"{question}\nOptions: {options}"
    return question

def get_answer(problem: dict) -> str:
    return problem['answer']

def get_solution_text(problem: dict) -> str:
    # Escape newlines so the multi-line solution stays on a single prompt line.
    solution = problem['solution'].replace("\n", "\\n")
    return solution

def create_one_example(
        format: str, table: str, question: str, answer: str, solution: str, test_example: bool = True
) -> str:
    # Instantiate one prompt example from a format template such as "TQ-A": the
    # part before "-" lists the input fields, the part after it names the output field.
    input_format, output_format = format.split("-")  # e.g., "TQ-A"
    elements = {
        "Q": f"Question: {question}",
        "T": f"Table: {table}",
        "S": f"Solution: {solution}",
        "A": f"Answer: The answer is {answer}.",
        "AS": f"Answer: The answer is {answer}. BECAUSE: {solution}",
        "SA": f"Answer: {solution} The answer is {answer}."
    }
    # Input fields.
    input = "\n".join(elements[label] for label in input_format)
    # Output: test examples end with a bare "Answer:" cue for the model to complete.
    if test_example:
        output = "Answer:"
    else:
        output = elements[output_format]
    # Assemble the prompt text and collapse double spaces.
    text = input + "\n" + output
    text = text.replace("  ", " ").strip()
    return text
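
# Usage sketch (hypothetical fields): with the "TQ-SA" format, a training
# example shows the table, the question, and a worked solution that ends with
# the answer; a test example ends with a bare "Answer:" cue instead.
#   ex = create_one_example("TQ-SA", "x | 1\ny | 2", "What is y?", "2", "Read row y.", test_example=False)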

def build_prompt(problems: dict, shot_pids: list, test_pid: str, args: dict) -> str:
    # Given problem ids, generate the complete few-shot prompt, i.e. the input to the LM.
    examples = []
    pids = shot_pids + [test_pid]
    # n-shot training examples followed by the test example.
    for pid in pids:
        problem = problems[pid]
        table = get_table_text(problem)
        question = get_question_text(problem, args.option_inds)
        answer = get_answer(problem)
        solution = get_solution_text(problem)
        if pid == test_pid:
            assert pid not in shot_pids
        example = create_one_example(
            args.prompt_format, table, question, answer, solution, test_example=(pid == test_pid)
        )
        examples.append(example)
    # Join the examples into the prompt input.
    prompt_input = '\n\n'.join(examples)
    return prompt_input
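
# Usage sketch (hypothetical ids): a 2-shot prompt ending with test problem "3".
#   prompt = build_prompt(problems, shot_pids=["1", "2"], test_pid="3", args=args)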

def extract_prediction(output: str, options: list, option_inds: list) -> str:
    # Keep only the first line of the model output.
    idx = output.find('\n')
    if idx > 0:
        output = output[:idx]
    # Keep only the right-hand side of an equation, e.g. "x = 5" -> "5".
    idx = output.find('=')
    if idx > 0:
        output = output[idx + 1:].strip()
    # $\frac{16}{95}$ -> 16/95
    output = re.sub(r"\$?\\frac\{([\d\.\,\-]+)\}\{([\d\.\,]+)\}\$?", r"\1/\2", output)
    # Drop a trailing period unless it belongs to "A.M."/"P.M.".
    output = re.sub(r"(?<![AP]\.M)\.$", "", output)
    output = re.sub(r"(?<=\d)[\=](?=[\-\$\d])", " = ", output)
    # Replace the Unicode minus sign with a plain hyphen.
    output = re.sub(r"\u2212", "-", output)
    # Multiple-choice questions.
    if options:
        patterns = [
            r'^\(([A-Za-z])\)$',  # "(b)", "(B)"
            r'^([A-Za-z])$',  # "b", "B"
            r'^([A-Za-z]). ',  # "b. ", "B. "
            r'[Tt]he answer is ([A-Z])',  # "The answer is B"
            r'^\(([A-Za-z])\) [\s\S]+$',  # "(A) XXXXX"
            r'[Tt]he answer is \(([A-Za-z])\) [\s\S]+$',  # "The answer is (B) XXXXX."
        ]
        # Look for an option letter in the output.
        for p in patterns:
            pattern = re.compile(p)
            res = pattern.findall(output)
            if len(res) > 0:
                pred = res[0].upper()  # e.g., "B"
                if pred in option_inds:
                    ind = option_inds.index(pred)  # e.g., 1
                    if ind >= len(options):
                        ind = random.choice(range(len(options)))
                    prediction = options[ind]
                    return prediction
        # Otherwise fall back to the most similar option.
        scores = [score_string_similarity(x, output) for x in options]
        max_idx = int(np.argmax(scores))  # json does not recognize NumPy data types
        prediction = options[max_idx]
        return prediction
    else:
        # Free-text QA problems with a numeric answer.
        patterns = [
            # r'^\([A-Za-z]\) ([\s\S]+)$',  # "(A) XXXXX"
            # r'[Tt]he answer is \([A-Za-z]\) ([\s\S]+)$',  # "The answer is (B) XXXXX."
            r'[Tt]he answer is ([\s\S]+)$',  # "The answer is XXXXX."
            r'[Tt]he table shows that ([\d\$\.\,\/\:]+) ',
            r' = ([\d\$\.\,\/\:]+)',  # "= $1.40"
            r'(?<= be| is) ([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "will be $1.40"
            r'(?<= are| was) ([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "are $1.40"
            r'(?<= were) ([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "were $1.40"
            r' ([\d\$\.\,\/\:]+ [AP]\.M\.)',  # "7:25 P.M."
            r'([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "14.5"
        ]
        for p in patterns:
            pattern = re.compile(p)
            res = pattern.findall(output)
            if len(res) > 0:
                prediction = res[-1].strip()
                if prediction.endswith(".") and ".M." not in prediction:
                    prediction = prediction[:-1]
                return prediction
    return output
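
# Usage sketch (made-up model output): the option letter is matched first, then
# mapped back to the option text.
#   extract_prediction("The answer is (B) 12 eggs.", ["10 eggs", "12 eggs"], ["A", "B", "C", "D"])
#   # -> "12 eggs"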

def normalize_answer(text: str, unit: str) -> str:
    # Normalize answers such as "1,000", "123", "3/4", "56.456", "$56.4", "-3", "-10.02", "-3/2".
    text = re.sub(r"^[\$]", "", text)
    text = re.sub(r"[\,\.\,\/]$", "", text)
    result = re.match(r"^[-+]?[\d,./]+$", text)
    if result is not None:
        # The answer is a number.
        text = text.replace(",", "")
        result = re.match(r"[-+]?\d+$", text)
        try:
            if result is not None:
                number = int(text)
            elif "/" in text:
                nums = text.split("/")
                number = round(float(nums[0]) / float(nums[1]), 3)
            else:
                number = round(float(text), 3)
            number = str(number)
            # Strip a trailing ".0" so "1000.0" becomes "1000".
            number = re.sub(r"\.[0]+$", "", number)
            return number
        except (ValueError, ZeroDivisionError):
            return text
    else:
        # The answer is text; strip the unit when present.
        if unit:
            text = text.replace(unit, "").strip()
        return text
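
# Usage sketch: "$1,000.00" -> "1000", "3/4" -> "0.75", and with unit "eggs",
# "12 eggs" -> "12".
#   normalize_answer("$1,000.00", "")  # -> "1000"
#   normalize_answer("3/4", "")        # -> "0.75"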

def score_string_similarity(str1: str, str2: str) -> float:
    # Score the similarity of two strings: 2.0 for an exact match, otherwise the
    # word-overlap ratio for multi-word strings and 0.0 for distinct single words.
    if str1 == str2:
        return 2.0
    if " " in str1 or " " in str2:
        str1_split = str1.split(" ")
        str2_split = str2.split(" ")
        overlap = list(set(str1_split) & set(str2_split))
        return len(overlap) / max(len(str1_split), len(str2_split))
    else:
        # Exact matches are handled above, so distinct single words score 0.
        return 0.0
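
# Usage sketch: "12 eggs" and "12 brown eggs" share two of three words.
#   score_string_similarity("12 eggs", "12 brown eggs")  # -> 0.667 (2/3)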

def create_example_from_pid(pid: str, problems: dict, args: dict, test: bool = False) -> str:
    # Build a single prompt example for the given problem id.
    problem = problems[pid]
    table = get_table_text(problem)
    question = get_question_text(problem, args.option_inds)
    answer = get_answer(problem)
    solution = get_solution_text(problem)
    example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=test)
    return example