Spaces:
Running
Running
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
import torch | |
# Page setup: set_page_config has to run before any other Streamlit command.
st.set_page_config(
    layout="centered",
    page_title="Grammar Fixer & Coach",
)
# Model loading.
# Streamlit re-executes this entire script on every widget interaction, so both
# loaders are wrapped in st.cache_resource: each model is built once per server
# process and the same objects are reused on every rerun instead of being
# reloaded from disk/hub each time the button is clicked.


@st.cache_resource
def load_grammar_model():
    """Load the grammar-correction model and its tokenizer.

    Returns:
        A ``(tokenizer, model)`` pair for
        ``vennify/t5-base-grammar-correction``.
    """
    model_name = "vennify/t5-base-grammar-correction"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model


@st.cache_resource
def load_explanation_model():
    """Load the FLAN-T5 ``text2text-generation`` pipeline used for explanations.

    ``max_length=512`` is passed through to the pipeline's generation call.
    """
    return pipeline("text2text-generation", model="google/flan-t5-large", max_length=512)


# Module-level handles used by correct_grammar / get_detailed_feedback below.
grammar_tokenizer, grammar_model = load_grammar_model()
explanation_model = load_explanation_model()
# Grammar correction.
# The vennify/t5-base-grammar-correction model card's usage example prefixes
# every input with the task tag "grammar: "; the model is meant to be fed
# inputs in that form.
_GRAMMAR_PREFIX = "grammar: "


def correct_grammar(text):
    """Return a grammatically corrected version of *text*.

    Args:
        text: Sentence or paragraph entered by the user.

    Returns:
        The grammar model's output (beam search with 4 beams), decoded
        without special tokens.

    Note:
        The prefixed input is truncated to 512 tokens, so only the leading
        part of very long text is corrected.
    """
    inputs = grammar_tokenizer.encode(
        _GRAMMAR_PREFIX + text,
        return_tensors="pt",
        truncation=True,
        # Explicit limit instead of relying on the tokenizer's configured default.
        max_length=512,
    )
    outputs = grammar_model.generate(inputs, max_length=512, num_beams=4, early_stopping=True)
    return grammar_tokenizer.decode(outputs[0], skip_special_tokens=True)
# Explanation of corrections.
def get_detailed_feedback(original, corrected):
    """Ask the explanation model to describe the changes from *original* to *corrected*.

    Args:
        original: Text as entered by the user.
        corrected: Corrected text produced for *original*.

    Returns:
        The model-generated explanation. If the two texts are identical
        apart from surrounding whitespace, no model call is made and a fixed
        message is returned instead: prompting the model to explain
        corrections that were never made invites it to invent some.
    """
    if original.strip() == corrected.strip():
        return "No corrections were needed. Your text already looks good."
    prompt = (
        "Analyze and explain all grammar, spelling, and punctuation corrections made when changing the following sentence:\n\n"
        f"Original: {original}\n"
        f"Corrected: {corrected}\n\n"
        "Give a list of corrections with the reason for each, and also suggest how the user can improve their writing."
    )
    return explanation_model(prompt)[0]["generated_text"]
# Streamlit UI: read text, correct it, then explain the corrections.
# The header/label emoji below replace mis-encoded glyphs (UTF-8 emoji bytes
# that had been decoded as Greek ISO-8859-7 text and rendered as "π§", "βοΈ", ...).
st.title("🧠 Grammar Fixer & Writing Coach")
st.write("Paste your sentence or paragraph. The AI will correct it and explain each fix to help you learn.")
user_input = st.text_area(
    "✍️ Enter your text below:",
    height=200,
    # Deliberately error-ridden example of the kind of input the app fixes.
    placeholder="e.g., I, want you! to please foucs on you work only!!",
)
if st.button("Correct & Explain"):
    if user_input.strip():
        with st.spinner("Correcting grammar..."):
            corrected = correct_grammar(user_input)
        with st.spinner("Explaining corrections..."):
            explanation = get_detailed_feedback(user_input, corrected)
        st.subheader("✅ Corrected Text")
        st.success(corrected)
        st.subheader("📝 Detailed Explanation")
        st.markdown(explanation)
    else:
        st.warning("Please enter some text.")