Spaces:
Sleeping
Sleeping
File size: 3,123 Bytes
009f8e2 23c0953 afb1063 d10fd10 f4b9a92 23c0953 6213036 23c0953 5df676a 23c0953 5df676a 9b8cca7 5df676a 23c0953 f4b9a92 23c0953 5aafc6e afb1063 23c0953 fa07e23 afb1063 90fd9d9 21b62df 90fd9d9 21b62df 23c0953 90fd9d9 fa07e23 5df676a 9b8cca7 fa07e23 90fd9d9 5df676a afb1063 23c0953 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
# Load the grammar-correction checkpoint once, at import time. The model and
# its tokenizer are fetched from the same Hugging Face Hub repository.
MODEL_NAME = "vennify/t5-base-grammar-correction"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def correct_text(text, genConfig):
    """Run the T5 grammar-correction model on *text* and return the result.

    Args:
        text: The user's input sentence. The ``"grammar: "`` task prefix the
            model is prompted with is prepended here.
        genConfig: A ``transformers.GenerationConfig`` controlling decoding.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # Calling the tokenizer (instead of ``encode``) also returns the attention
    # mask, so ``generate`` does not have to fall back to inferring it.
    inputs = tokenizer("grammar: " + text, return_tensors="pt")
    # Hand the config object over directly rather than splatting ``to_dict()``,
    # which would also forward bookkeeping keys such as ``transformers_version``.
    outputs = model.generate(**inputs, generation_config=genConfig)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def respond(text, max_length, min_length, max_new_tokens, min_new_tokens, num_beams, num_beam_groups, temperature, top_k, top_p, diversity_penalty=1.0):
    """Correct *text* using generation settings taken from the UI controls.

    Args:
        text: The sentence to correct.
        max_length, min_length: Bounds on the total generated length.
        max_new_tokens, min_new_tokens: Bounds on newly generated tokens;
            a value of 0 leaves the setting unset.
        num_beams: Beam width.
        num_beam_groups: Number of groups for diverse (group) beam search;
            1 means ordinary beam search / sampling.
        temperature, top_k, top_p: Sampling controls. They are only used when
            ``num_beam_groups`` is 1, because group beam search cannot sample.
        diversity_penalty: Penalty between beam groups, used only when
            ``num_beam_groups`` > 1 (transformers rejects 0.0 in that mode).

    Yields:
        The corrected text, once.

    Raises:
        gr.Error: If ``num_beams`` is not divisible by ``num_beam_groups``.
            This gives the user a readable message instead of a generic failure.
    """
    # Slider values may arrive as floats; these settings must be integers.
    num_beams = int(num_beams)
    num_beam_groups = int(num_beam_groups)
    if num_beam_groups < 1 or num_beams % num_beam_groups != 0:
        raise gr.Error(
            f"Num Beams ({num_beams}) must be divisible by "
            f"Num Beams Groups ({num_beam_groups})."
        )

    params = {
        "max_length": int(max_length),
        "min_length": int(min_length),
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "early_stopping": True,
    }
    if num_beam_groups > 1:
        # Diverse (group) beam search: transformers raises a ValueError when it
        # is combined with sampling or with a zero diversity penalty.
        params.update(do_sample=False, diversity_penalty=float(diversity_penalty))
    else:
        params.update(
            do_sample=True,
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
        )

    # 0 means "not set"; when set, max_new_tokens takes precedence over max_length.
    if max_new_tokens > 0:
        params["max_new_tokens"] = int(max_new_tokens)
    if min_new_tokens > 0:
        params["min_new_tokens"] = int(min_new_tokens)

    yield correct_text(text, GenerationConfig(**params))
def update_prompt(prompt):
    """Return *prompt* unchanged.

    This is the click handler for the sample-prompt buttons. Each button is
    wired as its own input and the prompt box as the output, so the button's
    value is copied into the prompt box.
    """
    selected = prompt
    return selected
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""# Grammar Correction App""")
    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")
    output_box = gr.Textbox()

    # Sample prompts: clicking a button copies its value into the prompt box.
    with gr.Row():
        samp1 = gr.Button("we shood buy an car")
        samp2 = gr.Button("she is more taller")
        samp3 = gr.Button("John and i saw a sheep over their.")
    samp1.click(update_prompt, samp1, prompt_box)
    samp2.click(update_prompt, samp2, prompt_box)
    samp3.click(update_prompt, samp3, prompt_box)

    submitBtn = gr.Button("Submit")

    with gr.Accordion("Generation Parameters:", open=False):
        max_length = gr.Slider(minimum=1, maximum=256, value=80, step=1, label="Max Length")
        # minimum is 0 (previously 1) so the default value of 0 -- no minimum
        # length -- lies inside the slider's range.
        min_length = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min Length")
        # A value of 0 leaves max/min_new_tokens unset for generation.
        max_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Max New Tokens")
        min_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min New Tokens")
        num_beams = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Num Beams")
        beam_groups = gr.Slider(minimum=1, maximum=20, value=1, step=1, label="Num Beams Groups")
        temperature = gr.Slider(minimum=0.1, maximum=100.0, value=0.7, step=0.1, label="Temperature")
        top_k = gr.Slider(minimum=0, maximum=200, value=50, step=1, label="Top-k")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p (nucleus sampling)")

    # Input order must match respond()'s positional parameters.
    submitBtn.click(
        respond,
        [prompt_box, max_length, min_length, max_tokens, min_tokens, num_beams, beam_groups, temperature, top_k, top_p],
        output_box,
    )

demo.launch()