File size: 3,357 Bytes
009f8e2
d10fd10
afb1063
d10fd10
 
 
 
95f281d
d10fd10
 
95f281d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d10fd10
 
 
 
 
 
 
 
95f281d
 
d10fd10
 
 
7066d45
d10fd10
afb1063
5aafc6e
 
afb1063
 
 
fa07e23
 
 
 
afb1063
fa07e23
 
afb1063
90fd9d9
 
21b62df
 
 
90fd9d9
21b62df
 
 
90fd9d9
 
fa07e23
 
 
95f281d
 
fa07e23
 
 
 
90fd9d9
95f281d
fa07e23
 
afb1063
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the model and tokenizer once at import time so every request reuses
# the same weights. The checkpoint is a T5 seq2seq model that (per its use
# in correct_text below) expects inputs prefixed with "grammar: ".
# NOTE(review): from_pretrained downloads on first run — requires network access.
model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")
tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")

def correct_text(text, max_length, min_length, max_new_tokens, min_new_tokens, num_beams, temperature, top_p):
    """Run the grammar-correction model on *text* and yield the corrected string.

    The four length controls collapse to two effective settings: each
    ``*_new_tokens`` value overrides its ``*_length`` counterpart whenever it
    is > 0 (the UI sliders default the new-token controls to 0, meaning
    "disabled").

    Args:
        text: Raw user input to correct.
        max_length: Upper bound on total sequence length (used when
            ``max_new_tokens`` is 0).
        min_length: Lower bound on total sequence length (used when
            ``min_new_tokens`` is 0).
        max_new_tokens: Upper bound on generated tokens; overrides
            ``max_length`` when > 0.
        min_new_tokens: Lower bound on generated tokens; overrides
            ``min_length`` when > 0.
        num_beams: Beam-search width.
        temperature: Sampling temperature (sampling is always enabled).
        top_p: Nucleus-sampling cutoff.

    Yields:
        The decoded corrected text (a generator, so Gradio can treat the
        output as streamed).
    """
    # The checkpoint expects a "grammar: " task prefix on its inputs.
    inputs = tokenizer.encode("grammar: " + text, return_tensors="pt")

    # Previously four near-identical model.generate() calls; their branch
    # logic reduces exactly to "prefer *_new_tokens over *_length when > 0",
    # so build one kwargs dict instead of duplicating the call.
    gen_kwargs = {
        "num_beams": num_beams,
        "temperature": temperature,
        "top_p": top_p,
        "early_stopping": True,
        "do_sample": True,
    }
    if max_new_tokens > 0:
        gen_kwargs["max_new_tokens"] = max_new_tokens
    else:
        gen_kwargs["max_length"] = max_length
    if min_new_tokens > 0:
        gen_kwargs["min_new_tokens"] = min_new_tokens
    else:
        gen_kwargs["min_length"] = min_length

    outputs = model.generate(inputs, **gen_kwargs)

    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    yield corrected_text


def update_prompt(prompt):
    """Echo *prompt* back unchanged.

    Wired to the sample buttons below: Gradio passes the clicked button's
    label in as *prompt*, and the returned value is written into the prompt
    textbox, so this identity function copies a sample sentence into the UI.
    """
    return prompt

# Create the Gradio interface: a prompt box, an output box, three sample
# buttons, and a collapsible panel of generation parameters.
with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Grammar Correction App
    """)
    # Main input / output widgets and the submit trigger.
    prompt_box = gr.Textbox(lines=2, placeholder="Enter your prompt here...")
    output_box = gr.Textbox()
    submitBtn = gr.Button("Submit")

    # Sample prompts — the button labels are deliberately ungrammatical;
    # clicking one copies its label into the prompt box as a demo input.
    with gr.Row():
        samp1 = gr.Button("we shood buy an car")
        samp2 = gr.Button("she is more taller")
        samp3 = gr.Button("John and i saw a sheep over their.")
        
        # Passing the Button itself as the input gives update_prompt the
        # button's label, which is then echoed into prompt_box.
        samp1.click(update_prompt, samp1, prompt_box)
        samp2.click(update_prompt, samp2, prompt_box)
        samp3.click(update_prompt, samp3, prompt_box)
        
    
    # Advanced generation controls, collapsed by default. The "New Tokens"
    # sliders default to 0, which correct_text treats as "disabled" (it then
    # falls back to the corresponding Length slider).
    with gr.Accordion("Generation Parameters:", open=False):
        max_length = gr.Slider(minimum=1, maximum=256, value=80, step=1, label="Max Length")
        min_length = gr.Slider(minimum=1, maximum=256, value=0,  step=1, label="Min Length")
        max_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Max New Tokens")
        min_tokens = gr.Slider(minimum=0, maximum=256, value=0, step=1, label="Min New Tokens")
        num_beams = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Num Beams")
        temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    
    # Run the correction on click; inputs are passed positionally in the
    # same order as correct_text's parameters.
    submitBtn.click(correct_text, [prompt_box, max_length, min_length, max_tokens, min_tokens, num_beams, temperature, top_p], output_box)

    

# Start the web server (blocks until the app is stopped).
demo.launch()