import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import time

# Load the model and tokenizer. HF_TOKEN must be set in the environment
# (e.g. as a Hugging Face Space secret) to access private or gated models.
token = os.environ.get("HF_TOKEN")
model_name = "large-traversaal/Phi-4-Hindi"

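# Cache the model and tokenizer so they load once per server process,
# not on every Streamlit rerun.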
@st.cache_resource()
def load_model():
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=token,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    )
    tok = AutoTokenizer.from_pretrained(model_name, token=token)
    return model, tok

model, tok = load_model()
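# Stop generation when the model emits its end-of-sequence token.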
terminators = [tok.eos_token_id]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Initialize session state if not set
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Chat function: builds the prompt from the stored history and streams the reply.
def chat(temperature, do_sample, max_tokens):
    # The Send handler has already appended the latest user message to
    # chat_history, so copy it as-is rather than appending it a second time.
    chat_log = st.session_state.chat_history.copy()
    messages = tok.apply_chat_template(chat_log, tokenize=False, add_generation_prompt=True)
    
    model_inputs = tok([messages], return_tensors="pt").to(device)
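    # The streamer yields decoded text incrementally as tokens are produced
    # by generate() running on the worker thread below.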
    streamer = TextIteratorStreamer(tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    
    generate_kwargs = {
        "inputs": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "eos_token_id": terminators,
    }
    
    # Sampling with temperature 0 is undefined, so fall back to greedy decoding.
    if temperature == 0:
        generate_kwargs["do_sample"] = False
        generate_kwargs.pop("temperature")
    
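    # model.generate() blocks until completion, so run it on a background
    # thread and consume the streamer on the main thread.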
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    
    st.session_state.chat_history.append({"role": "assistant", "content": partial_text})

# Streamlit UI
st.title("πŸ’¬ Chat With Phi-4-Hindi")
st.markdown("Chat with [large-traversaal/Phi-4-Hindi](https://huggingface.co/large-traversaal/Phi-4-Hindi)")

# Sidebar controls for decoding and display
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.3, 0.1)
do_sample = st.sidebar.checkbox("Use Sampling", value=True)
max_tokens = st.sidebar.slider("Max Tokens", 128, 4096, 512, 1)
text_color = st.sidebar.selectbox("Text Color", ["Red", "Black", "Blue", "Green", "Purple"], index=0)
dark_mode = st.sidebar.checkbox("🌙 Dark Mode", value=False)
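
# The Dark Mode checkbox was otherwise unused; a minimal sketch that darkens
# the app with injected CSS. The .stApp selector is a Streamlit implementation
# detail and may change between versions.
if dark_mode:
    st.markdown(
        "<style>.stApp { background-color: #111; color: #eee; }</style>",
        unsafe_allow_html=True,
    )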

def get_html_text(text, color):
    return f'<p style="color: {color.lower()}; font-size: 16px;">{text}</p>'

# Render the conversation so far, matching the user bubble color to the theme.
for msg in st.session_state.chat_history:
    if msg["role"] == "user":
        user_color = "White" if dark_mode else "Black"
        st.markdown(get_html_text("👤 " + msg["content"], user_color), unsafe_allow_html=True)
    else:
        st.markdown(get_html_text("🤖 " + msg["content"], text_color), unsafe_allow_html=True)

user_input = st.text_input("Type your message:", "")
if st.button("Send"):
    if user_input.strip():
        st.session_state.chat_history.append({"role": "user", "content": user_input})
        # Stream partial output into a placeholder so the reply appears as it
        # is generated, instead of being discarded behind a spinner.
        placeholder = st.empty()
        for partial_text in chat(temperature, do_sample, max_tokens):
            placeholder.markdown(get_html_text("🤖 " + partial_text, text_color), unsafe_allow_html=True)
        st.rerun()

if st.button("🧹 Clear Chat"):
    st.session_state.chat_history = []
    st.experimental_rerun()
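
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py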