import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
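# `spaces` provides Hugging Face ZeroGPU support: functions decorated with @spaces.GPU
# are allocated a GPU on demand (this Space runs on ZeroGPU hardware).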

# Load models
implicit_cot_model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
implicit_cot_model = AutoModelForCausalLM.from_pretrained(implicit_cot_model_name)
tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)

no_cot_model_name = 'yuntian-deng/gpt2-no-cot-multiplication'
no_cot_model = AutoModelForCausalLM.from_pretrained(no_cot_model_name)

explicit_cot_model_name = 'yuntian-deng/gpt2-explicit-cot-multiplication'
explicit_cot_model = AutoModelForCausalLM.from_pretrained(explicit_cot_model_name)

models = {'implicit': implicit_cot_model, 'no': no_cot_model, 'explicit': explicit_cot_model}
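# The dictionary above holds three GPT-2 variants: 'implicit' internalizes the chain of
# thought in its hidden states, 'no' is trained to predict the product directly with no CoT,
# and 'explicit' writes out its intermediate steps before the final answer.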

# Constants
# Per-model cap on the number of generated tokens; the explicit-CoT model gets a much larger
# budget, presumably because it also emits its intermediate steps before the final product.
MAX_PRODUCT_DIGITS_PER_MODEL = {'implicit': 100, 'no': 100, 'explicit': 900}
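
# The finetuned GPT-2 checkpoints read and write numbers with their digits reversed
# (least significant digit first) and space-separated; preprocess/postprocess below convert
# between that model-facing format and plain decimal strings.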
def preprocess(num):
    num = str(num).strip().replace(' ', '')
    reversed_num = ' '.join(num[::-1])
    return reversed_num

def postprocess(raw_output):
    prediction = raw_output.replace(' ', '')[::-1]
    return prediction
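# For example, preprocess('123') returns '3 2 1', and postprocess('3 2 1') returns '123'.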

@spaces.GPU  # Assumes ZeroGPU hardware (the Space runs on Zero): request a GPU for this call
def predict_product(num1, num2):
    input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    for model in models.values():
        model.to(device)
    input_ids = inputs['input_ids']
    input_len = input_ids.shape[-1]

    prediction = ""
    ground_truth_product = ""
    valid_input = True
    try:
        num1_int = int(num1)
        num2_int = int(num2)
        ground_truth_product = str(num1_int * num2_int)
        ground_truth_digits_reversed = list(ground_truth_product)[::-1]
    except ValueError:
        valid_input = False

    if not valid_input:
        # Without a valid ground truth there is nothing to decode against, so return early.
        yield [('Invalid input: please enter two integers.', None)], [], [], []
        return
    # Per-model decoding state: generated ids, finished flags, KV caches, and annotations.
    generated_ids_per_model = {model_name: input_ids.clone() for model_name in models}
    finished_per_model = {model_name: False for model_name in models}
    past_key_values_per_model = {model_name: None for model_name in models}
    predicted_annotations_per_model = {model_name: [] for model_name in models}
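
    # Decode one token at a time for all three models in lockstep so the UI can stream
    # partial products; each model keeps its own KV cache so the prompt and previously
    # generated tokens are not re-encoded at every step.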
    for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())):  # maximum limit to prevent infinite loops
        # Ground Truth
        ground_truth_annotations = [(ground_truth_digit, None) for ground_truth_digit in ground_truth_digits_reversed[:step+1]]
        ground_truth_annotations = ground_truth_annotations[::-1]
        # Predicted
        for model_name in models:
            model = models[model_name]
            if finished_per_model[model_name]:
                continue
            if step >= MAX_PRODUCT_DIGITS_PER_MODEL[model_name]:
                continue
            generation_kwargs = {
                'input_ids': generated_ids_per_model[model_name],
                'max_new_tokens': 1,
                'do_sample': False,
                'past_key_values': past_key_values_per_model[model_name],
                'return_dict_in_generate': True,
                'use_cache': True
            }
            if step == 0:
                del generation_kwargs['past_key_values']
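            # At step 0 there is no cache yet, so the past_key_values entry is dropped rather
            # than passed as None; on later steps generate() resumes from the cached keys and
            # values, so only the newly appended token has to be processed.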
            outputs = model.generate(**generation_kwargs)
            generated_ids = outputs.sequences
            next_token_id = generated_ids[0, -1]
            if next_token_id.item() == tokenizer.eos_token_id:
                finished_per_model[model_name] = True
                continue
            generated_ids_per_model[model_name] = generated_ids
            past_key_values_per_model[model_name] = outputs.past_key_values
            output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
            predicted_digits_reversed = output_text.strip().split(' ')
            predicted_annotations = []
            is_correct_sofar = True
            if model_name == 'explicit':
                # The explicit-CoT model emits intermediate reasoning tokens first; everything up
                # to (and including) the first '=' is shown uncolored, and only the digits after
                # it are graded against the ground truth product.
                if '=' not in predicted_digits_reversed:
                    predicted_annotations = [(predicted_digit, None) for predicted_digit in predicted_digits_reversed]
                    predicted_digits_reversed = []
                else:
                    equal_sign_position = predicted_digits_reversed.index('=')
                    predicted_annotations = [(predicted_digit, None) for predicted_digit in predicted_digits_reversed[:equal_sign_position+1]]
                    predicted_digits_reversed = predicted_digits_reversed[equal_sign_position+1:]
            # Grade each predicted digit against the ground truth (both are reversed, so
            # position i is the 10^i digit).
            for i in range(len(predicted_digits_reversed)):
                predicted_digit = predicted_digits_reversed[i]
                if i >= len(ground_truth_digits_reversed):
                    # Digits beyond the true product are only acceptable as leading zeros,
                    # and only while everything so far has been correct.
                    is_correct_digit = (predicted_digit == '0' and is_correct_sofar)
                else:
                    ground_truth_digit = ground_truth_digits_reversed[i]
                    is_correct_digit = (predicted_digit == ground_truth_digit)
                if not is_correct_digit:
                    is_correct_sofar = False
                predicted_annotations.append((predicted_digit, "correct" if is_correct_digit else "wrong"))
            predicted_annotations = predicted_annotations[::-1]
            predicted_annotations_per_model[model_name] = predicted_annotations
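
        # Yielding inside the step loop makes predict_product a generator, so Gradio streams
        # the four HighlightedText outputs, refreshing them after every newly decoded digit.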
        predicted_annotations_implicit_cot = predicted_annotations_per_model['implicit']
        predicted_annotations_nocot = predicted_annotations_per_model['no']
        predicted_annotations_explicit_cot = predicted_annotations_per_model['explicit']
        yield ground_truth_annotations, predicted_annotations_implicit_cot, predicted_annotations_nocot, predicted_annotations_explicit_cot

color_map = {"correct": "green", "wrong": "red"}

demo = gr.Interface(
    fn=predict_product,
    inputs=[
        gr.Textbox(label='First Number (up to 12 digits)', value='123456789'),
        gr.Textbox(label='Second Number (up to 12 digits)', value='987654321'),
    ],
    outputs=[
        gr.HighlightedText(label='Ground Truth Product', combine_adjacent=False, show_legend=False, color_map=color_map),
        gr.HighlightedText(label='Implicit CoT Predicted Product', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
        gr.HighlightedText(label='No CoT Predicted Product', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
        gr.HighlightedText(label='Explicit CoT Predicted Product', combine_adjacent=False, show_legend=False, color_map=color_map, show_inline_category=False),
    ],
    title='Use GPT-2 to Predict the Product of Two Numbers (Without Intermediate Steps)',
    description='This demo shows that it is possible to use GPT-2 to directly predict the product of two large numbers without writing out any intermediate reasoning steps. The GPT-2 model has been finetuned to internalize chain-of-thought (CoT) reasoning within its hidden states, following our stepwise internalization approach detailed in the paper linked at the bottom of this page.',
    article="""
- [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
- [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
- [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
""",
    clear_btn=None,
    submit_btn="Multiply!",
    live=False,
    concurrency_limit=1
)

demo.queue(max_size=20).launch()