dejanseo committed on
Commit
a794f66
·
verified ·
1 Parent(s): cd92d8b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import streamlit as st
5
+ import pandas as pd
6
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
7
+
8
+ # ==============================
9
+ # ⚙️ CONFIGURABLE PARAMETERS
10
+ # ==============================
11
# Hugging Face Hub repository the tokenizer and model weights are pulled from.
MODEL_PATH = "dejanseo/bulgarian-search-query-intent-alpha"  # HF model repository

# The label map ships in the same directory as app.py. Anchor the path to this
# file so it resolves correctly regardless of the process's working directory.
LABEL_MAP_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "label_map.json"
)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
+ # ==============================
16
+ # 📌 Load Model and Tokenizer
17
+ # ==============================
18
@st.cache_resource
def load_inference_resources():
    """Load the label map, tokenizer and classification model.

    The function is wrapped in ``st.cache_resource`` so the (expensive)
    download and model construction are not repeated on every Streamlit rerun.

    Returns:
        tuple: ``(model, tokenizer, label_to_id, id_to_label)`` where
        ``model`` has been moved to ``DEVICE`` and put in eval mode,
        ``label_to_id`` is taken verbatim from the label map file, and
        ``id_to_label`` is the file's ``"id_to_label"`` mapping with its keys
        converted to ``int``.

    Raises:
        OSError: If the label map file cannot be opened (model/tokenizer
            loading may raise its own errors as well).
        KeyError: If the label map lacks ``"id_to_label"`` or ``"label_to_id"``.
        ValueError: If the label map is not valid JSON or an id key is not an
            integer string.
    """
    # JSON is UTF-8 by specification; be explicit so non-ASCII label names are
    # decoded correctly even where the platform default encoding differs.
    with open(LABEL_MAP_PATH, "r", encoding="utf-8") as f:
        label_map = json.load(f)

    # JSON object keys are always strings; class ids must be ints so they can
    # be looked up with the integer index returned by argmax.
    id_to_label = {int(k): v for k, v in label_map["id_to_label"].items()}
    label_to_id = label_map["label_to_id"]

    # Load the tokenizer and model from Hugging Face
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(DEVICE)
    model.eval()  # Inference only: disables dropout and similar train-time layers

    return model, tokenizer, label_to_id, id_to_label
34
+
35
+ # ==============================
36
+ # 📌 Inference Function
37
+ # ==============================
38
def predict_intent(query, model, tokenizer, id_to_label):
    """Predict the intent of a Bulgarian search query.

    Args:
        query: The search query. Non-string values (e.g. numeric cells read
            from a spreadsheet) are converted with ``str()`` before
            tokenization; the original value is echoed back unchanged.
        model: Callable accepting the tokenizer output as keyword arguments
            and returning an object with a ``logits`` tensor whose first
            dimension is the batch.
        tokenizer: Callable producing a mapping of input tensors for ``model``.
        id_to_label: Mapping from integer class id to intent label.

    Returns:
        dict: ``"query"`` (the input as given), ``"predicted_intent"`` (label
        of the highest-probability class), ``"confidence"`` (its probability
        as a float) and ``"all_scores"`` (list of ``(label, probability)``
        tuples sorted by probability, highest first).

    Raises:
        KeyError: If ``id_to_label`` has no entry for one of the class ids
            produced by the model.
    """
    # The tokenizer expects text; coerce so numeric queries do not fail.
    text = query if isinstance(query, str) else str(query)

    # NOTE(review): padding a single query to max_length looks unnecessary;
    # kept as-is so model inputs stay identical to the original behaviour.
    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )

    # Place the inputs on whatever device the model actually lives on rather
    # than on a module-level constant that could drift out of sync with it.
    # Objects without a ``device`` attribute get the tensors untouched.
    device = getattr(model, "device", None)
    inputs = {
        name: (tensor if device is None else tensor.to(device))
        for name, tensor in inputs.items()
    }

    # Inference without gradient tracking
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax over the class dimension; [0] selects the single query in the batch.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)[0]

    predicted_class_id = torch.argmax(probabilities).item()
    predicted_intent = id_to_label[predicted_class_id]

    # One bulk transfer to Python floats instead of one .item() call per class.
    scores = probabilities.tolist()
    confidence = scores[predicted_class_id]

    all_intents = {id_to_label[i]: score for i, score in enumerate(scores)}
    sorted_intents = sorted(all_intents.items(), key=lambda x: x[1], reverse=True)

    return {
        "query": query,
        "predicted_intent": predicted_intent,
        "confidence": confidence,
        "all_scores": sorted_intents,
    }
76
+
77
+ # ==============================
78
+ # 🌟 Streamlit UI for Inference
79
+ # ==============================
80
def _read_uploaded_table(uploaded_file):
    """Parse an uploaded CSV, Excel or Parquet file into a DataFrame."""
    # Match the extension case-insensitively so "QUERIES.CSV" works too.
    name = uploaded_file.name.lower()
    if name.endswith(".csv"):
        return pd.read_csv(uploaded_file)
    if name.endswith(".xlsx"):
        return pd.read_excel(uploaded_file)
    if name.endswith(".parquet"):
        return pd.read_parquet(uploaded_file)
    # Fail with a clear message instead of leaving the frame undefined.
    raise ValueError(f"Unsupported file type: {uploaded_file.name}")


def inference_ui():
    """Render the Streamlit page for single-query and batch intent prediction.

    Loads the model resources, then offers a text box for classifying one
    query and a file uploader for classifying a column of queries, with the
    batch results downloadable as CSV. Failures are reported on the page via
    ``st.error`` rather than raised.
    """
    st.title("🔍 Bulgarian Search Intent Classification")

    # Loading is handled separately so that only genuine loading failures are
    # reported as such; later errors get their own message below.
    try:
        model, tokenizer, _label_to_id, id_to_label = load_inference_resources()
    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        st.error("Please ensure the model and label map files are available.")
        return

    st.success(f"✅ Model loaded successfully! Found {len(id_to_label)} intent classes.")

    try:
        # Show available intents
        with st.expander("Available Intent Classes"):
            st.write(", ".join(id_to_label.values()))

        # Single query inference
        query = st.text_input("Enter a Bulgarian search query:", "Как да направя резервация за ресторант?")

        if st.button("Predict Intent"):
            with st.spinner("Analyzing query..."):
                prediction = predict_intent(query, model, tokenizer, id_to_label)

            st.subheader("Prediction Results")
            st.metric(
                label="Predicted Intent",
                value=prediction["predicted_intent"],
                delta=f"{prediction['confidence']*100:.2f}% confidence"
            )

            st.subheader("Intent Probabilities")
            df_probs = pd.DataFrame(prediction["all_scores"], columns=["Intent", "Probability"])
            # Scores arrive sorted highest-first, so head(5) is the top five.
            df_top5 = df_probs.head(5)
            st.bar_chart(df_top5.set_index("Intent"))

            with st.expander("View All Intent Probabilities"):
                st.dataframe(df_probs)

        # Batch inference section
        st.subheader("Batch Inference")
        uploaded_file = st.file_uploader("Upload a CSV/Excel file with queries", type=["csv", "xlsx", "parquet"])

        if uploaded_file is not None:
            df = _read_uploaded_table(uploaded_file)

            # Use a column literally named "query" when present; otherwise ask.
            query_column = "query" if "query" in df.columns else st.selectbox("Select the column containing queries:", df.columns)

            if query_column and st.button("Run Batch Inference"):
                progress_bar = st.progress(0)
                results = []

                # Empty cells load as NaN and numeric cells as numbers; the
                # tokenizer needs text, so normalise the column to strings
                # (missing values become the empty string).
                queries = df[query_column].fillna("").astype(str)
                total = len(queries)

                for i, row in enumerate(queries):
                    progress_bar.progress((i + 1) / total)
                    prediction = predict_intent(row, model, tokenizer, id_to_label)
                    results.append({
                        "query": row,
                        "predicted_intent": prediction["predicted_intent"],
                        "confidence": prediction["confidence"]
                    })

                results_df = pd.DataFrame(results)
                st.subheader("Batch Inference Results")
                st.dataframe(results_df)

                csv = results_df.to_csv(index=False)
                st.download_button(
                    label="Download Results as CSV",
                    data=csv,
                    file_name="batch_inference_results.csv",
                    mime="text/csv"
                )

    except Exception as e:
        # Top-level UI boundary: surface the problem on the page instead of
        # crashing the app, without mislabelling it as a model-loading error.
        st.error(f"❌ Error during inference: {str(e)}")


if __name__ == "__main__":
    inference_ui()