MikhailPugachev commited on
Commit
d5fbae4
·
1 Parent(s): c38b845

Исправлен путь к модели

Browse files
Files changed (1) hide show
  1. app.py +35 -9
app.py CHANGED
@@ -5,6 +5,7 @@ from transformers import AutoTokenizer
5
  from model_SingleLabelClassifier import SingleLabelClassifier
6
  from safetensors.torch import load_file
7
  import json
 
8
 
9
 
10
  MODEL_NAME = "allenai/scibert_scivocab_uncased"
@@ -15,7 +16,7 @@ MAX_LEN = 250
15
  # Загрузка меток
16
  with open("label_mappings.json", "r") as f:
17
  mappings = json.load(f)
18
- abel2id = mappings["label2id"]
19
  id2label = {int(k): v for k, v in mappings["id2label"].items()}
20
 
21
  # Загрузка модели и токенизатора
@@ -30,9 +31,13 @@ def load_model_and_tokenizer():
30
 
31
  model, tokenizer = load_model_and_tokenizer()
32
 
33
- # Функция предсказания
34
- def predict(title, summary, model, tokenizer, id2label, max_length=320, top_k=3):
35
  model.eval()
 
 
 
 
36
  text = title + ". " + summary
37
 
38
  inputs = tokenizer(
@@ -44,7 +49,11 @@ def predict(title, summary, model, tokenizer, id2label, max_length=320, top_k=3)
44
  )
45
 
46
  with torch.no_grad():
47
- outputs = model(**inputs)
 
 
 
 
48
  logits = outputs["logits"]
49
  probs = F.softmax(logits, dim=1).squeeze().numpy()
50
 
@@ -52,17 +61,34 @@ def predict(title, summary, model, tokenizer, id2label, max_length=320, top_k=3)
52
  return [(id2label[i], round(probs[i], 3)) for i in top_indices]
53
 
54
  # Интерфейс Streamlit
55
- st.title("ArXiv Tag Predictor")
56
- st.write("Вставьте заголовок и аннотацию статьи!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  title = st.text_input("**Title**")
59
  summary = st.text_area("**Summary**", height=200)
60
 
61
- if st.button("Предсказать тег"):
62
  if not title or not summary:
63
  st.warning("Пожалуйста, введите и заголовок, и аннотацию!")
64
  else:
65
  preds = predict(title, summary, model, tokenizer, id2label)
66
- st.subheader("Предсказанные теги:")
67
  for tag, prob in preds:
68
- st.write(f"**{tag}** — вероятность: {prob:.3f}")
 
5
  from model_SingleLabelClassifier import SingleLabelClassifier
6
  from safetensors.torch import load_file
7
  import json
8
+ import re
9
 
10
 
11
  MODEL_NAME = "allenai/scibert_scivocab_uncased"
 
16
  # Загрузка меток
17
  with open("label_mappings.json", "r") as f:
18
  mappings = json.load(f)
19
+ label2id = mappings["label2id"]
20
  id2label = {int(k): v for k, v in mappings["id2label"].items()}
21
 
22
  # Загрузка модели и токенизатора
 
31
 
32
  model, tokenizer = load_model_and_tokenizer()
33
 
34
+ # Обновлённая функция предсказания
35
+ def predict(title, summary, model, tokenizer, id2label, max_length=MAX_LEN, top_k=3):
36
  model.eval()
37
+
38
+ # Удаляем лишние точки, пробелы и объединяем текст
39
+ title = re.sub(r"\.+$", "", title.strip())
40
+ summary = re.sub(r"\.+$", "", summary.strip())
41
  text = title + ". " + summary
42
 
43
  inputs = tokenizer(
 
49
  )
50
 
51
  with torch.no_grad():
52
+ outputs = model(
53
+ input_ids=inputs["input_ids"],
54
+ attention_mask=inputs["attention_mask"],
55
+ token_type_ids=inputs.get("token_type_ids")
56
+ )
57
  logits = outputs["logits"]
58
  probs = F.softmax(logits, dim=1).squeeze().numpy()
59
 
 
61
  return [(id2label[i], round(probs[i], 3)) for i in top_indices]
62
 
63
  # Интерфейс Streamlit
64
+ st.title("🧠 ArXiv Tag Predictor")
65
+
66
+ with st.expander("ℹ️ Описание модели"):
67
+ st.markdown("""
68
+ Данная модель обучена на основе [SciBERT](https://huggingface.co/allenai/scibert_scivocab_uncased) для классификации научных статей с сайта [arXiv.org](https://arxiv.org).
69
+
70
+ - Использует **65 различных тегов** из тематик arXiv (например: `cs.CL`, `math.CO`, `stat.ML`, и т.д.)
71
+ - Модель обучена на **заголовках и аннотациях** научных публикаций
72
+ - На вход принимает **англоязычный текст**
73
+ - Предсказывает **топ-3 наиболее вероятных тега** для каждой статьи
74
+
75
+ Ниже вы можете посмотреть полный список возможных тегов 👇
76
+ """)
77
+
78
+ with st.expander("📄 Список всех тегов"):
79
+ tag_list = sorted(label2id.keys())
80
+ st.markdown("\n".join([f"- `{tag}`" for tag in tag_list]))
81
+
82
+ st.write("Введите заголовок и аннотацию научной статьи (на английском):")
83
 
84
  title = st.text_input("**Title**")
85
  summary = st.text_area("**Summary**", height=200)
86
 
87
+ if st.button("📌 Предсказать теги"):
88
  if not title or not summary:
89
  st.warning("Пожалуйста, введите и заголовок, и аннотацию!")
90
  else:
91
  preds = predict(title, summary, model, tokenizer, id2label)
92
+ st.subheader("📚 Предсказанные теги:")
93
  for tag, prob in preds:
94
+ st.write(f"**{tag}** — вероятность: `{prob:.3f}`")