DrishtiSharma committed
Commit 76c69cb · verified · 1 Parent(s): bf6b981

Create app.py

Files changed (1):
  app.py  +172 -0
app.py ADDED
@@ -0,0 +1,172 @@
import gradio as gr
from theme4 import fast_rtc_theme
import torch
import json
import uuid
import os
import time
import pytz
from datetime import datetime
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from threading import Thread
from huggingface_hub import CommitScheduler
from pathlib import Path
import spaces

os.system("apt-get update && apt-get install -y libstdc++6")  # some compiled dependencies appear to need a newer libstdc++

# Load the HF token from the environment
token = os.environ["HF_TOKEN"]

# Load the model and tokenizer
model_id = "large-traversaal/Phi-4-Hindi"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=token,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
terminators = [tokenizer.eos_token_id]

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Set up the logging folder and a CommitScheduler that pushes logs to a Hugging Face dataset repo
log_folder = Path("logs")
log_folder.mkdir(parents=True, exist_ok=True)
log_file = log_folder / f"chat_log_{uuid.uuid4()}.json"

scheduler = CommitScheduler(
    repo_id="DrishtiSharma/phi-gradio-logs",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=0.01,
    token=token,
)
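# Note: CommitScheduler's `every` is specified in minutes, so every=0.01 pushes the
# logs folder roughly every 0.6 seconds; a larger value (e.g. every=5) would batch
# commits. Writes further below are wrapped in scheduler.lock so a background commit
# never reads a half-written line.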

# Timezone for logging timestamps
timezone = pytz.timezone("UTC")


@spaces.GPU(duration=60)
def chat(message, history, temperature, do_sample, max_tokens, top_p):
    start_time = time.time()
    timestamp = datetime.now(timezone).strftime("%Y-%m-%d %H:%M:%S %Z")

    # Rebuild the conversation as role/content messages for the chat template
    conversation_history = []
    for item in history:
        conversation_history.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            conversation_history.append({"role": "assistant", "content": item[1]})
    conversation_history.append({"role": "user", "content": message})
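    # `history` arrives in Gradio's tuple format, a list of [user_message, bot_reply]
    # pairs (the reply still None for an in-flight turn), which is why it maps
    # directly onto the role/content dicts above.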

    messages = tokenizer.apply_chat_template(
        conversation_history, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=70.0, skip_prompt=True, skip_special_tokens=True
    )

    # Generation parameters
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=terminators,
    )

    # Disable sampling if temperature is zero (deterministic generation)
    if temperature == 0:
        generate_kwargs["do_sample"] = False

    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()
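    # model.generate() blocks until generation finishes, so it runs on a background
    # thread while this generator drains the streamer below, yielding the partial
    # response as tokens arrive; timeout=70.0 makes the streamer raise if no new
    # token shows up within 70 seconds.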

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    # Total response time
    response_time = round(time.time() - start_time, 2)

    # Log entry for this interaction
    log_data = {
        "timestamp": timestamp,
        "input": message,
        "output": partial_text,
        "response_time": response_time,
        "temperature": temperature,
        "do_sample": do_sample,
        "max_tokens": max_tokens,
        "top_p": top_p,
    }

    with scheduler.lock:
        with log_file.open("a", encoding="utf-8") as f:
            f.write(json.dumps(log_data, ensure_ascii=False) + "\n")
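    # Each interaction is appended as one JSON object per line (JSON Lines), despite
    # the .json extension; the CommitScheduler picks the file up from logs/ and
    # pushes it to the dataset repo in the background.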


# Clear the chat history
def clear_chat():
    return [], []


# Export the chat history as a downloadable text file
def export_chat(history):
    if not history:
        return None  # no chat history to export

    file_path = "chat_history.txt"
    with open(file_path, "w", encoding="utf-8") as f:
        for msg in history:
            f.write(f"User: {msg[0]}\nBot: {msg[1]}\n")
    return file_path
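# The path returned above is handed to a gr.File output below, which renders it as a
# downloadable file in the UI; returning None leaves the component empty.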


# Gradio UI
with gr.Blocks(theme=fast_rtc_theme) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("#### ⚙️🛠 Configure Settings")
            temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature", interactive=True)
            do_sample = gr.Checkbox(label="Sampling", value=True, interactive=True)
            max_tokens = gr.Slider(minimum=128, maximum=4096, step=1, value=1024, label="max_new_tokens", interactive=True)
            top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="top_p", interactive=True)  # step was 0.2, which made the 0.1 default unreachable

        with gr.Column(scale=3):
            gr.Markdown("# **Chat With Phi-4-Hindi** 💬")

            chat_interface = gr.ChatInterface(
                fn=chat,
                examples=[
                    ["What is the English translation of: 'इस मॉडल को हिंदी और अंग्रेजी डेटा पर प्रशिक्षित किया गया था'?"],
                    # (Hindi) "Tim takes his 3 children trick-or-treating for 4 hours; each hour they visit x houses and each child gets 3 treats per house; the children get 180 treats in total. What is the value of the unknown variable x?"
                    ["टिम अपने 3 बच्चों को ट्रिक या ट्रीटिंग के लिए ले जाता है। वे 4 घंटे बाहर रहते हैं। हर घंटे वे x घरों में जाते हैं। हर घर में हर बच्चे को 3 ट्रीट मिलते हैं। उसके बच्चों को कुल 180 ट्रीट मिलते हैं। अज्ञात चर x का मान क्या है?"],
                    ["how do you play fetch? A) throw the object for the dog to get and bring back to you. B) get the object and bring it back to the dog."],
                ],
                additional_inputs=[temperature, do_sample, max_tokens, top_p],
                stop_btn="⏹ Stop",
                description="Phi-4-Hindi is a bilingual instruction-tuned LLM for Hindi and English, trained on a mixed dataset of 485K Hindi-English samples.",
            )
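            # additional_inputs are appended, in order, after (message, history) in
            # the call to chat(), which is how the sliders and checkbox feed
            # temperature, do_sample, max_tokens and top_p.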

    with gr.Row():
        clear_btn = gr.Button("🧹 Clear Chat", variant="primary")
        export_btn = gr.Button("📥 Export Chat", variant="primary")

    # Connect the buttons to their handlers (clear and export chat)
    clear_btn.click(
        fn=clear_chat,
        outputs=[chat_interface.chatbot, chat_interface.chatbot_value],
    )
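    # Note: this clears both the visible Chatbot and the ChatInterface's stored
    # history; chatbot_value is an attribute exposed by recent Gradio releases, so
    # the wiring assumes such a version.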

    export_btn.click(fn=export_chat, inputs=[chat_interface.chatbot], outputs=[gr.File()])
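    # Components instantiated inside a Blocks context render in place, so this
    # gr.File() shows up as a download component below the buttons once a chat
    # history is exported.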

demo.launch()