zhangfeng144
update batch import
c4e7690
raw
history blame contribute delete
7.96 kB
import gradio as gr
import logging
from app.services import sticker_service
from app.gradio_formatter import gradio_formatter
import zipfile
import threading
from queue import Queue
# Root-logger setup for the app: INFO and above, each record stamped with
# time and level. (logging.basicConfig is a no-op if the root logger already
# has handlers.)
_LOGGING_CONFIG = dict(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
logging.basicConfig(**_LOGGING_CONFIG)

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)
class StickerUI:
    """Gradio front-end for the sticker service.

    All components are created eagerly in ``__init__`` and later placed into the
    Search / Upload / All Stickers / Import Dataset tabs by ``create_ui``.
    """

    # Extensions counted as images when sizing the import progress bar.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.webp')

    def __init__(self):
        self.demo = None  # gr.Blocks instance; set by create_ui()
        self._initialize_components()

    def _initialize_components(self):
        """Create the UI components; each tab renders its own ones later."""
        # Search Tab
        self.search_input = gr.Textbox(label="Search keywords")
        self.limit_input = gr.Slider(1, 10, value=4, step=1, label="Max results")
        self.search_button = gr.Button("Search")
        self.deepseek_button = gr.Button("AI Search")
        # row_count takes an int or an (int, "fixed"|"dynamic") tuple;
        # ("dynamic") on its own is only a string, not a tuple.
        self.search_results = gr.Dataframe(
            headers=["Preview", "Relevance", "Description", "Filename"],
            datatype=["markdown", "number", "str", "str"],
            row_count=(1, "dynamic"),
        )
        self.ai_search_results = gr.Dataframe(
            headers=["Preview", "AI-Score", "AI-Reason", "Description", "Filename"],
            datatype=["markdown", "number", "str", "str", "str"],
            row_count=(1, "dynamic"),
        )
        # Upload Tab
        self.image_input = gr.Image(type="filepath")
        self.title_input = gr.Textbox(label="Title")
        self.tags_input = gr.Textbox(label="Tags (comma separated)")
        self.desc_input = gr.Textbox(label="Description")
        self.upload_button = gr.Button("Upload")
        self.upload_output = gr.Textbox(label="Status")
        # All Stickers Tab
        self.refresh_button = gr.Button("Refresh")
        self.stickers_table = gr.Dataframe(
            headers=["ID", "Preview", "Title", "Description", "Tags", "Filename", "Hash"],
            datatype=["str", "markdown", "str", "str", "str", "str", "str"],
        )
        # Import Dataset Tab
        self.dataset_input = gr.File(label="Upload Dataset (ZIP)", file_types=[".zip"])
        self.import_button = gr.Button("Import Dataset")
        # NOTE(review): not rendered by any tab in this class -- import progress
        # is reported through gr.Progress instead. Kept so the attribute exists.
        self.import_progress = gr.Slider(
            minimum=0,
            maximum=100,
            value=0,
            label="Import Progress",
            interactive=False,
        )
        self.import_output = gr.Textbox(label="Import Status", lines=10)

    def _setup_search_tab(self):
        """Lay out the search tab and wire plain and AI (reranked) search."""
        with gr.Tab("Search"):
            with gr.Row():
                self.search_input.render()
                self.limit_input.render()
            with gr.Row():
                self.search_button.render()
                self.deepseek_button.render()
            self.search_results.render()
            self.ai_search_results.render()

        # Event handlers
        self.search_button.click(
            fn=lambda query, limit: gradio_formatter.format_search_results(
                sticker_service.search_stickers(query, limit)
            ),
            inputs=[self.search_input, self.limit_input],
            outputs=self.search_results,
        )
        self.deepseek_button.click(
            fn=lambda query, limit: gradio_formatter.format_ai_search_results(
                sticker_service.search_stickers(query, limit, reranking=True)
            ),
            inputs=[self.search_input, self.limit_input],
            outputs=self.ai_search_results,
        )

    def _setup_upload_tab(self):
        """Lay out the upload tab and wire it to sticker_service.upload_sticker."""
        with gr.Tab("Upload"):
            with gr.Row():
                self.image_input.render()
            with gr.Row():
                self.title_input.render()
                self.tags_input.render()
            with gr.Row():
                self.desc_input.render()
            self.upload_button.render()
            self.upload_output.render()

        # Argument order passed to upload_sticker: image, title, description, tags.
        self.upload_button.click(
            fn=sticker_service.upload_sticker,
            inputs=[
                self.image_input,
                self.title_input,
                self.desc_input,
                self.tags_input,
            ],
            outputs=self.upload_output,
        )

    def _setup_all_stickers_tab(self):
        """Lay out the all-stickers tab; Refresh reloads the table."""
        with gr.Tab("All Stickers"):
            self.refresh_button.render()
            self.stickers_table.render()

        self.refresh_button.click(
            fn=sticker_service.get_all_stickers,
            outputs=self.stickers_table,
        )

    def _setup_import_tab(self):
        """Lay out the import tab and wire the import button."""
        with gr.Tab("Import Dataset"):
            with gr.Row():
                self.dataset_input.render()
            with gr.Row():
                self.upload_checkbox = gr.Checkbox(label="Upload to HuggingFace", value=False)
                self.save_to_milvus_checkbox = gr.Checkbox(label="Save to Milvus", value=False)
            with gr.Row():
                self.import_button.render()
            # Render the status box built in _initialize_components instead of
            # creating a second Textbox and orphaning the first.
            self.import_output.render()

        self.import_button.click(
            fn=self._import_stickers_with_progress,
            inputs=[
                self.dataset_input,
                self.upload_checkbox,
                self.save_to_milvus_checkbox,
            ],
            outputs=self.import_output,
        )

    @staticmethod
    def _count_images(dataset_path):
        """Return how many ZIP entries in *dataset_path* have an image extension."""
        with zipfile.ZipFile(dataset_path, 'r') as zip_ref:
            return sum(
                1
                for name in zip_ref.namelist()
                if name.lower().endswith(StickerUI._IMAGE_EXTENSIONS)
            )

    def _import_stickers_with_progress(self, dataset_path, upload, save_to_milvus, progress=gr.Progress()):
        """Import stickers from a ZIP archive while reporting progress to Gradio.

        Args:
            dataset_path: Value of ``self.dataset_input`` (the uploaded ZIP);
                ``None`` when nothing was uploaded.
            upload: Forwarded to ``sticker_service.import_stickers`` as ``upload``.
            save_to_milvus: Forwarded as ``save_to_milvus``.
            progress: Gradio progress tracker (injected by Gradio for events).

        Returns:
            The messages returned by ``sticker_service.import_stickers`` joined
            with newlines, or a human-readable status / error message.
        """
        if dataset_path is None:
            return "Please upload a dataset ZIP file first"
        try:
            # Count images up front so progress can be shown as a fraction.
            total_files = self._count_images(dataset_path)
            if total_files == 0:
                return "No image files found in the dataset"

            processed_files = 0

            def update_progress(file, status):
                nonlocal processed_files
                processed_files += 1
                message = f"Processing {file} - {status}"
                # Clamp: the service may report more items than were counted here.
                progress(min(processed_files / total_files, 1.0), desc=message)
                return message

            results = sticker_service.import_stickers(
                dataset_path,
                upload=upload,
                save_to_milvus=save_to_milvus,
                progress_callback=update_progress,
            )
            return "\n".join(results)
        except Exception as e:
            # UI boundary: log with traceback, show the error text to the user.
            logger.exception("Import failed: %s", e)
            return f"Import failed: {e}"

    def _update_progress(self, progress_queue):
        """Yield every ``(message, progress)`` pair currently queued, without blocking.

        Drains with ``get_nowait`` until the queue reports ``Empty``, which avoids
        the race between a separate ``empty()`` check and the ``get``.
        """
        from queue import Empty  # only this helper needs it

        while True:
            try:
                message, progress = progress_queue.get_nowait()
            except Empty:
                return
            yield message, progress

    def create_ui(self):
        """Build the full Blocks app, store it on ``self.demo`` and return it."""
        with gr.Blocks(title="Neko Sticker Search 🔍") as self.demo:
            self._setup_search_tab()
            self._setup_upload_tab()
            self._setup_all_stickers_tab()
            self._setup_import_tab()
        return self.demo