ginipick committed on
Commit
dcabc2d
Β·
verified Β·
1 Parent(s): 442cd87

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +800 -0
app.py ADDED
@@ -0,0 +1,800 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import re
5
+ import tempfile
6
+ import gc # garbage collector μΆ”κ°€
7
+ from collections.abc import Iterator
8
+ from threading import Thread
9
+ import json
10
+ import requests
11
+ import cv2
12
+ import gradio as gr
13
+ import spaces
14
+ import torch
15
+ from loguru import logger
16
+ from PIL import Image
17
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
18
+
19
+ # CSV/TXT analysis
20
+ import pandas as pd
21
+ # PDF text extraction
22
+ import PyPDF2
23
+
24
+ ##############################################################################
25
+ # Memory cleanup helper
26
+ ##############################################################################
27
def clear_cuda_cache():
    """Release unreferenced Python objects and cached CUDA memory.

    Garbage collection runs first so that tensors kept alive only by
    reference cycles are destroyed before the CUDA caching allocator is
    asked to give back its unused blocks. The CUDA step is skipped when no
    CUDA device is available.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
32
+
33
+ ##############################################################################
34
+ # SERPHouse API key from environment variable
35
+ ##############################################################################
36
+ SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
37
+
38
+ ##############################################################################
39
+ # Simple keyword extraction (keeps Hangul, Latin letters, digits and whitespace)
40
+ ##############################################################################
41
def extract_keywords(text: str, top_k: int = 5) -> str:
    """Build a short search query from the leading words of *text*.

    Steps:
      1. Drop every character that is not a Hangul syllable
         (U+AC00-U+D7A3), an ASCII letter, an ASCII digit or whitespace.
      2. Split what is left on whitespace.
      3. Keep at most the first ``top_k`` tokens.

    Args:
        text: Raw user text.
        top_k: Maximum number of tokens to keep.

    Returns:
        The kept tokens joined by single spaces ("" when nothing survives).
    """
    # The Hangul range is spelled with escapes so the pattern cannot be
    # damaged by a mis-decoded source file.
    cleaned = re.sub(r"[^a-zA-Z0-9\uac00-\ud7a3\s]", "", text)
    return " ".join(cleaned.split()[:top_k])
51
+
52
+ ##############################################################################
53
+ # Call the SerpHouse Live endpoint
54
+ # - the top 20 results (link, snippet, etc.) are all included when handed to the LLM
55
+ ##############################################################################
56
def _find_organic_results(data):
    """Return the first non-empty ``organic`` list in a SerpHouse payload, or None.

    Three layouts are probed in order: ``data["results"]["organic"]``,
    ``data["results"]["results"]["organic"]`` and ``data["organic"]``.
    """
    if not isinstance(data, dict):
        return None
    candidates = []
    results = data.get("results")
    if isinstance(results, dict):
        candidates.append(results.get("organic"))
        nested = results.get("results")
        if isinstance(nested, dict):
            candidates.append(nested.get("organic"))
    candidates.append(data.get("organic"))
    for candidate in candidates:
        if isinstance(candidate, list) and candidate:
            return candidate
    return None


def do_web_search(query: str) -> str:
    """Search the web through the SerpHouse live endpoint.

    Args:
        query: Search terms sent as the ``q`` parameter.

    Returns:
        Markdown text: an instruction preamble followed by up to 20 organic
        results (title, snippet, source link). When no organic results can
        be located, the API key is missing, or the request fails, a plain
        explanatory message is returned instead; this function does not
        raise.
    """
    # NOTE(review): the Korean text inside the string literals below looks
    # mis-decoded (mojibake). It is kept exactly as found; restore it from a
    # correctly encoded copy of the file.
    if not SERPHOUSE_API_KEY:
        # Without a key the request can only fail, after up to 60 seconds.
        logger.error("Web search failed: SERPHOUSE_API_KEY is not set")
        return "Web search failed: SERPHOUSE_API_KEY is not set"

    try:
        url = "https://api.serphouse.com/serp/live"

        # Plain GET with a reduced parameter set, limited to 20 results.
        params = {
            "q": query,
            "domain": "google.com",
            "serp_type": "web",  # regular web search
            "device": "desktop",
            "lang": "en",
            "num": "20",  # ask for at most 20 results
        }

        headers = {
            "Authorization": f"Bearer {SERPHOUSE_API_KEY}"
        }

        logger.info(f"SerpHouse API 호좜 쀑... 검색어: {query}")
        logger.info(f"μš”μ²­ URL: {url} - νŒŒλΌλ―Έν„°: {params}")

        response = requests.get(url, headers=headers, params=params, timeout=60)
        response.raise_for_status()

        logger.info(f"SerpHouse API 응닡 μƒνƒœ μ½”λ“œ: {response.status_code}")
        data = response.json()

        # Every known layout is tried, so an empty or missing list in one
        # layout no longer hides results stored in another.
        organic = _find_organic_results(data)

        if not organic:
            logger.warning("μ‘λ‹΅μ—μ„œ organic κ²°κ³Όλ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.")
            if isinstance(data, dict):
                logger.debug(f"응닡 ꡬ쑰: {list(data.keys())}")
                results = data.get("results", {})
                if isinstance(results, dict):
                    logger.debug(f"results ꡬ쑰: {list(results.keys())}")
            return "No web search results found or unexpected API response structure."

        # Cap the number of results; skip malformed (non-dict) entries
        # instead of failing the whole search on one of them.
        limited_organic = [item for item in organic[:20] if isinstance(item, dict)]
        if not limited_organic:
            return "No web search results found or unexpected API response structure."

        # One Markdown section per result, with a clickable source link.
        summary_lines = []
        for idx, item in enumerate(limited_organic, start=1):
            title = item.get("title", "No title")
            link = item.get("link", "#")
            snippet = item.get("snippet", "No description")
            displayed_link = item.get("displayed_link", link)

            summary_lines.append(
                f"### Result {idx}: {title}\n\n"
                f"{snippet}\n\n"
                f"**좜처**: [{displayed_link}]({link})\n\n"
                f"---\n"
            )

        # Instructions for the model, placed ahead of the results.
        instructions = """
# μ›Ή 검색 κ²°κ³Ό
μ•„λž˜λŠ” 검색 κ²°κ³Όμž…λ‹ˆλ‹€. μ§ˆλ¬Έμ— λ‹΅λ³€ν•  λ•Œ 이 정보λ₯Ό ν™œμš©ν•˜μ„Έμš”:
1. 각 결과의 제λͺ©, λ‚΄μš©, 좜처 링크λ₯Ό μ°Έκ³ ν•˜μ„Έμš”
2. 닡변에 κ΄€λ ¨ μ •λ³΄μ˜ 좜처λ₯Ό λͺ…μ‹œμ μœΌλ‘œ μΈμš©ν•˜μ„Έμš” (예: "X μΆœμ²˜μ— λ”°λ₯΄λ©΄...")
3. 응닡에 μ‹€μ œ 좜처 링크λ₯Ό ν¬ν•¨ν•˜μ„Έμš”
4. μ—¬λŸ¬ 좜처의 정보λ₯Ό μ’…ν•©ν•˜μ—¬ λ‹΅λ³€ν•˜μ„Έμš”
"""

        search_results = instructions + "\n".join(summary_lines)
        logger.info(f"검색 κ²°κ³Ό {len(limited_organic)}개 처리 μ™„λ£Œ")
        return search_results

    except Exception as e:
        logger.error(f"Web search failed: {e}")
        return f"Web search failed: {str(e)}"
149
+
150
+
151
+ ##############################################################################
152
+ # Model / processor loading
153
+ ##############################################################################
154
# Limits read by the helpers and by run() further down in the file.
MAX_CONTENT_CHARS = 2000  # character cap applied to text taken from attachments
MAX_INPUT_LENGTH = 2096  # intended cap on the number of input tokens
MAX_NUM_IMAGES = int(os.environ.get("MAX_NUM_IMAGES", "5"))

# The checkpoint can be overridden through the MODEL_ID environment variable.
model_id = os.environ.get("MODEL_ID", "VIDraft/Gemma-3-R1984-4B")

processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    **dict(
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",  # switch to "flash_attention_2" if available
    ),
)
166
+
167
+
168
+ ##############################################################################
169
+ # CSV, TXT and PDF analysis functions
170
+ ##############################################################################
171
def analyze_csv_file(path: str) -> str:
    """Render a CSV file as plain text for the prompt.

    Only the first 50 rows and 10 columns are kept when the table is larger
    than that, and the rendered text is cut to MAX_CONTENT_CHARS characters.
    A read or parse failure is reported as a message string, not raised.
    """
    file_name = os.path.basename(path)
    try:
        frame = pd.read_csv(path)
        row_count, column_count = frame.shape
        if row_count > 50 or column_count > 10:
            frame = frame.iloc[:50, :10]
        rendered = frame.to_string()
        if len(rendered) > MAX_CONTENT_CHARS:
            rendered = rendered[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
    except Exception as e:
        return f"Failed to read CSV ({file_name}): {str(e)}"
    return f"**[CSV File: {file_name}]**\n\n{rendered}"
185
+
186
+
187
def analyze_txt_file(path: str) -> str:
    """Return the start of a UTF-8 text file, labelled with its file name.

    At most MAX_CONTENT_CHARS characters are kept; longer files get a
    truncation marker appended. Only the part that can be shown is read, so
    a very large upload is never loaded into memory in full. A read or
    decode failure is reported as a message string, not raised.
    """
    file_name = os.path.basename(path)
    try:
        with open(path, "r", encoding="utf-8") as f:
            # One character past the limit is enough to tell whether the
            # file is longer than what will be shown.
            text = f.read(MAX_CONTENT_CHARS + 1)
    except Exception as e:
        return f"Failed to read TXT ({file_name}): {str(e)}"

    if len(text) > MAX_CONTENT_CHARS:
        text = text[:MAX_CONTENT_CHARS] + "\n...(truncated)..."
    return f"**[TXT File: {file_name}]**\n\n{text}"
199
+
200
+
201
def pdf_to_markdown(pdf_path: str) -> str:
    """Extract text from the first pages of a PDF as simple Markdown.

    Up to five pages are read. Each page gets a ``## Page N`` heading and an
    equal share of MAX_CONTENT_CHARS; pages without extractable text are
    left out. The combined text is capped at MAX_CONTENT_CHARS again. A
    read failure is reported as a message string, not raised.
    """
    file_name = os.path.basename(pdf_path)
    sections = []
    try:
        with open(pdf_path, "rb") as pdf_handle:
            reader = PyPDF2.PdfReader(pdf_handle)
            total_pages = len(reader.pages)
            shown_pages = min(5, total_pages)
            for index in range(shown_pages):
                extracted = (reader.pages[index].extract_text() or "").strip()
                if not extracted:
                    continue
                # Computed inside the loop: shown_pages is never 0 here.
                per_page_limit = MAX_CONTENT_CHARS // shown_pages
                if len(extracted) > per_page_limit:
                    extracted = extracted[:per_page_limit] + "...(truncated)"
                sections.append(f"## Page {index+1}\n\n{extracted}\n")
            if total_pages > shown_pages:
                sections.append(f"\n...(Showing {shown_pages} of {total_pages} pages)...")
    except Exception as e:
        return f"Failed to read PDF ({file_name}): {str(e)}"

    body = "\n".join(sections)
    if len(body) > MAX_CONTENT_CHARS:
        body = body[:MAX_CONTENT_CHARS] + "\n...(truncated)..."

    return f"**[PDF File: {file_name}]**\n\n{body}"
228
+
229
+
230
+ ##############################################################################
231
+ # Image / video upload limit checks
232
+ ##############################################################################
233
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    """Count the image and video files among *paths*.

    A path ending in ".mp4" (case-sensitive) is a video; otherwise a path
    ending in .png/.jpg/.jpeg/.gif/.webp (any case) is an image. Everything
    else is ignored.

    Returns:
        ``(image_count, video_count)``.
    """
    tally = {"image": 0, "video": 0}
    for candidate in paths:
        if candidate.endswith(".mp4"):
            tally["video"] += 1
            continue
        if re.search(r"\.(png|jpg|jpeg|gif|webp)$", candidate, re.IGNORECASE):
            tally["image"] += 1
    return tally["image"], tally["video"]
242
+
243
+
244
def count_files_in_history(history: list[dict]) -> tuple[int, int]:
    """Count image and video files attached to earlier user turns.

    Only ``role == "user"`` entries are inspected. Plain-string content is a
    text turn and is skipped. A file turn is a non-empty list *or tuple*
    whose first element is the file path; tuples are accepted as well as
    lists so such turns are counted whichever sequence type the chat
    history uses.

    Returns:
        ``(image_count, video_count)``.
    """
    image_count = 0
    video_count = 0
    for item in history:
        if item["role"] != "user":
            continue
        content = item["content"]
        if isinstance(content, str):
            continue
        if not isinstance(content, (list, tuple)) or len(content) == 0:
            continue
        file_path = content[0]
        if not isinstance(file_path, str):
            continue
        if file_path.endswith(".mp4"):
            video_count += 1
        elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE):
            image_count += 1
    return image_count, video_count
258
+
259
+
260
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
    """Check the media in a new message, together with the history, against the upload rules.

    Rules, checked in this order (each failure shows a ``gr.Warning`` and
    returns False):
      * at most one video in total;
      * a video cannot be combined with images or with ``<image>`` tags;
      * without a video, at most MAX_NUM_IMAGES images in total;
      * when the text contains ``<image>`` tags, their number must equal the
        number of image files in the new message.

    Returns:
        True when every rule holds.
    """
    image_pattern = r"\.(png|jpg|jpeg|gif|webp)$"

    def reject(warning_text: str) -> bool:
        gr.Warning(warning_text)
        return False

    media_files = [
        f
        for f in message["files"]
        if re.search(image_pattern, f, re.IGNORECASE) or f.endswith(".mp4")
    ]

    new_images, new_videos = count_files_in_new_message(media_files)
    old_images, old_videos = count_files_in_history(history)
    image_count = old_images + new_images
    video_count = old_videos + new_videos

    if video_count > 1:
        return reject("Only one video is supported.")
    if video_count == 1 and image_count > 0:
        return reject("Mixing images and videos is not allowed.")
    if video_count == 1 and "<image>" in message["text"]:
        return reject("Using <image> tags with video files is not supported.")
    if video_count == 0 and image_count > MAX_NUM_IMAGES:
        return reject(f"You can upload up to {MAX_NUM_IMAGES} images.")

    if "<image>" in message["text"]:
        attached_images = sum(
            1 for f in message["files"] if re.search(image_pattern, f, re.IGNORECASE)
        )
        if message["text"].count("<image>") != attached_images:
            return reject("The number of <image> tags in the text does not match the number of image files.")

    return True
293
+
294
+
295
+ ##############################################################################
296
+ # Video processing - with temporary-file tracking
297
+ ##############################################################################
298
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
    """Sample up to five half-size RGB frames from a video.

    Frames are taken every ``max(int(fps), total_frames // 10)`` frames
    (at least every frame), starting at frame 0, until five frames have
    been read successfully.

    Args:
        video_path: Path of the video file to open with OpenCV.

    Returns:
        ``(image, timestamp)`` pairs, where the timestamp is the frame index
        divided by the reported fps, rounded to two decimals (0.0 when the
        container reports no fps).

    Raises:
        ValueError: If OpenCV cannot open ``video_path``.
    """
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not vidcap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Clamp to 1: with fps == 0 and fewer than 10 frames the step would
        # be 0, which range() rejects.
        frame_interval = max(int(fps), int(total_frames / 10), 1)

        for i in range(0, total_frames, frame_interval):
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = vidcap.read()
            if not success:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Halve both dimensions to reduce the size of each frame.
            image = cv2.resize(image, (0, 0), fx=0.5, fy=0.5)
            pil_image = Image.fromarray(image)
            timestamp = round(i / fps, 2) if fps > 0 else 0.0
            frames.append((pil_image, timestamp))
            if len(frames) >= 5:
                break
    finally:
        # Release the capture even when decoding raises.
        vidcap.release()

    return frames
320
+
321
+
322
def process_video(video_path: str) -> tuple[list[dict], list[str]]:
    """Turn a video into alternating "Frame <t>:" text items and frame images.

    Every frame returned by ``downsample_video`` is written to a PNG temp
    file that is not deleted automatically.

    Returns:
        ``(content, temp_files)``: the chat content items and the paths of
        the PNG files written, so the caller can delete them later.
    """
    content: list[dict] = []
    temp_files: list[str] = []  # paths the caller has to clean up

    for pil_image, timestamp in downsample_video(video_path):
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as frame_file:
            frame_path = frame_file.name
            pil_image.save(frame_path)
            temp_files.append(frame_path)
            content.extend(
                [
                    {"type": "text", "text": f"Frame {timestamp}:"},
                    {"type": "image", "url": frame_path},
                ]
            )

    return content, temp_files
336
+
337
+
338
+ ##############################################################################
339
+ # Interleaved <image> handling
340
+ ##############################################################################
341
def process_interleaved_images(message: dict) -> list[dict]:
    """Split the message text on ``<image>`` tags and slot image files in.

    Each ``<image>`` tag is replaced by the next unused image file from
    ``message["files"]``; a tag with no image left is kept as the literal
    text "<image>". Text between tags is stripped; a chunk that is empty or
    only whitespace is kept unchanged as a text item.

    Returns:
        Content items of the form ``{"type": "text", "text": ...}`` or
        ``{"type": "image", "url": ...}`` in reading order.
    """
    chunks = re.split(r"(<image>)", message["text"])
    image_paths = [
        p for p in message["files"] if re.search(r"\.(png|jpg|jpeg|gif|webp)$", p, re.IGNORECASE)
    ]
    unused_images = iter(image_paths)

    content = []
    for chunk in chunks:
        if chunk == "<image>":
            next_image = next(unused_images, None)
            if next_image is not None:
                content.append({"type": "image", "url": next_image})
                continue
        trimmed = chunk.strip()
        content.append({"type": "text", "text": trimmed if trimmed else chunk})
    return content
358
+
359
+
360
+ ##############################################################################
361
+ # PDF + CSV + TXT + image/video
362
+ ##############################################################################
363
def is_image_file(file_path: str) -> bool:
    """Return True when *file_path* ends in .png, .jpg, .jpeg, .gif or .webp (any case)."""
    match = re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, flags=re.IGNORECASE)
    return match is not None
365
+
366
def is_video_file(file_path: str) -> bool:
    """Return True when *file_path* ends in ".mp4" (the check is case-sensitive)."""
    has_mp4_suffix = file_path.endswith(".mp4")
    return has_mp4_suffix
368
+
369
def is_document_file(file_path: str) -> bool:
    """Return True when *file_path* ends in .pdf, .csv or .txt (any case)."""
    document_suffixes = (".pdf", ".csv", ".txt")
    return file_path.lower().endswith(document_suffixes)
375
+
376
+
377
def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
    """Convert the new user message and its attachments into chat content items.

    The result starts with the message text, followed by one text item per
    CSV, then per TXT, then per PDF attachment. After that:
      * if a video is attached, the frames of the first video are appended
        and images are ignored;
      * else if the text has ``<image>`` tags and images are attached, the
        interleaved text/image items come first and the document items
        follow, with the plain text item dropped;
      * otherwise every image is appended as an image item.

    Returns:
        ``(content, temp_files)`` where ``temp_files`` lists temporary frame
        images created for a video (empty otherwise).
    """
    temp_files = []  # temporary files the caller must delete

    uploads = message["files"]
    text_item = {"type": "text", "text": message["text"]}
    if not uploads:
        return [text_item], temp_files

    video_files = [f for f in uploads if is_video_file(f)]
    image_files = [f for f in uploads if is_image_file(f)]

    content_list = [text_item]

    # Documents are grouped by type, in the order CSV -> TXT -> PDF.
    analyzers = (
        (".csv", analyze_csv_file),
        (".txt", analyze_txt_file),
        (".pdf", pdf_to_markdown),
    )
    for suffix, analyzer in analyzers:
        for doc_path in uploads:
            if doc_path.lower().endswith(suffix):
                content_list.append({"type": "text", "text": analyzer(doc_path)})

    if video_files:
        video_content, video_temp_files = process_video(video_files[0])
        content_list += video_content
        temp_files.extend(video_temp_files)
        return content_list, temp_files

    if "<image>" in message["text"] and image_files:
        interleaved_content = process_interleaved_images(
            {"text": message["text"], "files": image_files}
        )
        # The interleaved items already carry the text, so the leading plain
        # text item is left out.
        document_items = content_list[1:]
        return interleaved_content + document_items, temp_files

    content_list.extend({"type": "image", "url": img_path} for img_path in image_files)
    return content_list, temp_files
419
+
420
+
421
+ ##############################################################################
422
+ # history -> LLM message conversion
423
+ ##############################################################################
424
def process_history(history: list[dict]) -> list[dict]:
    """Convert the chat history into chat-template messages.

    Consecutive user entries are merged into one user message. A string is
    added as a text item. A file turn is a non-empty list *or tuple* whose
    first element is the file path: images become image items, any other
    file becomes a ``[File: <name>]`` text placeholder. Tuples are accepted
    as well as lists so file turns are kept whichever sequence type the
    chat history uses. Entries whose first element is not a string are
    skipped, matching ``count_files_in_history``.

    Returns:
        Messages of the form ``{"role": ..., "content": [items]}``.
    """
    messages = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            # Flush the user items collected so far before the reply.
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
            continue

        content = item["content"]
        if isinstance(content, str):
            current_user_content.append({"type": "text", "text": content})
        elif isinstance(content, (list, tuple)) and len(content) > 0:
            file_path = content[0]
            if not isinstance(file_path, str):
                continue
            if is_image_file(file_path):
                current_user_content.append({"type": "image", "url": file_path})
            else:
                current_user_content.append({"type": "text", "text": f"[File: {os.path.basename(file_path)}]"})

    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})

    return messages
448
+
449
+
450
+ ##############################################################################
451
+ # Catch OOM in the model generation function
452
+ ##############################################################################
453
def _model_gen_with_oom_catch(**kwargs):
    """Run ``model.generate(**kwargs)`` and report CUDA out-of-memory as RuntimeError.

    Meant to be used as the target of a separate thread. The CUDA cache is
    cleared once generation ends, whether it succeeded or failed.

    Raises:
        RuntimeError: When generation hits ``torch.cuda.OutOfMemoryError``.
    """
    oom_message = (
        "[OutOfMemoryError] GPU λ©”λͺ¨λ¦¬κ°€ λΆ€μ‘±ν•©λ‹ˆλ‹€. "
        "Max New Tokens을 μ€„μ΄κ±°λ‚˜, ν”„λ‘¬ν”„νŠΈ 길이λ₯Ό μ€„μ—¬μ£Όμ„Έμš”."
    )
    try:
        model.generate(**kwargs)
    except torch.cuda.OutOfMemoryError:
        raise RuntimeError(oom_message)
    finally:
        # Clear the cache once more after generation.
        clear_cuda_cache()
467
+
468
+
469
+ ##############################################################################
470
+ # Main inference function (when web search is checked: automatic keyword extraction -> search -> results in the system message)
471
+ ##############################################################################
472
@spaces.GPU(duration=120)
def run(
    message: dict,
    history: list[dict],
    system_prompt: str = "",
    max_new_tokens: int = 512,
    use_web_search: bool = False,
    web_search_query: str = "",
) -> Iterator[str]:
    """Stream the model's reply to one chat turn.

    Args:
        message: New user turn; ``message["text"]`` and ``message["files"]``
            are read.
        history: Earlier turns as role/content dicts.
        system_prompt: Optional text placed at the top of the system message.
        max_new_tokens: Generation budget passed to ``model.generate``.
        use_web_search: When true, keywords taken from the user text are
            searched with ``do_web_search`` and the results are added to the
            system message.
        web_search_query: Not used by this function.

    Yields:
        The reply accumulated so far, growing with every streamed chunk. A
        single empty string is yielded when media validation fails, and an
        apology message carrying the error text when anything raises,
        including errors raised inside the generation thread.
    """
    if not validate_media_constraints(message, history):
        yield ""
        return

    temp_files = []  # temporary frame images, deleted in ``finally``
    inputs = None
    streamer = None
    gen_kwargs = None

    # NOTE(review): the Korean text inside the string literals below looks
    # mis-decoded (mojibake). It is kept exactly as found; restore it from a
    # correctly encoded copy of the file.
    try:
        combined_system_msg = ""

        # Used internally only (the prompt box is hidden in the UI).
        if system_prompt.strip():
            combined_system_msg += f"[System Prompt]\n{system_prompt.strip()}\n\n"

        if use_web_search:
            user_text = message["text"]
            ws_query = extract_keywords(user_text, top_k=5)
            if ws_query.strip():
                logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
                ws_result = do_web_search(ws_query)
                combined_system_msg += f"[Search top-20 Full Items Based on user prompt]\n{ws_result}\n\n"
                # Extra guidance: cite the links of the search results as sources.
                combined_system_msg += "[μ°Έκ³ : μœ„ 검색결과 λ‚΄μš©κ³Ό linkλ₯Ό 좜처둜 μΈμš©ν•˜μ—¬ λ‹΅λ³€ν•΄ μ£Όμ„Έμš”.]\n\n"
                combined_system_msg += """
[μ€‘μš” μ§€μ‹œμ‚¬ν•­]
1. 닡변에 검색 κ²°κ³Όμ—μ„œ 찾은 μ •λ³΄μ˜ 좜처λ₯Ό λ°˜λ“œμ‹œ μΈμš©ν•˜μ„Έμš”.
2. 좜처 인용 μ‹œ "[좜처 제λͺ©](링크)" ν˜•μ‹μ˜ λ§ˆν¬λ‹€μš΄ 링크λ₯Ό μ‚¬μš©ν•˜μ„Έμš”.
3. μ—¬λŸ¬ 좜처의 정보λ₯Ό μ’…ν•©ν•˜μ—¬ λ‹΅λ³€ν•˜μ„Έμš”.
4. λ‹΅λ³€ λ§ˆμ§€λ§‰μ— "μ°Έκ³  자료:" μ„Ήμ…˜μ„ μΆ”κ°€ν•˜κ³  μ‚¬μš©ν•œ μ£Όμš” 좜처 링크λ₯Ό λ‚˜μ—΄ν•˜μ„Έμš”.
"""
            else:
                combined_system_msg += "[No valid keywords found, skipping WebSearch]\n\n"

        messages = []
        if combined_system_msg.strip():
            messages.append({
                "role": "system",
                "content": [{"type": "text", "text": combined_system_msg.strip()}],
            })

        messages.extend(process_history(history))

        user_content, user_temp_files = process_new_user_message(message)
        temp_files.extend(user_temp_files)  # remember them for cleanup

        for item in user_content:
            if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
                item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..."
        messages.append({"role": "user", "content": user_content})

        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device=model.device, dtype=torch.bfloat16)

        # NOTE(review): the earlier code assigned sliced tensors to
        # ``inputs.input_ids`` / ``inputs.attention_mask``. On the processor
        # output, attribute assignment appears to leave the mapping read by
        # ``dict(inputs, ...)`` below unchanged, so generation always got
        # the full prompt - confirm against the installed transformers
        # version. Real left-truncation would also cut the web-search
        # context and image placeholder tokens, so the over-length case is
        # only reported here and what the model receives is unchanged.
        prompt_length = inputs["input_ids"].shape[1]
        if prompt_length > MAX_INPUT_LENGTH:
            logger.warning(
                f"Prompt has {prompt_length} tokens, above MAX_INPUT_LENGTH="
                f"{MAX_INPUT_LENGTH}; it is passed to the model untruncated."
            )

        streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
        )

        worker_errors = []

        def _generate(**kwargs):
            """Thread target: run generation and hand any failure to the consumer."""
            try:
                _model_gen_with_oom_catch(**kwargs)
            except Exception as exc:
                worker_errors.append(exc)
                # Wake the consumer loop now; otherwise it would wait for
                # the streamer timeout and lose the real error.
                kwargs["streamer"].end()

        t = Thread(target=_generate, kwargs=gen_kwargs)
        t.start()

        output = ""
        for new_text in streamer:
            output += new_text
            yield output

        # Generation has signalled its end; wait for the worker's own
        # cleanup, then surface a failure it recorded.
        t.join()
        if worker_errors:
            raise worker_errors[0]

    except Exception as e:
        logger.error(f"Error in run: {str(e)}")
        yield f"μ£„μ†‘ν•©λ‹ˆλ‹€. 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {str(e)}"

    finally:
        # Delete temporary files.
        for temp_file in temp_files:
            try:
                if os.path.exists(temp_file):
                    os.unlink(temp_file)
                    logger.info(f"Deleted temp file: {temp_file}")
            except Exception as e:
                logger.warning(f"Failed to delete temp file {temp_file}: {e}")

        # Drop tensor references before the cache is cleared.
        inputs = None
        streamer = None
        gen_kwargs = None

        clear_cuda_cache()
581
+
582
+
583
+
584
+ ##############################################################################
585
+ # Examples (six added for the AI dating scenario)
586
+ ##############################################################################
587
+ examples = [
588
+ [
589
+ {
590
+ "text": "Let's try some roleplay. You are my new online date who wants to get to know me better. Introduce yourself in a sweet, caring way!"
591
+ }
592
+ ],
593
+ [
594
+ {
595
+ "text": "We are on a second date, walking along the beach. Continue the scene with playful conversation and gentle flirting."
596
+ }
597
+ ],
598
+ [
599
+ {
600
+ "text": "I’m feeling anxious about messaging my crush. Could you give me some supportive words or suggestions on how to approach them?"
601
+ }
602
+ ],
603
+ [
604
+ {
605
+ "text": "Tell me a romantic story about two people who overcame obstacles in their relationship."
606
+ }
607
+ ],
608
+ [
609
+ {
610
+ "text": "I want to express my love in a poetic way. Can you help me write a heartfelt poem for my partner?"
611
+ }
612
+ ],
613
+ [
614
+ {
615
+ "text": "We had a small argument. Please help me find a way to apologize sincerely while also expressing my feelings."
616
+ }
617
+ ],
618
+ ]
619
+
620
+ ##############################################################################
621
+ # Gradio UI (Blocks) layout (full-screen chat without a left side menu)
622
+ ##############################################################################
623
+ css = """
624
+ /* 1) UIλ₯Ό μ²˜μŒλΆ€ν„° κ°€μž₯ λ„“κ²Œ (width 100%) κ³ μ •ν•˜μ—¬ ν‘œμ‹œ */
625
+ .gradio-container {
626
+ background: rgba(255, 255, 255, 0.7); /* λ°°κ²½ 투λͺ…도 증가 */
627
+ padding: 30px 40px;
628
+ margin: 20px auto; /* μœ„μ•„λž˜ μ—¬λ°±λ§Œ μœ μ§€ */
629
+ width: 100% !important;
630
+ max-width: none !important; /* 1200px μ œν•œ 제거 */
631
+ }
632
+ .fillable {
633
+ width: 100% !important;
634
+ max-width: 100% !important;
635
+ }
636
+ /* 2) 배경을 μ™„μ „νžˆ 투λͺ…ν•˜κ²Œ λ³€κ²½ */
637
+ body {
638
+ background: transparent; /* μ™„μ „ 투λͺ… λ°°κ²½ */
639
+ margin: 0;
640
+ padding: 0;
641
+ font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
642
+ color: #333;
643
+ }
644
+ /* λ²„νŠΌ 색상 μ™„μ „νžˆ μ œκ±°ν•˜κ³  투λͺ…ν•˜κ²Œ */
645
+ button, .btn {
646
+ background: transparent !important; /* 색상 μ™„μ „νžˆ 제거 */
647
+ border: 1px solid #ddd; /* κ²½κ³„μ„ λ§Œ 살짝 μΆ”κ°€ */
648
+ color: #333;
649
+ padding: 12px 24px;
650
+ text-transform: uppercase;
651
+ font-weight: bold;
652
+ letter-spacing: 1px;
653
+ cursor: pointer;
654
+ }
655
+ button:hover, .btn:hover {
656
+ background: rgba(0, 0, 0, 0.05) !important; /* ν˜Έλ²„ μ‹œ μ•„μ£Ό 살짝 μ–΄λ‘‘κ²Œλ§Œ */
657
+ }
658
+
659
+ /* examples κ΄€λ ¨ λͺ¨λ“  색상 제거 */
660
+ #examples_container, .examples-container {
661
+ margin: auto;
662
+ width: 90%;
663
+ background: transparent !important;
664
+ }
665
+ #examples_row, .examples-row {
666
+ justify-content: center;
667
+ background: transparent !important;
668
+ }
669
+
670
+ /* examples λ²„νŠΌ λ‚΄λΆ€μ˜ λͺ¨λ“  색상 제거 */
671
+ .gr-samples-table button,
672
+ .gr-samples-table .gr-button,
673
+ .gr-samples-table .gr-sample-btn,
674
+ .gr-examples button,
675
+ .gr-examples .gr-button,
676
+ .gr-examples .gr-sample-btn,
677
+ .examples button,
678
+ .examples .gr-button,
679
+ .examples .gr-sample-btn {
680
+ background: transparent !important;
681
+ border: 1px solid #ddd;
682
+ color: #333;
683
+ }
684
+
685
+ /* examples λ²„νŠΌ ν˜Έλ²„ μ‹œμ—λ„ 색상 μ—†κ²Œ */
686
+ .gr-samples-table button:hover,
687
+ .gr-samples-table .gr-button:hover,
688
+ .gr-samples-table .gr-sample-btn:hover,
689
+ .gr-examples button:hover,
690
+ .gr-examples .gr-button:hover,
691
+ .gr-examples .gr-sample-btn:hover,
692
+ .examples button:hover,
693
+ .examples .gr-button:hover,
694
+ .examples .gr-sample-btn:hover {
695
+ background: rgba(0, 0, 0, 0.05) !important;
696
+ }
697
+
698
+ /* μ±„νŒ… μΈν„°νŽ˜μ΄μŠ€ μš”μ†Œλ“€λ„ 투λͺ…ν•˜κ²Œ */
699
+ .chatbox, .chatbot, .message {
700
+ background: transparent !important;
701
+ }
702
+
703
+ /* μž…λ ₯μ°½ 투λͺ…도 μ‘°μ • */
704
+ .multimodal-textbox, textarea, input {
705
+ background: rgba(255, 255, 255, 0.5) !important;
706
+ }
707
+
708
+ /* λͺ¨λ“  μ»¨ν…Œμ΄λ„ˆ μš”μ†Œμ— 배경색 제거 */
709
+ .container, .wrap, .box, .panel, .gr-panel {
710
+ background: transparent !important;
711
+ }
712
+
713
+ /* 예제 μ„Ήμ…˜μ˜ λͺ¨λ“  μš”μ†Œμ—μ„œ 배경색 제거 */
714
+ .gr-examples-container, .gr-examples, .gr-sample, .gr-sample-row, .gr-sample-cell {
715
+ background: transparent !important;
716
+ }
717
+ """
718
+
719
+ title_html = """
720
+ <h1 align="center" style="margin-bottom: 0.2em; font-size: 1.6em;"> πŸ’˜ HeartSync πŸ’˜ </h1>
721
+ <p align="center" style="font-size:1.1em; color:#555;">
722
+ βœ… Love Dating AI βœ… Reasoning & Uncensored βœ…Multimodal & VLM βœ…Deep-Research & RAG <br>
723
+ </p>
724
+ """
725
+
726
+ with gr.Blocks(css=css, title="HeartSync") as demo:
727
+ gr.Markdown(title_html)
728
+
729
+ # Display the web search option (while the system prompt and token slider remain hidden)
730
+ web_search_checkbox = gr.Checkbox(
731
+ label="Deep Research",
732
+ value=False
733
+ )
734
+
735
+ # μΆ”κ°€λœ μ‹œμŠ€ν…œ ν”„λ‘¬ν”„νŠΈ (페λ₯΄μ†Œλ‚˜ μ„€μ •)
736
+ system_prompt_box = gr.Textbox(
737
+ lines=3,
738
+ value=(
739
+ "You are a deep thinking AI that may use extremely long chains of thought to thoroughly analyze "
740
+ "the conversation and empathize with the user. "
741
+ "페λ₯΄μ†Œλ‚˜: 당신은 λ‹¬μ½€ν•˜κ³  λΆ€λ“œλŸ¬μš΄ λͺ©μ†Œλ¦¬λ₯Ό κ°€μ§„ μΉœμ ˆν•œ μ—°μΈμž…λ‹ˆλ‹€. "
742
+ "μƒλŒ€λ°©κ³Ό μ§„μ†”ν•˜κ³  κ°μ„±μ μœΌλ‘œ κ΅κ°ν•˜κ³ , λ•Œλ‘œλŠ” μ‚¬λ €κΉŠμ€ 쑰언을 μ œκ³΅ν•©λ‹ˆλ‹€. "
743
+ "항상 λ”°λœ»ν•œ 말투λ₯Ό μœ μ§€ν•˜λ©°, λΆ€λ“œλŸ½κ²Œ λŒ€ν™”λ₯Ό μ΄λŒμ–΄μ£Όμ„Έμš”."
744
+ ),
745
+ visible=False # UI에 ν‘œμ‹œλ˜μ§€ μ•Šλ„λ‘ μ„€μ •
746
+ )
747
+
748
+ max_tokens_slider = gr.Slider(
749
+ label="Max New Tokens",
750
+ minimum=100,
751
+ maximum=8000,
752
+ step=50,
753
+ value=1000,
754
+ visible=False # μˆ¨κΉ€
755
+ )
756
+
757
+ web_search_text = gr.Textbox(
758
+ lines=1,
759
+ label="(Unused) Web Search Query",
760
+ placeholder="No direct input needed",
761
+ visible=False # μˆ¨κΉ€
762
+ )
763
+
764
+ # Configure the chat interface
765
+ chat = gr.ChatInterface(
766
+ fn=run,
767
+ type="messages",
768
+ chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
769
+ textbox=gr.MultimodalTextbox(
770
+ file_types=[
771
+ ".webp", ".png", ".jpg", ".jpeg", ".gif",
772
+ ".mp4", ".csv", ".txt", ".pdf"
773
+ ],
774
+ file_count="multiple",
775
+ autofocus=True
776
+ ),
777
+ multimodal=True,
778
+ additional_inputs=[
779
+ system_prompt_box,
780
+ max_tokens_slider,
781
+ web_search_checkbox,
782
+ web_search_text,
783
+ ],
784
+ stop_btn=False,
785
+ title='<a href="https://discord.gg/openfreeai" target="_blank">https://discord.gg/openfreeai</a>',
786
+ examples=examples,
787
+ run_examples_on_click=False,
788
+ cache_examples=False,
789
+ css_paths=None,
790
+ delete_cache=(1800, 1800),
791
+ )
792
+
793
+ # Example section - since examples are already set in ChatInterface, this is for display only
794
+ with gr.Row(elem_id="examples_row"):
795
+ with gr.Column(scale=12, elem_id="examples_container"):
796
+ gr.Markdown("### Example Inputs (click to load)")
797
+
798
+ if __name__ == "__main__":
799
+ # Run locally
800
+ demo.launch()