import os
from threading import Thread

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Load model and tokenizer (cached so Streamlit reruns do not reload the weights)
token = os.environ.get("HF_TOKEN")
model_name = "large-traversaal/Phi-4-Hindi"

@st.cache_resource
def load_model():
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=token,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    tok = AutoTokenizer.from_pretrained(model_name, token=token)
    return model, tok

model, tok = load_model()
terminators = [tok.eos_token_id]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
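
# Optional sanity check (an added sketch, not part of the original app):
# apply_chat_template() below requires the tokenizer to ship a chat template,
# so failing fast here gives a clearer error than a mid-generation exception.
if tok.chat_template is None:
    st.error("Tokenizer has no chat template; cannot format conversations.")
    st.stop()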
# Initialize chat history in session state if not set
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
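# Note: st.session_state persists across script reruns within a single browser
# session, so the history above survives each st.rerun() call further down.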
# Chat function: formats the history, runs generation in a background thread,
# and yields the growing reply as the streamer produces text
def chat(temperature, do_sample, max_tokens):
    # The Send handler has already appended the latest user message, so the
    # history is used as-is (appending it again here would duplicate the turn).
    chat_log = st.session_state.chat_history
    messages = tok.apply_chat_template(chat_log, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "inputs": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "eos_token_id": terminators,
    }
    # Temperature 0 means greedy decoding; drop the sampling arguments so
    # generate() does not warn about unused flags
    if temperature == 0 or not do_sample:
        generate_kwargs["do_sample"] = False
        generate_kwargs.pop("temperature", None)
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    st.session_state.chat_history.append({"role": "assistant", "content": partial_text})
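
# Note: with timeout=20.0, TextIteratorStreamer raises queue.Empty if no new
# text arrives within 20 s; a try/except around the streaming loop above would
# let the UI fail gracefully on a stalled generation (omitted for brevity).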
# Streamlit UI
st.title("💬 Chat With Phi-4-Hindi")
st.markdown("Chat with [large-traversaal/Phi-4-Hindi](https://huggingface.co/large-traversaal/Phi-4-Hindi)")
# Sidebar generation and display settings
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.3, 0.1)
do_sample = st.sidebar.checkbox("Use Sampling", value=True)
max_tokens = st.sidebar.slider("Max Tokens", 128, 4096, 512, 1)
text_color = st.sidebar.selectbox("Text Color", ["Red", "Black", "Blue", "Green", "Purple"], index=0)
dark_mode = st.sidebar.checkbox("🌙 Dark Mode", value=False)  # read but not yet applied to the styling below
def get_html_text(text, color):
    return f'<p style="color: {color.lower()}; font-size: 16px;">{text}</p>'
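
# Note: the helper above interpolates raw user text into HTML rendered with
# unsafe_allow_html=True; wrapping the text in html.escape() (stdlib html
# module) would be a minimal safeguard against markup injection.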
# Render the conversation so far
for msg in st.session_state.chat_history:
    if msg["role"] == "user":
        st.markdown(get_html_text("👤 " + msg["content"], "black"), unsafe_allow_html=True)
    else:
        st.markdown(get_html_text("🤖 " + msg["content"], text_color), unsafe_allow_html=True)
user_input = st.text_input("Type your message:", "")

if st.button("Send"):
    if user_input.strip():
        st.session_state.chat_history.append({"role": "user", "content": user_input})
        placeholder = st.empty()  # live area for the streamed reply
        with st.spinner("Generating response..."):
            for output in chat(temperature, do_sample, max_tokens):
                placeholder.markdown(get_html_text("🤖 " + output, text_color), unsafe_allow_html=True)
        st.rerun()  # replaces the deprecated st.experimental_rerun()

if st.button("🧹 Clear Chat"):
    st.session_state.chat_history = []
    st.rerun()
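
# A minimal requirements.txt for this Space might look like the following.
# The package list is an assumption inferred from the imports above, not taken
# from the original repo; accelerate is optional but commonly included:
#
#   streamlit
#   torch
#   transformers
#   accelerate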