zhangfeng144 committed on
Commit
c4e7690
·
1 Parent(s): 66e44c7

update batch import

Browse files
Files changed (3) hide show
  1. app/services.py +17 -5
  2. app/ui.py +5 -3
  3. main.py +2 -2
app/services.py CHANGED
@@ -81,6 +81,7 @@ class StickerService:
81
  def import_stickers(
82
  sticker_dataset: str,
83
  upload: bool = False,
 
84
  progress_callback: callable = None,
85
  ) -> List[str]:
86
  """导入表情包数据集
@@ -98,7 +99,7 @@ class StickerService:
98
  cache_folder = os.path.join(TEMP_DIR, 'cache/')
99
  img_folder = os.path.join(TEMP_DIR, 'images/')
100
  data_json_path = os.path.join(cache_folder, 'data.json')
101
-
102
  logger.info(f"start import dataset")
103
  # 解压数据集
104
  with zipfile.ZipFile(sticker_dataset, 'r') as zip_ref:
@@ -142,11 +143,16 @@ class StickerService:
142
  image_filename = f"image_{random.randint(100000, 999999)}.png"
143
  file_path = f"images/{image_filename}"
144
  generate_temp_image(img_folder, image, image_filename)
145
-
146
- db.store_sticker("", description, "", file_path, image_hash)
147
- results.append(f"成功导入: {image_filename}")
 
 
 
 
148
 
149
  if progress_callback:
 
150
  progress_callback(file, "Imported")
151
 
152
  except Exception as e:
@@ -154,9 +160,15 @@ class StickerService:
154
  results.append(f"处理失败 {file}: {str(e)}")
155
  if progress_callback:
156
  progress_callback(file, f"Failed: {str(e)}")
 
 
 
 
 
157
 
158
  # 上传到 HuggingFace
159
- if upload:
 
160
  upload_folder_to_huggingface(img_folder)
161
 
162
  return results
 
81
  def import_stickers(
82
  sticker_dataset: str,
83
  upload: bool = False,
84
+ save_to_milvus: bool = False,
85
  progress_callback: callable = None,
86
  ) -> List[str]:
87
  """导入表情包数据集
 
99
  cache_folder = os.path.join(TEMP_DIR, 'cache/')
100
  img_folder = os.path.join(TEMP_DIR, 'images/')
101
  data_json_path = os.path.join(cache_folder, 'data.json')
102
+ stickers = []
103
  logger.info(f"start import dataset")
104
  # 解压数据集
105
  with zipfile.ZipFile(sticker_dataset, 'r') as zip_ref:
 
143
  image_filename = f"image_{random.randint(100000, 999999)}.png"
144
  file_path = f"images/{image_filename}"
145
  generate_temp_image(img_folder, image, image_filename)
146
+ stickers.append({
147
+ "title": "",
148
+ "description": description,
149
+ "tags": "",
150
+ "file_path": file_path,
151
+ "image_hash": image_hash
152
+ })
153
 
154
  if progress_callback:
155
+ results.append(f"成功导入: {image_filename}")
156
  progress_callback(file, "Imported")
157
 
158
  except Exception as e:
 
160
  results.append(f"处理失败 {file}: {str(e)}")
161
  if progress_callback:
162
  progress_callback(file, f"Failed: {str(e)}")
163
+
164
+
165
+ if save_to_milvus and len(stickers) > 0:
166
+ logger.info(f"save to milvus, {len(stickers)} stickers")
167
+ db.batch_store_stickers(stickers)
168
 
169
  # 上传到 HuggingFace
170
+ if upload and len(stickers) > 0:
171
+ logger.info(f"upload to huggingface, {len(stickers)} stickers")
172
  upload_folder_to_huggingface(img_folder)
173
 
174
  return results
app/ui.py CHANGED
@@ -144,6 +144,7 @@ class StickerUI:
144
 
145
  with gr.Row():
146
  self.upload_checkbox = gr.Checkbox(label="Upload to HuggingFace", value=False)
 
147
 
148
  with gr.Row():
149
  self.import_button.render()
@@ -154,12 +155,13 @@ class StickerUI:
154
  fn=self._import_stickers_with_progress,
155
  inputs=[
156
  self.dataset_input,
157
- self.upload_checkbox
 
158
  ],
159
  outputs=self.import_output
160
  )
161
 
162
- def _import_stickers_with_progress(self, dataset_path, upload, progress=gr.Progress()):
163
  """Import stickers with progress tracking."""
164
  try:
165
  # Count total files first
@@ -187,8 +189,8 @@ class StickerUI:
187
  results = sticker_service.import_stickers(
188
  dataset_path,
189
  upload=upload,
 
190
  progress_callback=update_progress,
191
- total_files=total_files
192
  )
193
 
194
  return "\n".join(results)
 
144
 
145
  with gr.Row():
146
  self.upload_checkbox = gr.Checkbox(label="Upload to HuggingFace", value=False)
147
+ self.save_to_milvus_checkbox = gr.Checkbox(label="Save to Milvus", value=False)
148
 
149
  with gr.Row():
150
  self.import_button.render()
 
155
  fn=self._import_stickers_with_progress,
156
  inputs=[
157
  self.dataset_input,
158
+ self.upload_checkbox,
159
+ self.save_to_milvus_checkbox
160
  ],
161
  outputs=self.import_output
162
  )
163
 
164
+ def _import_stickers_with_progress(self, dataset_path, upload, save_to_milvus, progress=gr.Progress()):
165
  """Import stickers with progress tracking."""
166
  try:
167
  # Count total files first
 
189
  results = sticker_service.import_stickers(
190
  dataset_path,
191
  upload=upload,
192
+ save_to_milvus=save_to_milvus,
193
  progress_callback=update_progress,
 
194
  )
195
 
196
  return "\n".join(results)
main.py CHANGED
@@ -72,7 +72,7 @@ async def api_delete_stickers(request: dict):
72
 
73
 
74
  @app.post("/api/import_dataset")
75
- async def api_import_dataset(file: UploadFile = File(...), upload: bool = False):
76
  """Import sticker dataset from ZIP file"""
77
  try:
78
  # Create a temporary file to store the uploaded ZIP
@@ -82,7 +82,7 @@ async def api_import_dataset(file: UploadFile = File(...), upload: bool = False)
82
  temp_file_path = temp_file.name
83
 
84
  # Import the dataset
85
- results = sticker_service.import_stickers(temp_file_path, upload)
86
 
87
  # Clean up the temporary file
88
  os.unlink(temp_file_path)
 
72
 
73
 
74
  @app.post("/api/import_dataset")
75
+ async def api_import_dataset(file: UploadFile = File(...), upload: bool = False, save_to_milvus: bool = False):
76
  """Import sticker dataset from ZIP file"""
77
  try:
78
  # Create a temporary file to store the uploaded ZIP
 
82
  temp_file_path = temp_file.name
83
 
84
  # Import the dataset
85
+ results = sticker_service.import_stickers(temp_file_path, upload, save_to_milvus)
86
 
87
  # Clean up the temporary file
88
  os.unlink(temp_file_path)