openfree commited on
Commit
42f4126
·
verified ·
1 Parent(s): ee63f7e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -0
app.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import threading
3
+
4
+ import gradio as gr
5
+ import spaces
6
+ import transformers
7
+ from transformers import pipeline
8
+
9
+ # ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ
10
+ model_name = "Qwen/Qwen2-1.5B-Instruct"
11
+ if gr.NO_RELOAD:
12
+ pipe = pipeline(
13
+ "text-generation",
14
+ model=model_name,
15
+ device_map="auto",
16
+ torch_dtype="auto",
17
+ )
18
+
19
+ # ์ตœ์ข… ๋‹ต๋ณ€์„ ๊ฐ์ง€ํ•˜๊ธฐ ์œ„ํ•œ ๋งˆ์ปค
20
+ ANSWER_MARKER = "**๋‹ต๋ณ€**"
21
+
22
+ # ๋‹จ๊ณ„๋ณ„ ์ถ”๋ก ์„ ์‹œ์ž‘ํ•˜๋Š” ๋ฌธ์žฅ๋“ค
23
+ rethink_prepends = [
24
+ "์ž, ์ด์ œ ๋‹ค์Œ์„ ํŒŒ์•…ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค ",
25
+ "์ œ ์ƒ๊ฐ์—๋Š” ",
26
+ "์ž ์‹œ๋งŒ์š”, ์ œ ์ƒ๊ฐ์—๋Š” ",
27
+ "๋‹ค์Œ ์‚ฌํ•ญ์ด ๋งž๋Š”์ง€ ํ™•์ธํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค ",
28
+ "๋˜ํ•œ ๊ธฐ์–ตํ•ด์•ผ ํ•  ๊ฒƒ์€ ",
29
+ "๋˜ ๋‹ค๋ฅธ ์ฃผ๋ชฉํ•  ์ ์€ ",
30
+ "๊ทธ๋ฆฌ๊ณ  ์ €๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์‚ฌ์‹ค๋„ ๊ธฐ์–ตํ•ฉ๋‹ˆ๋‹ค ",
31
+ "์ด์ œ ์ถฉ๋ถ„ํžˆ ์ดํ•ดํ–ˆ๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค ",
32
+ "์ง€๊ธˆ๊นŒ์ง€์˜ ์ •๋ณด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ, ์›๋ž˜ ์งˆ๋ฌธ์— ์‚ฌ์šฉ๋œ ์–ธ์–ด๋กœ ๋‹ต๋ณ€ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค:"
33
+ "\n{question}\n"
34
+ f"\n{ANSWER_MARKER}\n",
35
+ ]
36
+
37
+
38
# KaTeX delimiter configuration for gr.Chatbot so both display ($$...$$)
# and inline ($...$) math render correctly.
latex_delimiters = [
    dict(left="$$", right="$$", display=True),
    dict(left="$", right="$", display=False),
]
43
+
44
+
45
def reformat_math(text):
    """Convert MathJax-style delimiters into the KaTeX ones Gradio renders.

    ``\\[ ... \\]`` becomes ``$$ ... $$`` (display math) and ``\\( ... \\)``
    becomes ``$ ... $`` (inline math), trimming whitespace just inside the
    delimiters. This is a workaround: no latex_delimiters configuration was
    found that makes Gradio handle the MathJax forms directly.
    """
    display_fixed = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
    return re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", display_fixed, flags=re.DOTALL)
53
+
54
+
55
def user_input(message, history: list):
    """Append the user's message to the chat history and clear the textbox.

    ANSWER_MARKER is stripped from the message so user input cannot
    prematurely trigger the answer-detection logic in bot().
    Returns ("", new_history): the empty string clears the input box.
    """
    sanitized = message.replace(ANSWER_MARKER, "")
    new_history = history + [gr.ChatMessage(role="user", content=sanitized)]
    return "", new_history
60
+
61
+
62
def rebuild_messages(history: list):
    """Rebuild the message list the model will see from the chat history.

    Plain dict entries are kept as-is only when they carry no metadata
    "title" (i.e. they are not intermediate "thinking" bubbles), while
    gr.ChatMessage entries that DO have a title and string content are
    flattened into plain role/content dicts.
    """
    messages = []
    for entry in history:
        if isinstance(entry, dict):
            # previous-turn entries serialized by the Chatbot component
            if not entry.get("metadata", {}).get("title", False):
                messages.append(entry)
        elif (
            isinstance(entry, gr.ChatMessage)
            and entry.metadata.get("title")
            and isinstance(entry.content, str)
        ):
            # freshly-appended assistant placeholder from bot()
            messages.append({"role": entry.role, "content": entry.content})
    return messages
75
+
76
+
77
@spaces.GPU
def bot(
    history: list,
    max_num_tokens: int,
    final_num_tokens: int,
    do_sample: bool,
    temperature: float,
):
    """Make the model answer the question via forced step-by-step reasoning.

    Generator handler: runs one generation pass per entry of
    ``rethink_prepends`` and yields the updated chat ``history`` after every
    streamed token so the UI updates live.

    Args:
        history: chat history; the last entry holds the user's question.
        max_num_tokens: token budget for each intermediate reasoning step.
        final_num_tokens: token budget for the final answer pass.
        do_sample: whether to sample instead of greedy decoding.
        temperature: sampling temperature passed to the pipeline.
    """

    # Streamer lets us consume tokens here while generation runs in a thread.
    streamer = transformers.TextIteratorStreamer(
        pipe.tokenizer,  # pyright: ignore
        skip_special_tokens=True,
        skip_prompt=True,
    )

    # Keep the question around so it can be re-inserted into the final prompt.
    question = history[-1]["content"]

    # Placeholder assistant message that collects the "thinking" text.
    history.append(
        gr.ChatMessage(
            role="assistant",
            content=str(""),
            metadata={"title": "๐Ÿง  ์ƒ๊ฐ ์ค‘...", "status": "pending"},
        )
    )

    # Prompt the model will actually see (thinking bubbles from earlier
    # turns are filtered out; the placeholder just appended is included,
    # so messages[-1] is the assistant entry we extend below).
    messages = rebuild_messages(history)
    for i, prepend in enumerate(rethink_prepends):
        if i > 0:
            messages[-1]["content"] += "\n\n"
        messages[-1]["content"] += prepend.format(question=question)

        # The final pass (the one containing ANSWER_MARKER) gets its own,
        # typically larger, token budget.
        num_tokens = int(
            max_num_tokens if ANSWER_MARKER not in prepend else final_num_tokens
        )
        # Run generation in a background thread; we drain `streamer` here.
        t = threading.Thread(
            target=pipe,
            args=(messages,),
            kwargs=dict(
                max_new_tokens=num_tokens,
                streamer=streamer,
                do_sample=do_sample,
                temperature=temperature,
            ),
        )
        t.start()

        # Mirror the prepend into the visible history.
        history[-1].content += prepend.format(question=question)
        if ANSWER_MARKER in prepend:
            # Thinking is over: close the thought bubble and stream the
            # answer into a fresh, metadata-free assistant message.
            history[-1].metadata = {"title": "๐Ÿ’ญ ์‚ฌ๊ณ  ๊ณผ์ •", "status": "done"}
            history.append(gr.ChatMessage(role="assistant", content=""))
        for token in streamer:
            history[-1].content += token
            # Re-normalize math delimiters on every partial update so
            # formulas render while streaming.
            history[-1].content = reformat_math(history[-1].content)
            yield history
        # Stream is exhausted, so the thread is done (or about to be).
        t.join()

    yield history
141
+
142
+
143
# UI layout: a chat column (left) and a generation-parameter column (right).
with gr.Blocks(fill_height=True, title="๋ชจ๋“  LLM ๋ชจ๋ธ์— ์ถ”๋ก  ๋Šฅ๋ ฅ ๋ถ€์—ฌํ•˜๊ธฐ") as demo:
    with gr.Row(scale=1):
        with gr.Column(scale=5):
            # Intro text (Korean); explains the forced-reasoning trick.
            gr.Markdown(f"""
# ๋ชจ๋“  LLM์— ์ถ”๋ก  ๋Šฅ๋ ฅ ๊ฐ•์ œํ•˜๊ธฐ

์ด๊ฒƒ์€ ๋ชจ๋“  LLM(๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ)์ด ์‘๋‹ต ์ „์— ์ถ”๋ก ํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๋Š” ๊ฐ„๋‹จํ•œ ๊ฐœ๋… ์ฆ๋ช…์ž…๋‹ˆ๋‹ค.
์ด ์ธํ„ฐํŽ˜์ด์Šค๋Š” *{model_name}* ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜๋Š”๋ฐ, **์ด๋Š” ์ถ”๋ก  ๋ชจ๋ธ์ด ์•„๋‹™๋‹ˆ๋‹ค**. ์‚ฌ์šฉ๋œ ๋ฐฉ๋ฒ•์€
๋‹จ์ง€ ์ ‘๋‘์‚ฌ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ์ด ๋‹ต๋ณ€์„ ํ–ฅ์ƒ์‹œํ‚ค๋Š” ๋ฐ ๋„์›€์ด ๋˜๋Š” "์ถ”๋ก " ๋‹จ๊ณ„๋ฅผ ๊ฐ•์ œํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.
๊ด€๋ จ ๊ธฐ์‚ฌ๋Š” ๋‹ค์Œ์—์„œ ํ™•์ธํ•˜์„ธ์š”: [๋ชจ๋“  ๋ชจ๋ธ์— ์ถ”๋ก  ๋Šฅ๋ ฅ ๋ถ€์—ฌํ•˜๊ธฐ](https://huggingface.co/blog/Metal3d/making-any-model-reasoning)
""")
            chatbot = gr.Chatbot(
                scale=1,
                type="messages",
                latex_delimiters=latex_delimiters,
            )
            msg = gr.Textbox(
                submit_btn=True,
                label="",
                show_label=False,
                placeholder="์—ฌ๊ธฐ์— ์งˆ๋ฌธ์„ ์ž…๋ ฅํ•˜์„ธ์š”.",
                autofocus=True,
            )
        with gr.Column(scale=1):
            gr.Markdown("""## ๋งค๊ฐœ๋ณ€์ˆ˜ ์กฐ์ •""")
            # Token budget per intermediate reasoning step (min, max, default).
            num_tokens = gr.Slider(
                50,
                1024,
                100,
                step=1,
                label="์ถ”๋ก  ๋‹จ๊ณ„๋‹น ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜",
                interactive=True,
            )
            # Token budget for the final answer pass.
            final_num_tokens = gr.Slider(
                50,
                1024,
                512,
                step=1,
                label="์ตœ์ข… ๋‹ต๋ณ€์˜ ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜",
                interactive=True,
            )
            do_sample = gr.Checkbox(True, label="์ƒ˜ํ”Œ๋ง ์‚ฌ์šฉ")
            temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="์˜จ๋„")
            # Help text (Korean) describing the sliders above.
            gr.Markdown("""
์ถ”๋ก  ๋‹จ๊ณ„์—์„œ ๋” ์ ์€ ์ˆ˜์˜ ํ† ํฐ์„ ์‚ฌ์šฉํ•˜๋ฉด ๋ชจ๋ธ์ด
๋” ๋นจ๋ฆฌ ๋‹ต๋ณ€ํ•  ์ˆ˜ ์žˆ์ง€๋งŒ, ์ถฉ๋ถ„ํžˆ ๊นŠ๊ฒŒ ์ถ”๋ก ํ•˜์ง€ ๋ชปํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์ ์ ˆํ•œ ๊ฐ’์€ 100์—์„œ 512์ž…๋‹ˆ๋‹ค.
์ตœ์ข… ๋‹ต๋ณ€์— ๋” ์ ์€ ์ˆ˜์˜ ํ† ํฐ์„ ์‚ฌ์šฉํ•˜๋ฉด ๋ชจ๋ธ์˜
์‘๋‹ต์ด ๋œ ์žฅํ™ฉํ•ด์ง€์ง€๋งŒ, ์™„์ „ํ•œ ๋‹ต๋ณ€์„ ์ œ๊ณตํ•˜์ง€ ๋ชปํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์ ์ ˆํ•œ ๊ฐ’์€ 512์—์„œ 1024์ž…๋‹ˆ๋‹ค.
**์ƒ˜ํ”Œ๋ง ์‚ฌ์šฉ**์€ ๋‹ต๋ณ€์„ ์™„์„ฑํ•˜๊ธฐ ์œ„ํ•ด ๋‹ค์Œ ํ† ํฐ์„ ์„ ํƒํ•˜๋Š” ๋‹ค๋ฅธ ์ „๋žต์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
์ผ๋ฐ˜์ ์œผ๋กœ ์ด ์˜ต์…˜์„ ์ฒดํฌํ•ด ๋‘๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค.
**์˜จ๋„**๋Š” ๋ชจ๋ธ์ด ์–ผ๋งˆ๋‚˜ "์ฐฝ์˜์ "์ผ ์ˆ˜ ์žˆ๋Š”์ง€๋ฅผ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค. 0.7์ด ์ผ๋ฐ˜์ ์ธ ๊ฐ’์ž…๋‹ˆ๋‹ค.
๋„ˆ๋ฌด ๋†’์€ ๊ฐ’(์˜ˆ: 1.0)์„ ์„ค์ •ํ•˜๋ฉด ๋ชจ๋ธ์ด ์ผ๊ด€์„ฑ์ด ์—†์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‚ฎ์€ ๊ฐ’(์˜ˆ: 0.3)์œผ๋กœ
์„ค์ •ํ•˜๋ฉด ๋ชจ๋ธ์€ ๋งค์šฐ ์˜ˆ์ธก ๊ฐ€๋Šฅํ•œ ๋‹ต๋ณ€์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
""")
            gr.Markdown("""
์ด ์ธํ„ฐํŽ˜์ด์Šค๋Š” 6GB VRAM์„ ๊ฐ€์ง„ ๊ฐœ์ธ ์ปดํ“จํ„ฐ์—์„œ ์ž‘๋™ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค(์˜ˆ: ๋…ธํŠธ๋ถ์˜ NVidia 3050/3060).
์ž์œ ๋กญ๊ฒŒ ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜์„ ํฌํฌํ•˜์—ฌ ๋‹ค๋ฅธ instruct ๋ชจ๋ธ์„ ์‹œ๋„ํ•ด ๋ณด์„ธ์š”.
""")

    # When the user submits a message, the bot responds: first user_input
    # records the message and clears the textbox, then bot() streams the
    # reasoning + answer back into the chatbot.
    msg.submit(
        user_input,
        [msg, chatbot],  # inputs
        [msg, chatbot],  # outputs
    ).then(
        bot,
        [
            chatbot,
            num_tokens,
            final_num_tokens,
            do_sample,
            temperature,
        ],  # in reality the "history" input plus generation parameters
        chatbot,  # the new history is written back to the chatbot
    )
220
+
221
if __name__ == "__main__":
    # queue() enables request queuing, which Gradio requires for
    # generator handlers like bot() that stream partial outputs.
    demo.queue().launch()