Spaces:
Sleeping
Sleeping
Commit
·
706e736
1
Parent(s):
4f373f5
1.8
Browse files
app.py
CHANGED
@@ -27,6 +27,25 @@ class NewsProcessor:
|
|
27 |
self.model = AutoModel.from_pretrained('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
|
28 |
self.similarity_threshold = similarity_threshold
|
29 |
self.time_threshold = time_threshold
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
def get_company_variants(self, company_name: str) -> Set[str]:
|
32 |
"""Generate morphological variants of company name."""
|
@@ -302,7 +321,7 @@ def create_download_link(df: pd.DataFrame, filename: str) -> str:
|
|
302 |
return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
|
303 |
|
304 |
def main():
|
305 |
-
st.title("кластеризуем новости v.1.
|
306 |
st.write("Upload Excel file with columns: company, datetime, text")
|
307 |
|
308 |
uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])
|
|
|
27 |
self.model = AutoModel.from_pretrained('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
|
28 |
self.similarity_threshold = similarity_threshold
|
29 |
self.time_threshold = time_threshold
|
30 |
+
|
31 |
+
def mean_pooling(self, model_output, attention_mask):
|
32 |
+
token_embeddings = model_output[0]
|
33 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
34 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
35 |
+
|
36 |
+
def encode_text(self, text):
|
37 |
+
# Convert text to string and handle NaN values
|
38 |
+
if pd.isna(text):
|
39 |
+
text = ""
|
40 |
+
else:
|
41 |
+
text = str(text)
|
42 |
+
|
43 |
+
encoded_input = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
|
44 |
+
with torch.no_grad():
|
45 |
+
model_output = self.model(**encoded_input)
|
46 |
+
sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
|
47 |
+
return F.normalize(sentence_embeddings[0], p=2, dim=0).numpy()
|
48 |
+
|
49 |
|
50 |
def get_company_variants(self, company_name: str) -> Set[str]:
|
51 |
"""Generate morphological variants of company name."""
|
|
|
321 |
return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
|
322 |
|
323 |
def main():
|
324 |
+
st.title("кластеризуем новости v.1.8")
|
325 |
st.write("Upload Excel file with columns: company, datetime, text")
|
326 |
|
327 |
uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])
|