Update app.py

app.py CHANGED
@@ -1,4 +1,7 @@
 import os
+
+os.system("pip install git+https://github.com/shumingma/transformers.git")
+
 import threading
 import torch
 import torch._dynamo
@@ -12,55 +15,17 @@ from transformers import (
 import gradio as gr
 import spaces
 
-# Install the transformers branch with Bitnet support, if needed.
-# On Hugging Face Spaces it is more common to pre-install this via a Dockerfile.
-# It may also be needed when testing locally.
-# print("Installing required transformers branch...")
-# try:
-#     os.system("pip install git+https://github.com/shumingma/transformers.git -q")
-#     print("transformers branch installed.")
-# except Exception as e:
-#     print(f"Error installing transformers branch: {e}")
-#     print("Proceeding with potentially default transformers version.")
-
-# os.system("pip install accelerate bitsandbytes -q")  # bitsandbytes and accelerate may also be needed.
-
-
 model_id = "microsoft/bitnet-b1.58-2B-4T"
 
-try:
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-        # load_in_8bit=True  # Bitnet is 1.58-bit, so 8-bit loading may not be meaningful.
-    )
-    print(f"Model loaded successfully on device: {model.device}")
-except Exception as e:
-    print(f"Error loading model: {e}")
-    # On loading failure, fall back to a dummy model or handle the error
-    class DummyModel:
-        def generate(self, **kwargs):
-            # Produce a dummy response
-            input_ids = kwargs.get('input_ids')
-            streamer = kwargs.get('streamer')
-            if streamer:
-                # Stream a simple dummy reply
-                dummy_response = "Model loading failed; serving a dummy response. Check your settings/paths."
-                for char in dummy_response:
-                    streamer.put(char)
-                streamer.end()
-    model = DummyModel()
-    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # dummy tokenizer
-    print("Using dummy model due to loading failure.")
-
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    device_map="auto"
+)
+print(model.device)
 
-@spaces.GPU
+@spaces.GPU
 def respond(
     message: str,
     history: list[tuple[str, str]],
@@ -81,11 +46,6 @@ def respond(
     Yields:
         The growing response text as new tokens are generated.
     """
-    # Avoid streaming errors when the dummy model is in use
-    if isinstance(model, DummyModel):
-        yield "Model loading failed, so no response can be generated."
-        return
-
     messages = [{"role": "system", "content": system_message}]
     for user_msg, bot_msg in history:
         if user_msg:
@@ -94,236 +54,45 @@ def respond(
             messages.append({"role": "assistant", "content": bot_msg})
     messages.append({"role": "user", "content": message})
 
-    try:
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-        streamer = TextIteratorStreamer(
-            tokenizer, skip_prompt=True, skip_special_tokens=True
-        )
-        generate_kwargs = dict(
-            **inputs,
-            streamer=streamer,
-            max_new_tokens=max_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            do_sample=True,
-            # Extra arguments the Bitnet model may need (check the model docs),
-            # e.g. quantize_config, etc.
-        )
-
-        # Run model generation in a separate thread
-        thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
-        thread.start()
-
-        # Read text from the streamer and yield it
-        response = ""
-        for new_text in streamer:
-            # Unneeded whitespace/tokens could be stripped or post-processed before yielding
-            response += new_text
-            yield response
-
-    except Exception as e:
-        print(f"Error during response generation: {e}")
-        yield f"An error occurred while generating the response: {e}"
-
-
-# --- CSS for an improved design ---
-css_styles = """
-/* Page background and base font */
-body {
-    font-family: 'Segoe UI', 'Roboto', 'Arial', sans-serif;
-    line-height: 1.6;
-    margin: 0;
-    padding: 20px; /* breathing room around the app */
-    background-color: #f4f7f6; /* soft background color */
-}
-
-/* Main app container */
-.gradio-container {
-    max-width: 900px; /* center the app and cap its width */
-    margin: 20px auto;
-    border-radius: 12px; /* rounded corners */
-    overflow: hidden; /* keep children inside the corners */
-    box-shadow: 0 8px 16px rgba(0, 0, 0, 0.1); /* drop shadow */
-    background-color: #ffffff; /* content area background */
-}
-
-/* Title and description area (ChatInterface's default title/description) */
-/* The exact class names depend on the ChatInterface structure, but the first
-   block or the H1/P tags inside .gradio-container can be targeted. A theme
-   handles most of this; only extra padding is adjusted here. */
-.gradio-container > .gradio-block:first-child {
-    padding: 20px 20px 10px 20px; /* adjust top padding */
-}
-
-/* Chat box area */
-.gradio-chatbox {
-    /* styled by the theme, but inner padding etc. can be adjusted */
-    padding: 15px;
-    background-color: #fefefe; /* chat area background */
-    border-radius: 8px; /* inner corners of the chat area */
-    border: 1px solid #e0e0e0; /* border line */
-}
-
-/* Chat messages */
-.gradio-chatmessage {
-    margin-bottom: 12px;
-    padding: 10px 15px;
-    border-radius: 20px; /* rounded message corners */
-    max-width: 75%; /* cap message width */
-    word-wrap: break-word; /* wrap long words */
-    white-space: pre-wrap; /* preserve whitespace and line breaks */
-    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05); /* subtle shadow on messages */
-}
-
-/* User messages */
-.gradio-chatmessage.user {
-    background-color: #007bff; /* blue */
-    color: white;
-    margin-left: auto; /* align right */
-    border-bottom-right-radius: 2px; /* square off the bottom-right corner */
-}
-
-/* Bot messages */
-.gradio-chatmessage.bot {
-    background-color: #e9ecef; /* light gray */
-    color: #333; /* dark text */
-    margin-right: auto; /* align left */
-    border-bottom-left-radius: 2px; /* square off the bottom-left corner */
-}
-
-/* Input box and button area */
-.gradio-input-box {
-    padding: 15px;
-    border-top: 1px solid #eee; /* top border */
-    background-color: #f8f9fa; /* input area background */
-}
-/* Input textarea */
-.gradio-input-box textarea {
-    border-radius: 8px;
-    padding: 10px;
-    border: 1px solid #ccc;
-    resize: none !important; /* disable resizing (optional) */
-    min-height: 50px; /* minimum height */
-    max-height: 150px; /* maximum height */
-    overflow-y: auto; /* scroll when content overflows */
-}
-/* Scrollbar styling (optional) */
-.gradio-input-box textarea::-webkit-scrollbar {
-    width: 8px;
-}
-.gradio-input-box textarea::-webkit-scrollbar-thumb {
-    background-color: #ccc;
-    border-radius: 4px;
-}
-.gradio-input-box textarea::-webkit-scrollbar-track {
-    background-color: #f1f1f1;
-}
-
-
-/* Buttons */
-.gradio-button {
-    border-radius: 8px;
-    padding: 10px 20px;
-    font-weight: bold;
-    transition: background-color 0.2s ease, opacity 0.2s ease; /* hover animation */
-    border: none; /* remove default border */
-    cursor: pointer;
-}
-
-.gradio-button:not(.clear-button) { /* Send button */
-    background-color: #28a745; /* green */
-    color: white;
-}
-.gradio-button:not(.clear-button):hover {
-    background-color: #218838;
-}
-.gradio-button:disabled { /* disabled buttons */
-    opacity: 0.6;
-    cursor: not-allowed;
-}
-
-
-.gradio-button.clear-button { /* Clear button */
-    background-color: #dc3545; /* red */
-    color: white;
-}
-.gradio-button.clear-button:hover {
-    background-color: #c82333;
-}
-
-/* Additional-inputs area */
-/* Usually rendered as an accordion carrying the .gradio-accordion class. */
-.gradio-accordion {
-    border-radius: 12px; /* same corners as the outer container */
-    margin-top: 15px; /* gap from the chat area */
-    border: 1px solid #ddd; /* border line */
-    box-shadow: none; /* remove inner shadow */
-}
-/* Accordion header (label) */
-.gradio-accordion .label {
-    font-weight: bold;
-    color: #007bff; /* blue */
-    padding: 15px; /* header padding */
-    background-color: #e9ecef; /* header background */
-    border-bottom: 1px solid #ddd; /* header bottom border */
-    border-top-left-radius: 11px; /* top corners */
-    border-top-right-radius: 11px;
-}
-/* Accordion content area */
-.gradio-accordion .wrap {
-    padding: 15px; /* content padding */
-    background-color: #fefefe; /* content background */
-    border-bottom-left-radius: 11px; /* bottom corners */
-    border-bottom-right-radius: 11px;
-}
-/* Individual inputs inside the additional settings (slider, textbox, etc.) */
-.gradio-slider, .gradio-textbox, .gradio-number {
-    margin-bottom: 10px; /* gap below each input element */
-    padding: 8px; /* inner padding */
-    border: 1px solid #e0e0e0; /* border line */
-    border-radius: 8px; /* rounded corners */
-    background-color: #fff; /* background color */
-}
-/* Input field labels */
-.gradio-label {
-    font-weight: normal; /* label font weight */
-    margin-bottom: 5px; /* gap between label and input */
-    color: #555; /* label color */
-    display: block; /* make the label a block element above the input */
-}
-/* Slider track and handle styles (finer tuning is possible) */
-/* e.g. .gradio-slider input[type="range"]::-webkit-slider-thumb {} */
-"""
+    prompt = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
+    streamer = TextIteratorStreamer(
+        tokenizer, skip_prompt=True, skip_special_tokens=True
+    )
+    generate_kwargs = dict(
+        **inputs,
+        streamer=streamer,
+        max_new_tokens=max_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        do_sample=True,
+    )
+    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
 
-# Gradio interface setup
+    response = ""
+    for new_text in streamer:
+        response += new_text
+        yield response
+
 demo = gr.ChatInterface(
     fn=respond,
-    description="<p style='text-align: center; color: #555;'>This chat application is powered by Microsoft's SOTA Bitnet-b1.58-2B-4T and designed for natural and fast conversations.</p>",
+    title="Bitnet-b1.58-2B-4T",
+    description="Bitnet-b1.58-2B-4T",
     examples=[
         [
-            "Hello!
-            "You are a helpful AI
+            "Hello!",
+            "You are a helpful AI.",
             512,
             0.7,
             0.95,
         ],
         [
-            "Can you code a snake game
-            "You are a helpful AI
+            "Can you code a snake game?",
+            "You are a helpful AI.",
             2048,
             0.7,
             0.95,
@@ -332,8 +101,7 @@ demo = gr.ChatInterface(
     additional_inputs=[
         gr.Textbox(
             value="You are a helpful AI assistant.",
-            label="System message",
-            lines=3  # height of the system-message input box
+            label="System message"
         ),
         gr.Slider(
             minimum=1,
@@ -357,14 +125,7 @@
             label="Top-p (nucleus sampling)"
         ),
     ],
-    # Apply a theme (options include gr.themes.Soft(), gr.themes.Glass(), gr.themes.Default(), etc.)
-    theme=gr.themes.Soft(),
-    # Apply the custom CSS
-    css=css_styles,
 )
 
-# Run the application
 if __name__ == "__main__":
-    demo.launch()
-    # demo.launch(debug=True)  # enable debug mode
+    demo.launch()
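The new prompt construction also relies on the tokenizer's chat template rather than manual string formatting. With `tokenize=False` the call returns the fully formatted prompt string, and `add_generation_prompt=True` appends the assistant-turn marker so the model continues as the assistant. A sketch of the call as used in `respond`:

messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "Hello!"},
]
# tokenize=False -> return the formatted string rather than token ids;
# add_generation_prompt=True -> end the prompt with the assistant marker.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

This assumes the Bitnet checkpoint ships a chat template in its tokenizer config; checkpoints without one either fall back to a default template or raise an error, depending on the `transformers` version.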