Ozgur Unlu commited on
Commit
b00d113
·
1 Parent(s): d1fd071

changed the grammar checking model

Browse files
Files changed (3) hide show
  1. app.py +27 -12
  2. news_checker.py +18 -4
  3. requirements.txt +3 -1
app.py CHANGED
@@ -3,7 +3,8 @@ import torch
3
  from transformers import (
4
  AutoTokenizer,
5
  AutoModelForSequenceClassification,
6
- pipeline
 
7
  )
8
  import os
9
  from pdf_generator import ReportGenerator
@@ -18,13 +19,13 @@ def load_models():
18
  hate_tokenizer = AutoTokenizer.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target")
19
  hate_model = AutoModelForSequenceClassification.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target")
20
 
21
- # Bias detection (using same model with different labels)
22
- bias_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
23
- bias_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
24
 
25
  return {
26
  'hate_speech': (hate_model, hate_tokenizer),
27
- 'bias': (bias_model, bias_tokenizer)
28
  }
29
 
30
  # Initialize news checker
@@ -68,12 +69,20 @@ def check_hate_speech(text, model, tokenizer):
68
  'message': f'Error in hate speech detection: {str(e)}'
69
  }
70
 
71
- def check_grammar(text):
72
  try:
73
- nlp = pipeline("text2text-generation", model="gramformer/gramformer", device=0 if torch.cuda.is_available() else -1)
74
- corrected = nlp(text, max_length=1000)[0]['generated_text']
75
 
76
- if corrected.lower() != text.lower():
 
 
 
 
 
 
 
 
 
77
  return {
78
  'status': 'warning',
79
  'message': f'Suggested corrections:\n{corrected}'
@@ -115,12 +124,18 @@ def analyze_content(text):
115
  report_gen.add_check_result("Hate Speech Check", hate_result['status'], hate_result['message'])
116
 
117
  # 3. Grammar Check
118
- grammar_result = check_grammar(text)
119
  results['Grammar Check'] = grammar_result
120
  report_gen.add_check_result("Grammar Check", grammar_result['status'], grammar_result['message'])
121
 
122
  # 4. News Context Check
123
- news_result = news_checker.check_content_against_news(text)
 
 
 
 
 
 
124
  results['Current Events Context'] = news_result
125
  report_gen.add_check_result("Current Events Context", news_result['status'], news_result['message'])
126
 
@@ -186,7 +201,7 @@ def create_interface():
186
  - Text length
187
  - Hate speech and bias
188
  - Grammar
189
- - Current events context
190
  """)
191
 
192
  return interface
 
3
  from transformers import (
4
  AutoTokenizer,
5
  AutoModelForSequenceClassification,
6
+ T5ForConditionalGeneration,
7
+ T5Tokenizer
8
  )
9
  import os
10
  from pdf_generator import ReportGenerator
 
19
  hate_tokenizer = AutoTokenizer.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target")
20
  hate_model = AutoModelForSequenceClassification.from_pretrained("facebook/roberta-hate-speech-dynabench-r4-target")
21
 
22
+ # Grammar check model (using T5)
23
+ grammar_tokenizer = T5Tokenizer.from_pretrained("orthwand/t5-small-grammar-correction")
24
+ grammar_model = T5ForConditionalGeneration.from_pretrained("orthwand/t5-small-grammar-correction")
25
 
26
  return {
27
  'hate_speech': (hate_model, hate_tokenizer),
28
+ 'grammar': (grammar_model, grammar_tokenizer)
29
  }
30
 
31
  # Initialize news checker
 
69
  'message': f'Error in hate speech detection: {str(e)}'
70
  }
71
 
72
+ def check_grammar(text, model, tokenizer):
73
  try:
74
+ input_ids = tokenizer(f"grammar: {text}", return_tensors="pt", max_length=512, truncation=True).input_ids
 
75
 
76
+ outputs = model.generate(
77
+ input_ids,
78
+ max_length=512,
79
+ num_beams=4,
80
+ early_stopping=True
81
+ )
82
+
83
+ corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
84
+
85
+ if corrected.lower().strip() != text.lower().strip():
86
  return {
87
  'status': 'warning',
88
  'message': f'Suggested corrections:\n{corrected}'
 
124
  report_gen.add_check_result("Hate Speech Check", hate_result['status'], hate_result['message'])
125
 
126
  # 3. Grammar Check
127
+ grammar_result = check_grammar(text, models['grammar'][0], models['grammar'][1])
128
  results['Grammar Check'] = grammar_result
129
  report_gen.add_check_result("Grammar Check", grammar_result['status'], grammar_result['message'])
130
 
131
  # 4. News Context Check
132
+ if os.getenv('NEWS_API_KEY'):
133
+ news_result = news_checker.check_content_against_news(text)
134
+ else:
135
+ news_result = {
136
+ 'status': 'warning',
137
+ 'message': 'News API key not configured. Skipping current events check.'
138
+ }
139
  results['Current Events Context'] = news_result
140
  report_gen.add_check_result("Current Events Context", news_result['status'], news_result['message'])
141
 
 
201
  - Text length
202
  - Hate speech and bias
203
  - Grammar
204
+ - Current events context (requires News API key)
205
  """)
206
 
207
  return interface
news_checker.py CHANGED
@@ -9,9 +9,21 @@ load_dotenv()
9
  class NewsChecker:
10
  def __init__(self):
11
  self.api_key = os.getenv('NEWS_API_KEY')
12
- self.newsapi = NewsApiClient(api_key=self.api_key)
 
 
 
 
 
 
 
 
13
 
14
  def get_recent_news(self):
 
 
 
 
15
  try:
16
  # Get news from the last 7 days
17
  week_ago = (datetime.now() - timedelta(days=7)).strftime('%Y-%m-%d')
@@ -33,9 +45,12 @@ class NewsChecker:
33
  }
34
  for article in articles if article['description']
35
  ]
 
36
  return pd.DataFrame(news_data)
37
- return pd.DataFrame()
38
-
 
 
39
  except Exception as e:
40
  print(f"Error fetching news: {str(e)}")
41
  return pd.DataFrame()
@@ -49,7 +64,6 @@ class NewsChecker:
49
  }
50
 
51
  # Simple keyword matching for demo purposes
52
- # In a production environment, you'd want to use more sophisticated NLP techniques
53
  marketing_words = set(marketing_text.lower().split())
54
  potential_conflicts = []
55
 
 
9
  class NewsChecker:
10
  def __init__(self):
11
  self.api_key = os.getenv('NEWS_API_KEY')
12
+ if not self.api_key:
13
+ print("WARNING: NEWS_API_KEY not found in environment variables")
14
+ else:
15
+ print("NEWS_API_KEY found in environment variables")
16
+
17
+ try:
18
+ self.newsapi = NewsApiClient(api_key=self.api_key)
19
+ except Exception as e:
20
+ print(f"Error initializing NewsAPI client: {str(e)}")
21
 
22
  def get_recent_news(self):
23
+ if not self.api_key:
24
+ print("Cannot fetch news: No API key configured")
25
+ return pd.DataFrame()
26
+
27
  try:
28
  # Get news from the last 7 days
29
  week_ago = (datetime.now() - timedelta(days=7)).strftime('%Y-%m-%d')
 
45
  }
46
  for article in articles if article['description']
47
  ]
48
+ print(f"Successfully fetched {len(news_data)} articles")
49
  return pd.DataFrame(news_data)
50
+ else:
51
+ print(f"NewsAPI response status was not 'ok': {response.get('status')}")
52
+ return pd.DataFrame()
53
+
54
  except Exception as e:
55
  print(f"Error fetching news: {str(e)}")
56
  return pd.DataFrame()
 
64
  }
65
 
66
  # Simple keyword matching for demo purposes
 
67
  marketing_words = set(marketing_text.lower().split())
68
  potential_conflicts = []
69
 
requirements.txt CHANGED
@@ -6,4 +6,6 @@ fpdf2==2.7.8
6
  pandas==2.1.4
7
  numpy==1.24.3
8
  requests==2.31.0
9
- python-dotenv==1.0.0
 
 
 
6
  pandas==2.1.4
7
  numpy==1.24.3
8
  requests==2.31.0
9
+ python-dotenv==1.0.0
10
+ sentencepiece==0.2.0
11
+ sacremoses==0.1.1