Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -342,14 +342,84 @@ def plt_to_image():
|
|
342 |
img = Image.open(buf)
|
343 |
return img
|
344 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
|
|
|
346 |
def process_image(image_path):
|
347 |
#image = Image.open(image_path)
|
348 |
one_mask,color_mask, counts_dict = predict_and_visualize(image_path)
|
349 |
colors_per_class=ext_colors(image_path,one_mask,n_clusters=4)
|
350 |
colors_per_label = {id2material[key]: value for key, value in colors_per_class.items()}
|
351 |
# 定义一个列表,包含需要从字典中删除的键
|
352 |
-
labels_to_remove = ['
|
353 |
# 使用字典推导式删除列表中的键
|
354 |
colors_per_label = {key: value for key, value in colors_per_label.items() if key not in labels_to_remove}
|
355 |
palette_image = plot_material_color_palette_grid(colors_per_label)
|
@@ -369,7 +439,14 @@ def process_image(image_path):
|
|
369 |
# 重新命名 DataFrame 为 percentage_df 以清楚表达其内容
|
370 |
percentage_df = counts_df.rename(columns={'计数': '数量', '百分比': '占比 (%)'})
|
371 |
|
372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
|
374 |
iface = gr.Interface(
|
375 |
fn=process_image,
|
@@ -377,6 +454,7 @@ iface = gr.Interface(
|
|
377 |
outputs=[
|
378 |
gr.Image(type="pil", label="Color Mask"),
|
379 |
gr.Image(type="pil", label="Color Palette"),
|
|
|
380 |
gr.DataFrame()
|
381 |
],
|
382 |
title="Image Processing for Mask Visualization and Color Extraction",
|
|
|
342 |
img = Image.open(buf)
|
343 |
return img
|
344 |
|
345 |
+
def calculate_slice_statistics(one_mask, slice_size=128):
|
346 |
+
"""计算每个切片的材质占比"""
|
347 |
+
num_rows, num_cols = one_mask.shape[0] // slice_size, one_mask.shape[1] // slice_size
|
348 |
+
slice_stats = {}
|
349 |
+
|
350 |
+
for i in range(num_rows):
|
351 |
+
for j in range(num_cols):
|
352 |
+
slice_mask = one_mask[i*slice_size:(i+1)*slice_size, j*slice_size:(j+1)*slice_size]
|
353 |
+
unique, counts = np.unique(slice_mask, return_counts=True)
|
354 |
+
total_pixels = counts.sum()
|
355 |
+
slice_stats[(i, j)] = {k: v / total_pixels for k, v in zip(unique, counts)}
|
356 |
+
|
357 |
+
return slice_stats
|
358 |
+
|
359 |
+
def find_top_slices(slice_stats, exclusion_list, min_percent=0.7, min_slices=1, top_k=3):
|
360 |
+
"""找出每个类材质占比最高的前三个切片,加入新的筛选条件"""
|
361 |
+
from collections import defaultdict
|
362 |
+
import heapq
|
363 |
+
|
364 |
+
top_slices = defaultdict(list)
|
365 |
+
|
366 |
+
for slice_pos, stats in slice_stats.items():
|
367 |
+
for material_id, percent in stats.items():
|
368 |
+
# 第一个判断:材质是否在排除列表中
|
369 |
+
if material_id in exclusion_list:
|
370 |
+
continue
|
371 |
+
# 第二个判断:材质占比是否至少为70%
|
372 |
+
if percent < min_percent:
|
373 |
+
continue
|
374 |
+
|
375 |
+
# 将符合条件的切片添加到堆中
|
376 |
+
if len(top_slices[material_id]) < top_k:
|
377 |
+
heapq.heappush(top_slices[material_id], (percent, slice_pos))
|
378 |
+
else:
|
379 |
+
heapq.heappushpop(top_slices[material_id], (percent, slice_pos))
|
380 |
+
|
381 |
+
# 过滤出符合第三个条件的材质
|
382 |
+
valid_top_slices = {}
|
383 |
+
for material_id, slices in top_slices.items():
|
384 |
+
if len(slices) > min_slices: # 至少有超过一个切片
|
385 |
+
valid_top_slices[material_id] = sorted(slices, reverse=True)
|
386 |
+
|
387 |
+
return valid_top_slices
|
388 |
+
|
389 |
+
def extract_and_visualize_top_slices(image, top_slices, slice_size=128):
|
390 |
+
"""从原始图像中提取并可视化顶部切片"""
|
391 |
+
import matplotlib.pyplot as plt
|
392 |
+
|
393 |
+
fig, axs = plt.subplots(nrows=len(top_slices), ncols=3, figsize=(15, 5 * len(top_slices)))
|
394 |
+
|
395 |
+
if len(top_slices) == 1:
|
396 |
+
axs = [axs]
|
397 |
+
|
398 |
+
for idx, (material_id, slices) in enumerate(top_slices.items()):
|
399 |
+
for col, (_, pos) in enumerate(slices):
|
400 |
+
i, j = pos
|
401 |
+
img_slice = image.crop((j * slice_size, i * slice_size, (j + 1) * slice_size, (i + 1) * slice_size))
|
402 |
+
axs[idx][col].imshow(img_slice)
|
403 |
+
axs[idx][col].set_title(f'Material {id2material[material_id]} - Slice {pos}')
|
404 |
+
axs[idx][col].axis('off')
|
405 |
+
|
406 |
+
plt.tight_layout()
|
407 |
+
# 保存到内存,而不是显示图像
|
408 |
+
buf = io.BytesIO()
|
409 |
+
plt.savefig(buf, format='png')
|
410 |
+
plt.close()
|
411 |
+
buf.seek(0)
|
412 |
+
img = Image.open(buf)
|
413 |
+
return img
|
414 |
|
415 |
+
# main program
|
416 |
def process_image(image_path):
|
417 |
#image = Image.open(image_path)
|
418 |
one_mask,color_mask, counts_dict = predict_and_visualize(image_path)
|
419 |
colors_per_class=ext_colors(image_path,one_mask,n_clusters=4)
|
420 |
colors_per_label = {id2material[key]: value for key, value in colors_per_class.items()}
|
421 |
# 定义一个列表,包含需要从字典中删除的键
|
422 |
+
labels_to_remove = ['Sky', 'background','Glass','tree','water','Plastic, clear']
|
423 |
# 使用字典推导式删除列表中的键
|
424 |
colors_per_label = {key: value for key, value in colors_per_label.items() if key not in labels_to_remove}
|
425 |
palette_image = plot_material_color_palette_grid(colors_per_label)
|
|
|
439 |
# 重新命名 DataFrame 为 percentage_df 以清楚表达其内容
|
440 |
percentage_df = counts_df.rename(columns={'计数': '数量', '百分比': '占比 (%)'})
|
441 |
|
442 |
+
slice_size = 64
|
443 |
+
exclusion_list = [38]
|
444 |
+
slice_stats = calculate_slice_statistics(one_mask, slice_size=slice_size)
|
445 |
+
top_slices = find_top_slices(slice_stats, exclusion_list=exclusion_list, min_percent=0.5, min_slices=1)
|
446 |
+
slice_image=extract_and_visualize_top_slices(image_path, top_slices, slice_size=slice_size)
|
447 |
+
|
448 |
+
return color_mask_img, palette_image, percentage_df,slice_image
|
449 |
+
|
450 |
|
451 |
iface = gr.Interface(
|
452 |
fn=process_image,
|
|
|
454 |
outputs=[
|
455 |
gr.Image(type="pil", label="Color Mask"),
|
456 |
gr.Image(type="pil", label="Color Palette"),
|
457 |
+
gr.Image(type='pil', label='Texture Slices')
|
458 |
gr.DataFrame()
|
459 |
],
|
460 |
title="Image Processing for Mask Visualization and Color Extraction",
|