# SVFT_PEFT/SVFT-main/MetaMath/eval_gsm8k.py
# sakshi7502's picture
# Upload 64 files
# 6376749 verified
import argparse
import json
import re
import jsonlines
from fraction import Fraction
from vllm import LLM, SamplingParams
import sys
# sys.maxsize serves as an effectively unbounded default for end indices
# (e.g. the ``end`` parameter / ``--end`` option used to slice the dataset).
MAX_INT = sys.maxsize
def is_number(s):
    """Return True if *s* can be interpreted as a number.

    A value counts as a number when ``float(s)`` accepts it or, failing that,
    when ``unicodedata.numeric(s)`` recognises it as a single numeric
    character (for example ``'½'``).  Because ``float`` is used, strings such
    as ``'nan'`` and ``'inf'`` are also reported as numbers.

    Args:
        s: Candidate value, normally a string.

    Returns:
        bool: True when either conversion succeeds, False otherwise.
        Unsupported input types such as ``None`` yield False instead of
        raising.
    """
    try:
        float(s)
        return True
    except (TypeError, ValueError):
        # TypeError: non-string, non-numeric input (e.g. None or a list).
        # ValueError: a string that float() cannot parse.
        pass
    try:
        import unicodedata
        unicodedata.numeric(s)
        return True
    except (TypeError, ValueError):
        # TypeError: not a single character; ValueError: not numeric.
        pass
    return False
def extract_answer_number(completion):
    """Extract the final numeric answer from a model completion.

    The text after the last ``'The answer is: '`` marker is searched for the
    first number-like token (optional sign, digits, and at most one of
    ``.``, ``,`` or ``/`` as separator).  Commas are dropped as thousands
    separators, ``a/b`` is evaluated as a fraction, and the result is rounded
    to the nearest integer with ``round``.

    Args:
        completion: Generated text to parse.

    Returns:
        int | None: The rounded answer, or None when the marker is missing,
        no usable number follows it, or the value is too large to be
        represented as a finite float.  A fraction with a zero denominator
        yields its (rounded) numerator.
    """
    segments = completion.split('The answer is: ')
    if len(segments) < 2:
        return None

    # NOTE(review): the pattern allows only one separator, so "1,234,567" is
    # read as "1,234" (-> 1234).  Left unchanged so existing scores stay
    # comparable.
    match = re.search(r'[\-+]?\d*[\.,/]?\d+', segments[-1].strip())
    if match is None:
        return None
    token = match.group()

    try:
        if '/' in token:
            numerator, _, denominator = token.partition('/')
            # int() rejects an empty or sign-only numerator ("", "+", "-").
            top = int(numerator)
            bottom = int(denominator)
            # Any all-zero denominator ("0", "00", ...) falls back to the
            # numerator instead of raising ZeroDivisionError.
            value = float(top) if bottom == 0 else top / bottom
        else:
            value = float(token.replace(',', ''))
    except (ValueError, OverflowError):
        # Unparseable numerator, or a magnitude too large for a float.
        return None

    # round() raises OverflowError on +/-inf, so reject both signs.
    if value in (float('inf'), float('-inf')):
        return None
    return round(value)
def batch_data(data_list, batch_size=1):
    """Split *data_list* into consecutive batches of at most *batch_size* items.

    Every batch except possibly the last holds exactly ``batch_size`` items;
    the last one holds the remainder.  Concatenating the batches reproduces
    ``data_list`` in its original order.

    Args:
        data_list: Sliceable sequence (e.g. a list of prompts).
        batch_size: Maximum number of items per batch; must be >= 1.

    Returns:
        list: The batches, in order.  An empty input gives an empty list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    return [
        data_list[offset:offset + batch_size]
        for offset in range(0, len(data_list), batch_size)
    ]
def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
    """Run a model over a GSM8K-style JSON-lines file and print its accuracy.

    Each record must contain a ``query`` string and a ``response`` string
    whose reference answer follows a ``'#### '`` marker.  Every query is
    wrapped in the instruction prompt, generated with greedy decoding, and
    the number parsed by ``extract_answer_number`` is compared with the
    reference answer.

    Args:
        model: Value passed to ``LLM(model=...)``.
        data_path: Path of the JSON-lines evaluation file.
        start: Index of the first example to evaluate.
        end: Index one past the last example to evaluate.
        batch_size: Batch size handed to ``batch_data``.
        tensor_parallel_size: Value passed to ``LLM(tensor_parallel_size=...)``.

    Returns:
        None.  Progress, unparseable completions and the accuracy are printed.

    Raises:
        IndexError: If a ``response`` has no ``'#### '`` marker.
        ValueError: If a reference answer is not an integer.
    """
    problem_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
    )
    print('promt =====', problem_prompt)

    gsm8k_ins = []
    gsm8k_answers = []
    # The file is only read, so open it read-only (no write permission needed).
    with open(data_path, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            gsm8k_ins.append(problem_prompt.format(instruction=item["query"]))
            temp_ans = item['response'].split('#### ')[1]
            gsm8k_answers.append(int(temp_ans.replace(',', '')))

    gsm8k_ins = gsm8k_ins[start:end]
    gsm8k_answers = gsm8k_answers[start:end]
    print('lenght ====', len(gsm8k_ins))
    batch_gsm8k_ins = batch_data(gsm8k_ins, batch_size=batch_size)

    stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
    sampling_params = SamplingParams(temperature=0.0, top_p=1, max_tokens=512, stop=stop_tokens)
    print('sampleing =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size)

    # Generate batch by batch.  Answers are matched to completions afterwards,
    # so the batches are iterated on their own rather than zipped with answers.
    # NOTE(review): assumes llm.generate returns outputs in prompt order.
    res_completions = []
    for batch in batch_gsm8k_ins:
        if not isinstance(batch, list):
            batch = [batch]
        if not batch:
            # Nothing to generate for an empty batch.
            continue
        completions = llm.generate(batch, sampling_params)
        res_completions.extend(output.outputs[0].text for output in completions)

    result = []
    invalid_outputs = []
    for prompt, completion, prompt_answer in zip(gsm8k_ins, res_completions, gsm8k_answers):
        y_pred = extract_answer_number(completion)
        if y_pred is not None:
            result.append(float(y_pred) == float(prompt_answer))
        else:
            # No parseable answer: count as wrong and keep it for inspection.
            result.append(False)
            invalid_outputs.append({'question': prompt, 'output': completion, 'answer': prompt_answer})

    # An empty selection reports 0.0 instead of dividing by zero.
    acc = sum(result) / len(result) if result else 0.0
    print('len invalid outputs ====', len(invalid_outputs), ', invalid_outputs===', invalid_outputs)
    print('start===', start, ', end====', end)
    print('gsm8k length====', len(result), ', gsm8k acc====', acc)
def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace: Holds ``model``, ``data_file``, ``start``,
        ``end``, ``batch_size`` and ``tensor_parallel_size``.
    """
    # One (flag, type, default) entry per supported option.
    option_specs = (
        ("--model", str, None),                # model path
        ("--data_file", str, ''),              # data path
        ("--start", int, 0),                   # start index
        ("--end", int, MAX_INT),               # end index
        ("--batch_size", int, 400),            # batch size
        ("--tensor_parallel_size", int, 8),    # tensor parallel size
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        parser.add_argument(flag, type=value_type, default=fallback)
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: forward the parsed CLI options to the evaluation.
    cli_options = parse_args()
    gsm8k_test(
        model=cli_options.model,
        data_path=cli_options.data_file,
        start=cli_options.start,
        end=cli_options.end,
        batch_size=cli_options.batch_size,
        tensor_parallel_size=cli_options.tensor_parallel_size,
    )