pentarosarium committed on
Commit
cbb8180
·
1 Parent(s): 6d4a64c
Files changed (1) hide show
  1. app.py +39 -2
app.py CHANGED
@@ -25,6 +25,44 @@ class NewsProcessor:
25
  self.similarity_threshold = similarity_threshold
26
  self.time_threshold = time_threshold
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def process_news(self, df: pd.DataFrame, progress_bar=None):
29
  # Ensure the DataFrame is not empty
30
  if df.empty:
@@ -55,7 +93,6 @@ class NewsProcessor:
55
  progress_bar.progress(len(processed) / len(df))
56
  progress_bar.text(f'Processing item {len(processed)}/{len(df)}...')
57
 
58
- # Use index-based iteration instead of iterrows
59
  for j in range(len(df)):
60
  if j in processed:
61
  continue
@@ -171,7 +208,7 @@ def create_download_link(df: pd.DataFrame, filename: str) -> str:
171
  return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
172
 
173
  def main():
174
- st.title("кластеризуем новости v.1.2")
175
  st.write("Upload Excel file with columns: company, datetime, text")
176
 
177
  uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])
 
25
  self.similarity_threshold = similarity_threshold
26
  self.time_threshold = time_threshold
27
 
28
def mean_pooling(self, model_output, attention_mask):
    """Average token embeddings, weighting every token by its attention mask.

    ``model_output[0]`` is used as the per-token embeddings. The mask gets a
    trailing axis and is expanded to the embeddings' size, so positions with
    a zero mask add nothing to the sum or to the divisor.
    """
    embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    weighted_sum = torch.sum(embeddings * mask, 1)
    # Clamp the divisor so an all-zero mask cannot cause a division by zero.
    divisor = torch.clamp(mask.sum(1), min=1e-9)
    return weighted_sum / divisor
32
+
33
def encode_text(self, text):
    """Return the L2-normalised embedding of *text* as a NumPy array.

    A missing value (as reported by ``pd.isna``) is encoded as the empty
    string; anything else is converted with ``str`` before tokenisation.
    Only the first row of the pooled output is returned.
    """
    text = "" if pd.isna(text) else str(text)

    encoded_input = self.tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )
    # Forward pass without gradient tracking.
    with torch.no_grad():
        model_output = self.model(**encoded_input)

    pooled = self.mean_pooling(model_output, encoded_input['attention_mask'])
    first_row = pooled[0]
    return F.normalize(first_row, p=2, dim=0).numpy()
45
+
46
def is_company_main_subject(self, text: str, companies: List[str]) -> Tuple[bool, str]:
    """Decide whether one of *companies* is the main subject of *text*.

    Companies are tried in order. The first one that passes any of these
    case-insensitive checks wins:

    1. its name appears before the first ``.`` of the text;
    2. its name occurs at least three times in the text;
    3. a token tagged ``nsubj`` contains its name, within a sentence that
       mentions it, as parsed by ``self.nlp`` (presumably a spaCy
       pipeline -- confirm against the class constructor).

    Args:
        text: News text. A missing value (``pd.isna``) never matches.
        companies: Candidate company names. Missing or blank names are
            skipped, because ``"nan"`` or ``""`` would otherwise match
            unrelated text by substring.

    Returns:
        ``(True, company)`` for the first matching company, with the name
        exactly as given in *companies*; otherwise ``(False, "")``.
    """
    if pd.isna(text):
        return False, ""

    text_lower = str(text).lower()
    # Both values below are the same for every company, so compute them once.
    first_sentence = text_lower.split('.')[0]
    doc = None  # parsed lazily: the cheap checks often succeed without it

    for company in companies:
        if pd.isna(company):
            continue
        company_lower = str(company).lower()
        if not company_lower.strip():
            continue

        if company_lower in first_sentence:
            return True, company
        if text_lower.count(company_lower) >= 3:
            return True, company

        if doc is None:
            doc = self.nlp(text_lower)
        for sent in doc.sents:
            if company_lower not in sent.text:
                continue
            for token in sent:
                if token.dep_ == 'nsubj' and company_lower in token.text:
                    return True, company

    return False, ""
65
+
66
  def process_news(self, df: pd.DataFrame, progress_bar=None):
67
  # Ensure the DataFrame is not empty
68
  if df.empty:
 
93
  progress_bar.progress(len(processed) / len(df))
94
  progress_bar.text(f'Processing item {len(processed)}/{len(df)}...')
95
 
 
96
  for j in range(len(df)):
97
  if j in processed:
98
  continue
 
208
  return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
209
 
210
  def main():
211
+ st.title("кластеризуем новости v.1.3+")
212
  st.write("Upload Excel file with columns: company, datetime, text")
213
 
214
  uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])