Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,15 +1,31 @@
|
|
1 |
import os
|
2 |
-
|
|
|
3 |
os.environ["TORCH_DYNAMO_DISABLE"] = "1"
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import torch
|
6 |
-
#
|
7 |
torch.set_float32_matmul_precision('high')
|
|
|
|
|
8 |
import torch._inductor
|
9 |
torch._inductor.config.triton.cudagraphs = False
|
|
|
|
|
10 |
import torch._dynamo
|
|
|
|
|
11 |
import gradio as gr
|
12 |
import spaces
|
|
|
13 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
14 |
|
15 |
from threading import Thread
|
@@ -24,16 +40,14 @@ from datetime import datetime
|
|
24 |
import pyarrow.parquet as pq
|
25 |
import pypdf
|
26 |
import io
|
27 |
-
import pyarrow.parquet as pq
|
28 |
-
from tabulate import tabulate
|
29 |
import platform
|
30 |
import subprocess
|
31 |
import pytesseract
|
32 |
from pdf2image import convert_from_path
|
33 |
-
import queue #
|
34 |
-
import time
|
35 |
|
36 |
-
# --------------------
|
37 |
try:
|
38 |
import re
|
39 |
import requests
|
@@ -50,9 +64,6 @@ except ModuleNotFoundError as e:
|
|
50 |
)
|
51 |
# ---------------------------------------------------------------------------
|
52 |
|
53 |
-
# 1) Dynamo suppress_errors 옵션 사용 (오류 시 eager로 fallback)
|
54 |
-
torch._dynamo.config.suppress_errors = True
|
55 |
-
|
56 |
# 전역 변수
|
57 |
current_file_context = None
|
58 |
|
@@ -62,21 +73,21 @@ MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
|
|
62 |
MODELS = os.environ.get("MODELS")
|
63 |
MODEL_NAME = MODEL_ID.split("/")[-1]
|
64 |
|
65 |
-
model = None #
|
66 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
67 |
|
68 |
-
# 위키피디아 데이터셋 로드
|
69 |
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
|
70 |
print("Wikipedia dataset loaded:", wiki_dataset)
|
71 |
|
72 |
-
# TF-IDF 벡터라이저 초기화 및 학습
|
73 |
print("TF-IDF 벡터화 시작...")
|
74 |
questions = wiki_dataset['train']['question'][:10000] # 처음 10000개만 사용
|
75 |
vectorizer = TfidfVectorizer(max_features=1000)
|
76 |
question_vectors = vectorizer.fit_transform(questions)
|
77 |
print("TF-IDF 벡터화 완료")
|
78 |
|
79 |
-
|
80 |
class ChatHistory:
|
81 |
def __init__(self):
|
82 |
self.history = []
|
@@ -132,19 +143,18 @@ class ChatHistory:
|
|
132 |
print(f"히스토리 로드 실패: {e}")
|
133 |
self.history = []
|
134 |
|
135 |
-
|
136 |
-
# 전역 ChatHistory 인스턴스 생성
|
137 |
chat_history = ChatHistory()
|
138 |
|
139 |
-
|
140 |
def find_relevant_context(query, top_k=3):
|
141 |
# 쿼리 벡터화
|
142 |
query_vector = vectorizer.transform([query])
|
143 |
-
# 코사인 유사도
|
144 |
similarities = (query_vector * question_vectors.T).toarray()[0]
|
145 |
-
#
|
146 |
top_indices = np.argsort(similarities)[-top_k:][::-1]
|
147 |
-
|
148 |
relevant_contexts = []
|
149 |
for idx in top_indices:
|
150 |
if similarities[idx] > 0:
|
@@ -155,16 +165,14 @@ def find_relevant_context(query, top_k=3):
|
|
155 |
})
|
156 |
return relevant_contexts
|
157 |
|
158 |
-
|
159 |
def init_msg():
|
160 |
return "파일을 분석하고 있습니다..."
|
161 |
|
162 |
-
|
163 |
# -------------------- PDF 파일을 Markdown으로 변환하는 유틸 함수들 --------------------
|
164 |
def extract_text_from_pdf(reader: PdfReader) -> str:
|
165 |
"""
|
166 |
PyPDF를 사용해 모든 페이지 텍스트를 추출.
|
167 |
-
만약 텍스트가 없으면 빈 문자열 반환.
|
168 |
"""
|
169 |
full_text = ""
|
170 |
for idx, page in enumerate(reader.pages):
|
@@ -173,20 +181,17 @@ def extract_text_from_pdf(reader: PdfReader) -> str:
|
|
173 |
full_text += f"---- Page {idx+1} ----\n" + text + "\n\n"
|
174 |
return full_text.strip()
|
175 |
|
176 |
-
|
177 |
def convert_pdf_to_markdown(pdf_file: str):
|
178 |
"""
|
179 |
-
PDF
|
180 |
-
이미지가 많고 텍스트가
|
181 |
-
최종적으로 Markdown 형식으로 변환 가능한 텍스트를 반환한다.
|
182 |
-
메타데이터도 함께 반환.
|
183 |
"""
|
184 |
try:
|
185 |
reader = PdfReader(pdf_file)
|
186 |
except Exception as e:
|
187 |
return f"PDF 파일을 읽는 중 오류 발생: {e}", None, None
|
188 |
|
189 |
-
#
|
190 |
raw_meta = reader.metadata
|
191 |
metadata = {
|
192 |
"author": raw_meta.author if raw_meta else None,
|
@@ -196,19 +201,16 @@ def convert_pdf_to_markdown(pdf_file: str):
|
|
196 |
"title": raw_meta.title if raw_meta else None,
|
197 |
}
|
198 |
|
199 |
-
#
|
200 |
full_text = extract_text_from_pdf(reader)
|
201 |
|
202 |
-
#
|
203 |
-
image_count =
|
204 |
-
for page in reader.pages:
|
205 |
-
image_count += len(page.images)
|
206 |
-
|
207 |
if image_count > 0 and len(full_text) < 1000:
|
208 |
try:
|
209 |
out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
|
210 |
ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
|
211 |
-
#
|
212 |
reader_ocr = PdfReader(out_pdf_file)
|
213 |
full_text = extract_text_from_pdf(reader_ocr)
|
214 |
except Exception as e:
|
@@ -216,11 +218,9 @@ def convert_pdf_to_markdown(pdf_file: str):
|
|
216 |
|
217 |
return full_text, metadata, pdf_file
|
218 |
|
219 |
-
|
220 |
-
# ---------------------------------------------------------------------------
|
221 |
-
|
222 |
def analyze_file_content(content, file_type):
|
223 |
-
"""
|
224 |
if file_type in ['parquet', 'csv']:
|
225 |
try:
|
226 |
lines = content.split('\n')
|
@@ -245,15 +245,13 @@ def analyze_file_content(content, file_type):
|
|
245 |
words = len(content.split())
|
246 |
return f"📝 Document Structure: {total_lines} lines, {paragraphs} paragraphs, approximately {words} words"
|
247 |
|
248 |
-
|
249 |
def read_uploaded_file(file):
|
250 |
"""
|
251 |
-
업로드된
|
252 |
-
1) 파일 타입별로 내용을 읽고
|
253 |
-
2) 분석 결과와 함께 반환
|
254 |
"""
|
255 |
if file is None:
|
256 |
return "", ""
|
|
|
257 |
try:
|
258 |
file_ext = os.path.splitext(file.name)[1].lower()
|
259 |
|
@@ -267,7 +265,8 @@ def read_uploaded_file(file):
|
|
267 |
content += f"1. Basic Information:\n"
|
268 |
content += f"- Total Rows: {len(df):,}\n"
|
269 |
content += f"- Total Columns: {len(df.columns)}\n"
|
270 |
-
|
|
|
271 |
|
272 |
content += f"2. Column Information:\n"
|
273 |
for col in df.columns:
|
@@ -279,7 +278,8 @@ def read_uploaded_file(file):
|
|
279 |
content += f"\n\n4. Missing Values:\n"
|
280 |
null_counts = df.isnull().sum()
|
281 |
for col, count in null_counts[null_counts > 0].items():
|
282 |
-
|
|
|
283 |
|
284 |
numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
|
285 |
if len(numeric_cols) > 0:
|
@@ -291,7 +291,7 @@ def read_uploaded_file(file):
|
|
291 |
except Exception as e:
|
292 |
return f"Error reading Parquet file: {str(e)}", "error"
|
293 |
|
294 |
-
# PDF
|
295 |
if file_ext == '.pdf':
|
296 |
try:
|
297 |
markdown_text, metadata, processed_pdf_path = convert_pdf_to_markdown(file.name)
|
@@ -305,7 +305,6 @@ def read_uploaded_file(file):
|
|
305 |
|
306 |
content += "## Extracted Text\n\n"
|
307 |
content += markdown_text
|
308 |
-
|
309 |
return content, "pdf"
|
310 |
except Exception as e:
|
311 |
return f"Error reading PDF file: {str(e)}", "error"
|
@@ -320,7 +319,8 @@ def read_uploaded_file(file):
|
|
320 |
content += f"1. Basic Information:\n"
|
321 |
content += f"- Total Rows: {len(df):,}\n"
|
322 |
content += f"- Total Columns: {len(df.columns)}\n"
|
323 |
-
|
|
|
324 |
|
325 |
content += f"2. Column Information:\n"
|
326 |
for col in df.columns:
|
@@ -332,14 +332,17 @@ def read_uploaded_file(file):
|
|
332 |
content += f"\n\n4. Missing Values:\n"
|
333 |
null_counts = df.isnull().sum()
|
334 |
for col, count in null_counts[null_counts > 0].items():
|
335 |
-
|
|
|
336 |
|
337 |
return content, "csv"
|
338 |
except UnicodeDecodeError:
|
339 |
continue
|
340 |
-
raise UnicodeDecodeError(
|
|
|
|
|
341 |
|
342 |
-
#
|
343 |
else:
|
344 |
encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
|
345 |
for encoding in encodings:
|
@@ -350,15 +353,19 @@ def read_uploaded_file(file):
|
|
350 |
lines = content.split('\n')
|
351 |
total_lines = len(lines)
|
352 |
non_empty_lines = len([line for line in lines if line.strip()])
|
353 |
-
|
354 |
-
|
|
|
|
|
355 |
|
356 |
analysis = f"\n📝 File Analysis:\n"
|
357 |
if is_code:
|
358 |
-
functions =
|
359 |
-
classes =
|
360 |
-
imports =
|
361 |
-
|
|
|
|
|
362 |
analysis += f"- File Type: Code\n"
|
363 |
analysis += f"- Total Lines: {total_lines:,}\n"
|
364 |
analysis += f"- Functions: {functions}\n"
|
@@ -375,14 +382,18 @@ def read_uploaded_file(file):
|
|
375 |
analysis += f"- Character Count: {chars:,}\n"
|
376 |
|
377 |
return content + analysis, "text"
|
|
|
378 |
except UnicodeDecodeError:
|
379 |
continue
|
380 |
-
|
|
|
|
|
|
|
381 |
|
382 |
except Exception as e:
|
383 |
return f"Error reading file: {str(e)}", "error"
|
384 |
|
385 |
-
|
386 |
CSS = """
|
387 |
/* 3D 스타일 CSS */
|
388 |
:root {
|
@@ -539,22 +550,20 @@ body {
|
|
539 |
"""
|
540 |
|
541 |
def clear_cuda_memory():
|
|
|
542 |
if hasattr(torch.cuda, 'empty_cache'):
|
543 |
with torch.cuda.device('cuda'):
|
544 |
torch.cuda.empty_cache()
|
545 |
|
546 |
-
|
547 |
@spaces.GPU
|
548 |
def load_model():
|
549 |
try:
|
550 |
-
# 메모리 정리 먼저 수행
|
551 |
clear_cuda_memory()
|
552 |
-
|
553 |
loaded_model = AutoModelForCausalLM.from_pretrained(
|
554 |
MODEL_ID,
|
555 |
torch_dtype=torch.bfloat16,
|
556 |
device_map="auto",
|
557 |
-
# 낮은 메모리 사용을 위한 설정 추가
|
558 |
low_cpu_mem_usage=True,
|
559 |
)
|
560 |
return loaded_model
|
@@ -562,22 +571,8 @@ def load_model():
|
|
562 |
print(f"모델 로드 오류: {str(e)}")
|
563 |
raise
|
564 |
|
565 |
-
def _truncate_tokens_for_context(input_ids_str: str, desired_input_length: int) -> str:
|
566 |
-
"""
|
567 |
-
입력 문자열이 desired_input_length 토큰을 넘으면, 앞부분(오래된 컨텍스트)을 잘라내는 함수.
|
568 |
-
"""
|
569 |
-
tokens = input_ids_str.split()
|
570 |
-
if len(tokens) > desired_input_length:
|
571 |
-
tokens = tokens[-desired_input_length:]
|
572 |
-
return " ".join(tokens)
|
573 |
-
|
574 |
-
|
575 |
-
# build_prompt 함수: 대화 내역을 문자열로 변환
|
576 |
def build_prompt(conversation: list) -> str:
|
577 |
-
"""
|
578 |
-
conversation은 각 항목이 {"role": "user" 또는 "assistant", "content": ...} 형태의 딕셔너리 목록입니다.
|
579 |
-
이를 단순 텍스트 프롬프트로 변환합니다.
|
580 |
-
"""
|
581 |
prompt = ""
|
582 |
for msg in conversation:
|
583 |
if msg["role"] == "user":
|
@@ -587,7 +582,7 @@ def build_prompt(conversation: list) -> str:
|
|
587 |
prompt += "Assistant: "
|
588 |
return prompt
|
589 |
|
590 |
-
|
591 |
@spaces.GPU
|
592 |
def stream_chat(
|
593 |
message: str,
|
@@ -602,13 +597,14 @@ def stream_chat(
|
|
602 |
global model, current_file_context
|
603 |
|
604 |
try:
|
|
|
605 |
if model is None:
|
606 |
model = load_model()
|
607 |
|
608 |
-
print(f'
|
609 |
-
print(f'
|
610 |
|
611 |
-
# 파일 업로드 처리
|
612 |
file_context = ""
|
613 |
if uploaded_file and message == "파일을 분석하고 있습니다...":
|
614 |
current_file_context = None
|
@@ -623,23 +619,16 @@ def stream_chat(
|
|
623 |
current_file_context = file_context
|
624 |
message = "업로드된 파일을 분석해주세요."
|
625 |
except Exception as e:
|
626 |
-
print(f"파일 분석
|
627 |
file_context = f"\n\n❌ 파일 분석 중 오류가 발생했습니다: {str(e)}"
|
628 |
elif current_file_context:
|
629 |
file_context = current_file_context
|
630 |
|
631 |
-
|
632 |
-
print(f"CUDA 메모리 사용량: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
|
633 |
-
|
634 |
-
max_history_length = 10
|
635 |
-
if len(history) > max_history_length:
|
636 |
-
history = history[-max_history_length:]
|
637 |
-
|
638 |
-
# 위키피디아 컨텍스트 검색
|
639 |
wiki_context = ""
|
640 |
try:
|
641 |
relevant_contexts = find_relevant_context(message)
|
642 |
-
if relevant_contexts:
|
643 |
wiki_context = "\n\n관련 위키피디아 정보:\n"
|
644 |
for ctx in relevant_contexts:
|
645 |
wiki_context += (
|
@@ -648,9 +637,13 @@ def stream_chat(
|
|
648 |
f"유사도: {ctx['similarity']:.3f}\n\n"
|
649 |
)
|
650 |
except Exception as e:
|
651 |
-
print(f"컨텍스트 검색
|
|
|
|
|
|
|
|
|
|
|
652 |
|
653 |
-
# 대화 내역 구성
|
654 |
conversation = []
|
655 |
for prompt, answer in history:
|
656 |
conversation.extend([
|
@@ -658,7 +651,7 @@ def stream_chat(
|
|
658 |
{"role": "assistant", "content": answer}
|
659 |
])
|
660 |
|
661 |
-
# 최종 메시지
|
662 |
final_message = message
|
663 |
if file_context:
|
664 |
final_message = file_context + "\n현재 질문: " + message
|
@@ -666,53 +659,42 @@ def stream_chat(
|
|
666 |
final_message = wiki_context + "\n현재 질문: " + message
|
667 |
if file_context and wiki_context:
|
668 |
final_message = file_context + wiki_context + "\n현재 질문: " + message
|
669 |
-
|
670 |
conversation.append({"role": "user", "content": final_message})
|
671 |
|
672 |
-
#
|
673 |
input_ids_str = build_prompt(conversation)
|
674 |
-
|
675 |
-
# 먼저 컨텍스트 길이 확인 및 제한
|
676 |
max_context = 8192
|
677 |
tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
|
678 |
input_length = tokenized_input["input_ids"].shape[1]
|
679 |
-
|
680 |
-
# 컨텍스트가 너무 길면 자르기
|
681 |
if input_length > max_context - max_new_tokens:
|
682 |
-
print(f"입력이 너무 깁니다: {input_length}
|
683 |
-
# 최소 생성 토큰 수 확보
|
684 |
min_generation = min(256, max_new_tokens)
|
685 |
new_desired_input_length = max_context - min_generation
|
686 |
-
|
687 |
-
# 입력 텍스트를 토큰 단위로 자르기
|
688 |
tokens = tokenizer.encode(input_ids_str)
|
689 |
if len(tokens) > new_desired_input_length:
|
690 |
tokens = tokens[-new_desired_input_length:]
|
691 |
input_ids_str = tokenizer.decode(tokens)
|
692 |
-
|
693 |
-
# 다시 토큰화
|
694 |
tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
|
695 |
input_length = tokenized_input["input_ids"].shape[1]
|
696 |
-
|
697 |
-
print(f"
|
698 |
-
|
699 |
-
# CUDA로 입력 이동
|
700 |
inputs = tokenized_input.to("cuda")
|
701 |
-
|
702 |
-
# 남은 토큰
|
703 |
remaining = max_context - input_length
|
704 |
if remaining < max_new_tokens:
|
705 |
-
print(f"max_new_tokens
|
706 |
max_new_tokens = remaining
|
707 |
|
708 |
-
print(f"입력 텐서 생성 후 CUDA 메모리: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
|
709 |
-
|
710 |
# 스트리머 설정
|
711 |
streamer = TextIteratorStreamer(
|
712 |
tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
|
713 |
)
|
714 |
-
|
715 |
-
# 생성
|
716 |
generate_kwargs = dict(
|
717 |
**inputs,
|
718 |
streamer=streamer,
|
@@ -727,63 +709,56 @@ def stream_chat(
|
|
727 |
use_cache=True
|
728 |
)
|
729 |
|
730 |
-
# 메모리 정리
|
731 |
clear_cuda_memory()
|
732 |
|
733 |
-
# 별도 스레드에서 생성
|
734 |
thread = Thread(target=model.generate, kwargs=generate_kwargs)
|
735 |
thread.start()
|
736 |
|
737 |
-
#
|
738 |
buffer = ""
|
739 |
partial_message = ""
|
740 |
last_yield_time = time.time()
|
741 |
-
|
742 |
try:
|
743 |
for new_text in streamer:
|
744 |
-
|
745 |
-
|
746 |
-
|
747 |
-
|
748 |
-
|
749 |
-
|
750 |
-
|
751 |
-
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
print(f"개별 토큰 처리 중 오류: {str(inner_e)}")
|
756 |
-
continue
|
757 |
-
|
758 |
-
# 마지막 응답 확인
|
759 |
if buffer:
|
760 |
yield "", history + [[message, buffer]]
|
761 |
-
|
762 |
-
# 대화
|
763 |
chat_history.add_conversation(message, buffer)
|
764 |
-
|
765 |
except Exception as e:
|
766 |
-
print(f"스트리밍 중 오류
|
767 |
-
if not buffer: #
|
768 |
-
buffer = f"응답 생성 중
|
769 |
yield "", history + [[message, buffer]]
|
770 |
-
|
771 |
-
# 스레드가 여전히 실행 중이면 종료 대기
|
772 |
if thread.is_alive():
|
773 |
thread.join(timeout=5.0)
|
774 |
-
|
775 |
-
# 메모리 정리
|
776 |
clear_cuda_memory()
|
777 |
|
778 |
except Exception as e:
|
779 |
import traceback
|
780 |
error_details = traceback.format_exc()
|
781 |
error_message = f"오류가 발생했습니다: {str(e)}\n{error_details}"
|
782 |
-
print(f"Stream chat
|
783 |
clear_cuda_memory()
|
784 |
yield "", history + [[message, error_message]]
|
785 |
|
786 |
-
|
787 |
def create_demo():
|
788 |
with gr.Blocks(css=CSS) as demo:
|
789 |
with gr.Column(elem_classes="markdown-style"):
|
@@ -834,6 +809,7 @@ def create_demo():
|
|
834 |
scale=1
|
835 |
)
|
836 |
|
|
|
837 |
with gr.Accordion("🎮 Advanced Settings", open=False):
|
838 |
with gr.Row():
|
839 |
with gr.Column(scale=1):
|
@@ -859,6 +835,7 @@ def create_demo():
|
|
859 |
label="Repetition Penalty 🔄"
|
860 |
)
|
861 |
|
|
|
862 |
gr.Examples(
|
863 |
examples=[
|
864 |
["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n if n <= 1: return n\n return fibonacci(n-1) + fibonacci(n-2)"],
|
@@ -869,23 +846,25 @@ def create_demo():
|
|
869 |
inputs=msg
|
870 |
)
|
871 |
|
|
|
872 |
def clear_conversation():
|
873 |
global current_file_context
|
874 |
current_file_context = None
|
875 |
return [], None, "Start a new conversation..."
|
876 |
|
|
|
877 |
msg.submit(
|
878 |
stream_chat,
|
879 |
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
|
880 |
outputs=[msg, chatbot]
|
881 |
)
|
882 |
-
|
883 |
send.click(
|
884 |
stream_chat,
|
885 |
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
|
886 |
outputs=[msg, chatbot]
|
887 |
)
|
888 |
|
|
|
889 |
file_upload.change(
|
890 |
fn=lambda: ("처리 중...", [["시스템", "파일을 분석 중입니다. 잠시만 기다려주세요..."]]),
|
891 |
outputs=[msg, chatbot],
|
@@ -901,6 +880,7 @@ def create_demo():
|
|
901 |
queue=True
|
902 |
)
|
903 |
|
|
|
904 |
clear.click(
|
905 |
fn=clear_conversation,
|
906 |
outputs=[chatbot, file_upload, msg],
|
@@ -909,7 +889,7 @@ def create_demo():
|
|
909 |
|
910 |
return demo
|
911 |
|
912 |
-
|
913 |
if __name__ == "__main__":
|
914 |
demo = create_demo()
|
915 |
-
demo.launch()
|
|
|
1 |
import os
|
2 |
+
|
3 |
+
# 1) Dynamo 완전 비활성화
|
4 |
os.environ["TORCH_DYNAMO_DISABLE"] = "1"
|
5 |
|
6 |
+
# 2) Triton의 cudagraphs 최적화 비활성화
|
7 |
+
os.environ["TRITON_DISABLE_CUDAGRAPHS"] = "1"
|
8 |
+
|
9 |
+
# 3) 경고 무시 설정 (skipping cudagraphs 관련)
|
10 |
+
import warnings
|
11 |
+
warnings.filterwarnings("ignore", message="skipping cudagraphs due to mutated inputs")
|
12 |
+
warnings.filterwarnings("ignore", message="Not enough SMs to use max_autotune_gemm mode")
|
13 |
+
|
14 |
import torch
|
15 |
+
# TensorFloat32 연산 활성화 (성능 최적화)
|
16 |
torch.set_float32_matmul_precision('high')
|
17 |
+
|
18 |
+
# TorchInductor cudagraphs 비활성화
|
19 |
import torch._inductor
|
20 |
torch._inductor.config.triton.cudagraphs = False
|
21 |
+
|
22 |
+
# Dynamo suppress_errors 옵션 (오류 시 eager로 fallback)
|
23 |
import torch._dynamo
|
24 |
+
torch._dynamo.config.suppress_errors = True
|
25 |
+
|
26 |
import gradio as gr
|
27 |
import spaces
|
28 |
+
|
29 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
30 |
|
31 |
from threading import Thread
|
|
|
40 |
import pyarrow.parquet as pq
|
41 |
import pypdf
|
42 |
import io
|
|
|
|
|
43 |
import platform
|
44 |
import subprocess
|
45 |
import pytesseract
|
46 |
from pdf2image import convert_from_path
|
47 |
+
import queue # queue.Empty 예외 처리를 위해
|
48 |
+
import time # 스트리밍 타이밍을 위해
|
49 |
|
50 |
+
# -------------------- PDF to Markdown 변환 관련 import --------------------
|
51 |
try:
|
52 |
import re
|
53 |
import requests
|
|
|
64 |
)
|
65 |
# ---------------------------------------------------------------------------
|
66 |
|
|
|
|
|
|
|
67 |
# 전역 변수
|
68 |
current_file_context = None
|
69 |
|
|
|
73 |
MODELS = os.environ.get("MODELS")
|
74 |
MODEL_NAME = MODEL_ID.split("/")[-1]
|
75 |
|
76 |
+
model = None # 전역에서 관리
|
77 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
78 |
|
79 |
+
# (1) 위키피디아 데이터셋 로드
|
80 |
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
|
81 |
print("Wikipedia dataset loaded:", wiki_dataset)
|
82 |
|
83 |
+
# (2) TF-IDF 벡터라이저 초기화 및 학습
|
84 |
print("TF-IDF 벡터화 시작...")
|
85 |
questions = wiki_dataset['train']['question'][:10000] # 처음 10000개만 사용
|
86 |
vectorizer = TfidfVectorizer(max_features=1000)
|
87 |
question_vectors = vectorizer.fit_transform(questions)
|
88 |
print("TF-IDF 벡터화 완료")
|
89 |
|
90 |
+
# ------------------------- ChatHistory 클래스 -------------------------
|
91 |
class ChatHistory:
|
92 |
def __init__(self):
|
93 |
self.history = []
|
|
|
143 |
print(f"히스토리 로드 실패: {e}")
|
144 |
self.history = []
|
145 |
|
146 |
+
# 전역 ChatHistory 인스턴스
|
|
|
147 |
chat_history = ChatHistory()
|
148 |
|
149 |
+
# ------------------------- 위키 문서 검색 (TF-IDF) -------------------------
|
150 |
def find_relevant_context(query, top_k=3):
|
151 |
# 쿼리 벡터화
|
152 |
query_vector = vectorizer.transform([query])
|
153 |
+
# 코사인 유사도
|
154 |
similarities = (query_vector * question_vectors.T).toarray()[0]
|
155 |
+
# 유사도 높은 질문 인덱스
|
156 |
top_indices = np.argsort(similarities)[-top_k:][::-1]
|
157 |
+
|
158 |
relevant_contexts = []
|
159 |
for idx in top_indices:
|
160 |
if similarities[idx] > 0:
|
|
|
165 |
})
|
166 |
return relevant_contexts
|
167 |
|
168 |
+
# 파일 업로드 시 표시할 초기 메시지
|
169 |
def init_msg():
|
170 |
return "파일을 분석하고 있습니다..."
|
171 |
|
|
|
172 |
# -------------------- PDF 파일을 Markdown으로 변환하는 유틸 함수들 --------------------
|
173 |
def extract_text_from_pdf(reader: PdfReader) -> str:
|
174 |
"""
|
175 |
PyPDF를 사용해 모든 페이지 텍스트를 추출.
|
|
|
176 |
"""
|
177 |
full_text = ""
|
178 |
for idx, page in enumerate(reader.pages):
|
|
|
181 |
full_text += f"---- Page {idx+1} ----\n" + text + "\n\n"
|
182 |
return full_text.strip()
|
183 |
|
|
|
184 |
def convert_pdf_to_markdown(pdf_file: str):
|
185 |
"""
|
186 |
+
PDF 파일에서 텍스트를 추출하고,
|
187 |
+
이미지가 많고 텍스트가 적으면 OCR 시도
|
|
|
|
|
188 |
"""
|
189 |
try:
|
190 |
reader = PdfReader(pdf_file)
|
191 |
except Exception as e:
|
192 |
return f"PDF 파일을 읽는 중 오류 발생: {e}", None, None
|
193 |
|
194 |
+
# 메타데이터 추출
|
195 |
raw_meta = reader.metadata
|
196 |
metadata = {
|
197 |
"author": raw_meta.author if raw_meta else None,
|
|
|
201 |
"title": raw_meta.title if raw_meta else None,
|
202 |
}
|
203 |
|
204 |
+
# 텍스트 추출
|
205 |
full_text = extract_text_from_pdf(reader)
|
206 |
|
207 |
+
# 이미지-텍스트 비율 판단 후 OCR 시도
|
208 |
+
image_count = sum(len(page.images) for page in reader.pages)
|
|
|
|
|
|
|
209 |
if image_count > 0 and len(full_text) < 1000:
|
210 |
try:
|
211 |
out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
|
212 |
ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
|
213 |
+
# OCR된 PDF 다시 읽기
|
214 |
reader_ocr = PdfReader(out_pdf_file)
|
215 |
full_text = extract_text_from_pdf(reader_ocr)
|
216 |
except Exception as e:
|
|
|
218 |
|
219 |
return full_text, metadata, pdf_file
|
220 |
|
221 |
+
# ------------------------- 파일 분석 함수 -------------------------
|
|
|
|
|
222 |
def analyze_file_content(content, file_type):
|
223 |
+
"""간단한 구조 분석/요약."""
|
224 |
if file_type in ['parquet', 'csv']:
|
225 |
try:
|
226 |
lines = content.split('\n')
|
|
|
245 |
words = len(content.split())
|
246 |
return f"📝 Document Structure: {total_lines} lines, {paragraphs} paragraphs, approximately {words} words"
|
247 |
|
|
|
248 |
def read_uploaded_file(file):
|
249 |
"""
|
250 |
+
업로드된 파일 처리 -> 내용/타입
|
|
|
|
|
251 |
"""
|
252 |
if file is None:
|
253 |
return "", ""
|
254 |
+
|
255 |
try:
|
256 |
file_ext = os.path.splitext(file.name)[1].lower()
|
257 |
|
|
|
265 |
content += f"1. Basic Information:\n"
|
266 |
content += f"- Total Rows: {len(df):,}\n"
|
267 |
content += f"- Total Columns: {len(df.columns)}\n"
|
268 |
+
mem_usage = df.memory_usage(deep=True).sum() / 1024 / 1024
|
269 |
+
content += f"- Memory Usage: {mem_usage:.2f} MB\n\n"
|
270 |
|
271 |
content += f"2. Column Information:\n"
|
272 |
for col in df.columns:
|
|
|
278 |
content += f"\n\n4. Missing Values:\n"
|
279 |
null_counts = df.isnull().sum()
|
280 |
for col, count in null_counts[null_counts > 0].items():
|
281 |
+
rate = count / len(df) * 100
|
282 |
+
content += f"- {col}: {count:,} ({rate:.1f}%)\n"
|
283 |
|
284 |
numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
|
285 |
if len(numeric_cols) > 0:
|
|
|
291 |
except Exception as e:
|
292 |
return f"Error reading Parquet file: {str(e)}", "error"
|
293 |
|
294 |
+
# PDF
|
295 |
if file_ext == '.pdf':
|
296 |
try:
|
297 |
markdown_text, metadata, processed_pdf_path = convert_pdf_to_markdown(file.name)
|
|
|
305 |
|
306 |
content += "## Extracted Text\n\n"
|
307 |
content += markdown_text
|
|
|
308 |
return content, "pdf"
|
309 |
except Exception as e:
|
310 |
return f"Error reading PDF file: {str(e)}", "error"
|
|
|
319 |
content += f"1. Basic Information:\n"
|
320 |
content += f"- Total Rows: {len(df):,}\n"
|
321 |
content += f"- Total Columns: {len(df.columns)}\n"
|
322 |
+
mem_usage = df.memory_usage(deep=True).sum() / 1024 / 1024
|
323 |
+
content += f"- Memory Usage: {mem_usage:.2f} MB\n\n"
|
324 |
|
325 |
content += f"2. Column Information:\n"
|
326 |
for col in df.columns:
|
|
|
332 |
content += f"\n\n4. Missing Values:\n"
|
333 |
null_counts = df.isnull().sum()
|
334 |
for col, count in null_counts[null_counts > 0].items():
|
335 |
+
rate = count / len(df) * 100
|
336 |
+
content += f"- {col}: {count:,} ({rate:.1f}%)\n"
|
337 |
|
338 |
return content, "csv"
|
339 |
except UnicodeDecodeError:
|
340 |
continue
|
341 |
+
raise UnicodeDecodeError(
|
342 |
+
f"Unable to read file with supported encodings ({', '.join(encodings)})"
|
343 |
+
)
|
344 |
|
345 |
+
# 텍스트 파일
|
346 |
else:
|
347 |
encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
|
348 |
for encoding in encodings:
|
|
|
353 |
lines = content.split('\n')
|
354 |
total_lines = len(lines)
|
355 |
non_empty_lines = len([line for line in lines if line.strip()])
|
356 |
+
is_code = any(
|
357 |
+
keyword in content.lower()
|
358 |
+
for keyword in ['def ', 'class ', 'import ', 'function']
|
359 |
+
)
|
360 |
|
361 |
analysis = f"\n📝 File Analysis:\n"
|
362 |
if is_code:
|
363 |
+
functions = sum('def ' in line for line in lines)
|
364 |
+
classes = sum('class ' in line for line in lines)
|
365 |
+
imports = sum(
|
366 |
+
('import ' in line) or ('from ' in line)
|
367 |
+
for line in lines
|
368 |
+
)
|
369 |
analysis += f"- File Type: Code\n"
|
370 |
analysis += f"- Total Lines: {total_lines:,}\n"
|
371 |
analysis += f"- Functions: {functions}\n"
|
|
|
382 |
analysis += f"- Character Count: {chars:,}\n"
|
383 |
|
384 |
return content + analysis, "text"
|
385 |
+
|
386 |
except UnicodeDecodeError:
|
387 |
continue
|
388 |
+
|
389 |
+
raise UnicodeDecodeError(
|
390 |
+
f"Unable to read file with supported encodings ({', '.join(encodings)})"
|
391 |
+
)
|
392 |
|
393 |
except Exception as e:
|
394 |
return f"Error reading file: {str(e)}", "error"
|
395 |
|
396 |
+
# ------------------------- CSS -------------------------
|
397 |
CSS = """
|
398 |
/* 3D 스타일 CSS */
|
399 |
:root {
|
|
|
550 |
"""
|
551 |
|
552 |
def clear_cuda_memory():
|
553 |
+
"""CUDA 캐시 정리."""
|
554 |
if hasattr(torch.cuda, 'empty_cache'):
|
555 |
with torch.cuda.device('cuda'):
|
556 |
torch.cuda.empty_cache()
|
557 |
|
558 |
+
# ------------------------- 모델 로딩 함수 -------------------------
|
559 |
@spaces.GPU
|
560 |
def load_model():
|
561 |
try:
|
|
|
562 |
clear_cuda_memory()
|
|
|
563 |
loaded_model = AutoModelForCausalLM.from_pretrained(
|
564 |
MODEL_ID,
|
565 |
torch_dtype=torch.bfloat16,
|
566 |
device_map="auto",
|
|
|
567 |
low_cpu_mem_usage=True,
|
568 |
)
|
569 |
return loaded_model
|
|
|
571 |
print(f"모델 로드 오류: {str(e)}")
|
572 |
raise
|
573 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
574 |
def build_prompt(conversation: list) -> str:
|
575 |
+
"""대화 내역을 단순 텍스트 프롬프트로 변환."""
|
|
|
|
|
|
|
576 |
prompt = ""
|
577 |
for msg in conversation:
|
578 |
if msg["role"] == "user":
|
|
|
582 |
prompt += "Assistant: "
|
583 |
return prompt
|
584 |
|
585 |
+
# ------------------------- 메시지 스트리밍 함수 -------------------------
|
586 |
@spaces.GPU
|
587 |
def stream_chat(
|
588 |
message: str,
|
|
|
597 |
global model, current_file_context
|
598 |
|
599 |
try:
|
600 |
+
# 모델 미로드시 로딩
|
601 |
if model is None:
|
602 |
model = load_model()
|
603 |
|
604 |
+
print(f'[User input] message: {message}')
|
605 |
+
print(f'[History] {history}')
|
606 |
|
607 |
+
# (1) 파일 업로드 처리
|
608 |
file_context = ""
|
609 |
if uploaded_file and message == "파일을 분석하고 있습니다...":
|
610 |
current_file_context = None
|
|
|
619 |
current_file_context = file_context
|
620 |
message = "업로드된 파일을 분석해주세요."
|
621 |
except Exception as e:
|
622 |
+
print(f"[파일 분석 오류] {str(e)}")
|
623 |
file_context = f"\n\n❌ 파일 분석 중 오류가 발생했습니다: {str(e)}"
|
624 |
elif current_file_context:
|
625 |
file_context = current_file_context
|
626 |
|
627 |
+
# (2) TF-IDF 기반 관련 문서 탐색
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
628 |
wiki_context = ""
|
629 |
try:
|
630 |
relevant_contexts = find_relevant_context(message)
|
631 |
+
if relevant_contexts:
|
632 |
wiki_context = "\n\n관련 위키피디아 정보:\n"
|
633 |
for ctx in relevant_contexts:
|
634 |
wiki_context += (
|
|
|
637 |
f"유사도: {ctx['similarity']:.3f}\n\n"
|
638 |
)
|
639 |
except Exception as e:
|
640 |
+
print(f"[컨텍스트 검색 오류] {str(e)}")
|
641 |
+
|
642 |
+
# (3) 대화 이력 구성
|
643 |
+
max_history_length = 10
|
644 |
+
if len(history) > max_history_length:
|
645 |
+
history = history[-max_history_length:]
|
646 |
|
|
|
647 |
conversation = []
|
648 |
for prompt, answer in history:
|
649 |
conversation.extend([
|
|
|
651 |
{"role": "assistant", "content": answer}
|
652 |
])
|
653 |
|
654 |
+
# (4) 최종 메시지 결정
|
655 |
final_message = message
|
656 |
if file_context:
|
657 |
final_message = file_context + "\n현재 질문: " + message
|
|
|
659 |
final_message = wiki_context + "\n현재 질문: " + message
|
660 |
if file_context and wiki_context:
|
661 |
final_message = file_context + wiki_context + "\n현재 질문: " + message
|
662 |
+
|
663 |
conversation.append({"role": "user", "content": final_message})
|
664 |
|
665 |
+
# (5) 토큰화 및 프롬프트 구축
|
666 |
input_ids_str = build_prompt(conversation)
|
|
|
|
|
667 |
max_context = 8192
|
668 |
tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
|
669 |
input_length = tokenized_input["input_ids"].shape[1]
|
670 |
+
|
671 |
+
# (6) 컨텍스트가 너무 길면 앞부분 토큰 자르기
|
672 |
if input_length > max_context - max_new_tokens:
|
673 |
+
print(f"[경고] 입력이 너무 깁니다: {input_length} 토큰 -> 잘라냄.")
|
|
|
674 |
min_generation = min(256, max_new_tokens)
|
675 |
new_desired_input_length = max_context - min_generation
|
|
|
|
|
676 |
tokens = tokenizer.encode(input_ids_str)
|
677 |
if len(tokens) > new_desired_input_length:
|
678 |
tokens = tokens[-new_desired_input_length:]
|
679 |
input_ids_str = tokenizer.decode(tokens)
|
|
|
|
|
680 |
tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
|
681 |
input_length = tokenized_input["input_ids"].shape[1]
|
682 |
+
|
683 |
+
print(f"[토큰 길이] {input_length}")
|
|
|
|
|
684 |
inputs = tokenized_input.to("cuda")
|
685 |
+
|
686 |
+
# 남은 토큰 수로 max_new_tokens 조정
|
687 |
remaining = max_context - input_length
|
688 |
if remaining < max_new_tokens:
|
689 |
+
print(f"[max_new_tokens 조정] {max_new_tokens} -> {remaining}")
|
690 |
max_new_tokens = remaining
|
691 |
|
|
|
|
|
692 |
# 스트리머 설정
|
693 |
streamer = TextIteratorStreamer(
|
694 |
tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
|
695 |
)
|
696 |
+
|
697 |
+
# (7) 생성 파라미터
|
698 |
generate_kwargs = dict(
|
699 |
**inputs,
|
700 |
streamer=streamer,
|
|
|
709 |
use_cache=True
|
710 |
)
|
711 |
|
|
|
712 |
clear_cuda_memory()
|
713 |
|
714 |
+
# (8) 별도 스레드에서 생성
|
715 |
thread = Thread(target=model.generate, kwargs=generate_kwargs)
|
716 |
thread.start()
|
717 |
|
718 |
+
# (9) 스트리밍 응답
|
719 |
buffer = ""
|
720 |
partial_message = ""
|
721 |
last_yield_time = time.time()
|
722 |
+
|
723 |
try:
|
724 |
for new_text in streamer:
|
725 |
+
buffer += new_text
|
726 |
+
partial_message += new_text
|
727 |
+
|
728 |
+
# 일정 시간 또는 버퍼 길이 기준으로 yield
|
729 |
+
current_time = time.time()
|
730 |
+
if (current_time - last_yield_time > 0.1) or (len(partial_message) > 20):
|
731 |
+
yield "", history + [[message, buffer]]
|
732 |
+
partial_message = ""
|
733 |
+
last_yield_time = current_time
|
734 |
+
|
735 |
+
# 마지막 완성된 응답
|
|
|
|
|
|
|
|
|
736 |
if buffer:
|
737 |
yield "", history + [[message, buffer]]
|
738 |
+
|
739 |
+
# 대화 내용 저장
|
740 |
chat_history.add_conversation(message, buffer)
|
741 |
+
|
742 |
except Exception as e:
|
743 |
+
print(f"[스트리밍 중 오류] {str(e)}")
|
744 |
+
if not buffer: # buffer가 비어있다면 오류메시지 대화창 표시
|
745 |
+
buffer = f"응답 생성 중 오류 발생: {str(e)}"
|
746 |
yield "", history + [[message, buffer]]
|
747 |
+
|
|
|
748 |
if thread.is_alive():
|
749 |
thread.join(timeout=5.0)
|
750 |
+
|
|
|
751 |
clear_cuda_memory()
|
752 |
|
753 |
except Exception as e:
|
754 |
import traceback
|
755 |
error_details = traceback.format_exc()
|
756 |
error_message = f"오류가 발생했습니다: {str(e)}\n{error_details}"
|
757 |
+
print(f"[Stream chat 오류] {error_message}")
|
758 |
clear_cuda_memory()
|
759 |
yield "", history + [[message, error_message]]
|
760 |
|
761 |
+
# ------------------------- Gradio UI 구성 -------------------------
|
762 |
def create_demo():
|
763 |
with gr.Blocks(css=CSS) as demo:
|
764 |
with gr.Column(elem_classes="markdown-style"):
|
|
|
809 |
scale=1
|
810 |
)
|
811 |
|
812 |
+
# 고급 설정
|
813 |
with gr.Accordion("🎮 Advanced Settings", open=False):
|
814 |
with gr.Row():
|
815 |
with gr.Column(scale=1):
|
|
|
835 |
label="Repetition Penalty 🔄"
|
836 |
)
|
837 |
|
838 |
+
# 예시
|
839 |
gr.Examples(
|
840 |
examples=[
|
841 |
["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n if n <= 1: return n\n return fibonacci(n-1) + fibonacci(n-2)"],
|
|
|
846 |
inputs=msg
|
847 |
)
|
848 |
|
849 |
+
# 대화 내용 초기화
|
850 |
def clear_conversation():
|
851 |
global current_file_context
|
852 |
current_file_context = None
|
853 |
return [], None, "Start a new conversation..."
|
854 |
|
855 |
+
# 메시지 전송
|
856 |
msg.submit(
|
857 |
stream_chat,
|
858 |
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
|
859 |
outputs=[msg, chatbot]
|
860 |
)
|
|
|
861 |
send.click(
|
862 |
stream_chat,
|
863 |
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
|
864 |
outputs=[msg, chatbot]
|
865 |
)
|
866 |
|
867 |
+
# 파일 업로드 이벤트
|
868 |
file_upload.change(
|
869 |
fn=lambda: ("처리 중...", [["시스템", "파일을 분석 중입니다. 잠시만 기다려주세요..."]]),
|
870 |
outputs=[msg, chatbot],
|
|
|
880 |
queue=True
|
881 |
)
|
882 |
|
883 |
+
# Clear 버튼
|
884 |
clear.click(
|
885 |
fn=clear_conversation,
|
886 |
outputs=[chatbot, file_upload, msg],
|
|
|
889 |
|
890 |
return demo
|
891 |
|
892 |
+
# 메인 실행
|
893 |
if __name__ == "__main__":
|
894 |
demo = create_demo()
|
895 |
+
demo.launch()
|