openfree commited on
Commit
55caecd
ยท
verified ยท
1 Parent(s): e6c14df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -214
app.py CHANGED
@@ -6,7 +6,7 @@ os.environ["TORCH_DYNAMO_DISABLE"] = "1"
6
  # 2) Triton์˜ cudagraphs ์ตœ์ ํ™” ๋น„ํ™œ์„ฑํ™”
7
  os.environ["TRITON_DISABLE_CUDAGRAPHS"] = "1"
8
 
9
- # 3) ๊ฒฝ๊ณ  ๋ฌด์‹œ ์„ค์ • (skipping cudagraphs ๊ด€๋ จ)
10
  import warnings
11
  warnings.filterwarnings("ignore", message="skipping cudagraphs due to mutated inputs")
12
  warnings.filterwarnings("ignore", message="Not enough SMs to use max_autotune_gemm mode")
@@ -15,26 +15,22 @@ import torch
15
  # TensorFloat32 ์—ฐ์‚ฐ ํ™œ์„ฑํ™” (์„ฑ๋Šฅ ์ตœ์ ํ™”)
16
  torch.set_float32_matmul_precision('high')
17
 
18
- # TorchInductor cudagraphs ๋น„ํ™œ์„ฑํ™”
19
  import torch._inductor
20
  torch._inductor.config.triton.cudagraphs = False
21
 
22
- # Dynamo suppress_errors ์˜ต์…˜ (์˜ค๋ฅ˜ ์‹œ eager๋กœ fallback)
23
  import torch._dynamo
 
24
  torch._dynamo.config.suppress_errors = True
25
 
26
  import gradio as gr
27
  import spaces
28
-
29
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
30
 
31
  from threading import Thread
32
- import random
33
  from datasets import load_dataset
34
  import numpy as np
35
  from sklearn.feature_extraction.text import TfidfVectorizer
36
  import pandas as pd
37
- from typing import List, Tuple
38
  import json
39
  from datetime import datetime
40
  import pyarrow.parquet as pq
@@ -44,8 +40,8 @@ import platform
44
  import subprocess
45
  import pytesseract
46
  from pdf2image import convert_from_path
47
- import queue # queue.Empty ์˜ˆ์™ธ ์ฒ˜๋ฆฌ๋ฅผ ์œ„ํ•ด
48
- import time # ์ŠคํŠธ๋ฆฌ๋ฐ ํƒ€์ด๋ฐ์„ ์œ„ํ•ด
49
 
50
  # -------------------- PDF to Markdown ๋ณ€ํ™˜ ๊ด€๋ จ import --------------------
51
  try:
@@ -70,7 +66,6 @@ current_file_context = None
70
  # ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ์„ค์ •
71
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
72
  MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
73
- MODELS = os.environ.get("MODELS")
74
  MODEL_NAME = MODEL_ID.split("/")[-1]
75
 
76
  model = None # ์ „์—ญ์—์„œ ๊ด€๋ฆฌ
@@ -80,9 +75,9 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
80
  wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
81
  print("Wikipedia dataset loaded:", wiki_dataset)
82
 
83
- # (2) TF-IDF ๋ฒกํ„ฐ๋ผ์ด์ € ์ดˆ๊ธฐํ™” ๋ฐ ํ•™์Šต
84
  print("TF-IDF ๋ฒกํ„ฐํ™” ์‹œ์ž‘...")
85
- questions = wiki_dataset['train']['question'][:10000] # ์ฒ˜์Œ 10000๊ฐœ๋งŒ ์‚ฌ์šฉ
86
  vectorizer = TfidfVectorizer(max_features=1000)
87
  question_vectors = vectorizer.fit_transform(questions)
88
  print("TF-IDF ๋ฒกํ„ฐํ™” ์™„๋ฃŒ")
@@ -143,16 +138,12 @@ class ChatHistory:
143
  print(f"ํžˆ์Šคํ† ๋ฆฌ ๋กœ๋“œ ์‹คํŒจ: {e}")
144
  self.history = []
145
 
146
- # ์ „์—ญ ChatHistory ์ธ์Šคํ„ด์Šค
147
  chat_history = ChatHistory()
148
 
149
  # ------------------------- ์œ„ํ‚ค ๋ฌธ์„œ ๊ฒ€์ƒ‰ (TF-IDF) -------------------------
150
  def find_relevant_context(query, top_k=3):
151
- # ์ฟผ๋ฆฌ ๋ฒกํ„ฐํ™”
152
  query_vector = vectorizer.transform([query])
153
- # ์ฝ”์‚ฌ์ธ ์œ ์‚ฌ๋„
154
  similarities = (query_vector * question_vectors.T).toarray()[0]
155
- # ์œ ์‚ฌ๋„ ๋†’์€ ์งˆ๋ฌธ ์ธ๋ฑ์Šค
156
  top_indices = np.argsort(similarities)[-top_k:][::-1]
157
 
158
  relevant_contexts = []
@@ -165,15 +156,11 @@ def find_relevant_context(query, top_k=3):
165
  })
166
  return relevant_contexts
167
 
168
- # ํŒŒ์ผ ์—…๋กœ๋“œ ์‹œ ํ‘œ์‹œํ•  ์ดˆ๊ธฐ ๋ฉ”์‹œ์ง€
169
  def init_msg():
170
  return "ํŒŒ์ผ์„ ๋ถ„์„ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค..."
171
 
172
  # -------------------- PDF ํŒŒ์ผ์„ Markdown์œผ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ์œ ํ‹ธ ํ•จ์ˆ˜๋“ค --------------------
173
  def extract_text_from_pdf(reader: PdfReader) -> str:
174
- """
175
- PyPDF๋ฅผ ์‚ฌ์šฉํ•ด ๋ชจ๋“  ํŽ˜์ด์ง€ ํ…์ŠคํŠธ๋ฅผ ์ถ”์ถœ.
176
- """
177
  full_text = ""
178
  for idx, page in enumerate(reader.pages):
179
  text = page.extract_text() or ""
@@ -182,16 +169,11 @@ def extract_text_from_pdf(reader: PdfReader) -> str:
182
  return full_text.strip()
183
 
184
  def convert_pdf_to_markdown(pdf_file: str):
185
- """
186
- PDF ํŒŒ์ผ์—์„œ ํ…์ŠคํŠธ๋ฅผ ์ถ”์ถœํ•˜๊ณ ,
187
- ์ด๋ฏธ์ง€๊ฐ€ ๋งŽ๊ณ  ํ…์ŠคํŠธ๊ฐ€ ์ ์œผ๋ฉด OCR ์‹œ๋„
188
- """
189
  try:
190
  reader = PdfReader(pdf_file)
191
  except Exception as e:
192
  return f"PDF ํŒŒ์ผ์„ ์ฝ๋Š” ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}", None, None
193
 
194
- # ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ์ถ”์ถœ
195
  raw_meta = reader.metadata
196
  metadata = {
197
  "author": raw_meta.author if raw_meta else None,
@@ -201,16 +183,13 @@ def convert_pdf_to_markdown(pdf_file: str):
201
  "title": raw_meta.title if raw_meta else None,
202
  }
203
 
204
- # ํ…์ŠคํŠธ ์ถ”์ถœ
205
  full_text = extract_text_from_pdf(reader)
206
 
207
- # ์ด๋ฏธ์ง€-ํ…์ŠคํŠธ ๋น„์œจ ํŒ๋‹จ ํ›„ OCR ์‹œ๋„
208
  image_count = sum(len(page.images) for page in reader.pages)
209
  if image_count > 0 and len(full_text) < 1000:
210
  try:
211
  out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
212
  ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
213
- # OCR๋œ PDF ๋‹ค์‹œ ์ฝ๊ธฐ
214
  reader_ocr = PdfReader(out_pdf_file)
215
  full_text = extract_text_from_pdf(reader_ocr)
216
  except Exception as e:
@@ -220,7 +199,6 @@ def convert_pdf_to_markdown(pdf_file: str):
220
 
221
  # ------------------------- ํŒŒ์ผ ๋ถ„์„ ํ•จ์ˆ˜ -------------------------
222
  def analyze_file_content(content, file_type):
223
- """๊ฐ„๋‹จํ•œ ๊ตฌ์กฐ ๋ถ„์„/์š”์•ฝ."""
224
  if file_type in ['parquet', 'csv']:
225
  try:
226
  lines = content.split('\n')
@@ -246,16 +224,16 @@ def analyze_file_content(content, file_type):
246
  return f"๐Ÿ“ Document Structure: {total_lines} lines, {paragraphs} paragraphs, approximately {words} words"
247
 
248
  def read_uploaded_file(file):
249
- """
250
- ์—…๋กœ๋“œ๋œ ํŒŒ์ผ ์ฒ˜๋ฆฌ -> ๋‚ด์šฉ/ํƒ€์ž…
251
- """
252
  if file is None:
253
  return "", ""
254
 
 
 
 
 
255
  try:
256
  file_ext = os.path.splitext(file.name)[1].lower()
257
 
258
- # Parquet
259
  if file_ext == '.parquet':
260
  try:
261
  table = pq.read_table(file.name)
@@ -291,8 +269,7 @@ def read_uploaded_file(file):
291
  except Exception as e:
292
  return f"Error reading Parquet file: {str(e)}", "error"
293
 
294
- # PDF
295
- if file_ext == '.pdf':
296
  try:
297
  markdown_text, metadata, processed_pdf_path = convert_pdf_to_markdown(file.name)
298
  if metadata is None:
@@ -302,14 +279,13 @@ def read_uploaded_file(file):
302
  content += "## Metadata\n"
303
  for k, v in metadata.items():
304
  content += f"**{k.capitalize()}**: {v}\n\n"
305
-
306
  content += "## Extracted Text\n\n"
307
  content += markdown_text
 
308
  return content, "pdf"
309
  except Exception as e:
310
  return f"Error reading PDF file: {str(e)}", "error"
311
 
312
- # CSV
313
  elif file_ext == '.csv':
314
  encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
315
  for encoding in encodings:
@@ -342,7 +318,6 @@ def read_uploaded_file(file):
342
  f"Unable to read file with supported encodings ({', '.join(encodings)})"
343
  )
344
 
345
- # ํ…์ŠคํŠธ ํŒŒ์ผ
346
  else:
347
  encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
348
  for encoding in encodings:
@@ -358,7 +333,7 @@ def read_uploaded_file(file):
358
  for keyword in ['def ', 'class ', 'import ', 'function']
359
  )
360
 
361
- analysis = f"\n๐Ÿ“ File Analysis:\n"
362
  if is_code:
363
  functions = sum('def ' in line for line in lines)
364
  classes = sum('class ' in line for line in lines)
@@ -374,7 +349,6 @@ def read_uploaded_file(file):
374
  else:
375
  words = len(content.split())
376
  chars = len(content)
377
-
378
  analysis += f"- File Type: Text\n"
379
  analysis += f"- Total Lines: {total_lines:,}\n"
380
  analysis += f"- Non-empty Lines: {non_empty_lines:,}\n"
@@ -395,162 +369,10 @@ def read_uploaded_file(file):
395
 
396
  # ------------------------- CSS -------------------------
397
  CSS = """
398
- /* 3D ์Šคํƒ€์ผ CSS */
399
- :root {
400
- --primary-color: #2196f3;
401
- --secondary-color: #1976d2;
402
- --background-color: #f0f2f5;
403
- --card-background: #ffffff;
404
- --text-color: #333333;
405
- --shadow-color: rgba(0, 0, 0, 0.1);
406
- }
407
- body {
408
- background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
409
- min-height: 100vh;
410
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
411
- }
412
- .container {
413
- transform-style: preserve-3d;
414
- perspective: 1000px;
415
- }
416
- .chatbot {
417
- background: var(--card-background);
418
- border-radius: 20px;
419
- box-shadow:
420
- 0 10px 20px var(--shadow-color),
421
- 0 6px 6px var(--shadow-color);
422
- transform: translateZ(0);
423
- transition: transform 0.3s ease;
424
- backdrop-filter: blur(10px);
425
- }
426
- .chatbot:hover {
427
- transform: translateZ(10px);
428
- }
429
- /* ๋ฉ”์‹œ์ง€ ์ž…๋ ฅ ์˜์—ญ */
430
- .input-area {
431
- background: var(--card-background);
432
- border-radius: 15px;
433
- padding: 15px;
434
- margin-top: 20px;
435
- box-shadow:
436
- 0 5px 15px var(--shadow-color),
437
- 0 3px 3px var(--shadow-color);
438
- transform: translateZ(0);
439
- transition: all 0.3s ease;
440
- display: flex;
441
- align-items: center;
442
- gap: 10px;
443
- }
444
- .input-area:hover {
445
- transform: translateZ(5px);
446
- }
447
- /* ๋ฒ„ํŠผ ์Šคํƒ€์ผ */
448
- .custom-button {
449
- background: linear-gradient(145deg, var(--primary-color), var(--secondary-color));
450
- color: white;
451
- border: none;
452
- border-radius: 10px;
453
- padding: 10px 20px;
454
- font-weight: 600;
455
- cursor: pointer;
456
- transform: translateZ(0);
457
- transition: all 0.3s ease;
458
- box-shadow:
459
- 0 4px 6px var(--shadow-color),
460
- 0 1px 3px var(--shadow-color);
461
- }
462
- .custom-button:hover {
463
- transform: translateZ(5px) translateY(-2px);
464
- box-shadow:
465
- 0 7px 14px var(--shadow-color),
466
- 0 3px 6px var(--shadow-color);
467
- }
468
- /* ํŒŒ์ผ ์—…๋กœ๋“œ ๋ฒ„ํŠผ */
469
- .file-upload-icon {
470
- background: linear-gradient(145deg, #64b5f6, #42a5f5);
471
- color: white;
472
- border-radius: 8px;
473
- font-size: 2em;
474
- cursor: pointer;
475
- display: flex;
476
- align-items: center;
477
- justify-content: center;
478
- height: 70px;
479
- width: 70px;
480
- transition: all 0.3s ease;
481
- box-shadow: 0 2px 5px rgba(0,0,0,0.1);
482
- }
483
- .file-upload-icon:hover {
484
- transform: translateY(-2px);
485
- box-shadow: 0 4px 8px rgba(0,0,0,0.2);
486
- }
487
- /* ํŒŒ์ผ ์—…๋กœ๋“œ ๋ฒ„ํŠผ ๋‚ด๋ถ€ ์š”์†Œ ์Šคํƒ€์ผ๋ง */
488
- .file-upload-icon > .wrap {
489
- display: flex !important;
490
- align-items: center;
491
- justify-content: center;
492
- width: 100%;
493
- height: 100%;
494
- }
495
- .file-upload-icon > .wrap > p {
496
- display: none !important;
497
- }
498
- .file-upload-icon > .wrap::before {
499
- content: "๐Ÿ“";
500
- font-size: 2em;
501
- display: block;
502
- }
503
- /* ๋ฉ”์‹œ์ง€ ์Šคํƒ€์ผ */
504
- .message {
505
- background: var(--card-background);
506
- border-radius: 15px;
507
- padding: 15px;
508
- margin: 10px 0;
509
- box-shadow:
510
- 0 4px 6px var(--shadow-color),
511
- 0 1px 3px var(--shadow-color);
512
- transform: translateZ(0);
513
- transition: all 0.3s ease;
514
- }
515
- .message:hover {
516
- transform: translateZ(5px);
517
- }
518
- .chat-container {
519
- height: 600px !important;
520
- margin-bottom: 10px;
521
- }
522
- .input-container {
523
- height: 70px !important;
524
- display: flex;
525
- align-items: center;
526
- gap: 10px;
527
- margin-top: 5px;
528
- }
529
- .input-textbox {
530
- height: 70px !important;
531
- border-radius: 8px !important;
532
- font-size: 1.1em !important;
533
- padding: 10px 15px !important;
534
- display: flex !important;
535
- align-items: flex-start !important;
536
- }
537
- .input-textbox textarea {
538
- padding-top: 5px !important;
539
- }
540
- .send-button {
541
- height: 70px !important;
542
- min-width: 70px !important;
543
- font-size: 1.1em !important;
544
- }
545
- /* ์„ค์ • ํŒจ๋„ ๊ธฐ๋ณธ ์Šคํƒ€์ผ */
546
- .settings-panel {
547
- padding: 20px;
548
- margin-top: 20px;
549
- }
550
  """
551
 
552
  def clear_cuda_memory():
553
- """CUDA ์บ์‹œ ์ •๋ฆฌ."""
554
  if hasattr(torch.cuda, 'empty_cache'):
555
  with torch.cuda.device('cuda'):
556
  torch.cuda.empty_cache()
@@ -566,13 +388,14 @@ def load_model():
566
  device_map="auto",
567
  low_cpu_mem_usage=True,
568
  )
 
 
569
  return loaded_model
570
  except Exception as e:
571
  print(f"๋ชจ๋ธ ๋กœ๋“œ ์˜ค๋ฅ˜: {str(e)}")
572
  raise
573
 
574
  def build_prompt(conversation: list) -> str:
575
- """๋Œ€ํ™” ๋‚ด์—ญ์„ ๋‹จ์ˆœ ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ๋กœ ๋ณ€ํ™˜."""
576
  prompt = ""
577
  for msg in conversation:
578
  if msg["role"] == "user":
@@ -597,14 +420,13 @@ def stream_chat(
597
  global model, current_file_context
598
 
599
  try:
600
- # ๋ชจ๋ธ ๋ฏธ๋กœ๋“œ์‹œ ๋กœ๋”ฉ
601
  if model is None:
602
  model = load_model()
603
 
604
  print(f'[User input] message: {message}')
605
  print(f'[History] {history}')
606
 
607
- # (1) ํŒŒ์ผ ์—…๋กœ๋“œ ์ฒ˜๋ฆฌ
608
  file_context = ""
609
  if uploaded_file and message == "ํŒŒ์ผ์„ ๋ถ„์„ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค...":
610
  current_file_context = None
@@ -624,7 +446,7 @@ def stream_chat(
624
  elif current_file_context:
625
  file_context = current_file_context
626
 
627
- # (2) TF-IDF ๊ธฐ๋ฐ˜ ๊ด€๋ จ ๋ฌธ์„œ ํƒ์ƒ‰
628
  wiki_context = ""
629
  try:
630
  relevant_contexts = find_relevant_context(message)
@@ -639,7 +461,7 @@ def stream_chat(
639
  except Exception as e:
640
  print(f"[์ปจํ…์ŠคํŠธ ๊ฒ€์ƒ‰ ์˜ค๋ฅ˜] {str(e)}")
641
 
642
- # (3) ๋Œ€ํ™” ์ด๋ ฅ ๊ตฌ์„ฑ
643
  max_history_length = 10
644
  if len(history) > max_history_length:
645
  history = history[-max_history_length:]
@@ -651,7 +473,7 @@ def stream_chat(
651
  {"role": "assistant", "content": answer}
652
  ])
653
 
654
- # (4) ์ตœ์ข… ๋ฉ”์‹œ์ง€ ๊ฒฐ์ •
655
  final_message = message
656
  if file_context:
657
  final_message = file_context + "\nํ˜„์žฌ ์งˆ๋ฌธ: " + message
@@ -662,13 +484,13 @@ def stream_chat(
662
 
663
  conversation.append({"role": "user", "content": final_message})
664
 
665
- # (5) ํ† ํฐํ™” ๋ฐ ํ”„๋กฌํ”„ํŠธ ๊ตฌ์ถ•
666
  input_ids_str = build_prompt(conversation)
667
  max_context = 8192
668
  tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
669
  input_length = tokenized_input["input_ids"].shape[1]
670
 
671
- # (6) ์ปจํ…์ŠคํŠธ๊ฐ€ ๋„ˆ๋ฌด ๊ธธ๋ฉด ์•ž๋ถ€๋ถ„ ํ† ํฐ ์ž๋ฅด๊ธฐ
672
  if input_length > max_context - max_new_tokens:
673
  print(f"[๊ฒฝ๊ณ ] ์ž…๋ ฅ์ด ๋„ˆ๋ฌด ๊น๋‹ˆ๋‹ค: {input_length} ํ† ํฐ -> ์ž˜๋ผ๋ƒ„.")
674
  min_generation = min(256, max_new_tokens)
@@ -683,18 +505,18 @@ def stream_chat(
683
  print(f"[ํ† ํฐ ๊ธธ์ด] {input_length}")
684
  inputs = tokenized_input.to("cuda")
685
 
686
- # ๋‚จ์€ ํ† ํฐ ์ˆ˜๋กœ max_new_tokens ์กฐ์ •
687
  remaining = max_context - input_length
688
  if remaining < max_new_tokens:
689
  print(f"[max_new_tokens ์กฐ์ •] {max_new_tokens} -> {remaining}")
690
  max_new_tokens = remaining
691
 
692
- # ์ŠคํŠธ๋ฆฌ๋จธ ์„ค์ •
693
  streamer = TextIteratorStreamer(
694
  tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
695
  )
696
 
697
- # (7) ์ƒ์„ฑ ํŒŒ๋ผ๋ฏธํ„ฐ
698
  generate_kwargs = dict(
699
  **inputs,
700
  streamer=streamer,
@@ -704,18 +526,18 @@ def stream_chat(
704
  max_new_tokens=max_new_tokens,
705
  do_sample=True,
706
  temperature=temperature,
707
- pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id,
708
  eos_token_id=tokenizer.eos_token_id,
709
- use_cache=True
710
  )
711
 
712
  clear_cuda_memory()
713
 
714
- # (8) ๋ณ„๋„ ์Šค๋ ˆ๋“œ์—์„œ ์ƒ์„ฑ
715
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
716
  thread.start()
717
 
718
- # (9) ์ŠคํŠธ๋ฆฌ๋ฐ ์‘๋‹ต
719
  buffer = ""
720
  partial_message = ""
721
  last_yield_time = time.time()
@@ -725,23 +547,23 @@ def stream_chat(
725
  buffer += new_text
726
  partial_message += new_text
727
 
728
- # ์ผ์ • ์‹œ๊ฐ„ ๋˜๋Š” ๋ฒ„ํผ ๊ธธ์ด ๊ธฐ์ค€์œผ๋กœ yield
729
  current_time = time.time()
730
  if (current_time - last_yield_time > 0.1) or (len(partial_message) > 20):
731
  yield "", history + [[message, buffer]]
732
  partial_message = ""
733
  last_yield_time = current_time
734
 
735
- # ๋งˆ์ง€๋ง‰ ์™„์„ฑ๋œ ์‘๋‹ต
736
  if buffer:
737
  yield "", history + [[message, buffer]]
738
 
739
- # ๋Œ€ํ™” ๋‚ด์šฉ ์ €์žฅ
740
  chat_history.add_conversation(message, buffer)
741
 
742
  except Exception as e:
743
  print(f"[์ŠคํŠธ๋ฆฌ๋ฐ ์ค‘ ์˜ค๋ฅ˜] {str(e)}")
744
- if not buffer: # buffer๊ฐ€ ๋น„์–ด์žˆ๋‹ค๋ฉด ์˜ค๋ฅ˜๋ฉ”์‹œ์ง€ ๋Œ€ํ™”์ฐฝ ํ‘œ์‹œ
745
  buffer = f"์‘๋‹ต ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
746
  yield "", history + [[message, buffer]]
747
 
@@ -835,7 +657,7 @@ def create_demo():
835
  label="Repetition Penalty ๐Ÿ”„"
836
  )
837
 
838
- # ์˜ˆ์‹œ
839
  gr.Examples(
840
  examples=[
841
  ["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n if n <= 1: return n\n return fibonacci(n-1) + fibonacci(n-2)"],
@@ -852,7 +674,7 @@ def create_demo():
852
  current_file_context = None
853
  return [], None, "Start a new conversation..."
854
 
855
- # ๋ฉ”์‹œ์ง€ ์ „์†ก
856
  msg.submit(
857
  stream_chat,
858
  inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
@@ -889,7 +711,6 @@ def create_demo():
889
 
890
  return demo
891
 
892
- # ๋ฉ”์ธ ์‹คํ–‰
893
  if __name__ == "__main__":
894
  demo = create_demo()
895
  demo.launch()
 
6
  # 2) Triton์˜ cudagraphs ์ตœ์ ํ™” ๋น„ํ™œ์„ฑํ™”
7
  os.environ["TRITON_DISABLE_CUDAGRAPHS"] = "1"
8
 
9
+ # (์˜ต์…˜) ๊ฒฝ๊ณ  ๋ฌด์‹œ ์„ค์ •
10
  import warnings
11
  warnings.filterwarnings("ignore", message="skipping cudagraphs due to mutated inputs")
12
  warnings.filterwarnings("ignore", message="Not enough SMs to use max_autotune_gemm mode")
 
15
  # TensorFloat32 ์—ฐ์‚ฐ ํ™œ์„ฑํ™” (์„ฑ๋Šฅ ์ตœ์ ํ™”)
16
  torch.set_float32_matmul_precision('high')
17
 
 
18
  import torch._inductor
19
  torch._inductor.config.triton.cudagraphs = False
20
 
 
21
  import torch._dynamo
22
+ # suppress_errors (์˜ค๋ฅ˜ ์‹œ eager๋กœ fallback)
23
  torch._dynamo.config.suppress_errors = True
24
 
25
  import gradio as gr
26
  import spaces
 
27
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
28
 
29
  from threading import Thread
 
30
  from datasets import load_dataset
31
  import numpy as np
32
  from sklearn.feature_extraction.text import TfidfVectorizer
33
  import pandas as pd
 
34
  import json
35
  from datetime import datetime
36
  import pyarrow.parquet as pq
 
40
  import subprocess
41
  import pytesseract
42
  from pdf2image import convert_from_path
43
+ import queue
44
+ import time
45
 
46
  # -------------------- PDF to Markdown ๋ณ€ํ™˜ ๊ด€๋ จ import --------------------
47
  try:
 
66
  # ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ์„ค์ •
67
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
68
  MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
 
69
  MODEL_NAME = MODEL_ID.split("/")[-1]
70
 
71
  model = None # ์ „์—ญ์—์„œ ๊ด€๋ฆฌ
 
75
  wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
76
  print("Wikipedia dataset loaded:", wiki_dataset)
77
 
78
+ # (2) TF-IDF ๋ฒกํ„ฐ๋ผ์ด์ € ์ดˆ๊ธฐํ™” ๋ฐ ํ•™์Šต (์ผ๋ถ€๋งŒ ์‚ฌ์šฉ)
79
  print("TF-IDF ๋ฒกํ„ฐํ™” ์‹œ์ž‘...")
80
+ questions = wiki_dataset['train']['question'][:10000]
81
  vectorizer = TfidfVectorizer(max_features=1000)
82
  question_vectors = vectorizer.fit_transform(questions)
83
  print("TF-IDF ๋ฒกํ„ฐํ™” ์™„๋ฃŒ")
 
138
  print(f"ํžˆ์Šคํ† ๋ฆฌ ๋กœ๋“œ ์‹คํŒจ: {e}")
139
  self.history = []
140
 
 
141
  chat_history = ChatHistory()
142
 
143
  # ------------------------- ์œ„ํ‚ค ๋ฌธ์„œ ๊ฒ€์ƒ‰ (TF-IDF) -------------------------
144
  def find_relevant_context(query, top_k=3):
 
145
  query_vector = vectorizer.transform([query])
 
146
  similarities = (query_vector * question_vectors.T).toarray()[0]
 
147
  top_indices = np.argsort(similarities)[-top_k:][::-1]
148
 
149
  relevant_contexts = []
 
156
  })
157
  return relevant_contexts
158
 
 
159
  def init_msg():
160
  return "ํŒŒ์ผ์„ ๋ถ„์„ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค..."
161
 
162
  # -------------------- PDF ํŒŒ์ผ์„ Markdown์œผ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ์œ ํ‹ธ ํ•จ์ˆ˜๋“ค --------------------
163
  def extract_text_from_pdf(reader: PdfReader) -> str:
 
 
 
164
  full_text = ""
165
  for idx, page in enumerate(reader.pages):
166
  text = page.extract_text() or ""
 
169
  return full_text.strip()
170
 
171
  def convert_pdf_to_markdown(pdf_file: str):
 
 
 
 
172
  try:
173
  reader = PdfReader(pdf_file)
174
  except Exception as e:
175
  return f"PDF ํŒŒ์ผ์„ ์ฝ๋Š” ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}", None, None
176
 
 
177
  raw_meta = reader.metadata
178
  metadata = {
179
  "author": raw_meta.author if raw_meta else None,
 
183
  "title": raw_meta.title if raw_meta else None,
184
  }
185
 
 
186
  full_text = extract_text_from_pdf(reader)
187
 
 
188
  image_count = sum(len(page.images) for page in reader.pages)
189
  if image_count > 0 and len(full_text) < 1000:
190
  try:
191
  out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
192
  ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
 
193
  reader_ocr = PdfReader(out_pdf_file)
194
  full_text = extract_text_from_pdf(reader_ocr)
195
  except Exception as e:
 
199
 
200
  # ------------------------- ํŒŒ์ผ ๋ถ„์„ ํ•จ์ˆ˜ -------------------------
201
  def analyze_file_content(content, file_type):
 
202
  if file_type in ['parquet', 'csv']:
203
  try:
204
  lines = content.split('\n')
 
224
  return f"๐Ÿ“ Document Structure: {total_lines} lines, {paragraphs} paragraphs, approximately {words} words"
225
 
226
  def read_uploaded_file(file):
 
 
 
227
  if file is None:
228
  return "", ""
229
 
230
+ import pyarrow.parquet as pq
231
+ import pandas as pd
232
+ from tabulate import tabulate
233
+
234
  try:
235
  file_ext = os.path.splitext(file.name)[1].lower()
236
 
 
237
  if file_ext == '.parquet':
238
  try:
239
  table = pq.read_table(file.name)
 
269
  except Exception as e:
270
  return f"Error reading Parquet file: {str(e)}", "error"
271
 
272
+ elif file_ext == '.pdf':
 
273
  try:
274
  markdown_text, metadata, processed_pdf_path = convert_pdf_to_markdown(file.name)
275
  if metadata is None:
 
279
  content += "## Metadata\n"
280
  for k, v in metadata.items():
281
  content += f"**{k.capitalize()}**: {v}\n\n"
 
282
  content += "## Extracted Text\n\n"
283
  content += markdown_text
284
+
285
  return content, "pdf"
286
  except Exception as e:
287
  return f"Error reading PDF file: {str(e)}", "error"
288
 
 
289
  elif file_ext == '.csv':
290
  encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
291
  for encoding in encodings:
 
318
  f"Unable to read file with supported encodings ({', '.join(encodings)})"
319
  )
320
 
 
321
  else:
322
  encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
323
  for encoding in encodings:
 
333
  for keyword in ['def ', 'class ', 'import ', 'function']
334
  )
335
 
336
+ analysis = "\n๐Ÿ“ File Analysis:\n"
337
  if is_code:
338
  functions = sum('def ' in line for line in lines)
339
  classes = sum('class ' in line for line in lines)
 
349
  else:
350
  words = len(content.split())
351
  chars = len(content)
 
352
  analysis += f"- File Type: Text\n"
353
  analysis += f"- Total Lines: {total_lines:,}\n"
354
  analysis += f"- Non-empty Lines: {non_empty_lines:,}\n"
 
369
 
370
  # ------------------------- CSS -------------------------
371
  CSS = """
372
+ /* (์ƒ๋žต: ๋™์ผ) */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  """
374
 
375
  def clear_cuda_memory():
 
376
  if hasattr(torch.cuda, 'empty_cache'):
377
  with torch.cuda.device('cuda'):
378
  torch.cuda.empty_cache()
 
388
  device_map="auto",
389
  low_cpu_mem_usage=True,
390
  )
391
+ # (์ค‘์š”) ๋ชจ๋ธ ๊ธฐ๋ณธ config์—์„œ๋„ ์บ์‹œ ์‚ฌ์šฉ ๊บผ๋‘˜ ์ˆ˜ ์žˆ์Œ
392
+ loaded_model.config.use_cache = False
393
  return loaded_model
394
  except Exception as e:
395
  print(f"๋ชจ๋ธ ๋กœ๋“œ ์˜ค๋ฅ˜: {str(e)}")
396
  raise
397
 
398
  def build_prompt(conversation: list) -> str:
 
399
  prompt = ""
400
  for msg in conversation:
401
  if msg["role"] == "user":
 
420
  global model, current_file_context
421
 
422
  try:
 
423
  if model is None:
424
  model = load_model()
425
 
426
  print(f'[User input] message: {message}')
427
  print(f'[History] {history}')
428
 
429
+ # 1) ํŒŒ์ผ ์—…๋กœ๋“œ ์ฒ˜๋ฆฌ
430
  file_context = ""
431
  if uploaded_file and message == "ํŒŒ์ผ์„ ๋ถ„์„ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค...":
432
  current_file_context = None
 
446
  elif current_file_context:
447
  file_context = current_file_context
448
 
449
+ # 2) ์œ„ํ‚ค ์ปจํ…์ŠคํŠธ
450
  wiki_context = ""
451
  try:
452
  relevant_contexts = find_relevant_context(message)
 
461
  except Exception as e:
462
  print(f"[์ปจํ…์ŠคํŠธ ๊ฒ€์ƒ‰ ์˜ค๋ฅ˜] {str(e)}")
463
 
464
+ # 3) ๋Œ€ํ™” ์ด๋ ฅ ์ถ•์†Œ
465
  max_history_length = 10
466
  if len(history) > max_history_length:
467
  history = history[-max_history_length:]
 
473
  {"role": "assistant", "content": answer}
474
  ])
475
 
476
+ # 4) ์ตœ์ข… ๋ฉ”์‹œ์ง€
477
  final_message = message
478
  if file_context:
479
  final_message = file_context + "\nํ˜„์žฌ ์งˆ๋ฌธ: " + message
 
484
 
485
  conversation.append({"role": "user", "content": final_message})
486
 
487
+ # 5) ํ† ํฐํ™”
488
  input_ids_str = build_prompt(conversation)
489
  max_context = 8192
490
  tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
491
  input_length = tokenized_input["input_ids"].shape[1]
492
 
493
+ # 6) ์ปจํ…์ŠคํŠธ ์ดˆ๊ณผ ์‹œ ์ž๋ฅด๊ธฐ
494
  if input_length > max_context - max_new_tokens:
495
  print(f"[๊ฒฝ๊ณ ] ์ž…๋ ฅ์ด ๋„ˆ๋ฌด ๊น๋‹ˆ๋‹ค: {input_length} ํ† ํฐ -> ์ž˜๋ผ๋ƒ„.")
496
  min_generation = min(256, max_new_tokens)
 
505
  print(f"[ํ† ํฐ ๊ธธ์ด] {input_length}")
506
  inputs = tokenized_input.to("cuda")
507
 
508
+ # 7) ๋‚จ์€ ํ† ํฐ ์ˆ˜๋กœ max_new_tokens ๋ณด์ •
509
  remaining = max_context - input_length
510
  if remaining < max_new_tokens:
511
  print(f"[max_new_tokens ์กฐ์ •] {max_new_tokens} -> {remaining}")
512
  max_new_tokens = remaining
513
 
514
+ # 8) TextIteratorStreamer ์„ค์ •
515
  streamer = TextIteratorStreamer(
516
  tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
517
  )
518
 
519
+ # โ˜… use_cache=False ์„ค์ • (์ค‘์š”) โ˜…
520
  generate_kwargs = dict(
521
  **inputs,
522
  streamer=streamer,
 
526
  max_new_tokens=max_new_tokens,
527
  do_sample=True,
528
  temperature=temperature,
529
+ pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
530
  eos_token_id=tokenizer.eos_token_id,
531
+ use_cache=False, # โ† ์—ฌ๊ธฐ๊ฐ€ ํ•ต์‹ฌ!
532
  )
533
 
534
  clear_cuda_memory()
535
 
536
+ # 9) ๋ณ„๋„ ์Šค๋ ˆ๋“œ๋กœ ๋ชจ๋ธ ํ˜ธ์ถœ
537
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
538
  thread.start()
539
 
540
+ # 10) ์ŠคํŠธ๋ฆฌ๋ฐ
541
  buffer = ""
542
  partial_message = ""
543
  last_yield_time = time.time()
 
547
  buffer += new_text
548
  partial_message += new_text
549
 
550
+ # ํƒ€์ด๋ฐ or ์ผ์ • ๊ธธ์ด๋งˆ๋‹ค UI ์—…๋ฐ์ดํŠธ
551
  current_time = time.time()
552
  if (current_time - last_yield_time > 0.1) or (len(partial_message) > 20):
553
  yield "", history + [[message, buffer]]
554
  partial_message = ""
555
  last_yield_time = current_time
556
 
557
+ # ๋งˆ์ง€๋ง‰ ์ถœ๋ ฅ
558
  if buffer:
559
  yield "", history + [[message, buffer]]
560
 
561
+ # ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ ์ €์žฅ
562
  chat_history.add_conversation(message, buffer)
563
 
564
  except Exception as e:
565
  print(f"[์ŠคํŠธ๋ฆฌ๋ฐ ์ค‘ ์˜ค๋ฅ˜] {str(e)}")
566
+ if not buffer:
567
  buffer = f"์‘๋‹ต ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
568
  yield "", history + [[message, buffer]]
569
 
 
657
  label="Repetition Penalty ๐Ÿ”„"
658
  )
659
 
660
+ # ์˜ˆ์‹œ ์ž…๋ ฅ
661
  gr.Examples(
662
  examples=[
663
  ["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n if n <= 1: return n\n return fibonacci(n-1) + fibonacci(n-2)"],
 
674
  current_file_context = None
675
  return [], None, "Start a new conversation..."
676
 
677
+ # ๋ฉ”์‹œ์ง€ ์ „์†ก(Submit)
678
  msg.submit(
679
  stream_chat,
680
  inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
 
711
 
712
  return demo
713
 
 
714
  if __name__ == "__main__":
715
  demo = create_demo()
716
  demo.launch()