sashdev committed on
Commit
10a975c
·
verified ·
1 Parent(s): 9908a7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -26
app.py CHANGED
@@ -1,40 +1,35 @@
1
  import gradio as gr
2
  import torch
3
- import asyncio
4
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
- # Load model and tokenizer
7
- model_name = "hassaanik/grammar-correction-model"
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
9
 
10
- # Use GPU if available, otherwise fallback to CPU
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
13
 
14
- # Use FP16 for faster inference on GPU
15
- if torch.cuda.is_available():
16
- model.half()
17
-
18
- # Async grammar correction function with batch processing
19
- async def correct_grammar_async(texts):
20
- # Tokenize the batch of inputs and move it to the correct device (CPU/GPU)
21
- inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
22
-
23
- # Asynchronous generation process
24
- outputs = await asyncio.to_thread(model.generate, inputs["input_ids"], max_length=512, num_beams=5, early_stopping=True)
25
 
26
- # Decode outputs in parallel
27
- corrected_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
28
- return corrected_texts
 
 
29
 
30
- # Gradio interface function to handle input and output
31
  def correct_grammar_interface(text):
32
- corrected_text = asyncio.run(correct_grammar_async([text]))[0] # Single input for now
33
  return corrected_text
34
 
35
- # Gradio Interface with async capabilities and batch input/output
36
  with gr.Blocks() as grammar_app:
37
- gr.Markdown("<h1>Fast Async Grammar Correction</h1>")
38
 
39
  with gr.Row():
40
  input_box = gr.Textbox(label="Input Text", placeholder="Enter text to be corrected", lines=4)
@@ -42,7 +37,7 @@ with gr.Blocks() as grammar_app:
42
 
43
  submit_button = gr.Button("Correct Grammar")
44
 
45
- # When the button is clicked, run the correction process asynchronously
46
  submit_button.click(fn=correct_grammar_interface, inputs=input_box, outputs=output_box)
47
 
48
  # Launch the app
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
 
4
 
5
+ # Load T5 model and tokenizer
6
+ model_name = "t5-base" # Use a smaller model for faster inference
7
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
8
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
9
 
10
+ # Use GPU if available
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ model.to(device)
13
 
14
+ # Grammar correction function
15
+ def correct_grammar(text):
16
+ input_text = f"correct: {text}"
17
+ input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
 
 
 
 
 
 
 
18
 
19
+ # Generate corrected text
20
+ output_ids = model.generate(input_ids, max_length=512, num_beams=5, early_stopping=True)
21
+ corrected_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
22
+
23
+ return corrected_text
24
 
25
+ # Gradio interface function
26
  def correct_grammar_interface(text):
27
+ corrected_text = correct_grammar(text)
28
  return corrected_text
29
 
30
+ # Gradio interface
31
  with gr.Blocks() as grammar_app:
32
+ gr.Markdown("<h1>Fast Grammar Correction with T5</h1>")
33
 
34
  with gr.Row():
35
  input_box = gr.Textbox(label="Input Text", placeholder="Enter text to be corrected", lines=4)
 
37
 
38
  submit_button = gr.Button("Correct Grammar")
39
 
40
+ # Button click event
41
  submit_button.click(fn=correct_grammar_interface, inputs=input_box, outputs=output_box)
42
 
43
  # Launch the app