Yerzhxn commited on
Commit
35e4fae
·
verified ·
1 Parent(s): e517b01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -52
app.py CHANGED
@@ -1,57 +1,53 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import numpy as np
4
  import torch
 
 
5
  from transformers import BertTokenizer, BertForSequenceClassification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # Загрузка модели и токенизатора с Hugging Face Hub
8
- model = torch.load('bert_model.pkl', map_location=torch.device('cpu'))
9
- tokenizer = joblib.load('bert_tokenizer.pkl')
10
-
11
- # Устройство для использования модели
12
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- model = model.to(device)
14
-
15
- # Загрузка данных для поиска сходства
16
- try:
17
- data = pd.read_excel('DATA_new.xlsx')
18
- data_texts = data['Tags'].tolist()
19
- except FileNotFoundError:
20
- st.error("Файл 'DATA_new.xlsx' не найден.")
21
- except Exception as e:
22
- st.error(f"Ошибка загрузки файла: {e}")
23
-
24
- # Функция для нахождения сходства
25
- def find_similar_texts(input_text, top_n=5):
26
- inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True).to(device)
27
  with torch.no_grad():
28
- input_vector = model(**inputs).logits
29
- data_vectors = []
30
- for text in data_texts:
31
- inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True).to(device)
32
- with torch.no_grad():
33
- data_vectors.append(model(**inputs).logits)
34
- data_vectors = torch.stack(data_vectors).squeeze()
35
- similarities = torch.nn.functional.cosine_similarity(input_vector, data_vectors)
36
- similar_indices = torch.argsort(similarities, descending=True)[:top_n]
37
- similar_texts = [data_texts[i] for i in similar_indices]
38
- return similar_texts
39
-
40
- # Streamlit интерфейс в файле app.py
41
- if __name__ == "__main__":
42
- st.title("Поиск сходства текстов")
43
- st.write("Введите текст для поиска сходства")
44
-
45
- input_text = st.text_area("Текст для поиска сходства")
46
-
47
- if st.button("Найти похожие тексты"):
48
- if input_text.strip():
49
- similar_texts = find_similar_texts(input_text)
50
- if similar_texts:
51
- st.write("Похожие тексты:")
52
- for text in similar_texts:
53
- st.write(f"- {text}")
54
- else:
55
- st.write("Нет похожих текстов для данного ввода.")
56
  else:
57
- st.error("Пожалуйста, введите текст для поиска сходства.")
 
 
 
 
 
 
 
1
  import torch
2
+ import torch.nn.functional as F
3
+ import streamlit as st
4
  from transformers import BertTokenizer, BertForSequenceClassification
5
+ import joblib
6
+
7
+ # Загрузка модели, токенизатора и label_encoder
8
+ model = torch.load("bert_model.pkl", map_location=torch.device('cpu'))
9
+ tokenizer = joblib.load("bert_tokenizer.pkl")
10
+ label_encoder = joblib.load("label_encoder.pkl")
11
+
12
+ def predict_class_with_probabilities(text, model, tokenizer, label_encoder, max_len=128):
13
+ model.eval()
14
+
15
+ encodings = tokenizer(
16
+ text,
17
+ truncation=True,
18
+ padding="max_length",
19
+ max_length=max_len,
20
+ return_tensors="pt"
21
+ )
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  with torch.no_grad():
24
+ outputs = model(**encodings)
25
+ logits = outputs.logits
26
+
27
+ probabilities = F.softmax(logits, dim=1).squeeze().cpu().numpy()
28
+ predicted_class = torch.argmax(logits, dim=1).item()
29
+ predicted_label = label_encoder.inverse_transform([predicted_class])[0]
30
+
31
+ return predicted_label, probabilities
32
+
33
+ def main():
34
+ st.title("Text Classification App with Hugging Face Space")
35
+ st.write("Введите текст, чтобы получить предсказание и вероятности классов.")
36
+
37
+ input_text = st.text_input("Введите текст для классификации:")
38
+
39
+ if st.button("Предсказать"):
40
+ if input_text:
41
+ predicted_class, probabilities = predict_class_with_probabilities(input_text, model, tokenizer, label_encoder)
42
+
43
+ st.write(f"**Предсказанный класс:** {predicted_class}")
44
+ st.write("**Вероятности для каждого класса:**")
45
+
46
+ for idx, prob in enumerate(probabilities):
47
+ class_label = label_encoder.inverse_transform([idx])[0]
48
+ st.write(f"{class_label}: {prob:.4f}")
 
 
 
49
  else:
50
+ st.write("Пожалуйста, введите текст для предсказания.")
51
+
52
+ if __name__ == "__main__":
53
+ main()