DSVmon commited on
Commit
427fced
·
verified ·
1 Parent(s): dc8ed73
Files changed (1) hide show
  1. app.py +199 -0
app.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # === Загрузка библиотек ===
2
+
3
+ from pypdf import PdfReader, PdfWriter
4
+ import gradio as gr
5
+ import fitz
6
+ from PIL import Image
7
+ import pandas as pd
8
+ import cv2
9
+ import numpy as np
10
+ import os
11
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
12
+ import torch
13
+ import difflib
14
+ from pdf2image import convert_from_path
15
+
16
+ # Загрузка TrOCR
17
+ processor = TrOCRProcessor.from_pretrained('kazars24/trocr-base-handwritten-ru')
18
+ model = VisionEncoderDecoderModel.from_pretrained('kazars24/trocr-base-handwritten-ru')
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ model.to(device)
21
+
22
+ # === 1. Функция поиска и группировки линий ===
23
+ def group_lines(contours, img_size, y_tolerance=10, is_horizontal=True):
24
+ line_groups = []
25
+ used = [False] * len(contours)
26
+
27
+ for i in range(len(contours)):
28
+ if used[i]:
29
+ continue
30
+
31
+ group = [contours[i]]
32
+ used[i] = True
33
+ x, y, w, h = cv2.boundingRect(contours[i])
34
+
35
+ for j in range(i + 1, len(contours)):
36
+ if used[j]:
37
+ continue
38
+ x2, y2, w2, h2 = cv2.boundingRect(contours[j])
39
+ if is_horizontal:
40
+ if abs(y2 - y) < y_tolerance:
41
+ group.append(contours[j])
42
+ used[j] = True
43
+ else:
44
+ if abs(x2 - x) < y_tolerance:
45
+ group.append(contours[j])
46
+ used[j] = True
47
+
48
+ line_groups.append(group)
49
+
50
+ return line_groups
51
+
52
+ # === 2. Основная функция отрисовки линий и сохранения координат ===
53
+ def detect_table_lines_and_cells(img, min_cell_size=15):
54
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
55
+ thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
56
+
57
+ # Горизонтальные линии
58
+ horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
59
+ detect_horizontal = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)
60
+ horizontal_contours = cv2.findContours(detect_horizontal, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
61
+ horizontal_contours = horizontal_contours[0] if len(horizontal_contours) == 2 else horizontal_contours[1]
62
+
63
+ horizontal_line_groups = group_lines(horizontal_contours, img.shape[0], is_horizontal=True)
64
+ horizontal_line_groups.sort(key=lambda g: np.mean([cv2.boundingRect(c)[1] for c in g]))
65
+
66
+ horizontal_coords = [int(np.mean([cv2.boundingRect(c)[1] + cv2.boundingRect(c)[3] / 2 for c in group])) for group in horizontal_line_groups[3:]] # от 4-й линии
67
+
68
+ # Вертикальные линии
69
+ vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
70
+ detect_vertical = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, vertical_kernel, iterations=2)
71
+ vertical_contours = cv2.findContours(detect_vertical, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
72
+ vertical_contours = vertical_contours[0] if len(vertical_contours) == 2 else vertical_contours[1]
73
+
74
+ vertical_line_groups = group_lines(vertical_contours, img.shape[1], is_horizontal=False)
75
+ vertical_coords = [int(np.mean([cv2.boundingRect(c)[0] + cv2.boundingRect(c)[2] / 2 for c in group])) for group in vertical_line_groups]
76
+
77
+ # Поиск ячеек
78
+ cells = []
79
+ horizontal_coords = sorted(horizontal_coords)
80
+ vertical_coords = sorted(vertical_coords)
81
+
82
+ for row_idx in range(len(horizontal_coords) - 1):
83
+ y1, y2 = horizontal_coords[row_idx], horizontal_coords[row_idx + 1]
84
+ for col_idx in range(len(vertical_coords) - 1):
85
+ x1, x2 = vertical_coords[col_idx], vertical_coords[col_idx + 1]
86
+ w = x2 - x1
87
+ h = y2 - y1
88
+ if w > min_cell_size and h > min_cell_size:
89
+ cells.append({'row': row_idx, 'col': col_idx, 'box': (x1, y1, w, h)})
90
+
91
+ return cells
92
+
93
+ # === 3. Функция распознавания текста в ячейках ===
94
+ def recognize_text(image, max_length=10):
95
+ if image is None:
96
+ return "Не удалось загрузить изображение"
97
+ try:
98
+ inputs = processor(images=image, return_tensors="pt").to(device)
99
+ generated_ids = model.generate(
100
+ **inputs,
101
+ max_length=max_length,
102
+ early_stopping=True,
103
+ num_beams=1,
104
+ use_cache=True
105
+ )
106
+ return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
107
+ except Exception as e:
108
+ print(f"Ошибка распознавания: {e}")
109
+ return "Ошибка распознавания"
110
+
111
+ # === 4. Обрезка ячеек и OCR для таблицы ===
112
+ def crop_and_recognize_cells(image, cells):
113
+ allowed_words = ["труба", "врезка", "зкл", "отвод", "арм", "переход", "тройник", "заглушка", "зад-ка", "т-т", "комп-р"]
114
+
115
+ recognized = {}
116
+ pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
117
+
118
+ for cell in cells:
119
+ x, y, w, h = cell['box']
120
+ cropped = pil_image.crop((x, y, x + w, y + h))
121
+
122
+ text = recognize_text(cropped, max_length=10).strip().lower()
123
+
124
+ # Заменяем все запятые на точки
125
+ text = text.replace(',', '.')
126
+
127
+ if any(char.isdigit() for char in text):
128
+ # Если есть цифры - оставляем как есть (уже заменили запятые)
129
+ final_text = text
130
+ else:
131
+ # Если текст состоит только из букв
132
+ if len(text) <= 2:
133
+ # Если длина 1 или 2 символа - заменяем на пустую строку
134
+ final_text = ""
135
+ else:
136
+ # Ищем наиболее похожее слово из словаря
137
+ matches = difflib.get_close_matches(text, allowed_words, n=1, cutoff=0.5)
138
+ final_text = matches[0] if matches else text
139
+
140
+ recognized[(cell['row'], cell['col'])] = final_text
141
+
142
+ return recognized
143
+
144
+ # === 5. Полный процесс обработки изображения таблицы ===
145
+ def process_pdf_table(pdf_path, output_excel='results.xlsx'):
146
+ images = convert_from_path(pdf_path, dpi=300) # Загружаем первую страницу
147
+ if not images:
148
+ print("Ошибка: PDF пустой или не удалось сконвертировать.")
149
+ return
150
+
151
+ image = cv2.cvtColor(np.array(images[0]), cv2.COLOR_RGB2BGR) # Переводим PIL -> OpenCV
152
+
153
+ cells = detect_table_lines_and_cells(image, min_cell_size=15)
154
+ print(f"Найдено ячеек: {len(cells)}")
155
+
156
+ recognized = crop_and_recognize_cells(image, cells)
157
+
158
+ # Собираем в DataFrame
159
+ data = {}
160
+ for (row, col), text in recognized.items():
161
+ data.setdefault(row, {})[col] = text
162
+
163
+ max_cols = max((max(cols.keys()) for cols in data.values()), default=0) + 1
164
+ rows = []
165
+
166
+ for row_idx in range(max(data.keys()) + 1):
167
+ row = []
168
+ for col_idx in range(max_cols):
169
+ row.append(data.get(row_idx, {}).get(col_idx, ""))
170
+ rows.append(row)
171
+
172
+ df = pd.DataFrame(rows)
173
+ df.to_excel(output_excel, index=False, header=False)
174
+ print(f"Результат сохранён в {output_excel}")
175
+ return output_excel
176
+
177
+ # === Gradio приложение ===
178
+ def gradio_process(pdf_file, progress=gr.Progress()):
179
+ progress(0, desc="Чтение PDF...")
180
+
181
+ # Получаем имя без расширения и меняем его на .xlsx
182
+ base_name = os.path.splitext(os.path.basename(pdf_file.name))[0]
183
+ output_excel = f"{base_name}.xlsx"
184
+
185
+ progress(0.3, desc="Поиск ячеек таблицы...")
186
+ result_file = process_pdf_table(pdf_file.name, output_excel=output_excel)
187
+ progress(1.0, desc="Готово! Таблица сохранена.")
188
+ return result_file
189
+
190
+ app = gr.Interface(
191
+ fn=gradio_process,
192
+ inputs=gr.File(label="Загрузите PDF файл таблицы"),
193
+ outputs=gr.File(label="Скачайте Excel с распознанными ячейками"),
194
+ title="📄 PDF → Excel распознавание таблиц",
195
+ description="Загрузите PDF-файл с таблицей. Программа найдет ячейки, распознает текст и сохранит результат в Excel-файл.",
196
+ allow_flagging="never"
197
+ )
198
+
199
+ app.launch(share=True)