Ozgur Unlu committed
Commit 761644c · 1 Parent(s): b00d113

more error fixes for grammar check

Files changed (1)
  1. app.py +63 -53
app.py CHANGED
@@ -3,8 +3,7 @@ import torch
 from transformers import (
     AutoTokenizer,
     AutoModelForSequenceClassification,
-    T5ForConditionalGeneration,
-    T5Tokenizer
+    AutoModelForSeq2SeqLM
 )
 import os
 from pdf_generator import ReportGenerator
@@ -19,9 +18,9 @@ def load_models():
     hate_tokenizer = AutoTokenizer.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target")
     hate_model = AutoModelForSequenceClassification.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target")
 
-    # Grammar check model (using T5)
-    grammar_tokenizer = T5Tokenizer.from_pretrained("orthwand/t5-small-grammar-correction")
-    grammar_model = T5ForConditionalGeneration.from_pretrained("orthwand/t5-small-grammar-correction")
+    # Grammar check model
+    grammar_tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
+    grammar_model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")
 
     return {
         'hate_speech': (hate_model, hate_tokenizer),
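The replacement checkpoint can be sanity-checked on its own before it is wired into the app. A minimal standalone sketch under the same assumptions as this hunk (same model ID, loaded through the Auto classes, same "grammar: " task prefix that check_grammar builds further down); the example sentence is made up and the snippet is not part of the commit:

# Standalone sketch (illustrative only): load vennify/t5-base-grammar-correction via the Auto
# classes and correct one made-up sentence using the same "grammar: " prefix as check_grammar.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction")

encoding = tokenizer("grammar: She no went to the market.", return_tensors="pt",
                     max_length=512, truncation=True)
outputs = model.generate(
    input_ids=encoding.input_ids,
    attention_mask=encoding.attention_mask,
    max_length=512,
    num_beams=5,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))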
@@ -71,13 +70,15 @@ def check_hate_speech(text, model, tokenizer):
 
 def check_grammar(text, model, tokenizer):
     try:
-        input_ids = tokenizer(f"grammar: {text}", return_tensors="pt", max_length=512, truncation=True).input_ids
+        input_text = f"grammar: {text}"
+        encoding = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
 
         outputs = model.generate(
-            input_ids,
+            input_ids=encoding.input_ids,
+            attention_mask=encoding.attention_mask,
            max_length=512,
-            num_beams=4,
-            early_stopping=True
+            num_beams=5,
+            num_return_sequences=1
         )
 
         corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
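In this hunk the tokenizer's attention mask is now passed to generate alongside the input IDs (which matters whenever inputs are padded or batched), and beam search moves from four beams with early_stopping to five beams returning a single sequence. An illustrative call of the updated helper, using only names that appear in this commit; the sample sentence is made up and the snippet is not part of the diff:

# Illustrative usage (not part of this commit): run the updated check_grammar against the
# 'grammar' entry that load_models() stores as a (model, tokenizer) tuple.
models = load_models()
grammar_model, grammar_tokenizer = models['grammar']
result = check_grammar("She no went to the market.", grammar_model, grammar_tokenizer)
print(result['status'], result['message'])  # analyze_content reads exactly these two keys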
@@ -98,51 +99,60 @@ def check_grammar(text, model, tokenizer)
     }
 
 def analyze_content(text):
-    # Initialize report generator
-    report_gen = ReportGenerator()
-    report_gen.add_header()
-    report_gen.add_input_text(text)
-
-    # Load models
-    models = load_models()
-
-    # Run all checks
-    results = {}
-
-    # 1. Length Check
-    length_result = check_text_length(text)
-    results['Length Check'] = length_result
-    report_gen.add_check_result("Length Check", length_result['status'], length_result['message'])
-
-    if length_result['status'] == 'fail':
+    try:
+        # Initialize report generator
+        report_gen = ReportGenerator()
+        report_gen.add_header()
+        report_gen.add_input_text(text)
+
+        # Load models
+        models = load_models()
+
+        # Run all checks
+        results = {}
+
+        # 1. Length Check
+        length_result = check_text_length(text)
+        results['Length Check'] = length_result
+        report_gen.add_check_result("Length Check", length_result['status'], length_result['message'])
+
+        if length_result['status'] == 'fail':
+            report_path = report_gen.save_report()
+            return results, report_path
+
+        # 2. Hate Speech Check
+        hate_result = check_hate_speech(text, models['hate_speech'][0], models['hate_speech'][1])
+        results['Hate Speech Check'] = hate_result
+        report_gen.add_check_result("Hate Speech Check", hate_result['status'], hate_result['message'])
+
+        # 3. Grammar Check
+        grammar_result = check_grammar(text, models['grammar'][0], models['grammar'][1])
+        results['Grammar Check'] = grammar_result
+        report_gen.add_check_result("Grammar Check", grammar_result['status'], grammar_result['message'])
+
+        # 4. News Context Check
+        if os.getenv('NEWS_API_KEY'):
+            news_result = news_checker.check_content_against_news(text)
+        else:
+            news_result = {
+                'status': 'warning',
+                'message': 'News API key not configured. Skipping current events check.'
+            }
+        results['Current Events Context'] = news_result
+        report_gen.add_check_result("Current Events Context", news_result['status'], news_result['message'])
+
+        # Generate and save report
         report_path = report_gen.save_report()
+
         return results, report_path
-
-    # 2. Hate Speech Check
-    hate_result = check_hate_speech(text, models['hate_speech'][0], models['hate_speech'][1])
-    results['Hate Speech Check'] = hate_result
-    report_gen.add_check_result("Hate Speech Check", hate_result['status'], hate_result['message'])
-
-    # 3. Grammar Check
-    grammar_result = check_grammar(text, models['grammar'][0], models['grammar'][1])
-    results['Grammar Check'] = grammar_result
-    report_gen.add_check_result("Grammar Check", grammar_result['status'], grammar_result['message'])
-
-    # 4. News Context Check
-    if os.getenv('NEWS_API_KEY'):
-        news_result = news_checker.check_content_against_news(text)
-    else:
-        news_result = {
-            'status': 'warning',
-            'message': 'News API key not configured. Skipping current events check.'
-        }
-    results['Current Events Context'] = news_result
-    report_gen.add_check_result("Current Events Context", news_result['status'], news_result['message'])
-
-    # Generate and save report
-    report_path = report_gen.save_report()
-
-    return results, report_path
+    except Exception as e:
+        print(f"Error in analyze_content: {str(e)}")
+        return {
+            'Length Check': {'status': 'error', 'message': 'Analysis failed'},
+            'Hate Speech Check': {'status': 'error', 'message': 'Analysis failed'},
+            'Grammar Check': {'status': 'error', 'message': 'Analysis failed'},
+            'Current Events Context': {'status': 'error', 'message': 'Analysis failed'}
+        }, None
 
 def format_results(results):
     status_symbols = {
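Because analyze_content now catches its own exceptions, it can return (results, None) instead of raising, so anything that forwards report_path to the interface has to tolerate a missing report. The Gradio wiring is outside this diff, so the guard below is a hypothetical sketch, not the app's actual handler:

# Hypothetical caller-side guard (not part of this commit): when analyze_content falls into its
# except branch, report_path is None, so skip the file output instead of passing a bad path.
def run_checks(text):
    results, report_path = analyze_content(text)
    summary = format_results(results)
    if report_path is None:
        return summary, None  # e.g. leave a file-download output empty
    return summary, report_path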
@@ -201,7 +211,7 @@ def create_interface():
     - Text length
     - Hate speech and bias
     - Grammar
-    - Current events context (requires News API key)
+    - Current events context
     """)
 
     return interface
 