# Gradio Space: grammar-correction demo built on vennify/t5-base-grammar-correction.
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Load the grammar-correction model (a T5 seq2seq fine-tune) and its matching
# tokenizer from the Hugging Face Hub. Downloaded on first run, then cached.
model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")
tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
def correct_text(text, max_length, min_length, max_new_tokens, min_new_tokens, num_beams, temperature, top_p):
    """Run the grammar-correction model over *text* and yield the fixed string.

    Parameters:
        text: raw sentence to correct; the "grammar: " task prefix the T5
            fine-tune expects is prepended here.
        max_length, min_length: absolute output-length bounds (in tokens).
        max_new_tokens, min_new_tokens: bounds on newly generated tokens only;
            a value > 0 overrides the corresponding absolute bound above.
        num_beams: beam-search width.
        temperature, top_p: sampling controls (sampling is always on via
            do_sample=True).

    Yields:
        The corrected sentence (yielded rather than returned so Gradio can
        treat this as a streaming generator).
    """
    inputs = tokenizer.encode("grammar: " + text, return_tensors="pt")
    # Pick exactly one upper and one lower length bound: each *_new_tokens
    # slider, when set above 0, takes precedence over its absolute
    # counterpart. This replaces four near-duplicate generate() calls that
    # differed only in which pair of length kwargs they passed.
    length_kwargs = {}
    if max_new_tokens > 0:
        length_kwargs["max_new_tokens"] = max_new_tokens
    else:
        length_kwargs["max_length"] = max_length
    if min_new_tokens > 0:
        length_kwargs["min_new_tokens"] = min_new_tokens
    else:
        length_kwargs["min_length"] = min_length
    outputs = model.generate(
        inputs,
        num_beams=num_beams,
        temperature=temperature,
        top_p=top_p,
        early_stopping=True,
        do_sample=True,
        **length_kwargs,
    )
    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    yield corrected_text
def update_prompt(prompt):
    """Identity passthrough: return the clicked sample button's label so it
    can be written into the prompt textbox."""
    return prompt
# Assemble the Gradio interface: title, prompt/output textboxes, sample
# buttons that pre-fill the prompt, and an accordion of generation sliders.
with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Grammar Correction App
    """)
    user_prompt = gr.Textbox(lines=2, placeholder="Enter your prompt here...")
    corrected_output = gr.Textbox()
    submit_button = gr.Button("Submit")
    # Sample prompts: clicking one copies its label into the prompt box.
    with gr.Row():
        sample_buttons = [
            gr.Button("we shood buy an car"),
            gr.Button("she is more taller"),
            gr.Button("John and i saw a sheep over their."),
        ]
    for sample in sample_buttons:
        sample.click(update_prompt, sample, user_prompt)
    with gr.Accordion("Generation Parameters:", open=False):
        max_len_slider = gr.Slider(minimum=1, maximum=256, value=80, step=1, label="Max Length")
        min_len_slider = gr.Slider(minimum=1, maximum=256, value=0, step=1, label="Min Length")
        max_new_slider = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Max New Tokens")
        min_new_slider = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min New Tokens")
        beams_slider = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Num Beams")
        temp_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
    # Slider order must match correct_text's parameter order after the prompt.
    submit_button.click(
        correct_text,
        [user_prompt, max_len_slider, min_len_slider, max_new_slider, min_new_slider, beams_slider, temp_slider, top_p_slider],
        corrected_output,
    )
demo.launch()