# Hugging Face Spaces status banner captured along with this file ("Spaces: Running").
from PIL import Image, ImageDraw, ImageFont | |
from skimage.measure import label, regionprops | |
import gradio as gr | |
import tensorflow as tf | |
import numpy as np | |
from PIL import Image | |
from tensorflow.keras.preprocessing.image import array_to_img | |
import json | |
import os | |
from transformers import AutoModel | |
from transformers import TFSegformerForSemanticSegmentation | |
import matplotlib.pyplot as plt | |
import matplotlib | |
import matplotlib.font_manager as fm | |
from sklearn.cluster import KMeans | |
from skimage import color | |
import io | |
import pandas as pd | |
# Set the font to support Chinese characters | |
#font_path = 'simhei.ttf' | |
#font_prop = fm.FontProperties(fname=font_path) | |
#matplotlib.rcParams['font.family'] = font_prop.get_name() | |
#matplotlib.rcParams['font.family'] = 'Droid Sans Fallback' | |
# Display colour for each class id, as [R, G, B] with 0-255 components.
# predict_and_visualize paints every pixel of a predicted class with this
# colour; predicted ids that are missing from this table are left black.
# Keys are the same class ids used by id2label and id2material.
id2color= {1: [209, 35, 69],
           2: [216, 208, 246],
           3: [172, 196, 170],
           4: [178, 80, 80],
           6: [89, 89, 89],
           7: [160, 146, 229],
           8: [18, 17, 20],
           10: [190, 209, 189],
           13: [37, 12, 156],
           15: [250, 50, 83],
           16: [61, 245, 61],
           17: [230, 203, 104],
           18: [125, 104, 227],
           19: [228, 225, 249],
           20: [51, 221, 255],
           21: [95, 95, 95],
           23: [156, 239, 255],
           24: [153, 102, 51],
           26: [0, 0, 226],
           27: [254, 242, 208],
           29: [89, 134, 179],
           32: [255, 0, 204],
           33: [170, 240, 209],
           34: [140, 120, 240],
           35: [118, 255, 166],
           36: [250, 250, 55],
           37: [243, 232, 208],
           38: [1, 118, 141],
           39: [243, 241, 255],
           41: [158, 108, 4],
           43: [132, 0, 0],
           44: [245, 147, 49],
           46: [240, 120, 240],
           47: [149, 83, 203],
           48: [52, 209, 183],
           49: [200, 101, 0],
           50: [65, 112, 192],
           52: [255, 204, 51],
           53: [36, 179, 83],
           56: [90, 98, 89],
           57: [255, 191, 0],
           58: [204, 153, 51],
           59: [31, 73, 125],
           60: [155, 149, 205],
           61: [154, 150, 169],
           62: [128, 128, 128],
           63: [163, 160, 172],
           64: [255, 106, 77],
           65: [115, 51, 128],
           0: [10, 9, 10]}
# Chinese display name for each class id (same ids as id2color/id2material).
# process_image uses these names as the category column of the per-class
# pixel table. predict_and_visualize only annotates classes whose id appears
# here; the annotation text itself comes from id2material.
id2label= {1: '动物皮',
           2: '骨/牙/角',
           3: '砖块',
           4: '纸板/纸',
           6: '天花板瓦片',
           7: '瓷',
           8: '黑板',
           10: '混凝土',
           13: '织物/布/地毯',
           15: '火',
           16: '树叶',
           17: '食物',
           18: '毛皮',
           19: '宝石/石英',
           20: '玻璃',
           21: '毛发',
           23: '冰',
           24: '皮革',
           26: '金属',
           27: '镜子',
           29: '油漆/抹灰/石膏',
           32: '照片/绘画/布面招牌',
           33: '透明塑料',
           34: '非透明塑料',
           35: '橡胶/乳胶',
           36: '沙',
           37: '皮肤/嘴唇',
           38: '天空',
           39: '雪',
           41: '土壤/泥土',
           43: '天然石材',
           44: '抛光石材',
           46: '片状地砖/石地砖/瓷地砖',
           47: '壁纸',
           48: '水',
           49: '蜡',
           50: '白板',
           52: '木材',
           53: '树木',
           56: '沥青',
           57: '珐琅/琉璃',
           58: '夯土',
           59: '塑钢复合装饰板',
           60: '水泥',
           61: '陶',
           62: '屋顶防水卷材',
           63: '金属网窗(远景)',
           64: '砖雕',
           65: '纱窗',
           0: '背景/未知'}
# English name for each class id (same ids as id2color/id2label).
# Used as the text drawn on the colour mask, as palette headers, and in the
# texture-slice titles. process_image also filters the palette by exact
# string comparison against some of these names (e.g. 'Sky', 'tree',
# 'background'), so renaming or re-casing an entry changes that filtering.
id2material={1: 'Animal skin',
             2: 'Bone/teeth/horn',
             3: 'Brickwork',
             4: 'Cardboard/Paper',
             6: 'Ceiling tile',
             7: 'Ceramic',
             8: 'Chalkboard/blackboard',
             10: 'Concrete',
             13: 'Fabric/cloth',
             15: 'Fire',
             16: 'Foliage',
             17: 'Food',
             18: 'Fur',
             19: 'Gemstone/quartz',
             20: 'Glass',
             21: 'Hair',
             23: 'Ice',
             24: 'Leather',
             26: 'Metal',
             27: 'Mirror',
             29: 'Paint/plaster',
             32: 'Photograph/painting',
             33: 'Plastic, clear',
             34: 'Plastic, non-clear',
             35: 'Rubber/latex',
             36: 'Sand',
             37: 'Skin/lips',
             38: 'Sky',
             39: 'Snow',
             41: 'Soil/mud',
             43: 'natural stone',
             44: 'polished stone & engineered stone',
             46: 'Tile',
             47: 'Wallpaper',
             48: 'Water',
             49: 'Wax',
             50: 'Whiteboard',
             52: 'Wood',
             53: 'tree',
             56: 'Asphalt',
             57: 'enamel',
             58: 'Rammed earth',
             59: 'composite decorative board',
             60: 'Cement',
             61: 'Pottery',
             62: 'Roofing waterproof material',
             63: 'Metal mesh window (perspective)',
             64: 'carved brick',
             65: 'window screen',
             0: 'background'}
# Segmentation checkpoint, loaded once at import time; predict_and_visualize
# reads the module-level `model` for every call.
# NOTE(review): presumably a Hugging Face Hub repo id rather than a local
# directory, despite the variable name — confirm.
model_save_path = 'jinfengxie/BFM_segformer0821'
model = TFSegformerForSemanticSegmentation.from_pretrained(
    model_save_path
)
def predict_and_visualize(image):
    """Segment the materials in one image and draw a labelled colour mask.

    Args:
        image: the input picture in any form ``np.array`` accepts (the
            Gradio ``gr.Image()`` input delivers a NumPy array). Assumed to
            be (H, W, 3) RGB with 0-255 values — TODO confirm for other
            upload modes.

    Returns:
        tuple ``(pred_mask, result_image, counts_dict)``:
            pred_mask: ``tf.Tensor`` of shape (H, W) holding the predicted
                class id of every pixel at the original resolution.
            result_image: ``PIL.Image`` in which each class is painted with
                its ``id2color`` colour (ids without a colour stay black) and
                labelled with its ``id2material`` name at the mean position
                of its pixels.
            counts_dict: ``{class_id: pixel_count}`` for every id present.
    """
    image_np = np.array(image)
    height, width, _ = image_np.shape
    image = tf.convert_to_tensor(image_np, dtype=tf.float32)

    # Cap the longest side at 1500 px (keeping the aspect ratio) to bound
    # inference cost. A single if/else also avoids resizing square images twice.
    if max(height, width) > 1500:
        if height >= width:
            image = tf.image.resize(image, (1500, int(1500 * width / height)))
        else:
            image = tf.image.resize(image, (int(1500 * height / width), 1500))

    # Scale to [0, 1] and reorder to channels-first (1, C, H, W) for the model.
    image = tf.cast(image, tf.float32) / 255.0
    images = tf.expand_dims(tf.transpose(image, perm=[2, 0, 1]), axis=0)

    logits = model.predict(images).logits
    # The class axis is 1 in the channels-first logits; keep the single batch item.
    pred_mask = tf.argmax(logits, axis=1)[0]
    # Bring the label map back to the input size. Nearest-neighbour keeps the
    # integer class ids intact; [..., 0] (instead of squeeze) keeps both
    # spatial axes even for an image that is 1 px tall or wide.
    pred_mask = tf.image.resize(
        pred_mask[..., tf.newaxis], (height, width), method='nearest'
    )[..., 0]

    mask_np = pred_mask.numpy()
    unique, counts = np.unique(mask_np, return_counts=True)
    counts_dict = dict(zip(unique, counts))

    # Paint the mask through a lookup table indexed by class id: one
    # vectorised gather instead of a full-image comparison per class.
    # Colours are 0-255 uint8 values; ids without an entry stay black.
    lut = np.zeros((max(max(id2color), int(unique.max())) + 1, 3), dtype=np.uint8)
    for class_id, rgb in id2color.items():
        lut[class_id] = rgb
    color_mask = lut[mask_np]

    # Label anchor = mean (x, y) of the class's pixels, only for classes present.
    label_positions = {}
    for class_id in id2color:
        if class_id in counts_dict:
            rows, cols = np.nonzero(mask_np == class_id)
            label_positions[class_id] = (np.mean(cols), np.mean(rows))

    result_image = Image.fromarray(color_mask)
    draw = ImageDraw.Draw(result_image)
    font_size = max(1, int(height / 30))  # text scales with image height; never 0
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        # arial.ttf is not installed on every host; fall back to Pillow's
        # built-in font rather than failing the whole request.
        font = ImageFont.load_default()
    for class_id, (x, y) in label_positions.items():
        if class_id in id2label:
            draw.text((x, y), str(id2material[class_id]), font=font, fill='white')
    return pred_mask, result_image, counts_dict
def ext_colors(image_path, mask, n_clusters=4):
    """Extract the dominant colours of every class region with k-means.

    Args:
        image_path: despite the name, the image itself (anything
            ``np.array`` accepts, e.g. the NumPy array from the Gradio
            input). Its first two dimensions must match *mask*.
        mask: per-pixel class ids (array or tensor convertible by ``np.array``).
        n_clusters: number of colours to extract per class. A class region
            with fewer pixels than this gets one cluster per pixel instead,
            because KMeans rejects ``n_clusters > n_samples``.

    Returns:
        dict mapping each class id present in *mask* to an integer array of
        cluster-centre colours, one row per cluster (values truncated by
        ``astype(int)``).
    """
    image_np = np.array(image_path)
    mask_np = np.array(mask)
    colors_per_class = {}
    for cls in np.unique(mask_np):
        # Boolean indexing yields the region's pixels in row-major order.
        pixels = image_np[mask_np == cls]
        # Tiny regions (a few stray pixels) would otherwise make KMeans raise
        # and abort the whole request.
        k = min(n_clusters, len(pixels))
        kmeans = KMeans(n_clusters=k, n_init=10)
        kmeans.fit(pixels)
        colors_per_class[cls] = kmeans.cluster_centers_.astype(int)
    return colors_per_class
def plot_material_color_palette_grid(material_dict, materials_per_row=4):
    """Render colour swatches for each material as a grid and return a PIL image.

    Materials are laid out ``materials_per_row`` to a grid row. Each material
    gets a bold header followed by one swatch per colour, with the colour
    value printed to its right.

    Args:
        material_dict: mapping of material name -> sequence of RGB colours
            with 0-255 components (e.g. ``ext_colors`` output keyed by name).
        materials_per_row: number of materials placed side by side.

    Returns:
        PIL.Image.Image backed by an in-memory PNG. An empty *material_dict*
        yields a blank image instead of raising.
    """
    materials = list(material_dict)
    row_groups = [materials[i:i + materials_per_row]
                  for i in range(0, len(materials), materials_per_row)]
    # A grid row is as tall as its longest colour list plus one header line.
    row_heights = [max(len(material_dict[m]) for m in group) + 1 for group in row_groups]
    total_grid_rows = sum(row_heights)

    # Layout units (inches, since they also size the figure).
    block_width = 1
    block_height = 0.5
    text_gap = 0.2
    row_gap = 0.2
    column_gap = 1.5  # gap between material columns within the same row
    column_step = block_width + text_gap + column_gap
    line_step = block_height + row_gap

    fig_width = materials_per_row * column_step
    # Keep at least one line of height: matplotlib rejects a zero-height figure.
    fig_height = max(total_grid_rows, 1) * line_step

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    ax.axis('off')
    ax.invert_yaxis()  # y grows downward so the first row sits at the top

    current_row = 0
    for group, row_height in zip(row_groups, row_heights):
        for j, material in enumerate(group):
            x = j * column_step
            ax.text(x, current_row * line_step + 0.5,
                    material, va='center', fontsize=12, fontweight='bold', ha='left')
            for k, color in enumerate(material_dict[material]):
                y_pos = (current_row + 1 + k) * line_step
                # Matplotlib expects RGB components in [0, 1].
                ax.add_patch(plt.Rectangle((x, y_pos), block_width, block_height,
                                           color=np.array(color) / 255.0))
                ax.text(x + block_width + text_gap, y_pos + block_height / 2,
                        str(color), va='center', fontsize=10)
        current_row += row_height

    ax.set_xlim(0, fig_width)
    ax.set_ylim(max(current_row, 1) * line_step, 0)

    # Render into memory instead of displaying; save/close this figure explicitly.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def plt_to_image():
    """Rasterise the current matplotlib figure into a PIL image.

    The current figure is written as a 300-dpi PNG into an in-memory buffer
    and then closed, so callers must not draw on it afterwards.

    Returns:
        PIL.Image.Image opened on the PNG buffer.
    """
    png_stream = io.BytesIO()
    plt.savefig(png_stream, format='png', dpi=300)
    plt.close()
    png_stream.seek(0)
    return Image.open(png_stream)
def calculate_slice_statistics(one_mask, slice_size=256):
    """Compute the class composition of every full tile of a label mask.

    The mask is cut into non-overlapping ``slice_size`` x ``slice_size``
    tiles starting at the top-left corner; partial tiles along the bottom
    and right edges are ignored.

    Args:
        one_mask: array (or tensor) of per-pixel class ids, indexed
            ``[row, col]``.
        slice_size: tile edge length in pixels.

    Returns:
        dict mapping ``(tile_row, tile_col)``, in row-major order, to
        ``{class_id: fraction_of_tile_pixels}``.
    """
    tiles_down = one_mask.shape[0] // slice_size
    tiles_across = one_mask.shape[1] // slice_size
    stats_by_tile = {}
    for row in range(tiles_down):
        top = row * slice_size
        for col in range(tiles_across):
            left = col * slice_size
            tile = one_mask[top:top + slice_size, left:left + slice_size]
            labels, pixel_counts = np.unique(tile, return_counts=True)
            stats_by_tile[(row, col)] = dict(zip(labels, pixel_counts / pixel_counts.sum()))
    return stats_by_tile
def find_top_slices(slice_stats, exclusion_list, min_percent=0.7, min_slices=1, top_k=3):
    """Pick, for every material, the tiles it covers most.

    A tile qualifies for a material when the material is not in
    *exclusion_list* and covers at least *min_percent* of the tile. The
    *top_k* highest-coverage qualifying tiles are kept per material, and a
    material is reported only if it keeps strictly more than *min_slices*
    tiles (so the default ``min_slices=1`` requires at least two).

    Args:
        slice_stats: ``{(row, col): {material_id: fraction}}`` as returned by
            ``calculate_slice_statistics``.
        exclusion_list: material ids to ignore.
        min_percent: minimum coverage fraction (0-1) for a tile to qualify.
        min_slices: a material needs more than this many kept tiles.
        top_k: maximum number of tiles kept per material.

    Returns:
        dict ``{material_id: [(fraction, (row, col)), ...]}`` with each list
        sorted by descending fraction; materials appear in the order their
        first qualifying tile was encountered.
    """
    import heapq

    candidates = {}
    for tile_pos, coverage in slice_stats.items():
        for material_id, fraction in coverage.items():
            if material_id in exclusion_list or fraction < min_percent:
                continue
            candidates.setdefault(material_id, []).append((fraction, tile_pos))

    result = {}
    for material_id, entries in candidates.items():
        # Largest (fraction, tile_pos) entries first, at most top_k of them.
        best = heapq.nlargest(top_k, entries)
        if len(best) > min_slices:
            result[material_id] = best
    return result
def extract_and_visualize_top_slices(image, top_slices, slice_size=256):
    """Plot the selected texture tiles of each material and return the figure as an image.

    Args:
        image: the source image as a NumPy array (passed to ``Image.fromarray``).
        top_slices: ``{material_id: [(fraction, (row, col)), ...]}`` as
            returned by ``find_top_slices``; (row, col) index tiles of
            *slice_size* pixels.
        slice_size: tile edge length in pixels; must match the value used to
            compute *top_slices*.

    Returns:
        PIL.Image.Image backed by an in-memory PNG, with one row per material
        and one column per tile (at least 3 columns). If *top_slices* is
        empty, a small image carrying a notice is returned instead of raising.
    """
    if not top_slices:
        # plt.subplots(nrows=0) raises; this happens whenever no material
        # dominates enough tiles (e.g. small images), so render a notice.
        fig, ax = plt.subplots(figsize=(6, 1))
        ax.axis('off')
        ax.text(0.5, 0.5, 'No texture slices found', ha='center', va='center')
    else:
        pil_image = Image.fromarray(image)
        # Three columns by default; widen if a material kept more tiles.
        ncols = max(3, max(len(slices) for slices in top_slices.values()))
        # squeeze=False always yields a 2-D grid of axes, even for one row.
        fig, axs = plt.subplots(nrows=len(top_slices), ncols=ncols,
                                figsize=(5 * ncols, 5 * len(top_slices)), squeeze=False)
        for row_axes, (material_id, slices) in zip(axs, top_slices.items()):
            # Hide every axis first so columns without a tile stay blank.
            for ax in row_axes:
                ax.axis('off')
            # Fall back to the raw id for class ids missing from id2material.
            material_name = id2material.get(material_id, material_id)
            for ax, (_, pos) in zip(row_axes, slices):
                i, j = pos
                tile = pil_image.crop((j * slice_size, i * slice_size,
                                       (j + 1) * slice_size, (i + 1) * slice_size))
                ax.imshow(tile)
                ax.set_title(f'Material {material_name} - Slice {pos}')
        fig.tight_layout()

    # Render into memory instead of displaying; save/close this figure explicitly.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Main program: the Gradio callback and its private helpers.

def _build_palette_image(image, mask):
    """Extract per-class dominant colours and render them as a palette grid."""
    colors_per_class = ext_colors(image, mask, n_clusters=4)
    # Materials left out of the palette. Names are compared exactly against
    # id2material values, so casing matters ('Water', not 'water').
    labels_to_remove = ['Sky', 'background', 'Glass', 'tree', 'Water', 'Plastic, clear']
    # Class ids missing from id2material get a generic name instead of a KeyError.
    colors_per_label = {id2material.get(key, f'class {key}'): value
                        for key, value in colors_per_class.items()}
    colors_per_label = {key: value for key, value in colors_per_label.items()
                        if key not in labels_to_remove}
    return plot_material_color_palette_grid(colors_per_label)


def _render_color_mask(color_mask):
    """Draw the labelled colour mask on a 5x5-inch figure and rasterise it."""
    plt.figure(figsize=(5, 5))
    plt.imshow(color_mask)
    plt.tight_layout()
    plt.axis('off')
    return plt_to_image()


def _build_percentage_table(counts_dict):
    """Tabulate the pixel count and percentage share of every predicted class."""
    # Class ids missing from id2label get a generic name instead of a KeyError.
    counts_by_label = {id2label.get(key, f'class {key}'): value
                       for key, value in counts_dict.items()}
    counts_df = pd.DataFrame(list(counts_by_label.items()), columns=['类别', '计数'])
    total_count = counts_df['计数'].sum()
    counts_df['百分比'] = (counts_df['计数'] / total_count * 100).round(2)
    # Column names: category, pixel count, percentage.
    return counts_df.rename(columns={'计数': 'pixels', '百分比': 'percentage (%)'})


def process_image(image_path):
    """Run the full analysis pipeline for one uploaded image.

    Args:
        image_path: despite the name, the image itself as delivered by the
            ``gr.Image()`` input (a NumPy array), or ``None`` when the user
            submits without uploading.

    Returns:
        tuple ``(color_mask_img, palette_image, slice_image, percentage_df)``:
        the rendered colour mask, the per-material colour palette, the grid
        of representative texture tiles, and a DataFrame with the pixel count
        and percentage of every predicted class.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if image_path is None:
        raise gr.Error("Please upload an image.")

    one_mask, color_mask, counts_dict = predict_and_visualize(image_path)
    palette_image = _build_palette_image(image_path, one_mask)
    color_mask_img = _render_color_mask(color_mask)
    percentage_df = _build_percentage_table(counts_dict)

    # Texture tiles: 128-px tiles in which a non-sky material covers >= 50%.
    slice_size = 128
    exclusion_list = [38]  # 38 = 'Sky' in id2material
    slice_stats = calculate_slice_statistics(one_mask, slice_size=slice_size)
    top_slices = find_top_slices(slice_stats, exclusion_list=exclusion_list,
                                 min_percent=0.5, min_slices=1)
    slice_image = extract_and_visualize_top_slices(image_path, top_slices, slice_size=slice_size)
    return color_mask_img, palette_image, slice_image, percentage_df
# Gradio UI: one uploaded image in; colour mask, colour palette, texture
# tiles and a per-class pixel table out. Launched at import time.
input_component = gr.Image()
output_components = [
    gr.Image(type="pil", label="Color Mask"),
    gr.Image(type="pil", label="Color Palette"),
    gr.Image(type='pil', label='Texture Slices'),
    gr.DataFrame(),
]
iface = gr.Interface(
    fn=process_image,
    inputs=input_component,
    outputs=output_components,
    title="Building Facade Material Segmentation",
    description="Upload an image to segment material masks, and get color palettes.",
)
iface.launch()