Spaces: Sleeping
zhangfeng144
committed on
Commit
·
c4e7690
1
Parent(s):
66e44c7
update batch import
Browse files
- app/services.py +17 -5
- app/ui.py +5 -3
- main.py +2 -2
app/services.py
CHANGED
@@ -81,6 +81,7 @@ class StickerService:
|
|
81 |
def import_stickers(
|
82 |
sticker_dataset: str,
|
83 |
upload: bool = False,
|
|
|
84 |
progress_callback: callable = None,
|
85 |
) -> List[str]:
|
86 |
"""导入表情包数据集
|
@@ -98,7 +99,7 @@ class StickerService:
|
|
98 |
cache_folder = os.path.join(TEMP_DIR, 'cache/')
|
99 |
img_folder = os.path.join(TEMP_DIR, 'images/')
|
100 |
data_json_path = os.path.join(cache_folder, 'data.json')
|
101 |
-
|
102 |
logger.info(f"start import dataset")
|
103 |
# 解压数据集
|
104 |
with zipfile.ZipFile(sticker_dataset, 'r') as zip_ref:
|
@@ -142,11 +143,16 @@ class StickerService:
|
|
142 |
image_filename = f"image_{random.randint(100000, 999999)}.png"
|
143 |
file_path = f"images/{image_filename}"
|
144 |
generate_temp_image(img_folder, image, image_filename)
|
145 |
-
|
146 |
-
|
147 |
-
|
|
|
|
|
|
|
|
|
148 |
|
149 |
if progress_callback:
|
|
|
150 |
progress_callback(file, "Imported")
|
151 |
|
152 |
except Exception as e:
|
@@ -154,9 +160,15 @@ class StickerService:
|
|
154 |
results.append(f"处理失败 {file}: {str(e)}")
|
155 |
if progress_callback:
|
156 |
progress_callback(file, f"Failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
# 上传到 HuggingFace
|
159 |
-
if upload:
|
|
|
160 |
upload_folder_to_huggingface(img_folder)
|
161 |
|
162 |
return results
|
|
|
81 |
def import_stickers(
|
82 |
sticker_dataset: str,
|
83 |
upload: bool = False,
|
84 |
+
save_to_milvus: bool = False,
|
85 |
progress_callback: callable = None,
|
86 |
) -> List[str]:
|
87 |
"""导入表情包数据集
|
|
|
99 |
cache_folder = os.path.join(TEMP_DIR, 'cache/')
|
100 |
img_folder = os.path.join(TEMP_DIR, 'images/')
|
101 |
data_json_path = os.path.join(cache_folder, 'data.json')
|
102 |
+
stickers = []
|
103 |
logger.info(f"start import dataset")
|
104 |
# 解压数据集
|
105 |
with zipfile.ZipFile(sticker_dataset, 'r') as zip_ref:
|
|
|
143 |
image_filename = f"image_{random.randint(100000, 999999)}.png"
|
144 |
file_path = f"images/{image_filename}"
|
145 |
generate_temp_image(img_folder, image, image_filename)
|
146 |
+
stickers.append({
|
147 |
+
"title": "",
|
148 |
+
"description": description,
|
149 |
+
"tags": "",
|
150 |
+
"file_path": file_path,
|
151 |
+
"image_hash": image_hash
|
152 |
+
})
|
153 |
|
154 |
if progress_callback:
|
155 |
+
results.append(f"成功导入: {image_filename}")
|
156 |
progress_callback(file, "Imported")
|
157 |
|
158 |
except Exception as e:
|
|
|
160 |
results.append(f"处理失败 {file}: {str(e)}")
|
161 |
if progress_callback:
|
162 |
progress_callback(file, f"Failed: {str(e)}")
|
163 |
+
|
164 |
+
|
165 |
+
if save_to_milvus and len(stickers) > 0:
|
166 |
+
logger.info(f"save to milvus, {len(stickers)} stickers")
|
167 |
+
db.batch_store_stickers(stickers)
|
168 |
|
169 |
# 上传到 HuggingFace
|
170 |
+
if upload and len(stickers) > 0:
|
171 |
+
logger.info(f"upload to huggingface, {len(stickers)} stickers")
|
172 |
upload_folder_to_huggingface(img_folder)
|
173 |
|
174 |
return results
|
app/ui.py
CHANGED
@@ -144,6 +144,7 @@ class StickerUI:
|
|
144 |
|
145 |
with gr.Row():
|
146 |
self.upload_checkbox = gr.Checkbox(label="Upload to HuggingFace", value=False)
|
|
|
147 |
|
148 |
with gr.Row():
|
149 |
self.import_button.render()
|
@@ -154,12 +155,13 @@ class StickerUI:
|
|
154 |
fn=self._import_stickers_with_progress,
|
155 |
inputs=[
|
156 |
self.dataset_input,
|
157 |
-
self.upload_checkbox
|
|
|
158 |
],
|
159 |
outputs=self.import_output
|
160 |
)
|
161 |
|
162 |
-
def _import_stickers_with_progress(self, dataset_path, upload, progress=gr.Progress()):
|
163 |
"""Import stickers with progress tracking."""
|
164 |
try:
|
165 |
# Count total files first
|
@@ -187,8 +189,8 @@ class StickerUI:
|
|
187 |
results = sticker_service.import_stickers(
|
188 |
dataset_path,
|
189 |
upload=upload,
|
|
|
190 |
progress_callback=update_progress,
|
191 |
-
total_files=total_files
|
192 |
)
|
193 |
|
194 |
return "\n".join(results)
|
|
|
144 |
|
145 |
with gr.Row():
|
146 |
self.upload_checkbox = gr.Checkbox(label="Upload to HuggingFace", value=False)
|
147 |
+
self.save_to_milvus_checkbox = gr.Checkbox(label="Save to Milvus", value=False)
|
148 |
|
149 |
with gr.Row():
|
150 |
self.import_button.render()
|
|
|
155 |
fn=self._import_stickers_with_progress,
|
156 |
inputs=[
|
157 |
self.dataset_input,
|
158 |
+
self.upload_checkbox,
|
159 |
+
self.save_to_milvus_checkbox
|
160 |
],
|
161 |
outputs=self.import_output
|
162 |
)
|
163 |
|
164 |
+
def _import_stickers_with_progress(self, dataset_path, upload, save_to_milvus, progress=gr.Progress()):
|
165 |
"""Import stickers with progress tracking."""
|
166 |
try:
|
167 |
# Count total files first
|
|
|
189 |
results = sticker_service.import_stickers(
|
190 |
dataset_path,
|
191 |
upload=upload,
|
192 |
+
save_to_milvus=save_to_milvus,
|
193 |
progress_callback=update_progress,
|
|
|
194 |
)
|
195 |
|
196 |
return "\n".join(results)
|
main.py
CHANGED
@@ -72,7 +72,7 @@ async def api_delete_stickers(request: dict):
|
|
72 |
|
73 |
|
74 |
@app.post("/api/import_dataset")
|
75 |
-
async def api_import_dataset(file: UploadFile = File(...), upload: bool = False):
|
76 |
"""Import sticker dataset from ZIP file"""
|
77 |
try:
|
78 |
# Create a temporary file to store the uploaded ZIP
|
@@ -82,7 +82,7 @@ async def api_import_dataset(file: UploadFile = File(...), upload: bool = False)
|
|
82 |
temp_file_path = temp_file.name
|
83 |
|
84 |
# Import the dataset
|
85 |
-
results = sticker_service.import_stickers(temp_file_path, upload)
|
86 |
|
87 |
# Clean up the temporary file
|
88 |
os.unlink(temp_file_path)
|
|
|
72 |
|
73 |
|
74 |
@app.post("/api/import_dataset")
|
75 |
+
async def api_import_dataset(file: UploadFile = File(...), upload: bool = False, save_to_milvus: bool = False):
|
76 |
"""Import sticker dataset from ZIP file"""
|
77 |
try:
|
78 |
# Create a temporary file to store the uploaded ZIP
|
|
|
82 |
temp_file_path = temp_file.name
|
83 |
|
84 |
# Import the dataset
|
85 |
+
results = sticker_service.import_stickers(temp_file_path, upload, save_to_milvus)
|
86 |
|
87 |
# Clean up the temporary file
|
88 |
os.unlink(temp_file_path)
|