import pprint def last_boxed_only(sample): q, a = sample a = last_boxed_only_string(a) if a == None: return None return (q, a) def last_boxed_only_string(string): idx = string.rfind("\\boxed") if idx < 0: idx = string.rfind("\\fbox") if idx < 0: return None i = idx right_brace_idx = None num_left_braces_open = 0 while i < len(string): if string[i] == "{": num_left_braces_open += 1 if string[i] == "}": num_left_braces_open -= 1 if num_left_braces_open == 0: right_brace_idx = i break i += 1 if right_brace_idx == None: retval = None else: retval = string[idx:right_brace_idx + 1] return retval def only_until_first_boxed_from_tokens(string, tokens): idx = string.find("\\boxed") if idx < 0: idx = string.find("\\fbox") if idx < 0: return None cum_length = 0 for i, t in enumerate(tokens): cum_length += len(t) if cum_length >= idx: break return tokens[:i] def clean_numbers(sample): if not sample: return None new_sample = list() for s in sample: new_sample.append(_clean_numbers(s)) return tuple(new_sample) def _clean_numbers(string): """ Clean Numbers in the given string >>> _clean_numbers(None, "Hello 123") 'Hello 123' >>> _clean_numbers(None, "Hello 1234") 'Hello 1,234' >>> _clean_numbers(None, "Hello 1234324asdasd") 'Hello 1,234,324asdasd' """ num_prev_digits = 0 new_string = "" for i, c in enumerate(string): # isdigit() doesnt work here because of weird unicode chars. if c in {'1', '2', '3', '4', '5', '6', '7', '8', '9', '0'}: num_prev_digits += 1 else: if num_prev_digits > 3: # Some fixing string_number = new_string[-num_prev_digits:] new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number)) num_prev_digits = 0 new_string += c if num_prev_digits > 3: # Some fixing string_number = new_string[-num_prev_digits:] new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number)) return new_string def fix_fracs(string): substrs = string.split("\\frac") new_str = substrs[0] if len(substrs) > 1: substrs = substrs[1:] for substr in substrs: new_str += "\\frac" if substr[0] == "{": new_str += substr else: try: assert len(substr) >= 2 except AssertionError: return string a = substr[0] b = substr[1] if b != "{": if len(substr) > 2: post_substr = substr[2:] new_str += "{" + a + "}{" + b + "}" + post_substr else: new_str += "{" + a + "}{" + b + "}" else: if len(substr) > 2: post_substr = substr[2:] new_str += "{" + a + "}" + b + post_substr else: new_str += "{" + a + "}" + b string = new_str return string def fix_a_slash_b(string): if len(string.split("/")) != 2: return string a = string.split("/")[0] b = string.split("/")[1] try: a = int(a) b = int(b) assert string == "{}/{}".format(a, b) new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" return new_string except AssertionError: return string def remove_right_units(string): # "\\text{ " only ever occurs (at least in the val set) when describing units if "\\text{ " in string: splits = string.split("\\text{ ") assert len(splits) == 2 return splits[0] else: return string def fix_sqrt(string): if "\\sqrt" not in string: return string splits = string.split("\\sqrt") new_string = splits[0] for split in splits[1:]: if split[0] != "{": a = split[0] new_substr = "\\sqrt{" + a + "}" + split[1:] else: new_substr = "\\sqrt" + split new_string += new_substr return new_string def strip_string(string): # linebreaks string = string.replace("\n", "") # remove inverse spaces string = string.replace("\\!", "") # replace \\ with \ string = string.replace("\\\\", "\\") # replace tfrac and dfrac with frac string = string.replace("tfrac", "frac") string = string.replace("dfrac", "frac") # remove \left and \right string = string.replace("\\left", "") string = string.replace("\\right", "") # Remove circ (degrees) string = string.replace("^{\\circ}", "") string = string.replace("^\\circ", "") # remove dollar signs string = string.replace("\\$", "") # remove units (on the right) string = remove_right_units(string) # remove percentage string = string.replace("\\%", "") string = string.replace("\%", "") # noqa: W605 # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string string = string.replace(" .", " 0.") string = string.replace("{.", "{0.") # if empty, return empty string if len(string) == 0: return string if string[0] == ".": string = "0" + string # to consider: get rid of e.g. "k = " or "q = " at beginning if len(string.split("=")) == 2: if len(string.split("=")[0]) <= 2: string = string.split("=")[1] # fix sqrt3 --> sqrt{3} string = fix_sqrt(string) # remove spaces string = string.replace(" ", "") # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} string = fix_fracs(string) # manually change 0.5 --> \frac{1}{2} if string == "0.5": string = "\\frac{1}{2}" # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y string = fix_a_slash_b(string) return string def is_equiv(str1, str2, verbose=False): if str1 is None and str2 is None: print("WARNING: Both None") return True if str1 is None or str2 is None: return False try: ss1 = strip_string(str1) ss2 = strip_string(str2) #pdb.set_trace() if verbose: print(ss1, ss2) return ss1 == ss2 except Exception: return str1 == str2 class NotEqual: def __eq__(self, other): return False