DrishtiSharma committed (verified)
Commit 850643b · Parent(s): c1786ee

Update app.py

Files changed (1): app.py (+67 -34)
app.py CHANGED
@@ -3,28 +3,40 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import os
 from threading import Thread
-import time
 
-# Load Model and Tokenizer
-token = os.environ.get("HF_TOKEN")
-model_name = "large-traversaal/Phi-4-Hindi"
+# Define model path for caching (Avoids reloading every app restart)
+MODEL_PATH = "/mnt/data/Phi-4-Hindi"
+TOKEN = os.environ.get("HF_TOKEN")
+MODEL_NAME = "large-traversaal/Phi-4-Hindi"
 
+# Load Model & Tokenizer Once
 @st.cache_resource()
 def load_model():
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        token=token,
-        trust_remote_code=True,
-        torch_dtype=torch.bfloat16
-    )
-    tok = AutoTokenizer.from_pretrained(model_name, token=token)
-    return model, tok
+    with st.spinner("Loading model... Please wait ⏳"):
+        if not os.path.exists(MODEL_PATH):
+            model = AutoModelForCausalLM.from_pretrained(
+                MODEL_NAME, token=TOKEN, trust_remote_code=True, torch_dtype=torch.bfloat16
+            )
+            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=TOKEN)
+            model.save_pretrained(MODEL_PATH)
+            tokenizer.save_pretrained(MODEL_PATH)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
+            tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+
+    return model, tokenizer
 
+# Load and move model to appropriate device
 model, tok = load_model()
-terminators = [tok.eos_token_id]
-
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = model.to(device)
+try:
+    model = model.to(device)
+except torch.cuda.OutOfMemoryError:
+    st.error("⚠️ CUDA Out of Memory! Running on CPU instead.")
+    device = torch.device("cpu")
+    model = model.to(device)
+
+terminators = [tok.eos_token_id]
 
 # Initialize session state if not set
 if "chat_history" not in st.session_state:
@@ -32,13 +44,19 @@ if "chat_history" not in st.session_state:
 
 # Chat function
 def chat(message, temperature, do_sample, max_tokens):
-    chat_log = st.session_state.chat_history.copy()
-    chat_log.append({"role": "user", "content": message})
-    messages = tok.apply_chat_template(chat_log, tokenize=False, add_generation_prompt=True)
+    """Processes chat input and generates a response using the model."""
 
+    # Append new message to history
+    st.session_state.chat_history.append({"role": "user", "content": message})
+
+    # Convert chat history into model-friendly format
+    messages = tok.apply_chat_template(st.session_state.chat_history, tokenize=False, add_generation_prompt=True)
     model_inputs = tok([messages], return_tensors="pt").to(device)
+
+    # Initialize streamer for token-wise response
    streamer = TextIteratorStreamer(tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
 
+    # Define generation parameters
     generate_kwargs = {
         "inputs": model_inputs["input_ids"],
         "streamer": streamer,
@@ -50,46 +68,61 @@ def chat(message, temperature, do_sample, max_tokens):
 
     if temperature == 0:
         generate_kwargs["do_sample"] = False
-
+
+    # Generate response asynchronously
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
-    partial_text = ""
+    # Collect response as it streams
+    response_text = ""
     for new_text in streamer:
-        partial_text += new_text
-        yield partial_text
-
-    st.session_state.chat_history.append({"role": "assistant", "content": partial_text})
+        response_text += new_text
+        yield response_text
 
-# Streamlit UI
+    # Save the assistant's response to session history
+    st.session_state.chat_history.append({"role": "assistant", "content": response_text})
+
+# UI Setup
 st.title("💬 Chat With Phi-4-Hindi")
+st.success("✅ Model is READY to chat!")
 st.markdown("Chat with [large-traversaal/Phi-4-Hindi](https://huggingface.co/large-traversaal/Phi-4-Hindi)")
 
-# Chat input
+# Sidebar Chat Settings
 temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.3, 0.1)
 do_sample = st.sidebar.checkbox("Use Sampling", value=True)
 max_tokens = st.sidebar.slider("Max Tokens", 128, 4096, 512, 1)
 text_color = st.sidebar.selectbox("Text Color", ["Red", "Black", "Blue", "Green", "Purple"], index=0)
 dark_mode = st.sidebar.checkbox("🌙 Dark Mode", value=False)
 
+# Function to format chat messages
 def get_html_text(text, color):
     return f'<p style="color: {color.lower()}; font-size: 16px;">{text}</p>'
 
+# Display chat history
 for msg in st.session_state.chat_history:
-    if msg["role"] == "user":
-        st.markdown(get_html_text("👀 " + msg["content"], "black"), unsafe_allow_html=True)
-    else:
-        st.markdown(get_html_text("🤖 " + msg["content"], text_color), unsafe_allow_html=True)
+    role = "👀" if msg["role"] == "user" else "🤖"
+    st.markdown(get_html_text(f"**{role}:** {msg['content']}", text_color if role == "🤖" else "black"), unsafe_allow_html=True)
 
+# User Input Handling
 user_input = st.text_input("Type your message:", "")
+
 if st.button("Send"):
     if user_input.strip():
         st.session_state.chat_history.append({"role": "user", "content": user_input})
-        with st.spinner("Generating response..."):
-            for output in chat(user_input, temperature, do_sample, max_tokens):
-                pass
+
+        # Display chatbot response
+        with st.spinner("Generating response... 🤖💭"):
+            response_generator = chat(user_input, temperature, do_sample, max_tokens)
+            final_response = ""
+            for output in response_generator:
+                final_response = output  # Store latest output
+
+        #st.success("✅ Response generated!")
+        # Add generated response to session state
         st.experimental_rerun()
 
 if st.button("🧹 Clear Chat"):
-    st.session_state.chat_history = []
+    with st.spinner("Clearing chat history..."):
+        st.session_state.chat_history = []
+        st.success("✅ Chat history cleared!")
     st.experimental_rerun()