kimhyunwoo committed
Commit dc10a73 · verified · 1 Parent(s): 0ef9d3d

Update app.py

Files changed (1)
  1. app.py +41 -20
app.py CHANGED
@@ -29,11 +29,11 @@ try:
         MODEL_ID,
         torch_dtype=torch.float32,
         device_map="cpu",
-        force_download=True  # force a fresh download
+        force_download=True  # kept to work around an earlier error (set False or remove if no longer needed)
     )
     tokenizer = AutoTokenizer.from_pretrained(
         MODEL_ID,
-        force_download=True  # force a fresh download
+        force_download=True  # kept to work around an earlier error (set False or remove if no longer needed)
     )
     model.eval()
     print("--- Model and Tokenizer Loaded Successfully on CPU ---")
@@ -41,14 +41,24 @@ try:
     stop_token_strings = ["<|endofturn|>", "<|stop|>"]
     stop_token_ids_list = [tokenizer.convert_tokens_to_ids(token) for token in stop_token_strings]
 
-    if tokenizer.eos_token_id not in stop_token_ids_list:
-        stop_token_ids_list.append(tokenizer.eos_token_id)
+    # Check that the tokenizer actually has a usable eos_token_id
+    if tokenizer.eos_token is not None and tokenizer.eos_token_id not in stop_token_ids_list:
+        # Append only when eos_token_id is not None and missing from the list
+        if tokenizer.eos_token_id is not None:
+            stop_token_ids_list.append(tokenizer.eos_token_id)
+        else:
+            print("Warning: tokenizer.eos_token_id is None. Cannot add to stop tokens.")
+    elif tokenizer.eos_token is None:
+        print("Warning: tokenizer.eos_token is not defined.")
+
 
     stop_token_ids_list = [tid for tid in stop_token_ids_list if tid is not None]
 
     if not stop_token_ids_list:
-        print("Warning: Could not find any stop token IDs. Using default EOS only.")
-        stop_token_ids_list = [tokenizer.eos_token_id]
+        print("Warning: Could not find any stop token IDs. Generation might not stop correctly.")
+        # Fallback: with no EOS token ID at all, generation may fail to stop
+        # If necessary, hard-code a default EOS token ID or handle it another way,
+        # e.g. stop_token_ids_list = [some_default_eos_id]
 
     print(f"Using Stop Token IDs: {stop_token_ids_list}")
 
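Note: convert_tokens_to_ids can return None (or the unk token id) for a token string that is missing from the vocabulary, which is what the tid is not None filter above guards against. A minimal sketch for surfacing that case early, assuming the tokenizer loaded above:

    # Warn when a configured stop token does not exist in the vocabulary.
    for token in ["<|endofturn|>", "<|stop|>"]:
        tid = tokenizer.convert_tokens_to_ids(token)
        if tid is None or tid == tokenizer.unk_token_id:
            print(f"Warning: stop token {token!r} is not in the vocabulary (got {tid}).")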
 
@@ -77,6 +87,7 @@ def predict(message, history):
         {"role": "tool_list", "content": ""},
         {"role": "system", "content": system_prompt}
     ]
+    # Assumes history is a list of (user, ai) tuples
     for user_msg, ai_msg in history:
         chat_history_formatted.append({"role": "user", "content": user_msg})
         chat_history_formatted.append({"role": "assistant", "content": ai_msg})
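Note: this loop assumes the tuple-style history that gr.ChatInterface passes by default. The type='messages' change in a later hunk switches the chatbot to dict-style history, so this loop may need the adaptation sketched after that hunk.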
@@ -87,12 +98,13 @@ def predict(message, history):
     output_ids = None
 
     try:
+        # The model was loaded with device_map="cpu", so send the inputs to the CPU as well
         inputs = tokenizer.apply_chat_template(
             chat_history_formatted,
             add_generation_prompt=True,
             return_dict=True,
             return_tensors="pt"
-        ).to(model.device)
+        ).to("cpu")  # explicitly target the CPU
         input_length = inputs['input_ids'].shape[1]
         print(f"\nInput tokens: {input_length}")
 
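Note: apply_chat_template(..., return_dict=True, return_tensors="pt") returns a BatchEncoding, and its .to() moves every contained tensor. Because the model was loaded with device_map="cpu", .to(model.device) and .to("cpu") are equivalent here; the explicit form only documents the intent:

    inputs = inputs.to(model.device)  # resolves to the CPU under device_map="cpu"
    inputs = inputs.to("cpu")         # same result, stated explicitly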
 
@@ -103,11 +115,12 @@ def predict(message, history):
     try:
         print("Generating response...")
         with torch.no_grad():
+            # Passing a list of IDs to eos_token_id is the usual approach
            output_ids = model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
-               eos_token_id=stop_token_ids_list,
-               pad_token_id=tokenizer.eos_token_id,
+               eos_token_id=stop_token_ids_list,  # use the sanitized stop_token_ids_list
+               pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,  # guard pad_token_id
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
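Note: recent transformers releases accept a list of IDs for eos_token_id, so generation stops at whichever stop token appears first. The inline pad_token_id fallback can still end up None if the tokenizer defines neither token; a minimal sketch of a stricter guard, reusing stop_token_ids_list from above:

    # Prefer eos, then pad, then fall back to the first known stop token.
    pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id
    if pad_id is None and stop_token_ids_list:
        pad_id = stop_token_ids_list[0]  # assumption: any stop token is usable for padding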
@@ -121,26 +134,35 @@ def predict(message, history):
         gc.collect()
         return f"였λ₯˜: 응닡을 μƒμ„±ν•˜λŠ” 쀑 λ¬Έμ œκ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€. ({e})"
 
-    new_tokens = output_ids[0, input_length:]
-    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
+    # Attempt decoding only when output_ids is not None
+    if output_ids is not None:
+        new_tokens = output_ids[0, input_length:]
+        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
+        print(f"Output tokens: {len(new_tokens)}")
+        del new_tokens  # free memory
+    else:
+        response = "였λ₯˜: 응닡 생성에 μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€ (output_ids is None)."
+        print("Generation failed, output_ids is None.")
 
-    print(f"Output tokens: {len(new_tokens)}")
 
-    del inputs
-    del output_ids
-    del new_tokens
+    # Memory cleanup
+    if inputs is not None: del inputs
+    if output_ids is not None: del output_ids
     gc.collect()
     print("Memory cleaned.")
 
     return response
 
+# --- Gradio Interface ---
 print("--- Setting up Gradio Interface ---")
 
+# Resolves a UserWarning and uses the current format
 chatbot_component = gr.Chatbot(
     label="HyperCLOVA X SEED (0.5B) λŒ€ν™”",
     bubble_full_width=False,
-    height=600
-)
+    height=600,
+    type='messages'  # explicitly use the "messages" format
+)
 
 examples = [
     ["넀이버 ν΄λ‘œλ°”XλŠ” λ¬΄μ—‡μΈκ°€μš”?"],
@@ -149,6 +171,7 @@ examples = [
     ["μ œμ£Όλ„ μ—¬ν–‰ κ³„νšμ„ μ„Έμš°κ³  μžˆλŠ”λ°, 3λ°• 4일 μΆ”μ²œ μ½”μŠ€ μ’€ μ§œμ€„λž˜?"],
 ]
 
+# Removed the arguments that caused the error (retry_btn, undo_btn, clear_btn)
 demo = gr.ChatInterface(
     fn=predict,
     chatbot=chatbot_component,
@@ -162,11 +185,9 @@ demo = gr.ChatInterface(
     examples=examples,
     cache_examples=False,
     theme="soft",
-    retry_btn="λ‹€μ‹œ μ‹œλ„",
-    undo_btn="이전 ν„΄ μ‚­μ œ",
-    clear_btn="λŒ€ν™” μ΄ˆκΈ°ν™”",
 )
 
+# --- Launch the App ---
 if __name__ == "__main__":
     print("--- Launching Gradio App ---")
     demo.queue().launch()
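Note: newer Gradio releases no longer accept the retry_btn, undo_btn, and clear_btn keyword arguments on gr.ChatInterface, and passing them raises a TypeError, which is presumably why they were dropped here. If the Korean button labels still matter on older versions, a defensive sketch is to feature-detect the parameters:

    import inspect
    import gradio as gr

    # Pass the button labels only when this Gradio version still accepts them.
    kwargs = {}
    if "retry_btn" in inspect.signature(gr.ChatInterface.__init__).parameters:
        kwargs.update(retry_btn="λ‹€μ‹œ μ‹œλ„", undo_btn="이전 ν„΄ μ‚­μ œ", clear_btn="λŒ€ν™” μ΄ˆκΈ°ν™”")
    demo = gr.ChatInterface(fn=predict, chatbot=chatbot_component, **kwargs)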
 