Ozgur Unlu committed
Commit · 761644c
Parent(s): b00d113
more error fixes for grammar check
app.py CHANGED
@@ -3,8 +3,7 @@ import torch
 from transformers import (
     AutoTokenizer,
     AutoModelForSequenceClassification,
-
-    T5Tokenizer
+    AutoModelForSeq2SeqLM
 )
 import os
 from pdf_generator import ReportGenerator

@@ -19,9 +18,9 @@ def load_models():
     hate_tokenizer = AutoTokenizer.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target")
     hate_model = AutoModelForSequenceClassification.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target")
 
-    # Grammar check model
-    grammar_tokenizer =
-    grammar_model =
+    # Grammar check model
+    grammar_tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
+    grammar_model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")
 
     return {
         'hate_speech': (hate_model, hate_tokenizer),

@@ -71,13 +70,15 @@ def check_hate_speech(text, model, tokenizer):
 
 def check_grammar(text, model, tokenizer):
     try:
-
+        input_text = f"grammar: {text}"
+        encoding = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
 
         outputs = model.generate(
-            input_ids,
+            input_ids=encoding.input_ids,
+            attention_mask=encoding.attention_mask,
             max_length=512,
-            num_beams=
-
+            num_beams=5,
+            num_return_sequences=1
         )
 
         corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
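As an aside to the hunk above (not part of the commit): the reworked grammar path can be exercised on its own. A minimal sketch, assuming the same vennify/t5-base-grammar-correction checkpoint and a working transformers/torch install; the sample sentence is made up.

# Standalone sketch mirroring the new check_grammar flow shown above.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")

text = "she go to the store yesterday"  # hypothetical input
encoding = tokenizer(f"grammar: {text}", return_tensors="pt", max_length=512, truncation=True)
outputs = model.generate(
    input_ids=encoding.input_ids,
    attention_mask=encoding.attention_mask,
    max_length=512,
    num_beams=5,
    num_return_sequences=1,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))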
@@ -98,51 +99,60 @@ def check_grammar(text, model, tokenizer):
     }
 
 def analyze_content(text):
-[17 lines removed, not shown]
+    try:
+        # Initialize report generator
+        report_gen = ReportGenerator()
+        report_gen.add_header()
+        report_gen.add_input_text(text)
+
+        # Load models
+        models = load_models()
+
+        # Run all checks
+        results = {}
+
+        # 1. Length Check
+        length_result = check_text_length(text)
+        results['Length Check'] = length_result
+        report_gen.add_check_result("Length Check", length_result['status'], length_result['message'])
+
+        if length_result['status'] == 'fail':
+            report_path = report_gen.save_report()
+            return results, report_path
+
+        # 2. Hate Speech Check
+        hate_result = check_hate_speech(text, models['hate_speech'][0], models['hate_speech'][1])
+        results['Hate Speech Check'] = hate_result
+        report_gen.add_check_result("Hate Speech Check", hate_result['status'], hate_result['message'])
+
+        # 3. Grammar Check
+        grammar_result = check_grammar(text, models['grammar'][0], models['grammar'][1])
+        results['Grammar Check'] = grammar_result
+        report_gen.add_check_result("Grammar Check", grammar_result['status'], grammar_result['message'])
+
+        # 4. News Context Check
+        if os.getenv('NEWS_API_KEY'):
+            news_result = news_checker.check_content_against_news(text)
+        else:
+            news_result = {
+                'status': 'warning',
+                'message': 'News API key not configured. Skipping current events check.'
+            }
+        results['Current Events Context'] = news_result
+        report_gen.add_check_result("Current Events Context", news_result['status'], news_result['message'])
+
+        # Generate and save report
         report_path = report_gen.save_report()
+
         return results, report_path
-[8 lines removed, not shown]
-    results['Grammar Check'] = grammar_result
-    report_gen.add_check_result("Grammar Check", grammar_result['status'], grammar_result['message'])
-
-    # 4. News Context Check
-    if os.getenv('NEWS_API_KEY'):
-        news_result = news_checker.check_content_against_news(text)
-    else:
-        news_result = {
-            'status': 'warning',
-            'message': 'News API key not configured. Skipping current events check.'
-        }
-    results['Current Events Context'] = news_result
-    report_gen.add_check_result("Current Events Context", news_result['status'], news_result['message'])
-
-    # Generate and save report
-    report_path = report_gen.save_report()
-
-    return results, report_path
+    except Exception as e:
+        print(f"Error in analyze_content: {str(e)}")
+        return {
+            'Length Check': {'status': 'error', 'message': 'Analysis failed'},
+            'Hate Speech Check': {'status': 'error', 'message': 'Analysis failed'},
+            'Grammar Check': {'status': 'error', 'message': 'Analysis failed'},
+            'Current Events Context': {'status': 'error', 'message': 'Analysis failed'}
+        }, None
 
 def format_results(results):
     status_symbols = {
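Also not part of the commit: with the whole pipeline now wrapped in try/except, analyze_content always returns a (results, report_path) pair, with report_path set to None on failure. A minimal, hypothetical caller sketch, assuming app.py is importable and using a made-up input:

# Hypothetical caller; relies only on the return contract visible in the diff above.
from app import analyze_content

results, report_path = analyze_content("Our new product launches next week!")  # made-up text
for check, outcome in results.items():
    print(f"{check}: {outcome['status']} - {outcome['message']}")
if report_path:
    print(f"Report saved to {report_path}")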
@@ -201,7 +211,7 @@ def create_interface():
     - Text length
     - Hate speech and bias
     - Grammar
-    - Current events context
+    - Current events context
     """)
 
     return interface