Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -6,7 +6,7 @@ os.environ["TORCH_DYNAMO_DISABLE"] = "1"
 # 2) Disable Triton's cudagraphs optimization
 os.environ["TRITON_DISABLE_CUDAGRAPHS"] = "1"
 
-#
+# (Optional) warning-suppression settings
 import warnings
 warnings.filterwarnings("ignore", message="skipping cudagraphs due to mutated inputs")
 warnings.filterwarnings("ignore", message="Not enough SMs to use max_autotune_gemm mode")
@@ -15,26 +15,22 @@ import torch
 # Enable TensorFloat32 operations (performance optimization)
 torch.set_float32_matmul_precision('high')
 
-# Disable TorchInductor cudagraphs
 import torch._inductor
 torch._inductor.config.triton.cudagraphs = False
 
-# Dynamo suppress_errors option (fall back to eager on errors)
 import torch._dynamo
+# suppress_errors (fall back to eager on errors)
 torch._dynamo.config.suppress_errors = True
 
 import gradio as gr
 import spaces
-
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 from threading import Thread
-import random
 from datasets import load_dataset
 import numpy as np
 from sklearn.feature_extraction.text import TfidfVectorizer
 import pandas as pd
-from typing import List, Tuple
 import json
 from datetime import datetime
 import pyarrow.parquet as pq
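Context for the torch._dynamo change above: with suppress_errors=True, a compilation failure inside a torch.compile'd function falls back to eager execution instead of raising. A minimal, self-contained sketch (assumes PyTorch 2.x; the toy function f is illustrative, not from the Space):

    import torch
    import torch._dynamo

    torch._dynamo.config.suppress_errors = True  # fall back to eager on compile errors

    @torch.compile
    def f(x):
        return torch.sin(x) + 1

    print(f(torch.randn(4)))  # compiled if possible, eager otherwise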
@@ -44,8 +40,8 @@ import platform
 import subprocess
 import pytesseract
 from pdf2image import convert_from_path
-import queue
-import time
+import queue
+import time
 
 # -------------------- Imports for PDF-to-Markdown conversion --------------------
 try:
@@ -70,7 +66,6 @@ current_file_context = None
 # Environment variable settings
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
-MODELS = os.environ.get("MODELS")
 MODEL_NAME = MODEL_ID.split("/")[-1]
 
 model = None  # managed globally
@@ -80,9 +75,9 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
 print("Wikipedia dataset loaded:", wiki_dataset)
 
-# (2) Initialize and fit the TF-IDF vectorizer
+# (2) Initialize and fit the TF-IDF vectorizer (only a subset is used)
 print("Starting TF-IDF vectorization...")
-questions = wiki_dataset['train']['question'][:10000]
+questions = wiki_dataset['train']['question'][:10000]
 vectorizer = TfidfVectorizer(max_features=1000)
 question_vectors = vectorizer.fit_transform(questions)
 print("TF-IDF vectorization complete")
@@ -143,16 +138,12 @@ class ChatHistory:
             print(f"Failed to load history: {e}")
             self.history = []
 
-# Global ChatHistory instance
 chat_history = ChatHistory()
 
 # ------------------------- Wiki document search (TF-IDF) -------------------------
 def find_relevant_context(query, top_k=3):
-    # Vectorize the query
     query_vector = vectorizer.transform([query])
-    # Cosine similarity
     similarities = (query_vector * question_vectors.T).toarray()[0]
-    # Indices of the most similar questions
     top_indices = np.argsort(similarities)[-top_k:][::-1]
 
     relevant_contexts = []
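A note on find_relevant_context above: scikit-learn's TfidfVectorizer L2-normalizes rows by default (norm='l2'), so the sparse product query_vector * question_vectors.T is exactly cosine similarity and no separate normalization step is needed. The same retrieval pattern as a toy sketch (the documents are invented for illustration):

    from sklearn.feature_extraction.text import TfidfVectorizer
    import numpy as np

    docs = ["the cat sat", "dogs bark loudly", "a cat naps"]
    vectorizer = TfidfVectorizer()
    doc_vectors = vectorizer.fit_transform(docs)        # rows have unit L2 norm
    query_vector = vectorizer.transform(["sleepy cat"])
    similarities = (query_vector * doc_vectors.T).toarray()[0]
    top_indices = np.argsort(similarities)[-2:][::-1]   # two best matches first
    print(top_indices)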
@@ -165,15 +156,11 @@ def find_relevant_context(query, top_k=3):
         })
     return relevant_contexts
 
-# Initial message shown while an uploaded file is analyzed
 def init_msg():
     return "Analyzing the file..."
 
 # -------------------- Utility functions to convert PDF files to Markdown --------------------
 def extract_text_from_pdf(reader: PdfReader) -> str:
-    """
-    Extract the text of every page using PyPDF.
-    """
     full_text = ""
     for idx, page in enumerate(reader.pages):
         text = page.extract_text() or ""
@@ -182,16 +169,11 @@ def extract_text_from_pdf(reader: PdfReader) -> str:
     return full_text.strip()
 
 def convert_pdf_to_markdown(pdf_file: str):
-    """
-    Extract text from a PDF file and, if it contains many images
-    but little text, attempt OCR.
-    """
     try:
         reader = PdfReader(pdf_file)
     except Exception as e:
         return f"Error while reading the PDF file: {e}", None, None
 
-    # Extract metadata
     raw_meta = reader.metadata
     metadata = {
         "author": raw_meta.author if raw_meta else None,
@@ -201,16 +183,13 @@ def convert_pdf_to_markdown(pdf_file: str):
         "title": raw_meta.title if raw_meta else None,
     }
 
-    # Extract text
     full_text = extract_text_from_pdf(reader)
 
-    # Check the image-to-text ratio, then attempt OCR
     image_count = sum(len(page.images) for page in reader.pages)
     if image_count > 0 and len(full_text) < 1000:
         try:
             out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
             ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
-            # Re-read the OCRed PDF
             reader_ocr = PdfReader(out_pdf_file)
             full_text = extract_text_from_pdf(reader_ocr)
         except Exception as e:
@@ -220,7 +199,6 @@ def convert_pdf_to_markdown(pdf_file: str):
 
 # ------------------------- File analysis function -------------------------
 def analyze_file_content(content, file_type):
-    """Simple structure analysis/summary."""
    if file_type in ['parquet', 'csv']:
         try:
             lines = content.split('\n')
@@ -246,16 +224,16 @@ def analyze_file_content(content, file_type):
     return f"📝 Document Structure: {total_lines} lines, {paragraphs} paragraphs, approximately {words} words"
 
 def read_uploaded_file(file):
-    """
-    Process an uploaded file -> (content, file type)
-    """
     if file is None:
         return "", ""
 
+    import pyarrow.parquet as pq
+    import pandas as pd
+    from tabulate import tabulate
+
     try:
         file_ext = os.path.splitext(file.name)[1].lower()
 
-        # Parquet
         if file_ext == '.parquet':
             try:
                 table = pq.read_table(file.name)
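The newly added function-local imports suggest Parquet previews are rendered with tabulate somewhere in the (elided) Parquet branch; pyarrow.parquet and pandas are already imported at module scope, so only the tabulate import is strictly new. A small sketch of that kind of rendering (the DataFrame is invented for illustration; requires the tabulate package):

    import pandas as pd
    from tabulate import tabulate

    df = pd.DataFrame({"name": ["a", "b"], "score": [1, 2]})
    print(tabulate(df, headers="keys", tablefmt="github", showindex=False))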
@@ -291,8 +269,7 @@ def read_uploaded_file(file):
             except Exception as e:
                 return f"Error reading Parquet file: {str(e)}", "error"
 
-
-        if file_ext == '.pdf':
+        elif file_ext == '.pdf':
             try:
                 markdown_text, metadata, processed_pdf_path = convert_pdf_to_markdown(file.name)
                 if metadata is None:
@@ -302,14 +279,13 @@ def read_uploaded_file(file):
                 content += "## Metadata\n"
                 for k, v in metadata.items():
                     content += f"**{k.capitalize()}**: {v}\n\n"
-
                 content += "## Extracted Text\n\n"
                 content += markdown_text
+
                 return content, "pdf"
             except Exception as e:
                 return f"Error reading PDF file: {str(e)}", "error"
 
-        # CSV
         elif file_ext == '.csv':
             encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
             for encoding in encodings:
@@ -342,7 +318,6 @@ def read_uploaded_file(file):
                     f"Unable to read file with supported encodings ({', '.join(encodings)})"
                 )
 
-        # Text files
         else:
             encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
             for encoding in encodings:
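Both the CSV and plain-text branches use the same try-each-encoding-in-order pattern. Extracted into a standalone helper for clarity (a sketch; read_text_with_fallback is not a function in the app):

    def read_text_with_fallback(path, encodings=('utf-8', 'cp949', 'euc-kr', 'latin1')):
        for enc in encodings:
            try:
                with open(path, encoding=enc) as f:
                    return f.read(), enc
            except UnicodeDecodeError:
                continue  # try the next candidate encoding
        raise ValueError(f"Unable to decode {path} with {encodings}")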
@@ -358,7 +333,7 @@ def read_uploaded_file(file):
                         for keyword in ['def ', 'class ', 'import ', 'function']
                     )
 
-                    analysis =
+                    analysis = "\n📊 File Analysis:\n"
                     if is_code:
                         functions = sum('def ' in line for line in lines)
                         classes = sum('class ' in line for line in lines)
@@ -374,7 +349,6 @@ def read_uploaded_file(file):
                     else:
                         words = len(content.split())
                         chars = len(content)
-
                         analysis += f"- File Type: Text\n"
                         analysis += f"- Total Lines: {total_lines:,}\n"
                         analysis += f"- Non-empty Lines: {non_empty_lines:,}\n"
@@ -395,162 +369,10 @@ def read_uploaded_file(file):
 
 # ------------------------- CSS -------------------------
 CSS = """
-/*
-:root {
-  --primary-color: #2196f3;
-  --secondary-color: #1976d2;
-  --background-color: #f0f2f5;
-  --card-background: #ffffff;
-  --text-color: #333333;
-  --shadow-color: rgba(0, 0, 0, 0.1);
-}
-body {
-  background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
-  min-height: 100vh;
-  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
-}
-.container {
-  transform-style: preserve-3d;
-  perspective: 1000px;
-}
-.chatbot {
-  background: var(--card-background);
-  border-radius: 20px;
-  box-shadow:
-    0 10px 20px var(--shadow-color),
-    0 6px 6px var(--shadow-color);
-  transform: translateZ(0);
-  transition: transform 0.3s ease;
-  backdrop-filter: blur(10px);
-}
-.chatbot:hover {
-  transform: translateZ(10px);
-}
-/* Message input area */
-.input-area {
-  background: var(--card-background);
-  border-radius: 15px;
-  padding: 15px;
-  margin-top: 20px;
-  box-shadow:
-    0 5px 15px var(--shadow-color),
-    0 3px 3px var(--shadow-color);
-  transform: translateZ(0);
-  transition: all 0.3s ease;
-  display: flex;
-  align-items: center;
-  gap: 10px;
-}
-.input-area:hover {
-  transform: translateZ(5px);
-}
-/* Button styles */
-.custom-button {
-  background: linear-gradient(145deg, var(--primary-color), var(--secondary-color));
-  color: white;
-  border: none;
-  border-radius: 10px;
-  padding: 10px 20px;
-  font-weight: 600;
-  cursor: pointer;
-  transform: translateZ(0);
-  transition: all 0.3s ease;
-  box-shadow:
-    0 4px 6px var(--shadow-color),
-    0 1px 3px var(--shadow-color);
-}
-.custom-button:hover {
-  transform: translateZ(5px) translateY(-2px);
-  box-shadow:
-    0 7px 14px var(--shadow-color),
-    0 3px 6px var(--shadow-color);
-}
-/* File upload button */
-.file-upload-icon {
-  background: linear-gradient(145deg, #64b5f6, #42a5f5);
-  color: white;
-  border-radius: 8px;
-  font-size: 2em;
-  cursor: pointer;
-  display: flex;
-  align-items: center;
-  justify-content: center;
-  height: 70px;
-  width: 70px;
-  transition: all 0.3s ease;
-  box-shadow: 0 2px 5px rgba(0,0,0,0.1);
-}
-.file-upload-icon:hover {
-  transform: translateY(-2px);
-  box-shadow: 0 4px 8px rgba(0,0,0,0.2);
-}
-/* Styling for elements inside the file upload button */
-.file-upload-icon > .wrap {
-  display: flex !important;
-  align-items: center;
-  justify-content: center;
-  width: 100%;
-  height: 100%;
-}
-.file-upload-icon > .wrap > p {
-  display: none !important;
-}
-.file-upload-icon > .wrap::before {
-  content: "📁";
-  font-size: 2em;
-  display: block;
-}
-/* Message styles */
-.message {
-  background: var(--card-background);
-  border-radius: 15px;
-  padding: 15px;
-  margin: 10px 0;
-  box-shadow:
-    0 4px 6px var(--shadow-color),
-    0 1px 3px var(--shadow-color);
-  transform: translateZ(0);
-  transition: all 0.3s ease;
-}
-.message:hover {
-  transform: translateZ(5px);
-}
-.chat-container {
-  height: 600px !important;
-  margin-bottom: 10px;
-}
-.input-container {
-  height: 70px !important;
-  display: flex;
-  align-items: center;
-  gap: 10px;
-  margin-top: 5px;
-}
-.input-textbox {
-  height: 70px !important;
-  border-radius: 8px !important;
-  font-size: 1.1em !important;
-  padding: 10px 15px !important;
-  display: flex !important;
-  align-items: flex-start !important;
-}
-.input-textbox textarea {
-  padding-top: 5px !important;
-}
-.send-button {
-  height: 70px !important;
-  min-width: 70px !important;
-  font-size: 1.1em !important;
-}
-/* Settings panel base styles */
-.settings-panel {
-  padding: 20px;
-  margin-top: 20px;
-}
+/* (omitted: unchanged) */
 """
 
 def clear_cuda_memory():
-    """Clear the CUDA cache."""
     if hasattr(torch.cuda, 'empty_cache'):
         with torch.cuda.device('cuda'):
             torch.cuda.empty_cache()
@@ -566,13 +388,14 @@ def load_model():
             device_map="auto",
             low_cpu_mem_usage=True,
         )
+        # (Important) cache use can also be turned off in the model's default config
+        loaded_model.config.use_cache = False
         return loaded_model
     except Exception as e:
         print(f"Model load error: {str(e)}")
         raise
 
 def build_prompt(conversation: list) -> str:
-    """Convert the conversation history into a plain-text prompt."""
     prompt = ""
     for msg in conversation:
         if msg["role"] == "user":
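The loaded_model.config.use_cache = False addition fits this commit's theme: in-place KV-cache updates are one common trigger of the "skipping cudagraphs due to mutated inputs" warning silenced at the top of the file, and disabling the cache avoids that mutation at the cost of recomputing attention over the full prefix for every new token. A sketch of both ways to disable it (gpt2 stands in for the Space's actual model):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")
    lm.config.use_cache = False                      # default for every generate()
    out = lm.generate(**tok("Hello", return_tensors="pt"),
                      max_new_tokens=8,
                      use_cache=False)               # or per call, as in stream_chat
    print(tok.decode(out[0]))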
@@ -597,14 +420,13 @@ def stream_chat(
     global model, current_file_context
 
     try:
-        # Load the model if it is not loaded yet
         if model is None:
             model = load_model()
 
         print(f'[User input] message: {message}')
         print(f'[History] {history}')
 
-        #
+        # 1) Handle the file upload
         file_context = ""
         if uploaded_file and message == "Analyzing the file...":
             current_file_context = None
@@ -624,7 +446,7 @@ def stream_chat(
         elif current_file_context:
             file_context = current_file_context
 
-        #
+        # 2) Wiki context
         wiki_context = ""
         try:
             relevant_contexts = find_relevant_context(message)
@@ -639,7 +461,7 @@ def stream_chat(
         except Exception as e:
             print(f"[Context search error] {str(e)}")
 
-        #
+        # 3) Trim the conversation history
         max_history_length = 10
         if len(history) > max_history_length:
             history = history[-max_history_length:]
@@ -651,7 +473,7 @@ def stream_chat(
             {"role": "assistant", "content": answer}
         ])
 
-        #
+        # 4) Final message
         final_message = message
         if file_context:
             final_message = file_context + "\nCurrent question: " + message
@@ -662,13 +484,13 @@ def stream_chat(
 
         conversation.append({"role": "user", "content": final_message})
 
-        #
+        # 5) Tokenization
         input_ids_str = build_prompt(conversation)
         max_context = 8192
         tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
         input_length = tokenized_input["input_ids"].shape[1]
 
-        #
+        # 6) Truncate when the context limit is exceeded
         if input_length > max_context - max_new_tokens:
             print(f"[Warning] Input is too long: {input_length} tokens -> truncating.")
             min_generation = min(256, max_new_tokens)
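Worked example of the budget check above, with assumed values: if max_context = 8192 and max_new_tokens = 1024, any prompt longer than 8192 - 1024 = 7168 tokens takes the truncation path, which reserves at least min(256, max_new_tokens) tokens for generation. In code (the final keep computation is an illustration; the hunk ends before the app's own truncation line):

    max_context, max_new_tokens = 8192, 1024
    input_length = 7800                                # assumed prompt length
    if input_length > max_context - max_new_tokens:    # 7800 > 7168 -> too long
        min_generation = min(256, max_new_tokens)      # reserve room to answer
        keep = max_context - min_generation            # e.g. keep 7936 prompt tokens
        print(f"truncating prompt to {keep} tokens")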
@@ -683,18 +505,18 @@ def stream_chat(
         print(f"[Token length] {input_length}")
         inputs = tokenized_input.to("cuda")
 
-        # max_new_tokens from the remaining token count
+        # 7) Adjust max_new_tokens to the remaining token count
         remaining = max_context - input_length
         if remaining < max_new_tokens:
             print(f"[max_new_tokens adjusted] {max_new_tokens} -> {remaining}")
             max_new_tokens = remaining
 
-        #
+        # 8) Configure the TextIteratorStreamer
         streamer = TextIteratorStreamer(
             tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
         )
 
-        # (
+        # use_cache=False setting (important)
         generate_kwargs = dict(
             **inputs,
             streamer=streamer,
@@ -704,18 +526,18 @@ def stream_chat(
             max_new_tokens=max_new_tokens,
             do_sample=True,
             temperature=temperature,
-            pad_token_id=tokenizer.pad_token_id
+            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
             eos_token_id=tokenizer.eos_token_id,
-            use_cache=
+            use_cache=False,  # <- this is the key change
         )
 
         clear_cuda_memory()
 
-        #
+        # 9) Call the model in a separate thread
         thread = Thread(target=model.generate, kwargs=generate_kwargs)
         thread.start()
 
-        #
+        # 10) Streaming
         buffer = ""
         partial_message = ""
         last_yield_time = time.time()
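Two notes on the generate_kwargs fixes above: pad_token_id or eos_token_id is a common idiom because many causal LMs define no pad token, and model.generate must run in a worker thread because iterating a TextIteratorStreamer blocks until text arrives. The pattern in isolation (gpt2 stands in for the Space's actual model):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    inputs = tok("Streaming demo:", return_tensors="pt")

    Thread(target=lm.generate,
           kwargs=dict(**inputs, streamer=streamer, max_new_tokens=16)).start()
    for piece in streamer:              # yields decoded text as it is generated
        print(piece, end="", flush=True)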
@@ -725,23 +547,23 @@ def stream_chat(
             buffer += new_text
             partial_message += new_text
 
-            #
+            # Update the UI on a time interval or after enough characters
             current_time = time.time()
             if (current_time - last_yield_time > 0.1) or (len(partial_message) > 20):
                 yield "", history + [[message, buffer]]
                 partial_message = ""
                 last_yield_time = current_time
 
-        # Last
+        # Final output
         if buffer:
             yield "", history + [[message, buffer]]
 
-        # Conversation
+        # Save the conversation history
         chat_history.add_conversation(message, buffer)
 
     except Exception as e:
         print(f"[Streaming error] {str(e)}")
-        if not buffer:
+        if not buffer:
             buffer = f"Error while generating the response: {str(e)}"
         yield "", history + [[message, buffer]]
 
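The streaming loop rate-limits UI updates: a yield fires once 0.1 s has passed or 20+ characters have accumulated, whichever comes first. The same throttle as a standalone generator (a sketch; the names are illustrative):

    import time

    def throttled(chunks, min_interval=0.1, min_chars=20):
        buffer, pending, last = "", "", time.monotonic()
        for chunk in chunks:
            buffer += chunk
            pending += chunk
            now = time.monotonic()
            if (now - last > min_interval) or (len(pending) > min_chars):
                yield buffer                 # snapshot of everything so far
                pending, last = "", now
        if buffer:
            yield buffer                     # final flush

    for snapshot in throttled(iter("streaming text demo " * 3)):
        print(snapshot)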
@@ -835,7 +657,7 @@ def create_demo():
                 label="Repetition Penalty 🔄"
             )
 
-        # Examples
+        # Example inputs
         gr.Examples(
             examples=[
                 ["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n    if n <= 1: return n\n    return fibonacci(n-1) + fibonacci(n-2)"],
@@ -852,7 +674,7 @@ def create_demo():
         current_file_context = None
         return [], None, "Start a new conversation..."
 
-    # Send message
+    # Message submit
     msg.submit(
         stream_chat,
         inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
@@ -889,7 +711,6 @@ def create_demo():
 
     return demo
 
-# Main entry point
 if __name__ == "__main__":
     demo = create_demo()
     demo.launch()