dejanseo's picture
Update app.py
f56ff9e verified
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Force light theme globally
st.markdown("""
<style>
/* Hide Streamlit's menu and footer */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
/* Center and size the logo */
.block-container {
padding-top: 1rem;
}
</style>
""", unsafe_allow_html=True)
# Load model and tokenizer from Hugging Face Hub
@st.cache_resource
def load_model_and_tokenizer():
model_name = "dejanseo/bulgarian-search-query-intent"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
return model, tokenizer
# Load resources
model, tokenizer = load_model_and_tokenizer()
# Page layout with clickable logo
st.markdown("""
<div style="display: flex; justify-content: space-between; align-items: center;">
<h1>Класификация на намерения за търсене</h1>
<a href="https://dejan.ai" target="_blank">
<img src="https://huggingface.co/spaces/dejanseo/bulgarian-search-query-intent-classifier/resolve/main/dejan-300x103.png" width="300">
</a>
</div>
""", unsafe_allow_html=True)
st.write(
"Въведете една или повече заявки (всеки на нов ред) или качете `.txt` файл, в който "
"всяка заявка е на отделен ред без допълнителни параметри. "
"Моделът е създаден от [DEJAN AI](https://dejan.ai)."
)
# Текстово поле за въвеждане на заявки
queries_input = st.text_area("Въведете вашите заявки (по една на ред):")
# Качване на `.txt` файл
uploaded_file = st.file_uploader(
"Качете `.txt` файл с заявки (всеки ред съдържа една заявка)", type=["txt"]
)
# Събиране на заявките от текстовото поле и/или файла
queries = []
if queries_input.strip():
queries.extend([line.strip() for line in queries_input.splitlines() if line.strip()])
if uploaded_file is not None:
file_content = uploaded_file.read().decode("utf-8")
queries.extend([line.strip() for line in file_content.splitlines() if line.strip()])
# UI for button with spinner
button_disabled = False
if queries:
button_disabled = False
else:
button_disabled = True
if st.button("Класифицирай", disabled=button_disabled):
if queries:
with st.spinner("Обработване..."):
# Tokenize in batch
inputs = tokenizer(
queries,
return_tensors="pt",
truncation=True,
padding=True,
max_length=256
)
# Run inference
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predictions = logits.argmax(dim=-1).tolist()
probabilities = F.softmax(logits, dim=-1)
confidence_scores = probabilities.max(dim=-1).values.tolist()
# Използване на наличната label mapping от модела
id2label = model.config.id2label
results = []
for query, pred, conf in zip(queries, predictions, confidence_scores):
predicted_label = id2label.get(str(pred), id2label.get(pred, "Неизвестно"))
results.append({
"Заявка": query,
"Предсказано намерение": predicted_label,
"Доверие": f"{conf:.2f}"
})
st.write("### Резултати:")
st.dataframe(results, use_container_width=True)
else:
st.warning("Моля, въведете поне една заявка, преди да класифицирате.")