|
import os |
|
import json |
|
import torch |
|
import streamlit as st |
|
import pandas as pd |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
|
|
|
|
|
|
# Checkpoint handed to ``from_pretrained`` (a Hugging Face Hub id or local directory).
MODEL_PATH = "dejanseo/bulgarian-search-query-intent-alpha"
# JSON file providing the "label_to_id" and "id_to_label" mappings.
LABEL_MAP_PATH = "label_map.json"

# Prefer a CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_inference_resources():
    """Load the label map, tokenizer and classification model.

    Decorated with ``st.cache_resource`` so Streamlit reuses the returned
    objects across script reruns instead of reloading the weights each time.

    Returns:
        tuple: ``(model, tokenizer, label_to_id, id_to_label)`` where
        ``model`` has been moved to ``DEVICE`` and put in evaluation mode,
        ``label_to_id`` is the raw ``"label_to_id"`` mapping from the JSON
        file, and ``id_to_label`` maps integer class ids to label names.

    Raises:
        Exceptions from reading/parsing ``LABEL_MAP_PATH`` (e.g. ``OSError``,
        ``json.JSONDecodeError``), a ``KeyError`` if the map lacks
        ``"id_to_label"`` or ``"label_to_id"``, and whatever ``from_pretrained``
        raises, all propagate to the caller.
    """
    # The label map may hold non-ASCII (e.g. Bulgarian) label names; read it as
    # UTF-8 explicitly instead of relying on the platform default encoding,
    # which is not UTF-8 on e.g. Windows.
    with open(LABEL_MAP_PATH, "r", encoding="utf-8") as f:
        label_map = json.load(f)

    # JSON object keys are always strings, but the model reports integer
    # class ids, so convert the keys once here.
    id_to_label = {int(k): v for k, v in label_map["id_to_label"].items()}

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(DEVICE)
    model.eval()  # inference only: disable training-time behaviour such as dropout

    return model, tokenizer, label_map["label_to_id"], id_to_label
|
|
|
|
|
|
|
|
|
def predict_intent(query, model, tokenizer, id_to_label, max_length=128):
    """Predict the intent of a Bulgarian search query.

    Args:
        query: The search query text.
        model: Sequence-classification model already placed on ``DEVICE``
            (as returned by ``load_inference_resources``).
        tokenizer: Tokenizer matching ``model``.
        id_to_label: Mapping from integer class id to intent label; it must
            contain every class index the model can output.
        max_length: Token length the query is padded/truncated to. Defaults
            to 128, the value this app has always used.

    Returns:
        dict with keys:
            ``"query"``: the input query, unchanged;
            ``"predicted_intent"``: label of the most probable class;
            ``"confidence"``: softmax probability of that class (float);
            ``"all_scores"``: list of ``(label, probability)`` pairs sorted
            by descending probability.
    """
    encoded = tokenizer(
        query,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits

    # A single query is a batch of one, so keep only row 0 of the softmax.
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    predicted_class_id = torch.argmax(probabilities).item()

    # Copy all probabilities to host memory in one transfer. Calling .item()
    # once per class forces a separate device->host sync for each class,
    # which adds up on CUDA when this runs for every row of a batch file.
    probs = probabilities.tolist()

    # Built as a dict (like before) so labels shared by several ids collapse
    # to a single entry.
    all_intents = {id_to_label[i]: p for i, p in enumerate(probs)}
    sorted_intents = sorted(all_intents.items(), key=lambda item: item[1], reverse=True)

    return {
        "query": query,
        "predicted_intent": id_to_label[predicted_class_id],
        "confidence": probs[predicted_class_id],
        "all_scores": sorted_intents,
    }
|
|
|
|
|
|
|
|
|
def _read_uploaded_table(uploaded_file):
    """Read an uploaded file into a DataFrame based on its extension.

    Returns ``None`` if the extension is not ``.csv``, ``.xlsx`` or ``.parquet``.
    """
    name = uploaded_file.name.lower()  # accept e.g. "QUERIES.CSV" as well
    if name.endswith(".csv"):
        return pd.read_csv(uploaded_file)
    if name.endswith(".xlsx"):
        return pd.read_excel(uploaded_file)
    if name.endswith(".parquet"):
        return pd.read_parquet(uploaded_file)
    return None


def _render_single_query(model, tokenizer, id_to_label):
    """Render the single-query input and show its prediction when requested."""
    query = st.text_input("Enter a Bulgarian search query:", "Как да направя резервация за ресторант?")

    if not st.button("Predict Intent"):
        return
    if not query.strip():
        st.warning("Please enter a query first.")
        return

    with st.spinner("Analyzing query..."):
        prediction = predict_intent(query, model, tokenizer, id_to_label)

    st.subheader("Prediction Results")
    st.metric(
        label="Predicted Intent",
        value=prediction["predicted_intent"],
        delta=f"{prediction['confidence']*100:.2f}% confidence",
    )

    st.subheader("Intent Probabilities")
    df_probs = pd.DataFrame(prediction["all_scores"], columns=["Intent", "Probability"])
    # all_scores is already sorted by descending probability.
    st.bar_chart(df_probs.head(5).set_index("Intent"))

    with st.expander("View All Intent Probabilities"):
        st.dataframe(df_probs)


def _render_batch_inference(model, tokenizer, id_to_label):
    """Render the file-upload form and classify every row when requested."""
    st.subheader("Batch Inference")
    uploaded_file = st.file_uploader(
        "Upload a CSV, Excel or Parquet file with queries", type=["csv", "xlsx", "parquet"]
    )
    if uploaded_file is None:
        return

    try:
        df = _read_uploaded_table(uploaded_file)
    except Exception as e:  # malformed file, missing reader engine, ...
        st.error(f"❌ Could not read the uploaded file: {e}")
        return
    if df is None:
        st.error("❌ Unsupported file type. Please upload a .csv, .xlsx or .parquet file.")
        return
    if df.empty:
        st.warning("The uploaded file contains no rows.")
        return

    query_column = "query" if "query" in df.columns else st.selectbox("Select the column containing queries:", df.columns)

    # Short-circuit kept on purpose: the button is only shown once a column is known.
    if not (query_column and st.button("Run Batch Inference")):
        return

    # Empty cells come back from pandas as NaN (a float), which the tokenizer
    # rejects; treat them as empty strings and coerce everything else to str.
    queries = df[query_column].fillna("").astype(str)
    total = len(queries)

    progress_bar = st.progress(0)
    results = []
    for done, query in enumerate(queries, start=1):
        prediction = predict_intent(query, model, tokenizer, id_to_label)
        results.append({
            "query": query,
            "predicted_intent": prediction["predicted_intent"],
            "confidence": prediction["confidence"],
        })
        progress_bar.progress(done / total)

    results_df = pd.DataFrame(results)
    st.subheader("Batch Inference Results")
    st.dataframe(results_df)

    # The UTF-8 BOM makes Excel detect the encoding, so Cyrillic queries and
    # labels are not garbled when the downloaded CSV is opened there.
    csv_bytes = results_df.to_csv(index=False).encode("utf-8-sig")
    st.download_button(
        label="Download Results as CSV",
        data=csv_bytes,
        file_name="batch_inference_results.csv",
        mime="text/csv",
    )


def inference_ui():
    """Build the Streamlit page: model status, single-query and batch inference.

    Only failures while loading the model/label map are reported as loading
    errors; failures during inference get their own message.
    """
    st.title("🔍 Bulgarian Search Intent Classification")

    try:
        model, tokenizer, _, id_to_label = load_inference_resources()
    except Exception as e:  # without the model nothing else on the page can work
        st.error(f"❌ Error loading model: {str(e)}")
        st.error("Please ensure the model and label map files are available.")
        return

    st.success(f"✅ Model loaded successfully! Found {len(id_to_label)} intent classes.")

    with st.expander("Available Intent Classes"):
        st.write(", ".join(id_to_label.values()))

    try:
        _render_single_query(model, tokenizer, id_to_label)
        _render_batch_inference(model, tokenizer, id_to_label)
    except Exception as e:  # top-level UI boundary: show the error instead of a bare traceback
        st.error(f"❌ Inference failed: {e}")


if __name__ == "__main__":
    # Build the UI when this file is run as the main script.
    inference_ui()
|
|