# Spaces: Running
import gradio as gr | |
import torch | |
from transformers import T5ForConditionalGeneration, T5TokenizerFast | |
# Tokenizer shared by the encode/decode helpers in this file.
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

# Start from the stock t5-base architecture; the checkpoint loaded below
# overrides whichever weights it has in common with it.
quantized_model = T5ForConditionalGeneration.from_pretrained("t5-base")

# map_location="cpu" lets a checkpoint that was saved from a GPU process load
# on a CPU-only host instead of failing while deserializing CUDA tensors.
# NOTE(review): torch.load unpickles the file - only ship a trusted checkpoint.
state_dict = torch.load("quantized_model.pt", map_location="cpu")

# Snapshot the model's parameter names once: calling state_dict() inside the
# comprehension would rebuild the whole mapping for every checkpoint key.
_model_keys = set(quantized_model.state_dict())

# Keep only the checkpoint entries whose names exist in the model.
filtered_state_dict = {k: v for k, v in state_dict.items() if k in _model_keys}

# strict=False: model parameters missing from the filtered checkpoint keep
# their pretrained t5-base values instead of raising an error.
quantized_model.load_state_dict(filtered_state_dict, strict=False)
def encode_text(text, max_length=512):
    """Tokenize *text* into PyTorch tensors for the summarization model.

    Args:
        text: The raw input string to encode.
        max_length: Token length the encoding is padded or truncated to.
            Defaults to 512, the value this file has always used.

    Returns:
        A tuple ``(input_ids, attention_mask)`` taken from the tokenizer
        output.
    """
    # Call the tokenizer directly: ``encode_plus`` is documented as deprecated
    # in favour of ``__call__``, which takes the same keyword arguments.
    encoding = tokenizer(
        text,
        max_length=max_length,
        padding="max_length",  # always pad up to max_length
        truncation=True,  # and cut anything longer than max_length
        return_attention_mask=True,
        return_tensors="pt",
    )
    return encoding["input_ids"], encoding["attention_mask"]
def generate_summary(
    input_ids,
    attention_mask,
    model,
    *,
    max_length=150,
    num_beams=2,
    repetition_penalty=2.5,
    length_penalty=1.0,
    early_stopping=True,
):
    """Run ``model.generate`` on the encoded input and return its output.

    Args:
        input_ids: Encoded input; must expose a ``device`` attribute.
        attention_mask: Attention mask passed through to ``model.generate``.
        model: Object providing ``to(device)`` and ``generate(**kwargs)``.
        max_length: Forwarded to ``generate``. Defaults to 150.
        num_beams: Forwarded to ``generate``. Defaults to 2.
        repetition_penalty: Forwarded to ``generate``. Defaults to 2.5.
        length_penalty: Forwarded to ``generate``. Defaults to 1.0.
        early_stopping: Forwarded to ``generate``. Defaults to True.

    Returns:
        Whatever ``model.generate`` returns for the given inputs.
    """
    # Put the model on the same device as the inputs before generating.
    model = model.to(input_ids.device)
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        early_stopping=early_stopping,
    )
def decode_summary(generated_ids):
    """Convert generated token-id sequences back into one string.

    Each sequence in ``generated_ids`` is decoded with special tokens skipped
    and tokenization spaces cleaned up; the decoded pieces are then
    concatenated with no separator between them.
    """
    pieces = []
    for sequence in generated_ids:
        decoded = tokenizer.decode(
            sequence,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        pieces.append(decoded)
    return "".join(pieces)
def summarize(text):
    """Summarize *text* end to end: encode, generate, then decode.

    Args:
        text: The user-supplied text to summarize.

    Returns:
        The decoded summary string, or an empty string when *text* is empty,
        ``None`` or contains only whitespace.
    """
    # Nothing to summarize: skip the model instead of running beam search on
    # an input with no content and returning text generated from nothing.
    if not text or not text.strip():
        return ""

    input_ids, attention_mask = encode_text(text)
    generated_ids = generate_summary(input_ids, attention_mask, quantized_model)
    return decode_summary(generated_ids)
# Gradio UI: a multi-line textbox in, a textbox with the summary out.
input_text = gr.Textbox(lines=10, label="Input Text")
output_text = gr.Textbox(label="Summary")

# Build the interface around summarize() first, then start serving it.
demo = gr.Interface(
    fn=summarize,
    inputs=input_text,
    outputs=output_text,
    title="Poem Pulse",
    description="Enter a Poem and get its Jist.",
)
demo.launch()