|
import streamlit as st |
|
import torch |
|
import torch.nn.functional as F |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
|
# Page-wide CSS injected once at the top of the script: hides the elements
# matched by #MainMenu, footer and header, and sets the top padding of
# .block-container to 1rem.
_PAGE_CSS = """
<style>
/* Hide Streamlit's menu and footer */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}

/* Center and size the logo */
.block-container {
padding-top: 1rem;
}
</style>
"""

# The markup is raw HTML, so it has to be passed with unsafe_allow_html=True.
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
|
|
|
|
|
@st.cache_resource
def load_model_and_tokenizer():
    """Load the pretrained intent classifier together with its tokenizer.

    Both objects come from the ``dejanseo/bulgarian-search-query-intent``
    checkpoint. The function is wrapped in ``st.cache_resource`` so the
    loaded pair can be shared instead of being rebuilt by every call.

    Returns:
        tuple: ``(model, tokenizer)`` in that order.
    """
    checkpoint = "dejanseo/bulgarian-search-query-intent"
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return classifier, AutoTokenizer.from_pretrained(checkpoint)


# Module-level handles used by the classification step further down.
model, tokenizer = load_model_and_tokenizer()
|
|
|
|
|
# Header row: page title on the left, linked DEJAN logo on the right
# (flex container with space-between).
_HEADER_HTML = """
<div style="display: flex; justify-content: space-between; align-items: center;">
<h1>Класификация на намерения за търсене</h1>
<a href="https://dejan.ai" target="_blank">
<img src="https://huggingface.co/spaces/dejanseo/bulgarian-search-query-intent-classifier/resolve/main/dejan-300x103.png" width="300">
</a>
</div>
"""

st.markdown(_HEADER_HTML, unsafe_allow_html=True)
|
|
|
# Usage instructions shown below the header. The fragments are joined with
# no separator, so the first two keep their own trailing space.
_intro_fragments = (
    "Въведете една или повече заявки (всеки на нов ред) или качете `.txt` файл, в който ",
    "всяка заявка е на отделен ред без допълнителни параметри. ",
    "Моделът е създаден от [DEJAN AI](https://dejan.ai).",
)
st.write("".join(_intro_fragments))
|
|
|
|
|
# Input widgets: free-text queries (one per line) and an optional .txt upload.
queries_input = st.text_area(label="Въведете вашите заявки (по една на ред):")

uploaded_file = st.file_uploader(
    label="Качете `.txt` файл с заявки (всеки ред съдържа една заявка)",
    type=["txt"],
)
|
|
|
|
|
# Collect the queries to classify: non-blank lines from the text area first,
# followed by non-blank lines from the uploaded file (if any). Each line is
# stripped of surrounding whitespace.
queries = [line.strip() for line in queries_input.splitlines() if line.strip()]

if uploaded_file is not None:
    try:
        # "utf-8-sig" decodes plain UTF-8 as usual but also drops a leading
        # byte-order mark. With plain "utf-8" the BOM (U+FEFF) would stay
        # attached to the first query, because str.strip() does not treat
        # it as whitespace.
        file_content = uploaded_file.read().decode("utf-8-sig")
    except UnicodeDecodeError:
        # A file in another encoding should produce a readable message
        # instead of an unhandled exception; text-area queries are kept.
        st.error(
            "Файлът не може да бъде прочетен. "
            "Моля, качете `.txt` файл с UTF-8 кодиране."
        )
    else:
        queries.extend(
            line.strip() for line in file_content.splitlines() if line.strip()
        )
|
|
|
|
|
# The classify button stays disabled until at least one query is available.
button_disabled = not queries
|
|
|
# Number of queries sent through the model per forward pass. Batching keeps
# memory use bounded when a large file is uploaded.
_BATCH_SIZE = 32


def _classify_in_batches(texts, model, tokenizer, batch_size=_BATCH_SIZE):
    """Run the classifier over ``texts`` in fixed-size batches.

    Args:
        texts: List of query strings to classify.
        model: Sequence-classification model; called as ``model(**inputs)``
            and expected to expose ``.logits`` on its output.
        tokenizer: Tokenizer matching ``model``.
        batch_size: Maximum number of texts per forward pass.

    Returns:
        tuple: ``(predictions, confidences)`` where ``predictions`` holds the
        argmax class index for each text and ``confidences`` holds the
        highest softmax probability for each text, both in input order.
    """
    predictions = []
    confidences = []
    for start in range(0, len(texts), batch_size):
        inputs = tokenizer(
            texts[start:start + batch_size],
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256,
        )
        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = model(**inputs).logits
        probabilities = F.softmax(logits, dim=-1)
        predictions.extend(logits.argmax(dim=-1).tolist())
        confidences.extend(probabilities.max(dim=-1).values.tolist())
    return predictions, confidences


if st.button("Класифицирай", disabled=button_disabled):
    if queries:
        with st.spinner("Обработване..."):
            predictions, confidence_scores = _classify_in_batches(
                queries, model, tokenizer
            )

        id2label = model.config.id2label

        # The label map is probed with the string form of the class index
        # first, then the integer form, before falling back to "unknown".
        results = [
            {
                "Заявка": query,
                "Предсказано намерение": id2label.get(
                    str(pred), id2label.get(pred, "Неизвестно")
                ),
                "Доверие": f"{conf:.2f}",
            }
            for query, pred, conf in zip(queries, predictions, confidence_scores)
        ]

        st.write("### Резултати:")
        st.dataframe(results, use_container_width=True)
    else:
        # Defensive: the button is disabled when there are no queries.
        st.warning("Моля, въведете поне една заявка, преди да класифицирате.")
|
|