ZeeAI1 commited on
Commit
aad4ca6
Β·
verified Β·
1 Parent(s): 018dcdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -26
app.py CHANGED
@@ -1,51 +1,56 @@
1
  import streamlit as st
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import torch
4
- import difflib
5
 
6
- # Load the grammar correction model
7
  @st.cache_resource
8
- def load_model():
9
  model_name = "prithivida/grammar_error_correcter_v1"
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
12
  return tokenizer, model
13
 
14
- tokenizer, model = load_model()
 
 
 
 
 
 
 
15
 
16
- # Function to correct grammar
17
  def correct_grammar(text):
18
  input_text = "gec: " + text
19
- inputs = tokenizer.encode(input_text, return_tensors="pt", truncation=True)
20
- outputs = model.generate(inputs, max_length=512, num_beams=4, early_stopping=True)
21
- corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
22
  return corrected_text
23
 
24
- # Function to show differences
25
- def show_differences(original, corrected):
26
- diff = difflib.ndiff(original.split(), corrected.split())
27
- changes = []
28
- for d in diff:
29
- if d.startswith("- "):
30
- changes.append(f"❌ Removed: `{d[2:]}`")
31
- elif d.startswith("+ "):
32
- changes.append(f"βœ… Added: `{d[2:]}`")
33
- return "\n".join(changes) if changes else "No major changes found."
34
-
35
- # Streamlit UI
36
- st.title("πŸ“ Grammar Correction App")
37
- st.write("Enter a sentence or paragraph below, and the AI will correct grammatical errors and highlight the changes.")
38
 
39
  user_input = st.text_area("Your Text", height=200, placeholder="Type or paste your text here...")
40
 
41
- if st.button("Correct Grammar"):
42
  if user_input.strip():
43
  with st.spinner("Correcting grammar..."):
44
  corrected = correct_grammar(user_input)
 
 
 
 
45
  st.subheader("βœ… Corrected Text")
46
  st.success(corrected)
47
 
48
- st.subheader("πŸ•΅οΈ What Changed?")
49
- st.markdown(show_differences(user_input, corrected))
50
  else:
51
  st.warning("Please enter some text to correct.")
 
1
  import streamlit as st
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
3
  import torch
 
4
 
5
+ # Load grammar correction model
6
  @st.cache_resource
7
+ def load_grammar_model():
8
  model_name = "prithivida/grammar_error_correcter_v1"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
11
  return tokenizer, model
12
 
13
+ # Load explanation model
14
+ @st.cache_resource
15
+ def load_explainer():
16
+ explainer = pipeline("text2text-generation", model="google/flan-t5-base", max_length=256)
17
+ return explainer
18
+
19
+ grammar_tokenizer, grammar_model = load_grammar_model()
20
+ explanation_model = load_explainer()
21
 
22
+ # Grammar correction function
23
  def correct_grammar(text):
24
  input_text = "gec: " + text
25
+ inputs = grammar_tokenizer.encode(input_text, return_tensors="pt", truncation=True)
26
+ outputs = grammar_model.generate(inputs, max_length=512, num_beams=4, early_stopping=True)
27
+ corrected_text = grammar_tokenizer.decode(outputs[0], skip_special_tokens=True)
28
  return corrected_text
29
 
30
+ # Explanation function using second model
31
+ def explain_correction(original, corrected):
32
+ prompt = f"Explain the grammar improvements made from: \"{original}\" to: \"{corrected}\""
33
+ result = explanation_model(prompt)[0]['generated_text']
34
+ return result
35
+
36
+ # Streamlit App UI
37
+ st.title("πŸ“ Smart Grammar Correction with Explanations")
38
+ st.write("Enter your text below. The AI will correct grammar **and explain why** the changes were made using grammar principles.")
 
 
 
 
 
39
 
40
  user_input = st.text_area("Your Text", height=200, placeholder="Type or paste your text here...")
41
 
42
+ if st.button("Correct and Explain"):
43
  if user_input.strip():
44
  with st.spinner("Correcting grammar..."):
45
  corrected = correct_grammar(user_input)
46
+
47
+ with st.spinner("Explaining corrections..."):
48
+ explanation = explain_correction(user_input, corrected)
49
+
50
  st.subheader("βœ… Corrected Text")
51
  st.success(corrected)
52
 
53
+ st.subheader("πŸ“˜ Explanation (Why it was changed)")
54
+ st.markdown(explanation)
55
  else:
56
  st.warning("Please enter some text to correct.")