pentarosarium committed
Commit d32e04e · 1 Parent(s): cbb8180
Files changed (1)
  1. app.py +153 -99
app.py CHANGED
@@ -25,129 +25,183 @@ class NewsProcessor:
         self.similarity_threshold = similarity_threshold
         self.time_threshold = time_threshold
 
-    def mean_pooling(self, model_output, attention_mask):
-        token_embeddings = model_output[0]
-        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-
-    def encode_text(self, text):
-        # Convert text to string and handle NaN values
-        if pd.isna(text):
-            text = ""
-        else:
-            text = str(text)
-        encoded_input = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
-        with torch.no_grad():
-            model_output = self.model(**encoded_input)
-        sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
-        return F.normalize(sentence_embeddings[0], p=2, dim=0).numpy()
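
The removed mean_pooling is standard attention-masked averaging of token embeddings. A minimal self-contained check with toy tensors (hypothetical values, not the app's model output):

# Toy demonstration of attention-mask mean pooling.
import torch

token_embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])  # (batch=1, tokens=3, dim=2)
attention_mask = torch.tensor([[1, 1, 0]])  # third token is padding

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
pooled = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
print(pooled)  # tensor([[2., 3.]]): the padding token is excluded from the mean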
 
-    def is_company_main_subject(self, text: str, companies: List[str]) -> Tuple[bool, str]:
-        if pd.isna(text):
-            return False, ""
-        text_lower = str(text).lower()
-        for company in companies:
-            company_lower = str(company).lower()
-            if company_lower in text_lower.split('.')[0]:
-                return True, company
-            if text_lower.count(company_lower) >= 3:
-                return True, company
-            doc = self.nlp(text_lower)
-            for sent in doc.sents:
-                if company_lower in sent.text:
                     for token in sent:
-                        if token.dep_ == 'nsubj' and company_lower in token.text:
-                            return True, company
-        return False, ""
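
The dependency check above treats a company as the story's actor only when it appears as a grammatical subject. A standalone illustration of the same spaCy pattern (assumes some installed pipeline with a parser, such as en_core_web_sm; the app loads its own model into self.nlp):

# Hypothetical demo of the nsubj dependency test.
import spacy

nlp = spacy.load("en_core_web_sm")  # assumption: any spaCy model with a parser
doc = nlp("acme announced record quarterly profits.")

for sent in doc.sents:
    for token in sent:
        if token.dep_ == "nsubj" and "acme" in token.text:
            print(f"'{token.text}' is the subject of: {sent.text}")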
 
     def process_news(self, df: pd.DataFrame, progress_bar=None):
         # Ensure the DataFrame is not empty
         if df.empty:
-            return pd.DataFrame(columns=['cluster_id', 'datetime', 'company', 'main_company', 'text', 'cluster_size'])
 
-        # Create company_list safely
-        df['company_list'] = df['company'].fillna('').str.split(' | ')
         df = df.sort_values('datetime')
 
         clusters = []
         processed = set()
 
-        for i in tqdm(range(len(df)), total=len(df)):
             if i in processed:
                 continue
 
-            row1 = df.iloc[i]
-            if pd.isna(row1['text']) or not row1['company_list']:
-                processed.add(i)
-                clusters.append([i])
-                continue
-
-            cluster = [i]
             processed.add(i)
-            text1_embedding = self.encode_text(row1['text'])
-
-            if progress_bar:
-                progress_bar.progress(len(processed) / len(df))
-                progress_bar.text(f'Processing item {len(processed)}/{len(df)}...')
 
-            for j in range(len(df)):
-                if j in processed:
-                    continue
-
-                row2 = df.iloc[j]
-                if pd.isna(row2['text']) or not row2['company_list']:
-                    continue
-
-                time_diff = pd.to_datetime(row1['datetime']) - pd.to_datetime(row2['datetime'])
-                if abs(time_diff.total_seconds() / 3600) > self.time_threshold:
-                    continue
-
-                text2_embedding = self.encode_text(row2['text'])
-                similarity = np.dot(text1_embedding, text2_embedding)
-
-                is_main1, main_company1 = self.is_company_main_subject(row1['text'], row1['company_list'])
-                is_main2, main_company2 = self.is_company_main_subject(row2['text'], row2['company_list'])
 
-                # Safe set operation
-                companies_overlap = bool(set(row1['company_list']) & set(row2['company_list']))
 
-                if similarity >= self.similarity_threshold and companies_overlap:
-                    cluster.append(j)
-                    processed.add(j)
 
-            clusters.append(cluster)
 
         result_data = []
-        for cluster_id, cluster in enumerate(clusters, 1):
-            try:
-                cluster_texts = df.iloc[cluster]
-                main_companies = []
-
-                for _, row in cluster_texts.iterrows():
-                    if not pd.isna(row['text']) and isinstance(row['company_list'], list):
-                        is_main, company = self.is_company_main_subject(row['text'], row['company_list'])
-                        if is_main and company:
-                            main_companies.append(company)
-
-                main_company = main_companies[0] if main_companies else "Multiple/Unclear"
-
-                for idx in cluster:
-                    row_data = df.iloc[idx]
-                    result_data.append({
-                        'cluster_id': cluster_id,
-                        'datetime': row_data['datetime'],
-                        'company': ' | '.join(row_data['company_list']) if isinstance(row_data['company_list'], list) else '',
-                        'main_company': main_company,
-                        'text': row_data['text'],
-                        'cluster_size': len(cluster)
-                    })
-            except Exception as e:
-                print(f"Error processing cluster {cluster_id}: {str(e)}")
-                continue
-
         return pd.DataFrame(result_data)
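
Since encode_text L2-normalizes its output, the np.dot in the loop above is exactly cosine similarity. A toy check (hypothetical vectors):

# For unit-norm vectors, the dot product equals cosine similarity.
import numpy as np

a = np.array([3.0, 4.0]); a /= np.linalg.norm(a)  # [0.6, 0.8]
b = np.array([4.0, 3.0]); b /= np.linalg.norm(b)  # [0.8, 0.6]
print(np.dot(a, b))  # 0.96, the cosine of the angle between a and b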
 
 class NewsDeduplicator:
@@ -208,7 +262,7 @@ def create_download_link(df: pd.DataFrame, filename: str) -> str:
         return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
 
 def main():
-    st.title("кластеризуем новости v.1.3+")
     st.write("Upload Excel file with columns: company, datetime, text")
 
     uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])
 
         self.similarity_threshold = similarity_threshold
         self.time_threshold = time_threshold
 
+    def preprocess_company_name(self, company_name: str) -> List[str]:
+        """
+        Preprocesses a company name into search patterns.
+        Handles names containing commas, quotes, and multiple words.
+        Returns the key identifiable parts of the company name.
+        """
+        if pd.isna(company_name):
+            return []
+
+        # Remove surrounding quotes and extra spaces
+        name = str(company_name).strip('"\'').strip()
+
+        # Split on commas and take the first part (usually the main name)
+        main_name = name.split(',')[0].strip()
+
+        # Build patterns from significant parts of the name
+        patterns = []
+
+        # Full main name
+        patterns.append(main_name.lower())
+
+        # Significant words (3+ characters)
+        words = [w for w in main_name.split() if len(w) >= 3]
+        if len(words) > 1:
+            # First significant word on its own
+            patterns.append(words[0].lower())
+
+        # Bigrams of consecutive significant words
+        for i in range(len(words) - 1):
+            patterns.append(f"{words[i]} {words[i+1]}".lower())
+
+        return list(set(patterns))
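
For example, the method yields the following patterns (hypothetical input, assuming an already-constructed NewsProcessor instance named proc; any quoted, comma-separated legal name behaves the same way):

# Hypothetical call:
patterns = proc.preprocess_company_name('Сбербанк России, ПАО')
# main_name = 'Сбербанк России'
# patterns -> {'сбербанк россии', 'сбербанк'}  (full name, first word, bigram, deduplicated)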
 
+    def is_company_main_subject(self, text: str, company_name: str) -> Tuple[bool, float]:
+        """
+        Determines whether the company is the main subject of the news item.
+        Returns (is_main_subject, relevance_score).
+        """
+        if pd.isna(text) or pd.isna(company_name):
+            return False, 0.0
+
+        text = str(text).lower()
+
+        # Get company name patterns
+        company_patterns = self.preprocess_company_name(company_name)
+        if not company_patterns:
+            return False, 0.0
+
+        doc = self.nlp(text)
+
+        # Initialize metrics
+        mentions_count = 0
+        is_in_first_sentence = False
+        is_subject = False
+        other_companies_count = 0
+
+        # Check the first sentence (guard against empty documents)
+        first_sent = next(doc.sents, None)
+        if first_sent is None:
+            return False, 0.0
+        first_sent_text = first_sent.text.lower()
+
+        for pattern in company_patterns:
+            if pattern in first_sent_text:
+                is_in_first_sentence = True
+                break
+
+        # Analyze each sentence
+        for sent in doc.sents:
+            sent_text = sent.text.lower()
+
+            # Count company mentions
+            for pattern in company_patterns:
+                if pattern in sent_text:
+                    mentions_count += 1
+
+            # Check whether the company is the grammatical subject
             for token in sent:
+                if token.dep_ in ['nsubj', 'nsubjpass'] and any(p in token.text.lower() for p in company_patterns):
+                    is_subject = True
+
+            # Count potential mentions of other companies
+            # (simplified; could be improved with named entity recognition)
+            company_indicators = ['компания', 'корпорация', 'фирма', 'банк', 'group', 'inc', 'ltd', 'llc', 'corporation']
+            for indicator in company_indicators:
+                if indicator in sent_text:
+                    other_companies_count += 1
+
+        # Calculate the relevance score
+        relevance_score = 0.0
+        relevance_score += 0.4 if is_in_first_sentence else 0.0
+        relevance_score += 0.3 if is_subject else 0.0
+        relevance_score += min(0.3, mentions_count * 0.1)  # capped at 0.3
+
+        # Penalize the score when many other companies are mentioned
+        relevance_score *= max(0.2, 1 - (other_companies_count * 0.1))
+
+        # The company counts as the main subject if the score clears the threshold
+        return relevance_score >= 0.5, relevance_score
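
A worked example of the scoring with hypothetical counts: named in the first sentence (+0.4), appearing as a subject (+0.3), mentioned twice (+0.2), with two other-company indicator words in the text:

base = 0.4 + 0.3 + min(0.3, 2 * 0.1)   # = 0.9
score = base * max(0.2, 1 - 2 * 0.1)   # penalty factor 0.8 -> 0.72
print(score >= 0.5)                    # True: kept as main subject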
 
     def process_news(self, df: pd.DataFrame, progress_bar=None):
         # Ensure the DataFrame is not empty
         if df.empty:
+            return pd.DataFrame(columns=['cluster_id', 'datetime', 'company', 'relevance_score', 'text', 'cluster_size'])
 
         df = df.sort_values('datetime')
 
+        # First, filter out news where the company isn't the main subject
+        relevance_results = []
+        for idx, row in df.iterrows():
+            is_main, score = self.is_company_main_subject(row['text'], row['company'])
+            if is_main:
+                relevance_results.append({
+                    'idx': idx,
+                    'relevance_score': score
+                })
+
+        if not relevance_results:
+            return pd.DataFrame(columns=['cluster_id', 'datetime', 'company', 'relevance_score', 'text', 'cluster_size'])
+
+        relevant_indices = [r['idx'] for r in relevance_results]
+        relevance_scores = {r['idx']: r['relevance_score'] for r in relevance_results}
+
+        df_filtered = df.loc[relevant_indices].copy()
+        df_filtered['relevance_score'] = df_filtered.index.map(relevance_scores)
+
+        # Continue with clustering logic...
         clusters = []
         processed = set()
 
+        for i in tqdm(range(len(df_filtered)), total=len(df_filtered)):
             if i in processed:
                 continue
 
+            row1 = df_filtered.iloc[i]
+            cluster = [df_filtered.index[i]]
             processed.add(i)
 
+            # A row with missing text forms its own singleton cluster
+            if pd.isna(row1['text']):
+                clusters.append(cluster)
+                continue
+
+            text1_embedding = self.encode_text(row1['text'])
 
+            if progress_bar:
+                progress_bar.progress(len(processed) / len(df_filtered))
 
+            for j in range(len(df_filtered)):
+                if j in processed:
+                    continue
+
+                row2 = df_filtered.iloc[j]
+                if pd.isna(row2['text']):
+                    continue
+
+                time_diff = pd.to_datetime(row1['datetime']) - pd.to_datetime(row2['datetime'])
+                if abs(time_diff.total_seconds() / 3600) > self.time_threshold:
+                    continue
 
+                text2_embedding = self.encode_text(row2['text'])
+                similarity = np.dot(text1_embedding, text2_embedding)
+
+                if similarity >= self.similarity_threshold:
+                    cluster.append(df_filtered.index[j])
+                    processed.add(j)
 
+            clusters.append(cluster)
+
+        # Create the result DataFrame
         result_data = []
+        for cluster_id, cluster_indices in enumerate(clusters, 1):
+            for idx in cluster_indices:
+                result_data.append({
+                    'cluster_id': cluster_id,
+                    'datetime': df.loc[idx, 'datetime'],
+                    'company': df.loc[idx, 'company'],
+                    'relevance_score': relevance_scores[idx],
+                    'text': df.loc[idx, 'text'],
+                    'cluster_size': len(cluster_indices)
+                })
+
         return pd.DataFrame(result_data)
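
Note that the nested loop above calls encode_text on row2 for every surviving pair, so encoding cost grows quadratically. A hedged optimization sketch (hypothetical proc and df_filtered names; relies only on the embeddings being unit-normalized, as encode_text guarantees): encode each text once and read all pairwise cosine similarities from one matrix product.

# Sketch: precompute embeddings, then sim[i, j] = cosine(text_i, text_j).
import numpy as np

embeddings = np.stack([proc.encode_text(t) for t in df_filtered['text'].fillna('')])
similarity_matrix = embeddings @ embeddings.T
# The inner loop then reduces to: similarity_matrix[i, j] >= proc.similarity_threshold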
 
 class NewsDeduplicator:
         return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
 
 def main():
+    st.title("кластеризуем новости v.1.4")
     st.write("Upload Excel file with columns: company, datetime, text")
 
     uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])
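
For context, the b64 payload in create_download_link presumably comes from an in-memory Excel export. A minimal sketch of such a helper (the buffer handling and Excel engine are assumptions, not the app's confirmed code):

# Hypothetical reconstruction of create_download_link's setup:
import base64
import io
import pandas as pd

def create_download_link(df: pd.DataFrame, filename: str) -> str:
    buffer = io.BytesIO()
    df.to_excel(buffer, index=False)  # assumes an Excel writer such as openpyxl
    b64 = base64.b64encode(buffer.getvalue()).decode()
    return (f'<a href="data:application/vnd.openxmlformats-officedocument.'
            f'spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>')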