pentarosarium committed on
Commit
706e736
·
1 Parent(s): 4f373f5
Files changed (1) hide show
  1. app.py +20 -1
app.py CHANGED
@@ -27,6 +27,25 @@ class NewsProcessor:
27
  self.model = AutoModel.from_pretrained('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
28
  self.similarity_threshold = similarity_threshold
29
  self.time_threshold = time_threshold
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def get_company_variants(self, company_name: str) -> Set[str]:
32
  """Generate morphological variants of company name."""
@@ -302,7 +321,7 @@ def create_download_link(df: pd.DataFrame, filename: str) -> str:
302
  return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
303
 
304
  def main():
305
- st.title("кластеризуем новости v.1.7")
306
  st.write("Upload Excel file with columns: company, datetime, text")
307
 
308
  uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])
 
27
  self.model = AutoModel.from_pretrained('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
28
  self.similarity_threshold = similarity_threshold
29
  self.time_threshold = time_threshold
30
+
31
+ def mean_pooling(self, model_output, attention_mask):
32
+ token_embeddings = model_output[0]
33
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
34
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
35
+
36
+ def encode_text(self, text):
37
+ # Convert text to string and handle NaN values
38
+ if pd.isna(text):
39
+ text = ""
40
+ else:
41
+ text = str(text)
42
+
43
+ encoded_input = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
44
+ with torch.no_grad():
45
+ model_output = self.model(**encoded_input)
46
+ sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
47
+ return F.normalize(sentence_embeddings[0], p=2, dim=0).numpy()
48
+
49
 
50
  def get_company_variants(self, company_name: str) -> Set[str]:
51
  """Generate morphological variants of company name."""
 
321
  return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
322
 
323
  def main():
324
+ st.title("кластеризуем новости v.1.8")
325
  st.write("Upload Excel file with columns: company, datetime, text")
326
 
327
  uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])