pentarosarium committed on
Commit
6d4a64c
·
1 Parent(s): c8405c4

mend indexers

Browse files
Files changed (1) hide show
  1. app.py +49 -62
app.py CHANGED
@@ -25,54 +25,28 @@ class NewsProcessor:
25
  self.similarity_threshold = similarity_threshold
26
  self.time_threshold = time_threshold
27
 
28
- def mean_pooling(self, model_output, attention_mask):
29
- token_embeddings = model_output[0]
30
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
31
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
32
-
33
- def encode_text(self, text):
34
- # Convert text to string and handle NaN values
35
- if pd.isna(text):
36
- text = ""
37
- else:
38
- text = str(text)
39
-
40
- encoded_input = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
41
- with torch.no_grad():
42
- model_output = self.model(**encoded_input)
43
- sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
44
- return F.normalize(sentence_embeddings[0], p=2, dim=0).numpy()
45
-
46
- def is_company_main_subject(self, text: str, companies: List[str]) -> Tuple[bool, str]:
47
- if pd.isna(text):
48
- return False, ""
49
-
50
- text_lower = str(text).lower()
51
-
52
- for company in companies:
53
- company_lower = str(company).lower()
54
- if company_lower in text_lower.split('.')[0]:
55
- return True, company
56
- if text_lower.count(company_lower) >= 3:
57
- return True, company
58
- doc = self.nlp(text_lower)
59
- for sent in doc.sents:
60
- if company_lower in sent.text:
61
- for token in sent:
62
- if token.dep_ == 'nsubj' and company_lower in token.text:
63
- return True, company
64
- return False, ""
65
-
66
  def process_news(self, df: pd.DataFrame, progress_bar=None):
67
- df['company_list'] = df['company'].str.split(' | ')
 
 
 
 
 
68
  df = df.sort_values('datetime')
 
69
  clusters = []
70
  processed = set()
71
 
72
- for i, row1 in tqdm(df.iterrows(), total=len(df)):
73
  if i in processed:
74
  continue
75
-
 
 
 
 
 
 
76
  cluster = [i]
77
  processed.add(i)
78
  text1_embedding = self.encode_text(row1['text'])
@@ -80,11 +54,16 @@ class NewsProcessor:
80
  if progress_bar:
81
  progress_bar.progress(len(processed) / len(df))
82
  progress_bar.text(f'Processing item {len(processed)}/{len(df)}...')
83
-
84
- for j, row2 in df.iterrows():
 
85
  if j in processed:
86
  continue
87
 
 
 
 
 
88
  time_diff = pd.to_datetime(row1['datetime']) - pd.to_datetime(row2['datetime'])
89
  if abs(time_diff.total_seconds() / 3600) > self.time_threshold:
90
  continue
@@ -95,6 +74,7 @@ class NewsProcessor:
95
  is_main1, main_company1 = self.is_company_main_subject(row1['text'], row1['company_list'])
96
  is_main2, main_company2 = self.is_company_main_subject(row2['text'], row2['company_list'])
97
 
 
98
  companies_overlap = bool(set(row1['company_list']) & set(row2['company_list']))
99
 
100
  if similarity >= self.similarity_threshold and companies_overlap:
@@ -105,24 +85,31 @@ class NewsProcessor:
105
 
106
  result_data = []
107
  for cluster_id, cluster in enumerate(clusters, 1):
108
- cluster_texts = df.iloc[cluster]
109
- main_companies = []
110
- for _, row in cluster_texts.iterrows():
111
- is_main, company = self.is_company_main_subject(row['text'], row['company_list'])
112
- if is_main and company:
113
- main_companies.append(company)
114
-
115
- main_company = main_companies[0] if main_companies else "Multiple/Unclear"
116
-
117
- for idx in cluster:
118
- result_data.append({
119
- 'cluster_id': cluster_id,
120
- 'datetime': df.iloc[idx]['datetime'],
121
- 'company': ' | '.join(df.iloc[idx]['company_list']),
122
- 'main_company': main_company,
123
- 'text': df.iloc[idx]['text'],
124
- 'cluster_size': len(cluster)
125
- })
 
 
 
 
 
 
 
126
 
127
  return pd.DataFrame(result_data)
128
 
@@ -184,7 +171,7 @@ def create_download_link(df: pd.DataFrame, filename: str) -> str:
184
  return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
185
 
186
  def main():
187
- st.title("News Clustering App")
188
  st.write("Upload Excel file with columns: company, datetime, text")
189
 
190
  uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])
 
25
  self.similarity_threshold = similarity_threshold
26
  self.time_threshold = time_threshold
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def process_news(self, df: pd.DataFrame, progress_bar=None):
29
+ # Ensure the DataFrame is not empty
30
+ if df.empty:
31
+ return pd.DataFrame(columns=['cluster_id', 'datetime', 'company', 'main_company', 'text', 'cluster_size'])
32
+
33
+ # Create company_list safely
34
+ df['company_list'] = df['company'].fillna('').str.split(' | ')
35
  df = df.sort_values('datetime')
36
+
37
  clusters = []
38
  processed = set()
39
 
40
+ for i in tqdm(range(len(df)), total=len(df)):
41
  if i in processed:
42
  continue
43
+
44
+ row1 = df.iloc[i]
45
+ if pd.isna(row1['text']) or not row1['company_list']:
46
+ processed.add(i)
47
+ clusters.append([i])
48
+ continue
49
+
50
  cluster = [i]
51
  processed.add(i)
52
  text1_embedding = self.encode_text(row1['text'])
 
54
  if progress_bar:
55
  progress_bar.progress(len(processed) / len(df))
56
  progress_bar.text(f'Processing item {len(processed)}/{len(df)}...')
57
+
58
+ # Use index-based iteration instead of iterrows
59
+ for j in range(len(df)):
60
  if j in processed:
61
  continue
62
 
63
+ row2 = df.iloc[j]
64
+ if pd.isna(row2['text']) or not row2['company_list']:
65
+ continue
66
+
67
  time_diff = pd.to_datetime(row1['datetime']) - pd.to_datetime(row2['datetime'])
68
  if abs(time_diff.total_seconds() / 3600) > self.time_threshold:
69
  continue
 
74
  is_main1, main_company1 = self.is_company_main_subject(row1['text'], row1['company_list'])
75
  is_main2, main_company2 = self.is_company_main_subject(row2['text'], row2['company_list'])
76
 
77
+ # Safe set operation
78
  companies_overlap = bool(set(row1['company_list']) & set(row2['company_list']))
79
 
80
  if similarity >= self.similarity_threshold and companies_overlap:
 
85
 
86
  result_data = []
87
  for cluster_id, cluster in enumerate(clusters, 1):
88
+ try:
89
+ cluster_texts = df.iloc[cluster]
90
+ main_companies = []
91
+
92
+ for _, row in cluster_texts.iterrows():
93
+ if not pd.isna(row['text']) and isinstance(row['company_list'], list):
94
+ is_main, company = self.is_company_main_subject(row['text'], row['company_list'])
95
+ if is_main and company:
96
+ main_companies.append(company)
97
+
98
+ main_company = main_companies[0] if main_companies else "Multiple/Unclear"
99
+
100
+ for idx in cluster:
101
+ row_data = df.iloc[idx]
102
+ result_data.append({
103
+ 'cluster_id': cluster_id,
104
+ 'datetime': row_data['datetime'],
105
+ 'company': ' | '.join(row_data['company_list']) if isinstance(row_data['company_list'], list) else '',
106
+ 'main_company': main_company,
107
+ 'text': row_data['text'],
108
+ 'cluster_size': len(cluster)
109
+ })
110
+ except Exception as e:
111
+ print(f"Error processing cluster {cluster_id}: {str(e)}")
112
+ continue
113
 
114
  return pd.DataFrame(result_data)
115
 
 
171
  return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
172
 
173
  def main():
174
+ st.title("кластеризуем новости v.1.2")
175
  st.write("Upload Excel file with columns: company, datetime, text")
176
 
177
  uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])