kimhyunwoo committed on
Commit 4bf6d97 · verified · 1 Parent(s): 529d051

Update app.py

Files changed (1)
  1. app.py +127 -54
app.py CHANGED
@@ -1,7 +1,25 @@
- import gradio as gr
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ # Install the required library.
+ # This step runs once, at the start of the script.
  import os
+ print("Installing required transformers branch...")
+ os.system("pip install git+https://github.com/shumingma/transformers.git")
+ print("Installation complete.")
+
+ # Import the required libraries.
+ import threading
+ import torch
+ import torch._dynamo
+ import gradio as gr
+ import spaces  # Hugging Face Spaces utilities
+
+ # torch._dynamo setting (optional; attempts to improve performance)
+ torch._dynamo.config.suppress_errors = True
+
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
+ )
 
  # --- Load the model ---
  # Set the model path (Hugging Face model ID)
@@ -11,28 +29,18 @@ model_id = "microsoft/bitnet-b1.58-2B-4T"
  os.environ["TRANSFORMERS_VERBOSITY"] = "error"
 
  # Load AutoModelForCausalLM and AutoTokenizer.
- # The BitNet model requires trust_remote_code=True.
- # Uses the transformers build installed from the specific GitHub branch.
+ # trust_remote_code=True is required; device_map="auto" picks the device automatically.
  try:
      print(f"Loading model: {model_id}...")
-     # Use bf16 when a GPU is available
-     if torch.cuda.is_available():
-         # Set torch_dtype explicitly to try to avoid load errors
-         model = AutoModelForCausalLM.from_pretrained(
-             model_id,
-             torch_dtype=torch.bfloat16,
-             trust_remote_code=True
-         ).to("cuda")  # Move the model to the GPU
-         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-         print("Model loaded on GPU.")
-     else:
-         # On CPU, omit torch_dtype or use float32
-         model = AutoModelForCausalLM.from_pretrained(
-             model_id,
-             trust_remote_code=True
-         )
-         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-         print("Model loaded on CPU. Performance may be slow.")
+     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.bfloat16,  # Use bf16 (GPU recommended)
+         device_map="auto",           # Place the model on available devices automatically
+         trust_remote_code=True
+     )
+     print(f"Model device: {model.device}")
+     print("Model loaded.")
 
  except Exception as e:
      print(f"Error while loading the model: {e}")
@@ -41,54 +49,119 @@ except Exception as e:
      print("Model loading failed. The application may not work correctly.")
 
 
- # --- Text generation function ---
- def generate_text(prompt, max_length=100, temperature=0.7):
+ # --- Text generation function (for the Gradio ChatInterface) ---
+ @spaces.GPU  # Mark this function as using GPU resources (Hugging Face Spaces)
+ def respond(
+     message: str,
+     history: list[tuple[str, str]],
+     system_message: str,
+     max_tokens: int,
+     temperature: float,
+     top_p: float,
+ ):
      if model is None or tokenizer is None:
-         return "Text generation is unavailable because the model failed to load."
+         yield "Text generation is unavailable because the model failed to load."
+         return  # This is a generator function, so just return
 
      try:
-         # Tokenize the prompt
-         inputs = tokenizer(prompt, return_tensors="pt")
-         # Move the inputs to the GPU when available
-         if torch.cuda.is_available():
-             inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-         # Generate text
-         # The model uses the LLaMA 3 tokenizer, so a chat template can be applied (optional)
-         # The code below feeds the prompt in directly, without the message format
-         outputs = model.generate(
+         # Build the message list to match the model's chat template
+         messages = [{"role": "system", "content": system_message}]
+         for user_msg, bot_msg in history:
+             if user_msg:
+                 messages.append({"role": "user", "content": user_msg})
+             if bot_msg:
+                 messages.append({"role": "assistant", "content": bot_msg})
+         messages.append({"role": "user", "content": message})
+
+         prompt = tokenizer.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+         # Set up a streamer for token-by-token text streaming
+         streamer = TextIteratorStreamer(
+             tokenizer, skip_prompt=True, skip_special_tokens=True
+         )
+         generate_kwargs = dict(
              **inputs,
-             max_new_tokens=max_length,
+             streamer=streamer,
+             max_new_tokens=max_tokens,
              temperature=temperature,
-             do_sample=True,  # Enable sampling
-             pad_token_id=tokenizer.eos_token_id  # Set the padding token ID (if needed)
+             top_p=top_p,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id  # Set the padding token ID
          )
 
-         # Decode the generated text
-         # Decode only the newly generated part, excluding the input prompt
-         generated_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
+         # Run model generation in a separate thread
+         thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
+         thread.start()
 
-         return generated_text
+         # Read the generated text from the streamer and yield it
+         response = ""
+         for new_text in streamer:
+             response += new_text
+             yield response  # Stream the partial response to the Gradio interface
 
      except Exception as e:
-         return f"Error during text generation: {e}"
+         yield f"Error during text generation: {e}"
+         # Thread cleanup on error could be added here (optional)
+
 
  # --- Gradio interface setup ---
  if model is not None and tokenizer is not None:
-     interface = gr.Interface(
-         fn=generate_text,
-         inputs=[
-             gr.Textbox(lines=2, placeholder="Enter text...", label="Input prompt"),
-             gr.Slider(minimum=10, maximum=500, value=100, label="Max generation length"),
-             gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature (creativity)")
+     demo = gr.ChatInterface(
+         fn=respond,
+         title="Bitnet-b1.58-2B-4T Chatbot",
+         description="A chat demo using the Microsoft Bitnet-b1.58-2B-4T model.",
+         examples=[
+             [
+                 "Hello! Please introduce yourself.",
+                 "You are a helpful AI assistant.",  # Example system message
+                 512,   # Example max new tokens
+                 0.7,   # Example temperature
+                 0.95,  # Example top-p
+             ],
+             [
+                 "Show me code for a simple web server in Python.",
+                 "You are a skilled AI developer.",  # Example system message
+                 1024,  # Example max new tokens
+                 0.8,   # Example temperature
+                 0.9,   # Example top-p
+             ],
+         ],
+         additional_inputs=[
+             gr.Textbox(
+                 value="You are a helpful AI assistant.",  # Default system message
+                 label="System message",
+                 lines=1
+             ),
+             gr.Slider(
+                 minimum=1,
+                 maximum=4096,  # Consider the model's maximum context length (or set higher)
+                 value=512,
+                 step=1,
+                 label="Max new tokens"
+             ),
+             gr.Slider(
+                 minimum=0.1,
+                 maximum=2.0,  # Adjust the temperature range if needed
+                 value=0.7,
+                 step=0.1,
+                 label="Temperature"
+             ),
+             gr.Slider(
+                 minimum=0.0,  # Adjust the top-p range if needed
+                 maximum=1.0,
+                 value=0.95,
+                 step=0.05,
+                 label="Top-p (nucleus sampling)"
+             ),
          ],
-         outputs=gr.Textbox(label="Generated text"),
-         title="BitNet b1.58-2B-4T text generation demo",
-         description="Generates text using the BitNet b1.58-2B-4T model."
     )
 
      # Launch the Gradio app
      # On Hugging Face Spaces, share=True is set automatically.
-     interface.launch()
+     # Set debug=True to see detailed logs.
+     demo.launch(debug=True)
  else:
      print("Cannot launch the Gradio interface because the model failed to load.")
 
 
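Note on the prompt construction: respond() flattens gr.ChatInterface's (user, assistant) history tuples into a messages list and lets the tokenizer's chat template do the formatting. A standalone sketch of that step, assuming any tokenizer that ships a chat template behaves the same way (the zephyr tokenizer below is only an illustration, not what the Space loads):

    from transformers import AutoTokenizer

    # Illustrative tokenizer; the Space uses the BitNet model's own tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

    history = [("Hi!", "Hello! How can I help?")]  # gr.ChatInterface-style tuples
    messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": "What can you do?"})

    # tokenize=False returns the formatted prompt string; add_generation_prompt=True
    # appends the assistant header so the model continues as the assistant.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print(prompt)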
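Note on the streaming pattern: model.generate() blocks until generation finishes, so the new code runs it in a worker thread and iterates over TextIteratorStreamer on the main thread, yielding partial text as it arrives. A runnable sketch of the same pattern with the small public gpt2 checkpoint (chosen only so the sketch runs on CPU; the Space applies the identical pattern to the BitNet model):

    import threading
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks, so run it in a background thread and consume decoded
    # text chunks from the streamer on the main thread as they are produced.
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=30,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        ),
    )
    thread.start()

    for piece in streamer:
        print(piece, end="", flush=True)
    thread.join()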
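Note on the interface wiring: gr.ChatInterface passes each component in additional_inputs to the callback positionally, after (message, history), so the component order must match respond()'s remaining parameters, and each row in examples supplies values for those same inputs in the same order. A model-free sketch of that contract (echo is a hypothetical stand-in for respond):

    import gradio as gr

    # Hypothetical stand-in for respond(): same signature, no model needed.
    def echo(message, history, system_message, max_tokens, temperature, top_p):
        yield f"[system={system_message!r}, max_tokens={max_tokens}] {message}"

    demo = gr.ChatInterface(
        fn=echo,
        # Order must match the parameters after (message, history) above.
        additional_inputs=[
            gr.Textbox(value="You are a helpful AI assistant.", label="System message"),
            gr.Slider(1, 4096, value=512, step=1, label="Max new tokens"),
            gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(0.0, 1.0, value=0.95, step=0.05, label="Top-p"),
        ],
    )

    if __name__ == "__main__":
        demo.launch()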