# Uploaded by: UpendraAI
# Commit message: "Update app.py"
# Commit: 913ae3d (verified)
import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from imblearn.over_sampling import RandomOverSampler
@st.cache_resource
def load_model_and_tokenizer(model_name="ai4bharat/indic-bert"):
    """Load the tokenizer and encoder for ``model_name`` (IndicBERT by default).

    Wrapped in ``st.cache_resource`` so the expensive load is not repeated on
    every Streamlit rerun. Taking the checkpoint id as a single parameter
    guarantees the tokenizer and the model always come from the same
    checkpoint.

    Args:
        model_name: Hugging Face model id or local path passed to both
            ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.

    Returns:
        Tuple ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    # The model is only used for feature extraction; make inference mode explicit.
    model.eval()
    return tokenizer, model
def get_embeddings(texts, tokenizer, model, batch_size=32):
    """Return one embedding per text: the first-token (CLS) hidden state.

    Texts are encoded in mini-batches rather than in one forward pass. When a
    whole dataset is embedded, the texts are not all padded to the single
    longest one, and the activations for every text are not held in memory
    at the same time.

    Args:
        texts: A list (or any iterable) of strings, or a single string.
        tokenizer: Hugging Face tokenizer; called with ``return_tensors="pt"``.
        model: Encoder whose output exposes ``last_hidden_state``.
        batch_size: Maximum number of texts per forward pass (must be >= 1).

    Returns:
        ``torch.Tensor`` with one row per input text, in input order.

    Raises:
        ValueError: If ``texts`` is empty or ``batch_size`` is less than 1.
    """
    # A bare string would otherwise be sliced character by character below.
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        raise ValueError("texts must contain at least one string")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
            outputs = model(**inputs)
            chunks.append(outputs.last_hidden_state[:, 0, :])  # CLS token
    return torch.cat(chunks, dim=0)
@st.cache_data
def load_data(path="SushasanSampleData.csv"):
    """Load the labelled application dataset from a UTF-8 CSV file.

    Missing ``applicationDetail`` values become empty strings. Missing
    ``applicationCategoryName`` values become the fallback category "अन्य"
    ("other"). Both columns are then coerced to ``str``, so numeric or
    mixed-type cells cannot break the tokenizer (which requires strings) or
    ``LabelEncoder`` (which cannot sort a mix of str and int).

    Args:
        path: CSV file to read; defaults to the bundled sample dataset.

    Returns:
        The loaded ``pandas.DataFrame`` with both columns cleaned.
    """
    df = pd.read_csv(path, encoding="utf-8")
    df['applicationDetail'] = df['applicationDetail'].fillna("").astype(str)
    df['applicationCategoryName'] = df['applicationCategoryName'].fillna("अन्य").astype(str)
    return df
@st.cache_resource
def preprocess_and_train(df):
    """Embed the application texts and fit a category classifier.

    Each ``applicationDetail`` is embedded with ``get_embeddings``, and
    ``applicationCategoryName`` is label-encoded. Minority categories are
    randomly oversampled so every class is equally represented, and a
    logistic-regression model is fit on the balanced data.

    Args:
        df: DataFrame with ``applicationDetail`` and
            ``applicationCategoryName`` columns (as produced by ``load_data``).

    Returns:
        Tuple ``(clf, tokenizer, model, label_encoder)``: the fitted
        classifier, the tokenizer and encoder used to build features, and
        the encoder that maps class ids back to category names.

    Raises:
        ValueError: If the data contains fewer than two categories;
            a classifier cannot be trained on a single class.
    """
    tokenizer, model = load_model_and_tokenizer()

    texts = df['applicationDetail'].astype(str).tolist()
    features = get_embeddings(texts, tokenizer, model).cpu().numpy()

    label_encoder = LabelEncoder()
    labels = label_encoder.fit_transform(df['applicationCategoryName'].astype(str))
    if len(label_encoder.classes_) < 2:
        raise ValueError(
            "Training data must contain at least two distinct categories."
        )

    # Balance the classes so rare categories are not ignored by the classifier.
    ros = RandomOverSampler(random_state=42)
    X_resampled, y_resampled = ros.fit_resample(features, labels)

    # No train/test split here: nothing evaluates a held-out set, and splitting
    # AFTER oversampling would leak duplicated rows into both halves. Train on
    # the full balanced set; any future evaluation must split BEFORE resampling.
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_resampled, y_resampled)
    return clf, tokenizer, model, label_encoder
# Build the dataset and classifier up front; both helpers are cached by Streamlit,
# so reruns triggered by widget interaction reuse the same objects.
df = load_data()
clf, tokenizer, model, label_encoder = preprocess_and_train(df)

# --- Streamlit UI ---
st.title("🇮🇳 Hindi Category Classifier (IndicBERT Powered)")
user_input = st.text_area("✍️ Enter Application Detail", "")

if st.button("🔍 Predict"):
    if not user_input.strip():
        # Blank or whitespace-only input: nothing to classify.
        st.warning("Please write something.")
    else:
        query_vector = get_embeddings([user_input], tokenizer, model).cpu().numpy()
        predicted_ids = clf.predict(query_vector)
        predicted_name = label_encoder.inverse_transform(predicted_ids)[0]
        st.success(f"🧠 Predicted Category: **{predicted_name}**")