pentarosarium commited on
Commit
2e0c24f
·
1 Parent(s): a0341a8

amend torch mistake

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
- from transformers import pipeline
 
5
  from datetime import datetime
6
  import io
7
  import base64
@@ -10,25 +11,35 @@ from rapidfuzz import fuzz, process
10
  from collections import defaultdict
11
  from tqdm import tqdm
12
  import spacy
 
 
13
 
14
  # Download Russian model
15
  spacy.cli.download("ru_core_news_sm")
16
 
17
-
18
  class NewsProcessor:
19
  def __init__(self, similarity_threshold=0.75, time_threshold=24):
20
  try:
21
  self.nlp = spacy.load("ru_core_news_sm")
22
  except:
23
  self.nlp = spacy.load("en_core_web_sm")
24
- self.embeddings = pipeline("feature-extraction",
25
- model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
26
- token=st.secrets["hf_token"])
27
  self.similarity_threshold = similarity_threshold
28
  self.time_threshold = time_threshold
29
 
 
 
 
 
 
30
  def encode_text(self, text):
31
- return np.mean(self.embeddings(text)[0], axis=0)
 
 
 
 
32
 
33
  def is_company_main_subject(self, text: str, companies: List[str]) -> Tuple[bool, str]:
34
  text_lower = text.lower()
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import torch
6
  from datetime import datetime
7
  import io
8
  import base64
 
11
  from collections import defaultdict
12
  from tqdm import tqdm
13
  import spacy
14
+ import torch.nn.functional as F
15
+
16
 
17
  # Download Russian model
18
  spacy.cli.download("ru_core_news_sm")
19
 
 
20
  class NewsProcessor:
21
  def __init__(self, similarity_threshold=0.75, time_threshold=24):
22
  try:
23
  self.nlp = spacy.load("ru_core_news_sm")
24
  except:
25
  self.nlp = spacy.load("en_core_web_sm")
26
+
27
+ self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
28
+ self.model = AutoModel.from_pretrained('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
29
  self.similarity_threshold = similarity_threshold
30
  self.time_threshold = time_threshold
31
 
32
+ def mean_pooling(self, model_output, attention_mask):
33
+ token_embeddings = model_output[0]
34
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
35
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
36
+
37
  def encode_text(self, text):
38
+ encoded_input = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
39
+ with torch.no_grad():
40
+ model_output = self.model(**encoded_input)
41
+ sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
42
+ return F.normalize(sentence_embeddings[0], p=2, dim=0).numpy()
43
 
44
  def is_company_main_subject(self, text: str, companies: List[str]) -> Tuple[bool, str]:
45
  text_lower = text.lower()