# Hugging Face Spaces page header captured with this file: "Spaces: Running".
import logging
import threading
import zipfile
from queue import Empty, Queue

import gradio as gr

from app.services import sticker_service
from app.gradio_formatter import gradio_formatter
# Logging setup: the root logger emits INFO and above as
# "<timestamp> <level> <message>" lines.
_LOG_FORMAT = '%(asctime)s %(levelname)s %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by everything defined in this file.
logger = logging.getLogger(__name__)
class StickerUI:
    """Main class for the Sticker Search UI application.

    Components are constructed once in ``_initialize_components`` and are
    placed into the page later by the ``_setup_*_tab`` methods via
    ``.render()``; those methods also attach the event handlers.  Call
    ``create_ui`` to build and obtain the ``gr.Blocks`` application.
    """

    # File suffixes counted as images when sizing an import's progress bar.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.webp')

    def __init__(self):
        # Populated by create_ui(); None until the UI has been built.
        self.demo = None
        self._initialize_components()

    def _initialize_components(self):
        """Create every UI component used by the tabs (not yet rendered)."""
        # Search Tab
        self.search_input = gr.Textbox(label="Search keywords")
        self.limit_input = gr.Slider(1, 10, value=4, step=1, label="Max results")
        self.search_button = gr.Button("Search")
        self.deepseek_button = gr.Button("AI Search")
        # row_count is documented as an int or an (int, "fixed"|"dynamic")
        # tuple; the previous value ("dynamic") was a bare string because
        # parentheses alone do not make a tuple.
        self.search_results = gr.Dataframe(
            headers=["Preview", "Relevance", "Description", "Filename"],
            datatype=["markdown", "number", "str", "str"],
            row_count=(1, "dynamic"),
        )
        self.ai_search_results = gr.Dataframe(
            headers=["Preview", "AI-Score", "AI-Reason", "Description", "Filename"],
            datatype=["markdown", "number", "str", "str", "str"],
            row_count=(1, "dynamic"),
        )

        # Upload Tab
        self.image_input = gr.Image(type="filepath")
        self.title_input = gr.Textbox(label="Title")
        self.tags_input = gr.Textbox(label="Tags (comma separated)")
        self.desc_input = gr.Textbox(label="Description")
        self.upload_button = gr.Button("Upload")
        self.upload_output = gr.Textbox(label="Status")

        # All Stickers Tab
        self.refresh_button = gr.Button("Refresh")
        self.stickers_table = gr.Dataframe(
            headers=["ID", "Preview", "Title", "Description", "Tags", "Filename", "Hash"],
            datatype=["str", "markdown", "str", "str", "str", "str", "str"],
        )

        # Import Dataset Tab
        self.dataset_input = gr.File(label="Upload Dataset (ZIP)", file_types=[".zip"])
        self.upload_checkbox = gr.Checkbox(label="Upload to HuggingFace", value=False)
        self.save_to_milvus_checkbox = gr.Checkbox(label="Save to Milvus", value=False)
        self.import_button = gr.Button("Import Dataset")
        # NOTE: this slider is not rendered by any tab in this class; import
        # progress is reported through the gr.Progress tracker instead.
        self.import_progress = gr.Slider(
            minimum=0,
            maximum=100,
            value=0,
            label="Import Progress",
            interactive=False,
        )
        self.import_output = gr.Textbox(label="Import Status", lines=10)

    def _setup_search_tab(self):
        """Configure the search tab layout and functionality."""
        with gr.Tab("Search"):
            with gr.Row():
                self.search_input.render()
                self.limit_input.render()
            with gr.Row():
                self.search_button.render()
                self.deepseek_button.render()
            self.search_results.render()
            self.ai_search_results.render()

            # Event handlers: both buttons read the same query/limit inputs
            # but write to their own results table.
            self.search_button.click(
                fn=lambda query, limit: gradio_formatter.format_search_results(
                    sticker_service.search_stickers(query, limit)
                ),
                inputs=[self.search_input, self.limit_input],
                outputs=self.search_results,
            )
            self.deepseek_button.click(
                fn=lambda query, limit: gradio_formatter.format_ai_search_results(
                    sticker_service.search_stickers(query, limit, reranking=True)
                ),
                inputs=[self.search_input, self.limit_input],
                outputs=self.ai_search_results,
            )

    def _setup_upload_tab(self):
        """Configure the upload tab layout and functionality."""
        with gr.Tab("Upload"):
            with gr.Row():
                self.image_input.render()
            with gr.Row():
                self.title_input.render()
                self.tags_input.render()
            with gr.Row():
                self.desc_input.render()
            self.upload_button.render()
            self.upload_output.render()

            # Argument order passed to upload_sticker is image, title,
            # description, tags (description comes before tags even though
            # the tags box is laid out first).
            self.upload_button.click(
                fn=sticker_service.upload_sticker,
                inputs=[
                    self.image_input,
                    self.title_input,
                    self.desc_input,
                    self.tags_input,
                ],
                outputs=self.upload_output,
            )

    def _setup_all_stickers_tab(self):
        """Configure the all stickers tab layout and functionality."""
        with gr.Tab("All Stickers"):
            self.refresh_button.render()
            self.stickers_table.render()

            # The table is only filled when the user presses Refresh.
            self.refresh_button.click(
                fn=sticker_service.get_all_stickers,
                outputs=self.stickers_table,
            )

    def _setup_import_tab(self):
        """Configure the import dataset tab layout and functionality."""
        with gr.Tab("Import Dataset"):
            with gr.Row():
                self.dataset_input.render()
            with gr.Row():
                self.upload_checkbox.render()
                self.save_to_milvus_checkbox.render()
            with gr.Row():
                self.import_button.render()
            # Render the status box built in _initialize_components rather
            # than constructing a second, duplicate Textbox here.
            self.import_output.render()

            self.import_button.click(
                fn=self._import_stickers_with_progress,
                inputs=[
                    self.dataset_input,
                    self.upload_checkbox,
                    self.save_to_milvus_checkbox,
                ],
                outputs=self.import_output,
            )

    @classmethod
    def _count_images(cls, dataset_path):
        """Return how many entries of the ZIP at *dataset_path* have an image suffix."""
        with zipfile.ZipFile(dataset_path, 'r') as archive:
            return sum(
                1
                for entry in archive.namelist()
                if entry.lower().endswith(cls.IMAGE_EXTENSIONS)
            )

    def _import_stickers_with_progress(self, dataset_path, upload, save_to_milvus, progress=gr.Progress()):
        """Import stickers from a ZIP dataset with progress tracking.

        Args:
            dataset_path: Value of the dataset file input; handed to
                ``zipfile.ZipFile`` and to ``sticker_service.import_stickers``.
            upload: Value of the "Upload to HuggingFace" checkbox.
            save_to_milvus: Value of the "Save to Milvus" checkbox.
            progress: Progress tracker; the ``gr.Progress()`` default must
                stay in the signature for Gradio to supply it.

        Returns:
            The import results joined with newlines, or a human-readable
            message when there is nothing to import or the import fails.
            Exceptions are logged and reported as text rather than raised.
        """
        # Nothing selected: say so instead of surfacing a zipfile error.
        if not dataset_path:
            return "No dataset file provided"

        try:
            # Count the images first so each callback maps to a fraction.
            total_files = self._count_images(dataset_path)
            if total_files == 0:
                return "No image files found in the dataset"

            processed_files = 0

            def update_progress(file, status):
                nonlocal processed_files
                processed_files += 1
                # Clamp: the service may call back more often than the
                # number of images counted above.
                fraction = min(processed_files / total_files, 1.0)
                message = f"Processing {file} - {status}"
                progress(fraction, desc=message)
                return message

            results = sticker_service.import_stickers(
                dataset_path,
                upload=upload,
                save_to_milvus=save_to_milvus,
                progress_callback=update_progress,
            )
            return "\n".join(results)
        except Exception as e:
            # UI boundary: keep the traceback in the log, show text to the user.
            logger.exception("Import failed: %s", e)
            return f"Import failed: {str(e)}"

    def _update_progress(self, progress_queue):
        """Yield ``(message, progress)`` pairs until *progress_queue* is drained.

        Iteration stops at the first empty read.  An item that cannot be
        unpacked into two values is logged and also ends the iteration.
        """
        while True:
            try:
                # get_nowait() raises Empty once drained, so no separate
                # (racy) empty() check is needed.
                message, progress = progress_queue.get_nowait()
            except Empty:
                return
            except (TypeError, ValueError):
                logger.warning("Stopping progress updates: malformed queue item")
                return
            yield message, progress

    def create_ui(self):
        """Create and configure the complete UI.

        Returns:
            The ``gr.Blocks`` app, also stored on ``self.demo``.
        """
        with gr.Blocks(title="Neko Sticker Search π") as self.demo:
            self._setup_search_tab()
            self._setup_upload_tab()
            self._setup_all_stickers_tab()
            self._setup_import_tab()
        return self.demo