# phi-4-streamlit / app.py
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import requests
# Local cache path for the model weights (avoids re-downloading on every app restart)
MODEL_PATH = "/mnt/data/Phi-4-Hindi"
TOKEN = os.environ.get("HF_TOKEN")
MODEL_NAME = "DrishtiSharma/Phi-4-Hindi-quantized"
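# Usage note (assumption: run locally or as a HF Space with HF_TOKEN set as a
# secret): the first run downloads the checkpoint using the token; later runs
# reuse the copy under MODEL_PATH. A local invocation might look like
# (the token value is a placeholder):
#
#   HF_TOKEN=hf_...your_token... streamlit run app.py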
# Load Model & Tokenizer Once
@st.cache_resource()
def load_model():
with st.spinner("Loading model... Please wait ⏳"):
try:
if not os.path.exists(MODEL_PATH):
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME, token=TOKEN, trust_remote_code=True, torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=TOKEN)
model.save_pretrained(MODEL_PATH)
tokenizer.save_pretrained(MODEL_PATH)
            else:
                # Load from the local cache with the same settings as the fresh
                # download, so dtype and custom code behavior stay consistent
                model = AutoModelForCausalLM.from_pretrained(
                    MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16
                )
                tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
except requests.exceptions.ConnectionError:
st.error("⚠️ Connection error! Unable to download the model. Please check your internet connection and try again.")
return None, None
except requests.exceptions.ReadTimeout:
st.error("⚠️ Read Timeout! The request took too long. Please try again later.")
return None, None
return model, tokenizer
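# Note: @st.cache_resource keeps a single model/tokenizer pair per server
# process and reuses it across reruns and sessions. If the cached copy ever
# needs refreshing (e.g. after swapping checkpoints), it can be evicted
# explicitly — a sketch, assuming the current st.cache_resource API:
#
#   load_model.clear()   # drop the cached pair; the next call reloads it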
# Load and move model to appropriate device
model, tok = load_model()
if model is None or tok is None:
st.stop()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
model = model.to(device)
except torch.cuda.OutOfMemoryError:
st.error("⚠️ CUDA Out of Memory! Running on CPU instead.")
device = torch.device("cpu")
model = model.to(device)
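# Note: the OOM fallback above only triggers once the weight copy itself
# overflows GPU memory. A rough pre-check is possible on CUDA — a sketch
# using torch.cuda.mem_get_info(), which returns (free, total) in bytes:
#
#   free_bytes, total_bytes = torch.cuda.mem_get_info()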
terminators = [tok.eos_token_id]
# Initialize session state if not set
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
# Chat function
def chat(message, temperature, do_sample, max_tokens):
"""Processes chat input and generates a response using the model."""
# Append new message to history
st.session_state.chat_history.append({"role": "user", "content": message})
# Convert chat history into model-friendly format
    prompt = tok.apply_chat_template(st.session_state.chat_history, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([prompt], return_tensors="pt").to(device)
# Initialize streamer for token-wise response
streamer = TextIteratorStreamer(tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
# Define generation parameters
    generate_kwargs = {
        "inputs": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "eos_token_id": terminators,
    }
    # A temperature of 0 is invalid for sampling; fall back to greedy decoding
    if temperature == 0:
        generate_kwargs["do_sample"] = False
    # Run generation on a background thread so the streamer can be consumed here
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # Collect the response as it streams, yielding the accumulated text
    response_text = ""
    for new_text in streamer:
        response_text += new_text
        yield response_text
    t.join()  # ensure generation has finished before recording the turn
    # Save the assistant's response to session history
    st.session_state.chat_history.append({"role": "assistant", "content": response_text})
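# Note: chat() is a generator — each yielded value is the accumulated response
# so far, so a caller can live-update a single placeholder with it (this is
# what the Send handler below does). A sketch, with illustrative arguments:
#
#   for partial in chat("Namaste!", temperature=0.3, do_sample=True, max_tokens=128):
#       placeholder.markdown(partial)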
# UI Setup
st.title("πŸ’¬ Chat With Phi-4-Hindi")
st.success("βœ… Model is READY to chat!")
st.markdown("Chat with [large-traversaal/Phi-4-Hindi](https://huggingface.co/large-traversaal/Phi-4-Hindi)")
# Sidebar Chat Settings
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.3, 0.1)
do_sample = st.sidebar.checkbox("Use Sampling", value=True)
max_tokens = st.sidebar.slider("Max Tokens", 128, 4096, 512, 1)
text_color = st.sidebar.selectbox("Text Color", ["Red", "Black", "Blue", "Green", "Purple"], index=0)
dark_mode = st.sidebar.checkbox("🌙 Dark Mode", value=False)  # note: not yet wired to any styling
# Function to format chat messages
def get_html_text(text, color):
return f'<p style="color: {color.lower()}; font-size: 16px;">{text}</p>'
# Display chat history
for msg in st.session_state.chat_history:
    icon = "👤" if msg["role"] == "user" else "🤖"
    color = text_color if icon == "🤖" else "black"
    # Use an HTML <b> tag: markdown bold is not rendered inside raw HTML blocks
    st.markdown(get_html_text(f"<b>{icon}:</b> {msg['content']}", color), unsafe_allow_html=True)
# User Input Handling
user_input = st.text_input("Type your message:", "")
if st.button("Send"):
if user_input.strip():
st.session_state.chat_history.append({"role": "user", "content": user_input})
# Display chatbot response
with st.spinner("Generating response... πŸ€–πŸ’­"):
response_generator = chat(user_input, temperature, do_sample, max_tokens)
final_response = ""
for output in response_generator:
final_response = output # Store latest output
st.success("βœ… Response generated!")
# Add generated response to session state
st.rerun()
if st.button("🧹 Clear Chat"):
with st.spinner("Clearing chat history..."):
st.session_state.chat_history = []
st.success("βœ… Chat history cleared!")
st.rerun()