jinfengxie commited on
Commit
ca86773
·
verified ·
1 Parent(s): e221ef4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -2
app.py CHANGED
@@ -342,14 +342,84 @@ def plt_to_image():
342
  img = Image.open(buf)
343
  return img
344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
 
346
  def process_image(image_path):
347
  #image = Image.open(image_path)
348
  one_mask,color_mask, counts_dict = predict_and_visualize(image_path)
349
  colors_per_class=ext_colors(image_path,one_mask,n_clusters=4)
350
  colors_per_label = {id2material[key]: value for key, value in colors_per_class.items()}
351
  # 定义一个列表,包含需要从字典中删除的键
352
- labels_to_remove = ['sky', 'background','Glass','tree','water','Plastic, clear']
353
  # 使用字典推导式删除列表中的键
354
  colors_per_label = {key: value for key, value in colors_per_label.items() if key not in labels_to_remove}
355
  palette_image = plot_material_color_palette_grid(colors_per_label)
@@ -369,7 +439,14 @@ def process_image(image_path):
369
  # 重新命名 DataFrame 为 percentage_df 以清楚表达其内容
370
  percentage_df = counts_df.rename(columns={'计数': '数量', '百分比': '占比 (%)'})
371
 
372
- return color_mask_img, palette_image, percentage_df
 
 
 
 
 
 
 
373
 
374
  iface = gr.Interface(
375
  fn=process_image,
@@ -377,6 +454,7 @@ iface = gr.Interface(
377
  outputs=[
378
  gr.Image(type="pil", label="Color Mask"),
379
  gr.Image(type="pil", label="Color Palette"),
 
380
  gr.DataFrame()
381
  ],
382
  title="Image Processing for Mask Visualization and Color Extraction",
 
342
  img = Image.open(buf)
343
  return img
344
 
345
+ def calculate_slice_statistics(one_mask, slice_size=128):
346
+ """计算每个切片的材质占比"""
347
+ num_rows, num_cols = one_mask.shape[0] // slice_size, one_mask.shape[1] // slice_size
348
+ slice_stats = {}
349
+
350
+ for i in range(num_rows):
351
+ for j in range(num_cols):
352
+ slice_mask = one_mask[i*slice_size:(i+1)*slice_size, j*slice_size:(j+1)*slice_size]
353
+ unique, counts = np.unique(slice_mask, return_counts=True)
354
+ total_pixels = counts.sum()
355
+ slice_stats[(i, j)] = {k: v / total_pixels for k, v in zip(unique, counts)}
356
+
357
+ return slice_stats
358
+
359
+ def find_top_slices(slice_stats, exclusion_list, min_percent=0.7, min_slices=1, top_k=3):
360
+ """找出每个类材质占比最高的前三个切片,加入新的筛选条件"""
361
+ from collections import defaultdict
362
+ import heapq
363
+
364
+ top_slices = defaultdict(list)
365
+
366
+ for slice_pos, stats in slice_stats.items():
367
+ for material_id, percent in stats.items():
368
+ # 第一个判断:材质是否在排除列表中
369
+ if material_id in exclusion_list:
370
+ continue
371
+ # 第二个判断:材质占比是否至少为70%
372
+ if percent < min_percent:
373
+ continue
374
+
375
+ # 将符合条件的切片添加到堆中
376
+ if len(top_slices[material_id]) < top_k:
377
+ heapq.heappush(top_slices[material_id], (percent, slice_pos))
378
+ else:
379
+ heapq.heappushpop(top_slices[material_id], (percent, slice_pos))
380
+
381
+ # 过滤出符合第三个条件的材质
382
+ valid_top_slices = {}
383
+ for material_id, slices in top_slices.items():
384
+ if len(slices) > min_slices: # 至少有超过一个切片
385
+ valid_top_slices[material_id] = sorted(slices, reverse=True)
386
+
387
+ return valid_top_slices
388
+
389
+ def extract_and_visualize_top_slices(image, top_slices, slice_size=128):
390
+ """从原始图像中提取并可视化顶部切片"""
391
+ import matplotlib.pyplot as plt
392
+
393
+ fig, axs = plt.subplots(nrows=len(top_slices), ncols=3, figsize=(15, 5 * len(top_slices)))
394
+
395
+ if len(top_slices) == 1:
396
+ axs = [axs]
397
+
398
+ for idx, (material_id, slices) in enumerate(top_slices.items()):
399
+ for col, (_, pos) in enumerate(slices):
400
+ i, j = pos
401
+ img_slice = image.crop((j * slice_size, i * slice_size, (j + 1) * slice_size, (i + 1) * slice_size))
402
+ axs[idx][col].imshow(img_slice)
403
+ axs[idx][col].set_title(f'Material {id2material[material_id]} - Slice {pos}')
404
+ axs[idx][col].axis('off')
405
+
406
+ plt.tight_layout()
407
+ # 保存到内存,而不是显示图像
408
+ buf = io.BytesIO()
409
+ plt.savefig(buf, format='png')
410
+ plt.close()
411
+ buf.seek(0)
412
+ img = Image.open(buf)
413
+ return img
414
 
415
+ # main program
416
  def process_image(image_path):
417
  #image = Image.open(image_path)
418
  one_mask,color_mask, counts_dict = predict_and_visualize(image_path)
419
  colors_per_class=ext_colors(image_path,one_mask,n_clusters=4)
420
  colors_per_label = {id2material[key]: value for key, value in colors_per_class.items()}
421
  # 定义一个列表,包含需要从字典中删除的键
422
+ labels_to_remove = ['Sky', 'background','Glass','tree','water','Plastic, clear']
423
  # 使用字典推导式删除列表中的键
424
  colors_per_label = {key: value for key, value in colors_per_label.items() if key not in labels_to_remove}
425
  palette_image = plot_material_color_palette_grid(colors_per_label)
 
439
  # 重新命名 DataFrame 为 percentage_df 以清楚表达其内容
440
  percentage_df = counts_df.rename(columns={'计数': '数量', '百分比': '占比 (%)'})
441
 
442
+ slice_size = 64
443
+ exclusion_list = [38]
444
+ slice_stats = calculate_slice_statistics(one_mask, slice_size=slice_size)
445
+ top_slices = find_top_slices(slice_stats, exclusion_list=exclusion_list, min_percent=0.5, min_slices=1)
446
+ slice_image=extract_and_visualize_top_slices(image_path, top_slices, slice_size=slice_size)
447
+
448
+ return color_mask_img, palette_image, percentage_df,slice_image
449
+
450
 
451
  iface = gr.Interface(
452
  fn=process_image,
 
454
  outputs=[
455
  gr.Image(type="pil", label="Color Mask"),
456
  gr.Image(type="pil", label="Color Palette"),
457
+ gr.Image(type='pil', label='Texture Slices')
458
  gr.DataFrame()
459
  ],
460
  title="Image Processing for Mask Visualization and Color Extraction",