# Source: Hugging Face Space file view (Space status: Sleeping; file size 5,298 bytes; revision 6376749).
import argparse
import json
import re
import jsonlines
from fraction import Fraction
from vllm import LLM, SamplingParams
import sys
# Sentinel "no upper bound" index: used as a slice stop it keeps everything to
# the end of a list, which is why it serves as the default `end` of the range.
MAX_INT: int = sys.maxsize
def is_number(s):
    """Return True if ``s`` can be interpreted as a number.

    Two probes are tried in order:

    1. ``float(s)`` -- accepts integers, decimals and exponent notation (and
       also the special strings ``"inf"`` / ``"nan"``).
    2. ``unicodedata.numeric(s)`` -- accepts a single Unicode numeric
       character such as ``"½"``.

    Args:
        s: Candidate value, normally a string.

    Returns:
        True if either probe succeeds, otherwise False. Input of the wrong
        type (e.g. ``None``) yields False instead of raising.
    """
    import unicodedata

    try:
        float(s)
        return True
    except (TypeError, ValueError):
        # TypeError covers wrong-typed input such as None, matching how the
        # unicodedata probe below already treats it.
        pass
    try:
        unicodedata.numeric(s)
        return True
    except (TypeError, ValueError):
        # TypeError: not a single character; ValueError: not a numeric char.
        return False
def extract_answer_number(completion):
    """Extract the model's final numeric answer from a completion.

    The text after the *last* ``'The answer is: '`` marker is scanned for the
    first number-like token. That token is one of:

    * an optionally signed integer or decimal, with optional ``,ddd``
      thousands groups (``1,234,567``);
    * a simple fraction ``a/b``.

    The value is rounded with :func:`round`, so halves go to the nearest
    even integer.

    Args:
        completion: Text generated by the model.

    Returns:
        The rounded answer as an ``int``, or ``None`` when:

        * the marker is missing;
        * no number follows the marker;
        * a fraction part is empty (e.g. ``"/5"``);
        * the value is too large to represent as a float.

        A fraction with a zero denominator (``"7/0"``, ``"7/00"``) falls back
        to its numerator.
    """
    pieces = completion.split('The answer is: ')
    if len(pieces) < 2:
        return None
    tail = pieces[-1].strip()
    # Optional sign, digits with optional ",ddd" groups, then at most one
    # '.', ',' or '/' separator followed by more digits. The thousands groups
    # keep "1,234,567" whole; a lone separator still matches "1,5" or ".5".
    match = re.search(r'[\-+]?(?:\d+(?:,\d{3})*)?[\.,/]?\d+', tail)
    if match is None:
        return None
    token = match.group().replace(',', '')
    try:
        if '/' in token:
            numerator_text, denominator_text = token.split('/')
            # int() rejects empty or sign-only parts such as "" or "-".
            numerator = int(numerator_text)
            denominator = int(denominator_text)
            if denominator == 0:
                return round(float(numerator))
            # int / int is correctly rounded, so this equals converting the
            # exact rational value to a float.
            return round(numerator / denominator)
        return round(float(token))
    except (OverflowError, ValueError):
        # OverflowError: the value exceeds the float range (round(+/-inf) or a
        # too-large int quotient). ValueError: int() refuses digit strings
        # beyond the interpreter's int/str conversion length limit.
        return None
def batch_data(data_list, batch_size=1):
    """Partition ``data_list`` into consecutive batches for generation.

    ``len(data_list) // batch_size`` batches are produced, but never fewer
    than one. All but the last batch hold exactly ``batch_size`` items. The
    last batch runs to the end of the list, so it absorbs any remainder and
    can hold up to ``2 * batch_size - 1`` items.

    A list shorter than ``batch_size`` comes back as a single batch, and an
    empty list comes back as ``[[]]``. A ``batch_size`` of 0 raises
    ZeroDivisionError.
    """
    full_batches = len(data_list) // batch_size
    batches = [
        data_list[k * batch_size:(k + 1) * batch_size]
        for k in range(full_batches - 1)
    ]
    # Final batch: everything from the last full-batch boundary onward.
    batches.append(data_list[(full_batches - 1) * batch_size:])
    return batches
def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
    """Evaluate ``model`` on a GSM8K-style JSONL file using greedy vLLM decoding.

    Each JSONL record is read for ``"query"`` (the question) and ``"response"``
    (a reference solution whose final integer answer follows ``'#### '``).
    Questions are wrapped in an instruction prompt and generated in batches.
    Each completion is scored by comparing the number returned by
    ``extract_answer_number`` with the reference answer.

    Args:
        model: Model name or path passed to ``vllm.LLM``.
        data_path: Path to the JSONL test file.
        start: Index of the first record to evaluate.
        end: Index one past the last record to evaluate.
        batch_size: Number of prompts per ``llm.generate`` call; the last
            batch also takes any remainder.
        tensor_parallel_size: Tensor-parallel degree passed to ``vllm.LLM``.

    Returns:
        Accuracy as a float in [0, 1], or 0.0 when the selected range is
        empty. The same figures are also printed.
    """
    problem_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
    )
    print('promt =====', problem_prompt)

    gsm8k_ins = []
    gsm8k_answers = []
    # Read-only: the data file is never written, so "r+" needlessly required
    # write permission.
    with open(data_path, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            gsm8k_ins.append(problem_prompt.format(instruction=item["query"]))
            # The reference answer is the integer after the '#### ' marker.
            answer_text = item['response'].split('#### ')[1]
            gsm8k_answers.append(int(answer_text.replace(',', '')))

    gsm8k_ins = gsm8k_ins[start:end]
    gsm8k_answers = gsm8k_answers[start:end]
    print('lenght ====', len(gsm8k_ins))
    batch_gsm8k_ins = batch_data(gsm8k_ins, batch_size=batch_size)

    stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
    sampling_params = SamplingParams(temperature=0.0, top_p=1, max_tokens=512, stop=stop_tokens)
    print('sampleing =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size)

    res_completions = []
    for prompts in batch_gsm8k_ins:
        if not isinstance(prompts, list):
            prompts = [prompts]
        if not prompts:
            # An empty evaluation range yields a single empty batch; skip it.
            continue
        for output in llm.generate(prompts, sampling_params):
            res_completions.append(output.outputs[0].text)

    result = []
    invalid_outputs = []
    for prompt, completion, prompt_answer in zip(gsm8k_ins, res_completions, gsm8k_answers):
        y_pred = extract_answer_number(completion)
        if y_pred is not None:
            result.append(float(y_pred) == float(prompt_answer))
        else:
            # No parseable answer: count as wrong and keep it for inspection.
            result.append(False)
            invalid_outputs.append({'question': prompt, 'output': completion, 'answer': prompt_answer})

    # Avoid ZeroDivisionError when the selected range is empty.
    acc = sum(result) / len(result) if result else 0.0
    print('len invalid outputs ====', len(invalid_outputs), ', invalid_outputs===', invalid_outputs)
    print('start===', start, ', end====', end)
    print('gsm8k length====', len(result), ', gsm8k acc====', acc)
    return acc
def parse_args(argv=None):
    """Parse command-line options for the GSM8K evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``model``, ``data_file``, ``start``, ``end``,
        ``batch_size`` and ``tensor_parallel_size``.
    """
    parser = argparse.ArgumentParser(description="Evaluate a model on GSM8K with vLLM.")
    # Both paths are mandatory. The old default data_file='' could only fail
    # later at open(''), and a model path is always passed on to LLM(...).
    parser.add_argument("--model", type=str, required=True, help="model name or path")
    parser.add_argument("--data_file", type=str, required=True, help="path to the GSM8K JSONL test file")
    parser.add_argument("--start", type=int, default=0, help="index of the first example to evaluate")
    parser.add_argument("--end", type=int, default=MAX_INT, help="index one past the last example to evaluate")
    parser.add_argument("--batch_size", type=int, default=400, help="prompts per generate call")
    parser.add_argument("--tensor_parallel_size", type=int, default=8, help="tensor-parallel degree for vLLM")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read CLI options and run the evaluation with them.
    cli = parse_args()
    gsm8k_test(
        model=cli.model,
        data_path=cli.data_file,
        start=cli.start,
        end=cli.end,
        batch_size=cli.batch_size,
        tensor_parallel_size=cli.tensor_parallel_size,
    )
# End of file.