File size: 3,123 Bytes
009f8e2
23c0953
afb1063
d10fd10
 
 
 
 
f4b9a92
23c0953
 
 
 
6213036
23c0953
5df676a
23c0953
 
 
 
5df676a
 
9b8cca7
5df676a
23c0953
 
 
 
 
 
 
 
 
 
f4b9a92
23c0953
 
5aafc6e
 
afb1063
 
 
23c0953
 
fa07e23
afb1063
90fd9d9
 
21b62df
 
 
90fd9d9
21b62df
 
 
23c0953
90fd9d9
fa07e23
5df676a
 
 
 
 
 
 
 
 
9b8cca7
fa07e23
90fd9d9
5df676a
afb1063
23c0953
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

# Hugging Face checkpoint used for grammar correction; the model and its tokenizer
# are both loaded from it once, at import time.
MODEL_NAME = "vennify/t5-base-grammar-correction"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def correct_text(text, genConfig):
    """Run the grammar-correction model on ``text`` and return the decoded result.

    Args:
        text: Raw user input to correct.
        genConfig: A ``transformers.GenerationConfig`` controlling decoding.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # The model is prompted with a "grammar: " task prefix. Calling the tokenizer
    # (instead of ``encode``) also returns the attention mask for ``generate``.
    inputs = tokenizer("grammar: " + text, return_tensors="pt")
    # Pass the config object itself. Expanding ``genConfig.to_dict()`` into kwargs
    # would also forward every unset field (e.g. ``eos_token_id=None``) as an
    # explicit override of the model's own generation defaults.
    outputs = model.generate(**inputs, generation_config=genConfig)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def respond(text, max_length, min_length, max_new_tokens, min_new_tokens, num_beams, num_beam_groups, temperature, top_k, top_p):
    """Gradio handler: build a GenerationConfig from the UI controls and yield the correction.

    Slider values are cast to the types ``GenerationConfig`` expects.
    ``max_new_tokens`` and ``min_new_tokens`` are applied only when positive,
    so 0 means "not set".

    Yields:
        The corrected text (once).

    Raises:
        gr.Error: if ``num_beams`` is not divisible by ``num_beam_groups``.
    """
    num_beams = int(num_beams)
    num_beam_groups = int(num_beam_groups)
    # Group (diverse) beam search splits the beams evenly across groups;
    # reject impossible combinations with a readable message instead of a crash.
    if num_beams % num_beam_groups != 0:
        raise gr.Error("'Num Beams' must be divisible by 'Num Beams Groups'.")
    # Diverse beam search cannot be combined with sampling, so sampling is only
    # enabled when a single beam group is used (the original behavior).
    do_sample = num_beam_groups == 1

    config = GenerationConfig(
        max_length=int(max_length),
        min_length=int(min_length),
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        early_stopping=True,
        do_sample=do_sample,
    )

    # Add max/min new tokens only if they are set (> 0)
    if max_new_tokens > 0:
        config.max_new_tokens = int(max_new_tokens)
    if min_new_tokens > 0:
        config.min_new_tokens = int(min_new_tokens)

    yield correct_text(text, config)

def update_prompt(prompt):
    """Pass the clicked sample text straight through so it fills the prompt box."""
    selected = prompt
    return selected

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""# Grammar Correction App""")
    prompt_box = gr.Textbox(placeholder="Enter your prompt here...")
    output_box = gr.Textbox()

    # Sample prompts: clicking a button copies its label into the prompt box.
    with gr.Row():
        samp1 = gr.Button("we shood buy an car")
        samp2 = gr.Button("she is more taller")
        samp3 = gr.Button("John and i saw a sheep over their.")

        for sample_btn in (samp1, samp2, samp3):
            sample_btn.click(update_prompt, sample_btn, prompt_box)
    submitBtn = gr.Button("Submit")

    with gr.Accordion("Generation Parameters:", open=False):
        max_length  = gr.Slider(minimum=1,   maximum=256,   value=80,  step=1,    label="Max Length")
        # Minimum is 0 so the default value (0, no forced minimum) lies inside the range.
        min_length  = gr.Slider(minimum=0,   maximum=256,   value=0,   step=1,    label="Min Length")
        # For the two "new tokens" sliders, 0 means "unset": respond() only applies values > 0.
        max_tokens  = gr.Slider(minimum=0,   maximum=256,   value=0,   step=1,    label="Max New Tokens")
        min_tokens  = gr.Slider(minimum=0,   maximum=256,   value=0,   step=1,    label="Min New Tokens")
        num_beams   = gr.Slider(minimum=1,   maximum=20,    value=5,   step=1,    label="Num Beams")
        beam_groups = gr.Slider(minimum=1,   maximum=20,    value=1,   step=1,    label="Num Beams Groups")
        temperature = gr.Slider(minimum=0.1, maximum=100.0, value=0.7, step=0.1,  label="Temperature")
        top_k       = gr.Slider(minimum=0,   maximum=200,   value=50,  step=1,    label="Top-k")
        top_p       = gr.Slider(minimum=0.1, maximum=1.0,   value=1.0, step=0.05, label="Top-p (nucleus sampling)")

    # Input order must match respond()'s positional parameters.
    submitBtn.click(
        respond,
        [prompt_box, max_length, min_length, max_tokens, min_tokens,
         num_beams, beam_groups, temperature, top_k, top_p],
        output_box,
    )

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()