Yerzhxn commited on
Commit
e426245
·
verified ·
1 Parent(s): dcdf73a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -47
app.py CHANGED
@@ -1,53 +1,59 @@
1
- import torch
2
- import torch.nn.functional as F
3
  import joblib
4
- from transformers import BertTokenizer, BertForSequenceClassification
5
-
6
- # Загрузка предобученной модели и токенизатора
7
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
8
 
9
- model = torch.load('bert_model.pkl', map_location=device)
10
- model = model.to(device)
 
 
 
11
 
 
12
  tokenizer = joblib.load('bert_tokenizer.pkl')
13
 
14
- # Функция для предсказания класса с вероятностями
15
- def predict_class_with_probabilities(text, model, tokenizer, label_encoder, max_len=128):
16
- model.eval()
17
-
18
- # Токенизация текста
19
- encodings = tokenizer(
20
- text,
21
- truncation=True,
22
- padding="max_length",
23
- max_length=max_len,
24
- return_tensors="pt"
25
- )
26
-
27
- # Перенос токенизированных данных на устройство
28
- encodings = {key: val.to(device) for key, val in encodings.items()}
29
-
30
  with torch.no_grad():
31
- # Предсказание модели
32
- outputs = model(**encodings)
33
- logits = outputs.logits
34
-
35
- # Вычисление вероятностей классов
36
- probabilities = F.softmax(logits, dim=1).squeeze().cpu().numpy()
37
-
38
- # Определение предсказанного класса
39
- predicted_class = torch.argmax(logits, dim=1).item()
40
- predicted_label = label_encoder.inverse_transform([predicted_class])[0]
41
-
42
- return predicted_label, probabilities
43
-
44
- # Пример использования функции
45
- input_text = "художественно разработка дизайн проекта"
46
- predicted_class, probabilities = predict_class_with_probabilities(input_text, model, tokenizer, label_encoder)
47
-
48
- # Вывод результатов
49
- print(f"Предсказанный класс: {predicted_class}")
50
- print("Вероятности для каждого класса:")
51
- for idx, prob in enumerate(probabilities):
52
- class_label = label_encoder.inverse_transform([idx])[0]
53
- print(f"{class_label}: {prob:.4f}")
 
 
 
 
 
 
 
 
 
 
1
  import joblib
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ import torch
6
+ from transformers import BertTokenizer
7
 
8
+ # Загрузка модели и токенизатора с обработкой ошибки CUDA
9
+ try:
10
+ model = torch.load('bert_model.pkl', map_location=torch.device('cpu'))
11
+ except RuntimeError as e:
12
+ st.error(f"Ошибка загрузки модели: {e}")
13
 
14
+ # Загрузка токенизатора BERT
15
  tokenizer = joblib.load('bert_tokenizer.pkl')
16
 
17
+ # Загрузка данных для поиска сходства
18
+ try:
19
+ data = pd.read_excel('DATA_new.xlsx')
20
+ data_texts = data['Text'].tolist()
21
+ except FileNotFoundError:
22
+ st.error("Файл 'DATA_new.xlsx' не найден.")
23
+ except Exception as e:
24
+ st.error(f"Ошибка загрузки файла: {e}")
25
+
26
+ # Функция для нахождения сходства
27
+ def find_similar_texts(input_text, top_n=5):
28
+ inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True)
 
 
 
 
29
  with torch.no_grad():
30
+ input_vector = model(**inputs).logits
31
+ data_vectors = []
32
+ for text in data_texts:
33
+ inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
34
+ with torch.no_grad():
35
+ data_vectors.append(model(**inputs).logits)
36
+ data_vectors = torch.stack(data_vectors).squeeze()
37
+ similarities = torch.nn.functional.cosine_similarity(input_vector, data_vectors)
38
+ similar_indices = torch.argsort(similarities, descending=True)[:top_n]
39
+ similar_texts = [data_texts[i] for i in similar_indices]
40
+ return similar_texts
41
+
42
+ # Streamlit интерфейс в файле app.py
43
+ if __name__ == "__main__":
44
+ st.title("Поиск сходства текстов")
45
+ st.write("Введите текст для поиска сходства")
46
+
47
+ input_text = st.text_area("Текст для поиска сходства")
48
+
49
+ if st.button("Найти похожие тексты"):
50
+ if input_text.strip():
51
+ similar_texts = find_similar_texts(input_text)
52
+ if similar_texts:
53
+ st.write("Похожие тексты:")
54
+ for text in similar_texts:
55
+ st.write(f"- {text}")
56
+ else:
57
+ st.write("Нет похожих текстов для данного ввода.")
58
+ else:
59
+ st.error("Пожалуйста, введите текст для поиска сходства.")