Spaces:
Sleeping
Sleeping
import pprint | |
def last_boxed_only(sample):
    r"""Reduce a two-item ``(q, a)`` sample to its last boxed expression.

    *sample* is unpacked as ``q, a`` (presumably question and solution text;
    confirm against callers).  ``a`` is replaced by
    ``last_boxed_only_string(a)``.

    Returns:
        ``(q, boxed)`` where *boxed* is the extracted ``\boxed{...}`` /
        ``\fbox{...}`` substring, or None when no boxed expression can be
        extracted from ``a``.
    """
    q, a = sample
    a = last_boxed_only_string(a)
    if a is None:
        return None
    return (q, a)
def last_boxed_only_string(string):
    r"""Return the last ``\boxed{...}`` (or ``\fbox{...}``) expression in *string*.

    The last occurrence of ``\boxed`` is used. ``\fbox`` is consulted only
    when ``\boxed`` does not occur at all. Starting at that command, braces
    are counted and the returned slice ends at the ``}`` that brings the
    brace depth back to zero.

    Returns:
        The substring from the command name through its matching closing
        brace, or None when neither command is present or the braces never
        balance.
    """
    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    depth = 0
    for i in range(idx, len(string)):
        ch = string[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            # Depth is only checked on a closing brace, so the command name
            # itself (depth 0 before any "{") never terminates the scan.
            if depth == 0:
                return string[idx:i + 1]
    return None
def only_until_first_boxed_from_tokens(string, tokens):
    r"""Return the prefix of *tokens* that lies entirely before the first box.

    The first ``\boxed`` in *string* is located (falling back to the first
    ``\fbox`` only if there is no ``\boxed``).  Token lengths are accumulated
    and the tokens are cut just before the first token that extends past that
    position.  This assumes ``"".join(tokens)`` lines up with *string*; TODO
    confirm with callers.

    Returns:
        None when *string* contains neither command; otherwise a slice of
        *tokens*.  If the tokens never reach the box position, all tokens are
        returned.
    """
    idx = string.find("\\boxed")
    if idx < 0:
        idx = string.find("\\fbox")
        if idx < 0:
            return None

    cum_length = 0
    for i, t in enumerate(tokens):
        cum_length += len(t)
        # Strictly greater: a token ending exactly at ``idx`` is still fully
        # before the box and must be kept (``>=`` used to drop it).
        if cum_length > idx:
            return tokens[:i]
    # Loop exhausted (including empty *tokens*): nothing overlaps the box.
    return tokens[:]
def clean_numbers(sample):
    """Apply ``_clean_numbers`` to every field of *sample*.

    Returns:
        None when *sample* is falsy (None or empty); otherwise a tuple whose
        i-th element is ``_clean_numbers(sample[i])``.
    """
    if not sample:
        return None
    return tuple(_clean_numbers(field) for field in sample)
def _clean_numbers(string): | |
""" | |
Clean Numbers in the given string | |
>>> _clean_numbers(None, "Hello 123") | |
'Hello 123' | |
>>> _clean_numbers(None, "Hello 1234") | |
'Hello 1,234' | |
>>> _clean_numbers(None, "Hello 1234324asdasd") | |
'Hello 1,234,324asdasd' | |
""" | |
num_prev_digits = 0 | |
new_string = "" | |
for i, c in enumerate(string): | |
# isdigit() doesnt work here because of weird unicode chars. | |
if c in {'1', '2', '3', '4', '5', '6', '7', '8', '9', '0'}: | |
num_prev_digits += 1 | |
else: | |
if num_prev_digits > 3: | |
# Some fixing | |
string_number = new_string[-num_prev_digits:] | |
new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number)) | |
num_prev_digits = 0 | |
new_string += c | |
if num_prev_digits > 3: | |
# Some fixing | |
string_number = new_string[-num_prev_digits:] | |
new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number)) | |
return new_string | |
def fix_fracs(string):
    r"""Add braces to ``\frac`` arguments that were written without them.

    Examples: ``\frac12`` -> ``\frac{1}{2}``, ``\frac1b`` -> ``\frac{1}{b}``,
    ``\frac1{72}`` -> ``\frac{1}{72}``.  A ``\frac`` whose first argument is
    already braced is copied as-is, so ``\frac{72}1`` is not repaired.

    If any ``\frac`` is followed by fewer than two characters and the first
    is not ``{`` (including nothing at all, e.g. a trailing ``\frac``), the
    whole input is returned unchanged.
    """
    pieces = string.split("\\frac")
    new_str = pieces[0]
    for piece in pieces[1:]:
        new_str += "\\frac"
        if piece.startswith("{"):
            new_str += piece
            continue
        # Explicit check instead of ``assert`` (stripped under ``python -O``);
        # also covers the empty piece that used to raise IndexError.
        if len(piece) < 2:
            return string
        a, b, rest = piece[0], piece[1], piece[2:]
        if b != "{":
            # Both arguments are single bare characters: \frac12 -> \frac{1}{2}
            new_str += "{" + a + "}{" + b + "}" + rest
        else:
            # Second argument already braced: \frac1{72} -> \frac{1}{72}
            new_str += "{" + a + "}" + b + rest
    return new_str
def fix_a_slash_b(string):
    r"""Rewrite a plain integer fraction ``a/b`` as ``\frac{a}{b}``.

    The rewrite happens only when *string* has exactly one ``/`` and both
    sides are integers whose canonical ``str(int(...))`` spelling reproduces
    *string* exactly (so ``01/2`` or `` 1/2`` are left alone).  Anything else,
    including non-numeric sides such as ``x/y``, is returned unchanged.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        a = int(parts[0])
        b = int(parts[1])
    except ValueError:
        # Non-integer numerator/denominator: nothing to normalize.
        return string
    # Explicit check instead of ``assert`` (stripped under ``python -O``).
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"
def remove_right_units(string):
    r"""Strip a trailing unit annotation written as ``\text{ ...}``.

    Per the original author's note, ``\text{ `` (with a space after the brace)
    only ever occurs, at least in the val set, when describing units.
    Everything from its first occurrence onward is removed; strings without
    it are returned unchanged.
    """
    # str.partition replaces the former ``assert len(splits) == 2``, which
    # raised on repeated unit annotations and vanished under ``python -O``.
    return string.partition("\\text{ ")[0]
def fix_sqrt(string):
    r"""Brace a single-character ``\sqrt`` argument: ``\sqrt3`` -> ``\sqrt{3}``.

    Only the one character immediately after ``\sqrt`` is wrapped, so
    ``\sqrt25`` becomes ``\sqrt{2}5``.  Already-braced arguments and a bare
    trailing ``\sqrt`` (nothing after it) are left unchanged.
    """
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    pieces = [splits[0]]
    for tail in splits[1:]:
        # ``tail`` is empty for a trailing "\sqrt" or "\sqrt\sqrt"; indexing
        # tail[0] there used to raise IndexError.
        if tail and tail[0] != "{":
            pieces.append("\\sqrt{" + tail[0] + "}" + tail[1:])
        else:
            pieces.append("\\sqrt" + tail)
    return "".join(pieces)
def strip_string(string):
    r"""Normalize a LaTeX answer string so that equivalent answers compare equal.

    Steps, in order:
      1. Remove newlines and ``\!`` spacing; collapse ``\\`` to ``\``.
      2. Rewrite ``tfrac``/``dfrac`` to ``frac``.
      3. Drop ``\left``/``\right``, degree markers and escaped ``\$``.
      4. Drop trailing ``\text{ ...}`` units (``remove_right_units``) and
         escaped ``\%``.
      5. Add a leading zero to decimals such as ``.5``.
      6. Drop a left-hand side of at most two characters before a single
         ``=``.
      7. Brace ``\sqrt`` arguments, delete spaces and brace ``\frac``
         arguments.
      8. Map ``0.5`` to ``\frac{1}{2}`` and simple integer ``a/b`` to
         ``\frac{a}{b}``.
    Helper functions called here may raise on unusual input; callers such as
    ``is_equiv`` are expected to handle that.
    """
    # linebreaks
    string = string.replace("\n", "")
    # remove inverse spaces
    string = string.replace("\\!", "")
    # replace \\ with \
    string = string.replace("\\\\", "\\")
    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")
    # remove dollar signs
    string = string.replace("\\$", "")
    # remove units (on the right)
    string = remove_right_units(string)
    # remove percentage: "\\%" is the two characters backslash + percent.
    # A second replace of "\%" was dropped: it is the same two characters
    # (an invalid escape sequence, SyntaxWarning on Python 3.12+), so it was a no-op.
    string = string.replace("\\%", "")
    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string
    # get rid of a short left-hand side such as "k = " or "q = " at the beginning
    lhs_rhs = string.split("=")
    if len(lhs_rhs) == 2 and len(lhs_rhs[0]) <= 2:
        string = lhs_rhs[1]
    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)
    # remove spaces
    string = string.replace(" ", "")
    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with
    # \frac1{72} (but not \frac{72}1). The a/b case is handled by fix_a_slash_b below.
    string = fix_fracs(string)
    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"
    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = fix_a_slash_b(string)
    return string
def is_equiv(str1, str2, verbose=False):
    """Decide whether two answer strings are equivalent.

    - Both None: prints a warning and counts as equivalent.
    - Exactly one None: never equivalent.
    - Otherwise both are normalized with ``strip_string`` and compared.
      With *verbose*, the two normalized forms are printed first.  If
      anything in that step raises, the raw strings are compared instead
      (deliberate best-effort fallback).
    """
    both_missing = str1 is None and str2 is None
    if both_missing:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False
    try:
        normalized = [strip_string(answer) for answer in (str1, str2)]
        if verbose:
            print(*normalized)
        return normalized[0] == normalized[1]
    except Exception:
        return str1 == str2
class NotEqual:
    """Object that compares unequal to everything, including itself.

    ``==`` always yields False, so the inherited ``!=`` yields True.  Because
    ``__eq__`` is defined without ``__hash__``, Python sets ``__hash__`` to
    None and instances are unhashable.
    """

    def __eq__(self, other):
        # Deliberately ignore ``other``: no value is ever considered equal.
        return False