Commit 6d4a64c · mend indexers
Parent(s): c8405c4
app.py
CHANGED
@@ -25,54 +25,28 @@ class NewsProcessor:
         self.similarity_threshold = similarity_threshold
         self.time_threshold = time_threshold

-    def mean_pooling(self, model_output, attention_mask):
-        token_embeddings = model_output[0]
-        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-
-    def encode_text(self, text):
-        # Convert text to string and handle NaN values
-        if pd.isna(text):
-            text = ""
-        else:
-            text = str(text)
-
-        encoded_input = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
-        with torch.no_grad():
-            model_output = self.model(**encoded_input)
-        sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
-        return F.normalize(sentence_embeddings[0], p=2, dim=0).numpy()
-
-    def is_company_main_subject(self, text: str, companies: List[str]) -> Tuple[bool, str]:
-        if pd.isna(text):
-            return False, ""
-
-        text_lower = str(text).lower()
-
-        for company in companies:
-            company_lower = str(company).lower()
-            if company_lower in text_lower.split('.')[0]:
-                return True, company
-            if text_lower.count(company_lower) >= 3:
-                return True, company
-            doc = self.nlp(text_lower)
-            for sent in doc.sents:
-                if company_lower in sent.text:
-                    for token in sent:
-                        if token.dep_ == 'nsubj' and company_lower in token.text:
-                            return True, company
-        return False, ""
-
     def process_news(self, df: pd.DataFrame, progress_bar=None):
-
+        # Ensure the DataFrame is not empty
+        if df.empty:
+            return pd.DataFrame(columns=['cluster_id', 'datetime', 'company', 'main_company', 'text', 'cluster_size'])
+
+        # Create company_list safely
+        df['company_list'] = df['company'].fillna('').str.split(' | ')
         df = df.sort_values('datetime')
+
         clusters = []
         processed = set()

-        for i
+        for i in tqdm(range(len(df)), total=len(df)):
             if i in processed:
                 continue
-
+
+            row1 = df.iloc[i]
+            if pd.isna(row1['text']) or not row1['company_list']:
+                processed.add(i)
+                clusters.append([i])
+                continue
+
             cluster = [i]
             processed.add(i)
             text1_embedding = self.encode_text(row1['text'])
@@ -80,11 +54,16 @@ class NewsProcessor:
             if progress_bar:
                 progress_bar.progress(len(processed) / len(df))
                 progress_bar.text(f'Processing item {len(processed)}/{len(df)}...')
-
-
+
+            # Use index-based iteration instead of iterrows
+            for j in range(len(df)):
                 if j in processed:
                     continue

+                row2 = df.iloc[j]
+                if pd.isna(row2['text']) or not row2['company_list']:
+                    continue
+
                 time_diff = pd.to_datetime(row1['datetime']) - pd.to_datetime(row2['datetime'])
                 if abs(time_diff.total_seconds() / 3600) > self.time_threshold:
                     continue
@@ -95,6 +74,7 @@ class NewsProcessor:
                 is_main1, main_company1 = self.is_company_main_subject(row1['text'], row1['company_list'])
                 is_main2, main_company2 = self.is_company_main_subject(row2['text'], row2['company_list'])

+                # Safe set operation
                 companies_overlap = bool(set(row1['company_list']) & set(row2['company_list']))

                 if similarity >= self.similarity_threshold and companies_overlap:
@@ -105,24 +85,31 @@ class NewsProcessor:

         result_data = []
         for cluster_id, cluster in enumerate(clusters, 1):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            try:
+                cluster_texts = df.iloc[cluster]
+                main_companies = []
+
+                for _, row in cluster_texts.iterrows():
+                    if not pd.isna(row['text']) and isinstance(row['company_list'], list):
+                        is_main, company = self.is_company_main_subject(row['text'], row['company_list'])
+                        if is_main and company:
+                            main_companies.append(company)
+
+                main_company = main_companies[0] if main_companies else "Multiple/Unclear"
+
+                for idx in cluster:
+                    row_data = df.iloc[idx]
+                    result_data.append({
+                        'cluster_id': cluster_id,
+                        'datetime': row_data['datetime'],
+                        'company': ' | '.join(row_data['company_list']) if isinstance(row_data['company_list'], list) else '',
+                        'main_company': main_company,
+                        'text': row_data['text'],
+                        'cluster_size': len(cluster)
+                    })
+            except Exception as e:
+                print(f"Error processing cluster {cluster_id}: {str(e)}")
+                continue

         return pd.DataFrame(result_data)

@@ -184,7 +171,7 @@ def create_download_link(df: pd.DataFrame, filename: str) -> str:
     return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'

 def main():
-    st.title("
+    st.title("кластеризуем новости v.1.2")
     st.write("Upload Excel file with columns: company, datetime, text")

     uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])
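Note on the comparison step: the hunks above use a similarity value for each pair of rows, but the line that computes it lies outside the changed regions and is not shown in this diff. Below is a minimal sketch of how that comparison could work, assuming it relies on the L2-normalized embeddings returned by encode_text, in which case cosine similarity reduces to a dot product. The helper name embedding_similarity and the usage lines are illustrative assumptions, not the committed code.

import numpy as np

# Hypothetical helper (not part of the commit): since encode_text() returns
# L2-normalized vectors, cosine similarity is simply their dot product.
def embedding_similarity(vec_a: np.ndarray, vec_b: np.ndarray) -> float:
    return float(np.dot(vec_a, vec_b))

# Assumed usage inside the pairwise loop shown above:
#   text2_embedding = self.encode_text(row2['text'])
#   similarity = embedding_similarity(text1_embedding, text2_embedding)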