openfree commited on
Commit
e6c14df
·
verified ·
1 Parent(s): 3e45a0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -153
app.py CHANGED
@@ -1,15 +1,31 @@
1
  import os
2
- # Dynamo 완전 비활성화
 
3
  os.environ["TORCH_DYNAMO_DISABLE"] = "1"
4
 
 
 
 
 
 
 
 
 
5
  import torch
6
- # 성능 최적화를 위한 설정 (TensorFloat32 연산 활성화)
7
  torch.set_float32_matmul_precision('high')
 
 
8
  import torch._inductor
9
  torch._inductor.config.triton.cudagraphs = False
 
 
10
  import torch._dynamo
 
 
11
  import gradio as gr
12
  import spaces
 
13
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
14
 
15
  from threading import Thread
@@ -24,16 +40,14 @@ from datetime import datetime
24
  import pyarrow.parquet as pq
25
  import pypdf
26
  import io
27
- import pyarrow.parquet as pq
28
- from tabulate import tabulate
29
  import platform
30
  import subprocess
31
  import pytesseract
32
  from pdf2image import convert_from_path
33
- import queue # 추가: queue.Empty 예외 처리를 위해
34
- import time # 추가: 스트리밍 타이밍을 위해
35
 
36
- # -------------------- 추가: PDF to Markdown 변환 관련 import --------------------
37
  try:
38
  import re
39
  import requests
@@ -50,9 +64,6 @@ except ModuleNotFoundError as e:
50
  )
51
  # ---------------------------------------------------------------------------
52
 
53
- # 1) Dynamo suppress_errors 옵션 사용 (오류 시 eager로 fallback)
54
- torch._dynamo.config.suppress_errors = True
55
-
56
  # 전역 변수
57
  current_file_context = None
58
 
@@ -62,21 +73,21 @@ MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
62
  MODELS = os.environ.get("MODELS")
63
  MODEL_NAME = MODEL_ID.split("/")[-1]
64
 
65
- model = None # 전역 변수로 선언
66
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
67
 
68
- # 위키피디아 데이터셋 로드
69
  wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
70
  print("Wikipedia dataset loaded:", wiki_dataset)
71
 
72
- # TF-IDF 벡터라이저 초기화 및 학습
73
  print("TF-IDF 벡터화 시작...")
74
  questions = wiki_dataset['train']['question'][:10000] # 처음 10000개만 사용
75
  vectorizer = TfidfVectorizer(max_features=1000)
76
  question_vectors = vectorizer.fit_transform(questions)
77
  print("TF-IDF 벡터화 완료")
78
 
79
-
80
  class ChatHistory:
81
  def __init__(self):
82
  self.history = []
@@ -132,19 +143,18 @@ class ChatHistory:
132
  print(f"히스토리 로드 실패: {e}")
133
  self.history = []
134
 
135
-
136
- # 전역 ChatHistory 인스턴스 생성
137
  chat_history = ChatHistory()
138
 
139
-
140
  def find_relevant_context(query, top_k=3):
141
  # 쿼리 벡터화
142
  query_vector = vectorizer.transform([query])
143
- # 코사인 유사도 계산
144
  similarities = (query_vector * question_vectors.T).toarray()[0]
145
- # 가장 유사한 질문들의 인덱스
146
  top_indices = np.argsort(similarities)[-top_k:][::-1]
147
- # 관련 컨텍스트 추출
148
  relevant_contexts = []
149
  for idx in top_indices:
150
  if similarities[idx] > 0:
@@ -155,16 +165,14 @@ def find_relevant_context(query, top_k=3):
155
  })
156
  return relevant_contexts
157
 
158
-
159
  def init_msg():
160
  return "파일을 분석하고 있습니다..."
161
 
162
-
163
  # -------------------- PDF 파일을 Markdown으로 변환하는 유틸 함수들 --------------------
164
  def extract_text_from_pdf(reader: PdfReader) -> str:
165
  """
166
  PyPDF를 사용해 모든 페이지 텍스트를 추출.
167
- 만약 텍스트가 없으면 빈 문자열 반환.
168
  """
169
  full_text = ""
170
  for idx, page in enumerate(reader.pages):
@@ -173,20 +181,17 @@ def extract_text_from_pdf(reader: PdfReader) -> str:
173
  full_text += f"---- Page {idx+1} ----\n" + text + "\n\n"
174
  return full_text.strip()
175
 
176
-
177
  def convert_pdf_to_markdown(pdf_file: str):
178
  """
179
- PDF 파일을 읽고 텍스트를 추출한 뒤,
180
- 이미지가 많고 텍스트가 적은 경우에는 OCR 시도한다.
181
- 최종적으로 Markdown 형식으로 변환 가능한 텍스트를 반환한다.
182
- 메타데이터도 함께 반환.
183
  """
184
  try:
185
  reader = PdfReader(pdf_file)
186
  except Exception as e:
187
  return f"PDF 파일을 읽는 중 오류 발생: {e}", None, None
188
 
189
- # Extract metadata
190
  raw_meta = reader.metadata
191
  metadata = {
192
  "author": raw_meta.author if raw_meta else None,
@@ -196,19 +201,16 @@ def convert_pdf_to_markdown(pdf_file: str):
196
  "title": raw_meta.title if raw_meta else None,
197
  }
198
 
199
- # Extract text
200
  full_text = extract_text_from_pdf(reader)
201
 
202
- # 이미지가 많고 텍스트가 너무 짧으면 OCR 시도
203
- image_count = 0
204
- for page in reader.pages:
205
- image_count += len(page.images)
206
-
207
  if image_count > 0 and len(full_text) < 1000:
208
  try:
209
  out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
210
  ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
211
- # Re-extract text from OCR-processed PDF
212
  reader_ocr = PdfReader(out_pdf_file)
213
  full_text = extract_text_from_pdf(reader_ocr)
214
  except Exception as e:
@@ -216,11 +218,9 @@ def convert_pdf_to_markdown(pdf_file: str):
216
 
217
  return full_text, metadata, pdf_file
218
 
219
-
220
- # ---------------------------------------------------------------------------
221
-
222
  def analyze_file_content(content, file_type):
223
- """파일 내용을 간단히 분석한 후 구조 요약을 반환."""
224
  if file_type in ['parquet', 'csv']:
225
  try:
226
  lines = content.split('\n')
@@ -245,15 +245,13 @@ def analyze_file_content(content, file_type):
245
  words = len(content.split())
246
  return f"📝 Document Structure: {total_lines} lines, {paragraphs} paragraphs, approximately {words} words"
247
 
248
-
249
  def read_uploaded_file(file):
250
  """
251
- 업로드된 파일을 처리하여
252
- 1) 파일 타입별로 내용을 읽고
253
- 2) 분석 결과와 함께 반환
254
  """
255
  if file is None:
256
  return "", ""
 
257
  try:
258
  file_ext = os.path.splitext(file.name)[1].lower()
259
 
@@ -267,7 +265,8 @@ def read_uploaded_file(file):
267
  content += f"1. Basic Information:\n"
268
  content += f"- Total Rows: {len(df):,}\n"
269
  content += f"- Total Columns: {len(df.columns)}\n"
270
- content += f"- Memory Usage: {df.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB\n\n"
 
271
 
272
  content += f"2. Column Information:\n"
273
  for col in df.columns:
@@ -279,7 +278,8 @@ def read_uploaded_file(file):
279
  content += f"\n\n4. Missing Values:\n"
280
  null_counts = df.isnull().sum()
281
  for col, count in null_counts[null_counts > 0].items():
282
- content += f"- {col}: {count:,} ({count/len(df)*100:.1f}%)\n"
 
283
 
284
  numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
285
  if len(numeric_cols) > 0:
@@ -291,7 +291,7 @@ def read_uploaded_file(file):
291
  except Exception as e:
292
  return f"Error reading Parquet file: {str(e)}", "error"
293
 
294
- # PDF (Markdown 변환)
295
  if file_ext == '.pdf':
296
  try:
297
  markdown_text, metadata, processed_pdf_path = convert_pdf_to_markdown(file.name)
@@ -305,7 +305,6 @@ def read_uploaded_file(file):
305
 
306
  content += "## Extracted Text\n\n"
307
  content += markdown_text
308
-
309
  return content, "pdf"
310
  except Exception as e:
311
  return f"Error reading PDF file: {str(e)}", "error"
@@ -320,7 +319,8 @@ def read_uploaded_file(file):
320
  content += f"1. Basic Information:\n"
321
  content += f"- Total Rows: {len(df):,}\n"
322
  content += f"- Total Columns: {len(df.columns)}\n"
323
- content += f"- Memory Usage: {df.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB\n\n"
 
324
 
325
  content += f"2. Column Information:\n"
326
  for col in df.columns:
@@ -332,14 +332,17 @@ def read_uploaded_file(file):
332
  content += f"\n\n4. Missing Values:\n"
333
  null_counts = df.isnull().sum()
334
  for col, count in null_counts[null_counts > 0].items():
335
- content += f"- {col}: {count:,} ({count/len(df)*100:.1f}%)\n"
 
336
 
337
  return content, "csv"
338
  except UnicodeDecodeError:
339
  continue
340
- raise UnicodeDecodeError(f"Unable to read file with supported encodings ({', '.join(encodings)})")
 
 
341
 
342
- # 일반 텍스트 파일
343
  else:
344
  encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
345
  for encoding in encodings:
@@ -350,15 +353,19 @@ def read_uploaded_file(file):
350
  lines = content.split('\n')
351
  total_lines = len(lines)
352
  non_empty_lines = len([line for line in lines if line.strip()])
353
-
354
- is_code = any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function'])
 
 
355
 
356
  analysis = f"\n📝 File Analysis:\n"
357
  if is_code:
358
- functions = len([line for line in lines if 'def ' in line])
359
- classes = len([line for line in lines if 'class ' in line])
360
- imports = len([line for line in lines if 'import ' in line or 'from ' in line])
361
-
 
 
362
  analysis += f"- File Type: Code\n"
363
  analysis += f"- Total Lines: {total_lines:,}\n"
364
  analysis += f"- Functions: {functions}\n"
@@ -375,14 +382,18 @@ def read_uploaded_file(file):
375
  analysis += f"- Character Count: {chars:,}\n"
376
 
377
  return content + analysis, "text"
 
378
  except UnicodeDecodeError:
379
  continue
380
- raise UnicodeDecodeError(f"Unable to read file with supported encodings ({', '.join(encodings)})")
 
 
 
381
 
382
  except Exception as e:
383
  return f"Error reading file: {str(e)}", "error"
384
 
385
-
386
  CSS = """
387
  /* 3D 스타일 CSS */
388
  :root {
@@ -539,22 +550,20 @@ body {
539
  """
540
 
541
  def clear_cuda_memory():
 
542
  if hasattr(torch.cuda, 'empty_cache'):
543
  with torch.cuda.device('cuda'):
544
  torch.cuda.empty_cache()
545
 
546
-
547
  @spaces.GPU
548
  def load_model():
549
  try:
550
- # 메모리 정리 먼저 수행
551
  clear_cuda_memory()
552
-
553
  loaded_model = AutoModelForCausalLM.from_pretrained(
554
  MODEL_ID,
555
  torch_dtype=torch.bfloat16,
556
  device_map="auto",
557
- # 낮은 메모리 사용을 위한 설정 추가
558
  low_cpu_mem_usage=True,
559
  )
560
  return loaded_model
@@ -562,22 +571,8 @@ def load_model():
562
  print(f"모델 로드 오류: {str(e)}")
563
  raise
564
 
565
- def _truncate_tokens_for_context(input_ids_str: str, desired_input_length: int) -> str:
566
- """
567
- 입력 문자열이 desired_input_length 토큰을 넘으면, 앞부분(오래된 컨텍스트)을 잘라내는 함수.
568
- """
569
- tokens = input_ids_str.split()
570
- if len(tokens) > desired_input_length:
571
- tokens = tokens[-desired_input_length:]
572
- return " ".join(tokens)
573
-
574
-
575
- # build_prompt 함수: 대화 내역을 문자열로 변환
576
  def build_prompt(conversation: list) -> str:
577
- """
578
- conversation은 각 항목이 {"role": "user" 또는 "assistant", "content": ...} 형태의 딕셔너리 목록입니다.
579
- 이를 단순 텍스트 프롬프트로 변환합니다.
580
- """
581
  prompt = ""
582
  for msg in conversation:
583
  if msg["role"] == "user":
@@ -587,7 +582,7 @@ def build_prompt(conversation: list) -> str:
587
  prompt += "Assistant: "
588
  return prompt
589
 
590
-
591
  @spaces.GPU
592
  def stream_chat(
593
  message: str,
@@ -602,13 +597,14 @@ def stream_chat(
602
  global model, current_file_context
603
 
604
  try:
 
605
  if model is None:
606
  model = load_model()
607
 
608
- print(f'message is - {message}')
609
- print(f'history is - {history}')
610
 
611
- # 파일 업로드 처리
612
  file_context = ""
613
  if uploaded_file and message == "파일을 분석하고 있습니다...":
614
  current_file_context = None
@@ -623,23 +619,16 @@ def stream_chat(
623
  current_file_context = file_context
624
  message = "업로드된 파일을 분석해주세요."
625
  except Exception as e:
626
- print(f"파일 분석 오류: {str(e)}")
627
  file_context = f"\n\n❌ 파일 분석 중 오류가 발생했습니다: {str(e)}"
628
  elif current_file_context:
629
  file_context = current_file_context
630
 
631
- if torch.cuda.is_available():
632
- print(f"CUDA 메모리 사용량: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
633
-
634
- max_history_length = 10
635
- if len(history) > max_history_length:
636
- history = history[-max_history_length:]
637
-
638
- # 위키피디아 컨텍스트 검색
639
  wiki_context = ""
640
  try:
641
  relevant_contexts = find_relevant_context(message)
642
- if relevant_contexts: # 결과가 있을 경우만 추가
643
  wiki_context = "\n\n관련 위키피디아 정보:\n"
644
  for ctx in relevant_contexts:
645
  wiki_context += (
@@ -648,9 +637,13 @@ def stream_chat(
648
  f"유사도: {ctx['similarity']:.3f}\n\n"
649
  )
650
  except Exception as e:
651
- print(f"컨텍스트 검색 오류: {str(e)}")
 
 
 
 
 
652
 
653
- # 대화 내역 구성
654
  conversation = []
655
  for prompt, answer in history:
656
  conversation.extend([
@@ -658,7 +651,7 @@ def stream_chat(
658
  {"role": "assistant", "content": answer}
659
  ])
660
 
661
- # 최종 메시지 구성
662
  final_message = message
663
  if file_context:
664
  final_message = file_context + "\n현재 질문: " + message
@@ -666,53 +659,42 @@ def stream_chat(
666
  final_message = wiki_context + "\n현재 질문: " + message
667
  if file_context and wiki_context:
668
  final_message = file_context + wiki_context + "\n현재 질문: " + message
669
-
670
  conversation.append({"role": "user", "content": final_message})
671
 
672
- # 프롬프트 구성토큰화
673
  input_ids_str = build_prompt(conversation)
674
-
675
- # 먼저 컨텍스트 길이 확인 및 제한
676
  max_context = 8192
677
  tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
678
  input_length = tokenized_input["input_ids"].shape[1]
679
-
680
- # 컨텍스트가 너무 길면 자르기
681
  if input_length > max_context - max_new_tokens:
682
- print(f"입력이 너무 깁니다: {input_length} 토큰. 자르는 중...")
683
- # 최소 생성 토큰 수 확보
684
  min_generation = min(256, max_new_tokens)
685
  new_desired_input_length = max_context - min_generation
686
-
687
- # 입력 텍스트를 토큰 단위로 자르기
688
  tokens = tokenizer.encode(input_ids_str)
689
  if len(tokens) > new_desired_input_length:
690
  tokens = tokens[-new_desired_input_length:]
691
  input_ids_str = tokenizer.decode(tokens)
692
-
693
- # 다시 토큰화
694
  tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
695
  input_length = tokenized_input["input_ids"].shape[1]
696
-
697
- print(f"최종 입력 길이: {input_length} 토큰")
698
-
699
- # CUDA로 입력 이동
700
  inputs = tokenized_input.to("cuda")
701
-
702
- # 남은 토큰 계산 및 max_new_tokens 조정
703
  remaining = max_context - input_length
704
  if remaining < max_new_tokens:
705
- print(f"max_new_tokens 조정: {max_new_tokens} -> {remaining}")
706
  max_new_tokens = remaining
707
 
708
- print(f"입력 텐서 생성 후 CUDA 메모리: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
709
-
710
  # 스트리머 설정
711
  streamer = TextIteratorStreamer(
712
  tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
713
  )
714
-
715
- # 생성 매개변수 설정
716
  generate_kwargs = dict(
717
  **inputs,
718
  streamer=streamer,
@@ -727,63 +709,56 @@ def stream_chat(
727
  use_cache=True
728
  )
729
 
730
- # 메모리 정리
731
  clear_cuda_memory()
732
 
733
- # 별도 스레드에서 생성 실행
734
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
735
  thread.start()
736
 
737
- # 응답 스트리밍
738
  buffer = ""
739
  partial_message = ""
740
  last_yield_time = time.time()
741
-
742
  try:
743
  for new_text in streamer:
744
- try:
745
- buffer += new_text
746
- partial_message += new_text
747
-
748
- # 일정 시간마다 또는 텍스트가 쌓일 때마다 결과 업데이트
749
- current_time = time.time()
750
- if current_time - last_yield_time > 0.1 or len(partial_message) > 20:
751
- yield "", history + [[message, buffer]]
752
- partial_message = ""
753
- last_yield_time = current_time
754
- except Exception as inner_e:
755
- print(f"개별 토큰 처리 중 오류: {str(inner_e)}")
756
- continue
757
-
758
- # 마지막 응답 확인
759
  if buffer:
760
  yield "", history + [[message, buffer]]
761
-
762
- # 대화 기록에 저장
763
  chat_history.add_conversation(message, buffer)
764
-
765
  except Exception as e:
766
- print(f"스트리밍 중 오류 발생: {str(e)}")
767
- if not buffer: # 버퍼가 비어있으면 오류 메시지 표시
768
- buffer = f"응답 생성 중 오류가 발생했습니다: {str(e)}"
769
  yield "", history + [[message, buffer]]
770
-
771
- # 스레드가 여전히 실행 중이면 종료 대기
772
  if thread.is_alive():
773
  thread.join(timeout=5.0)
774
-
775
- # 메모리 정리
776
  clear_cuda_memory()
777
 
778
  except Exception as e:
779
  import traceback
780
  error_details = traceback.format_exc()
781
  error_message = f"오류가 발생했습니다: {str(e)}\n{error_details}"
782
- print(f"Stream chat 오류: {error_message}")
783
  clear_cuda_memory()
784
  yield "", history + [[message, error_message]]
785
 
786
-
787
  def create_demo():
788
  with gr.Blocks(css=CSS) as demo:
789
  with gr.Column(elem_classes="markdown-style"):
@@ -834,6 +809,7 @@ def create_demo():
834
  scale=1
835
  )
836
 
 
837
  with gr.Accordion("🎮 Advanced Settings", open=False):
838
  with gr.Row():
839
  with gr.Column(scale=1):
@@ -859,6 +835,7 @@ def create_demo():
859
  label="Repetition Penalty 🔄"
860
  )
861
 
 
862
  gr.Examples(
863
  examples=[
864
  ["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n if n <= 1: return n\n return fibonacci(n-1) + fibonacci(n-2)"],
@@ -869,23 +846,25 @@ def create_demo():
869
  inputs=msg
870
  )
871
 
 
872
  def clear_conversation():
873
  global current_file_context
874
  current_file_context = None
875
  return [], None, "Start a new conversation..."
876
 
 
877
  msg.submit(
878
  stream_chat,
879
  inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
880
  outputs=[msg, chatbot]
881
  )
882
-
883
  send.click(
884
  stream_chat,
885
  inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
886
  outputs=[msg, chatbot]
887
  )
888
 
 
889
  file_upload.change(
890
  fn=lambda: ("처리 중...", [["시스템", "파일을 분석 중입니다. 잠시만 기다려주세요..."]]),
891
  outputs=[msg, chatbot],
@@ -901,6 +880,7 @@ def create_demo():
901
  queue=True
902
  )
903
 
 
904
  clear.click(
905
  fn=clear_conversation,
906
  outputs=[chatbot, file_upload, msg],
@@ -909,7 +889,7 @@ def create_demo():
909
 
910
  return demo
911
 
912
-
913
  if __name__ == "__main__":
914
  demo = create_demo()
915
- demo.launch()
 
1
  import os
2
+
3
+ # 1) Dynamo 완전 비활성화
4
  os.environ["TORCH_DYNAMO_DISABLE"] = "1"
5
 
6
+ # 2) Triton의 cudagraphs 최적화 비활성화
7
+ os.environ["TRITON_DISABLE_CUDAGRAPHS"] = "1"
8
+
9
+ # 3) 경고 무시 설정 (skipping cudagraphs 관련)
10
+ import warnings
11
+ warnings.filterwarnings("ignore", message="skipping cudagraphs due to mutated inputs")
12
+ warnings.filterwarnings("ignore", message="Not enough SMs to use max_autotune_gemm mode")
13
+
14
  import torch
15
+ # TensorFloat32 연산 활성화 (성능 최적화)
16
  torch.set_float32_matmul_precision('high')
17
+
18
+ # TorchInductor cudagraphs 비활성화
19
  import torch._inductor
20
  torch._inductor.config.triton.cudagraphs = False
21
+
22
+ # Dynamo suppress_errors 옵션 (오류 시 eager로 fallback)
23
  import torch._dynamo
24
+ torch._dynamo.config.suppress_errors = True
25
+
26
  import gradio as gr
27
  import spaces
28
+
29
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
30
 
31
  from threading import Thread
 
40
  import pyarrow.parquet as pq
41
  import pypdf
42
  import io
 
 
43
  import platform
44
  import subprocess
45
  import pytesseract
46
  from pdf2image import convert_from_path
47
+ import queue # queue.Empty 예외 처리를 위해
48
+ import time # 스트리밍 타이밍을 위해
49
 
50
+ # -------------------- PDF to Markdown 변환 관련 import --------------------
51
  try:
52
  import re
53
  import requests
 
64
  )
65
  # ---------------------------------------------------------------------------
66
 
 
 
 
67
  # 전역 변수
68
  current_file_context = None
69
 
 
73
  MODELS = os.environ.get("MODELS")
74
  MODEL_NAME = MODEL_ID.split("/")[-1]
75
 
76
+ model = None # 전역에서 관리
77
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
78
 
79
+ # (1) 위키피디아 데이터셋 로드
80
  wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
81
  print("Wikipedia dataset loaded:", wiki_dataset)
82
 
83
+ # (2) TF-IDF 벡터라이저 초기화 및 학습
84
  print("TF-IDF 벡터화 시작...")
85
  questions = wiki_dataset['train']['question'][:10000] # 처음 10000개만 사용
86
  vectorizer = TfidfVectorizer(max_features=1000)
87
  question_vectors = vectorizer.fit_transform(questions)
88
  print("TF-IDF 벡터화 완료")
89
 
90
+ # ------------------------- ChatHistory 클래스 -------------------------
91
  class ChatHistory:
92
  def __init__(self):
93
  self.history = []
 
143
  print(f"히스토리 로드 실패: {e}")
144
  self.history = []
145
 
146
+ # 전역 ChatHistory 인스턴스
 
147
  chat_history = ChatHistory()
148
 
149
+ # ------------------------- 위키 문서 검색 (TF-IDF) -------------------------
150
  def find_relevant_context(query, top_k=3):
151
  # 쿼리 벡터화
152
  query_vector = vectorizer.transform([query])
153
+ # 코사인 유사도
154
  similarities = (query_vector * question_vectors.T).toarray()[0]
155
+ # 유사도 높은 질문 인덱스
156
  top_indices = np.argsort(similarities)[-top_k:][::-1]
157
+
158
  relevant_contexts = []
159
  for idx in top_indices:
160
  if similarities[idx] > 0:
 
165
  })
166
  return relevant_contexts
167
 
168
+ # 파일 업로드 시 표시할 초기 메시지
169
  def init_msg():
170
  return "파일을 분석하고 있습니다..."
171
 
 
172
  # -------------------- PDF 파일을 Markdown으로 변환하는 유틸 함수들 --------------------
173
  def extract_text_from_pdf(reader: PdfReader) -> str:
174
  """
175
  PyPDF를 사용해 모든 페이지 텍스트를 추출.
 
176
  """
177
  full_text = ""
178
  for idx, page in enumerate(reader.pages):
 
181
  full_text += f"---- Page {idx+1} ----\n" + text + "\n\n"
182
  return full_text.strip()
183
 
 
184
  def convert_pdf_to_markdown(pdf_file: str):
185
  """
186
+ PDF 파일에서 텍스트를 추출하고,
187
+ 이미지가 많고 텍스트가 적으면 OCR 시도
 
 
188
  """
189
  try:
190
  reader = PdfReader(pdf_file)
191
  except Exception as e:
192
  return f"PDF 파일을 읽는 중 오류 발생: {e}", None, None
193
 
194
+ # 메타데이터 추출
195
  raw_meta = reader.metadata
196
  metadata = {
197
  "author": raw_meta.author if raw_meta else None,
 
201
  "title": raw_meta.title if raw_meta else None,
202
  }
203
 
204
+ # 텍스트 추출
205
  full_text = extract_text_from_pdf(reader)
206
 
207
+ # 이미지-텍스트 비율 판단 OCR 시도
208
+ image_count = sum(len(page.images) for page in reader.pages)
 
 
 
209
  if image_count > 0 and len(full_text) < 1000:
210
  try:
211
  out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
212
  ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
213
+ # OCR된 PDF 다시 읽기
214
  reader_ocr = PdfReader(out_pdf_file)
215
  full_text = extract_text_from_pdf(reader_ocr)
216
  except Exception as e:
 
218
 
219
  return full_text, metadata, pdf_file
220
 
221
+ # ------------------------- 파일 분석 함수 -------------------------
 
 
222
  def analyze_file_content(content, file_type):
223
+ """간단한 구조 분석/요약."""
224
  if file_type in ['parquet', 'csv']:
225
  try:
226
  lines = content.split('\n')
 
245
  words = len(content.split())
246
  return f"📝 Document Structure: {total_lines} lines, {paragraphs} paragraphs, approximately {words} words"
247
 
 
248
  def read_uploaded_file(file):
249
  """
250
+ 업로드된 파일 처리 -> 내용/타입
 
 
251
  """
252
  if file is None:
253
  return "", ""
254
+
255
  try:
256
  file_ext = os.path.splitext(file.name)[1].lower()
257
 
 
265
  content += f"1. Basic Information:\n"
266
  content += f"- Total Rows: {len(df):,}\n"
267
  content += f"- Total Columns: {len(df.columns)}\n"
268
+ mem_usage = df.memory_usage(deep=True).sum() / 1024 / 1024
269
+ content += f"- Memory Usage: {mem_usage:.2f} MB\n\n"
270
 
271
  content += f"2. Column Information:\n"
272
  for col in df.columns:
 
278
  content += f"\n\n4. Missing Values:\n"
279
  null_counts = df.isnull().sum()
280
  for col, count in null_counts[null_counts > 0].items():
281
+ rate = count / len(df) * 100
282
+ content += f"- {col}: {count:,} ({rate:.1f}%)\n"
283
 
284
  numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
285
  if len(numeric_cols) > 0:
 
291
  except Exception as e:
292
  return f"Error reading Parquet file: {str(e)}", "error"
293
 
294
+ # PDF
295
  if file_ext == '.pdf':
296
  try:
297
  markdown_text, metadata, processed_pdf_path = convert_pdf_to_markdown(file.name)
 
305
 
306
  content += "## Extracted Text\n\n"
307
  content += markdown_text
 
308
  return content, "pdf"
309
  except Exception as e:
310
  return f"Error reading PDF file: {str(e)}", "error"
 
319
  content += f"1. Basic Information:\n"
320
  content += f"- Total Rows: {len(df):,}\n"
321
  content += f"- Total Columns: {len(df.columns)}\n"
322
+ mem_usage = df.memory_usage(deep=True).sum() / 1024 / 1024
323
+ content += f"- Memory Usage: {mem_usage:.2f} MB\n\n"
324
 
325
  content += f"2. Column Information:\n"
326
  for col in df.columns:
 
332
  content += f"\n\n4. Missing Values:\n"
333
  null_counts = df.isnull().sum()
334
  for col, count in null_counts[null_counts > 0].items():
335
+ rate = count / len(df) * 100
336
+ content += f"- {col}: {count:,} ({rate:.1f}%)\n"
337
 
338
  return content, "csv"
339
  except UnicodeDecodeError:
340
  continue
341
+ raise UnicodeDecodeError(
342
+ f"Unable to read file with supported encodings ({', '.join(encodings)})"
343
+ )
344
 
345
+ # 텍스트 파일
346
  else:
347
  encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
348
  for encoding in encodings:
 
353
  lines = content.split('\n')
354
  total_lines = len(lines)
355
  non_empty_lines = len([line for line in lines if line.strip()])
356
+ is_code = any(
357
+ keyword in content.lower()
358
+ for keyword in ['def ', 'class ', 'import ', 'function']
359
+ )
360
 
361
  analysis = f"\n📝 File Analysis:\n"
362
  if is_code:
363
+ functions = sum('def ' in line for line in lines)
364
+ classes = sum('class ' in line for line in lines)
365
+ imports = sum(
366
+ ('import ' in line) or ('from ' in line)
367
+ for line in lines
368
+ )
369
  analysis += f"- File Type: Code\n"
370
  analysis += f"- Total Lines: {total_lines:,}\n"
371
  analysis += f"- Functions: {functions}\n"
 
382
  analysis += f"- Character Count: {chars:,}\n"
383
 
384
  return content + analysis, "text"
385
+
386
  except UnicodeDecodeError:
387
  continue
388
+
389
+ raise UnicodeDecodeError(
390
+ f"Unable to read file with supported encodings ({', '.join(encodings)})"
391
+ )
392
 
393
  except Exception as e:
394
  return f"Error reading file: {str(e)}", "error"
395
 
396
+ # ------------------------- CSS -------------------------
397
  CSS = """
398
  /* 3D 스타일 CSS */
399
  :root {
 
550
  """
551
 
552
  def clear_cuda_memory():
553
+ """CUDA 캐시 정리."""
554
  if hasattr(torch.cuda, 'empty_cache'):
555
  with torch.cuda.device('cuda'):
556
  torch.cuda.empty_cache()
557
 
558
+ # ------------------------- 모델 로딩 함수 -------------------------
559
  @spaces.GPU
560
  def load_model():
561
  try:
 
562
  clear_cuda_memory()
 
563
  loaded_model = AutoModelForCausalLM.from_pretrained(
564
  MODEL_ID,
565
  torch_dtype=torch.bfloat16,
566
  device_map="auto",
 
567
  low_cpu_mem_usage=True,
568
  )
569
  return loaded_model
 
571
  print(f"모델 로드 오류: {str(e)}")
572
  raise
573
 
 
 
 
 
 
 
 
 
 
 
 
574
  def build_prompt(conversation: list) -> str:
575
+ """대화 내역을 단순 텍스트 프롬프트로 변환."""
 
 
 
576
  prompt = ""
577
  for msg in conversation:
578
  if msg["role"] == "user":
 
582
  prompt += "Assistant: "
583
  return prompt
584
 
585
+ # ------------------------- 메시지 스트리밍 함수 -------------------------
586
  @spaces.GPU
587
  def stream_chat(
588
  message: str,
 
597
  global model, current_file_context
598
 
599
  try:
600
+ # 모델 미로드시 로딩
601
  if model is None:
602
  model = load_model()
603
 
604
+ print(f'[User input] message: {message}')
605
+ print(f'[History] {history}')
606
 
607
+ # (1) 파일 업로드 처리
608
  file_context = ""
609
  if uploaded_file and message == "파일을 분석하고 있습니다...":
610
  current_file_context = None
 
619
  current_file_context = file_context
620
  message = "업로드된 파일을 분석해주세요."
621
  except Exception as e:
622
+ print(f"[파일 분석 오류] {str(e)}")
623
  file_context = f"\n\n❌ 파일 분석 중 오류가 발생했습니다: {str(e)}"
624
  elif current_file_context:
625
  file_context = current_file_context
626
 
627
+ # (2) TF-IDF 기반 관련 문서 탐색
 
 
 
 
 
 
 
628
  wiki_context = ""
629
  try:
630
  relevant_contexts = find_relevant_context(message)
631
+ if relevant_contexts:
632
  wiki_context = "\n\n관련 위키피디아 정보:\n"
633
  for ctx in relevant_contexts:
634
  wiki_context += (
 
637
  f"유사도: {ctx['similarity']:.3f}\n\n"
638
  )
639
  except Exception as e:
640
+ print(f"[컨텍스트 검색 오류] {str(e)}")
641
+
642
+ # (3) 대화 이력 구성
643
+ max_history_length = 10
644
+ if len(history) > max_history_length:
645
+ history = history[-max_history_length:]
646
 
 
647
  conversation = []
648
  for prompt, answer in history:
649
  conversation.extend([
 
651
  {"role": "assistant", "content": answer}
652
  ])
653
 
654
+ # (4) 최종 메시지 결정
655
  final_message = message
656
  if file_context:
657
  final_message = file_context + "\n현재 질문: " + message
 
659
  final_message = wiki_context + "\n현재 질문: " + message
660
  if file_context and wiki_context:
661
  final_message = file_context + wiki_context + "\n현재 질문: " + message
662
+
663
  conversation.append({"role": "user", "content": final_message})
664
 
665
+ # (5) 토큰화프롬프트 구축
666
  input_ids_str = build_prompt(conversation)
 
 
667
  max_context = 8192
668
  tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
669
  input_length = tokenized_input["input_ids"].shape[1]
670
+
671
+ # (6) 컨텍스트가 너무 길면 앞부분 토큰 자르기
672
  if input_length > max_context - max_new_tokens:
673
+ print(f"[경고] 입력이 너무 깁니다: {input_length} 토큰 -> 잘라냄.")
 
674
  min_generation = min(256, max_new_tokens)
675
  new_desired_input_length = max_context - min_generation
 
 
676
  tokens = tokenizer.encode(input_ids_str)
677
  if len(tokens) > new_desired_input_length:
678
  tokens = tokens[-new_desired_input_length:]
679
  input_ids_str = tokenizer.decode(tokens)
 
 
680
  tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
681
  input_length = tokenized_input["input_ids"].shape[1]
682
+
683
+ print(f"[토큰 길이] {input_length}")
 
 
684
  inputs = tokenized_input.to("cuda")
685
+
686
+ # 남은 토큰 수로 max_new_tokens 조정
687
  remaining = max_context - input_length
688
  if remaining < max_new_tokens:
689
+ print(f"[max_new_tokens 조정] {max_new_tokens} -> {remaining}")
690
  max_new_tokens = remaining
691
 
 
 
692
  # 스트리머 설정
693
  streamer = TextIteratorStreamer(
694
  tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
695
  )
696
+
697
+ # (7) 생성 파라미터
698
  generate_kwargs = dict(
699
  **inputs,
700
  streamer=streamer,
 
709
  use_cache=True
710
  )
711
 
 
712
  clear_cuda_memory()
713
 
714
+ # (8) 별도 스레드에서 생성
715
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
716
  thread.start()
717
 
718
+ # (9) 스트리밍 응답
719
  buffer = ""
720
  partial_message = ""
721
  last_yield_time = time.time()
722
+
723
  try:
724
  for new_text in streamer:
725
+ buffer += new_text
726
+ partial_message += new_text
727
+
728
+ # 일정 시간 또는 버퍼 길이 기준으로 yield
729
+ current_time = time.time()
730
+ if (current_time - last_yield_time > 0.1) or (len(partial_message) > 20):
731
+ yield "", history + [[message, buffer]]
732
+ partial_message = ""
733
+ last_yield_time = current_time
734
+
735
+ # 마지막 완성된 응답
 
 
 
 
736
  if buffer:
737
  yield "", history + [[message, buffer]]
738
+
739
+ # 대화 내용 저장
740
  chat_history.add_conversation(message, buffer)
741
+
742
  except Exception as e:
743
+ print(f"[스트리밍 중 오류] {str(e)}")
744
+ if not buffer: # buffer가 비어있다면 오류메시지 대화창 표시
745
+ buffer = f"응답 생성 중 오류 발생: {str(e)}"
746
  yield "", history + [[message, buffer]]
747
+
 
748
  if thread.is_alive():
749
  thread.join(timeout=5.0)
750
+
 
751
  clear_cuda_memory()
752
 
753
  except Exception as e:
754
  import traceback
755
  error_details = traceback.format_exc()
756
  error_message = f"오류가 발생했습니다: {str(e)}\n{error_details}"
757
+ print(f"[Stream chat 오류] {error_message}")
758
  clear_cuda_memory()
759
  yield "", history + [[message, error_message]]
760
 
761
+ # ------------------------- Gradio UI 구성 -------------------------
762
  def create_demo():
763
  with gr.Blocks(css=CSS) as demo:
764
  with gr.Column(elem_classes="markdown-style"):
 
809
  scale=1
810
  )
811
 
812
+ # 고급 설정
813
  with gr.Accordion("🎮 Advanced Settings", open=False):
814
  with gr.Row():
815
  with gr.Column(scale=1):
 
835
  label="Repetition Penalty 🔄"
836
  )
837
 
838
+ # 예시
839
  gr.Examples(
840
  examples=[
841
  ["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n if n <= 1: return n\n return fibonacci(n-1) + fibonacci(n-2)"],
 
846
  inputs=msg
847
  )
848
 
849
+ # 대화 내용 초기화
850
  def clear_conversation():
851
  global current_file_context
852
  current_file_context = None
853
  return [], None, "Start a new conversation..."
854
 
855
+ # 메시지 전송
856
  msg.submit(
857
  stream_chat,
858
  inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
859
  outputs=[msg, chatbot]
860
  )
 
861
  send.click(
862
  stream_chat,
863
  inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
864
  outputs=[msg, chatbot]
865
  )
866
 
867
+ # 파일 업로드 이벤트
868
  file_upload.change(
869
  fn=lambda: ("처리 중...", [["시스템", "파일을 분석 중입니다. 잠시만 기다려주세요..."]]),
870
  outputs=[msg, chatbot],
 
880
  queue=True
881
  )
882
 
883
+ # Clear 버튼
884
  clear.click(
885
  fn=clear_conversation,
886
  outputs=[chatbot, file_upload, msg],
 
889
 
890
  return demo
891
 
892
+ # 메인 실행
893
  if __name__ == "__main__":
894
  demo = create_demo()
895
+ demo.launch()