nehakothari committed on
Commit
06fb7e9
·
verified ·
1 Parent(s): 6e742cd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import sys
3
+ import os
4
+
5
# Install Python dependencies at start-up, before the third-party imports below
def install_dependencies():
    """Install every required Python package with pip, one package per call.

    Each requirement is installed through ``<python> -m pip install <req>``
    using the interpreter that is running this script, with pip's output
    forwarded to this process's stdout/stderr.

    Raises:
        subprocess.CalledProcessError: if any pip invocation exits non-zero;
            the remaining requirements are then not attempted.
    """
    # NOTE(review): the file later imports Qwen2VLForConditionalGeneration
    # from transformers -- confirm the pinned transformers release ships it.
    requirements = (
        "pip==23.3.1",
        "setuptools",
        "wheel",
        "pytesseract",
        "torch==2.1.0",
        "torchvision==0.16.0",
        "torchaudio==2.1.0",
        "transformers==4.38.2",
        "auto-gptq==0.7.1",
        "autoawq==0.2.8",
        "qwen_vl_utils==0.0.8",
        "gradio==4.27.0",
        "pyodbc",
        "sqlalchemy",
        "azure-storage-blob",
        "pymssql",
        "pandas",
        "opencv-python",
    )
    pip_install = [sys.executable, "-m", "pip", "install"]
    for requirement in requirements:
        subprocess.check_call(
            pip_install + [requirement],
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        print(f"Installed {requirement}")


install_dependencies()
32
+
33
# Install the OS-level packages (ODBC headers, Tesseract, MS ODBC driver)
def install_system_dependencies():
    """Run the apt-get commands that install the system packages, in order.

    Every command is executed through the shell (the last one relies on the
    shell's ``VAR=value command`` syntax to accept the driver EULA).

    Raises:
        subprocess.CalledProcessError: if a command exits with a non-zero
            status; later commands are then not run.
    """
    # NOTE(review): presumably this needs root and an apt source that
    # provides msodbcsql17 -- confirm for the deployment image.
    shell_commands = (
        "apt-get update",
        "apt-get install -y unixodbc-dev tesseract-ocr",
        "ACCEPT_EULA=Y apt-get install -y msodbcsql17",
    )
    for shell_command in shell_commands:
        subprocess.run(shell_command, shell=True, check=True)
        print(f"Executed: {shell_command}")


install_system_dependencies()
45
+
46
+ # Now import the libraries after installation
47
+ import gradio as gr
48
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
49
+ from qwen_vl_utils import process_vision_info
50
+ import torch
51
+ import pandas as pd
52
+ import pytesseract
53
+ import cv2
54
+ import pymssql
55
+
56
# Hugging Face token and SQL server IP address.
# The token must not live in source control: it is read from the environment.
# When neither variable is set the value is None, which is passed on as-is.
HUGGINGFACE_API_KEY = os.environ.get("HUGGINGFACE_API_KEY") or os.environ.get("HF_TOKEN")

SERVER_IP = "35.227.148.156"

# Checkpoint shared by the model and its processor.
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct-AWQ"

# Initialize model and processor.
# `token` replaces the deprecated `use_auth_token` keyword.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    token=HUGGINGFACE_API_KEY,
)

processor = AutoProcessor.from_pretrained(MODEL_ID, token=HUGGINGFACE_API_KEY)

# pytesseract reads the binary location from the `tesseract_cmd` attribute of
# its inner `pytesseract` module; assigning `pytesseract.pytesseract_cmd`
# only creates an unused attribute and leaves the default lookup in place.
pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract'
70
+
71
# Function to preprocess the image for OCR
def preprocess_image(image_path):
    """Load an image and return a binarized grayscale version of it.

    Args:
        image_path: Path of the image file to read with OpenCV.

    Returns:
        The thresholded image: pixels brighter than 150 become 255, all
        others become 0.

    Raises:
        ValueError: if OpenCV cannot read or decode ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising, which
    # would otherwise surface as an obscure error inside cvtColor.
    if image is None:
        raise ValueError(f"Could not read image for OCR: {image_path}")
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY)
    return binary
77
+
78
# OCR helper: binarize the image, then hand it to Tesseract
def ocr_extract_text(image_path):
    """Return the text Tesseract recognizes in the preprocessed image.

    Args:
        image_path: Path of the image file, forwarded to ``preprocess_image``.

    Returns:
        The result of ``pytesseract.image_to_string`` for the preprocessed
        image.
    """
    return pytesseract.image_to_string(preprocess_image(image_path))
82
+
83
# Function to process image and extract details
def process_image(image_path):
    """Extract invoice fields from an image with the vision-language model.

    The model is prompted to list the invoice number, date, place, amount and
    category; its answer is handed to ``parse_details``. If anything in the
    model path raises, the error is printed and the text obtained from
    ``ocr_extract_text`` is parsed instead.

    Args:
        image_path: Path of the invoice image.

    Returns:
        The dictionary produced by ``parse_details``.
    """
    try:
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": (
                    "Extract the following details from the invoice:\n"
                    "- 'invoice_number'\n"
                    "- 'date'\n"
                    "- 'place'\n"
                    "- 'amount' (monetary value in the relevant currency)\n"
                    "- 'category' (based on the invoice type)"
                )}
            ]
        }]

        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
        inputs = inputs.to(model.device)

        generated_ids = model.generate(**inputs, max_new_tokens=128)
        # generate() returns the prompt tokens followed by the new tokens.
        # Drop the prompt so the instruction text (which itself mentions
        # "invoice", "date", ...) is not parsed as if it were the answer.
        generated_ids_trimmed = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        return parse_details(output_text[0])

    except Exception as e:
        # Deliberate best-effort fallback: any model failure degrades to OCR.
        print(f"Model failed, falling back to OCR: {e}")
        ocr_text = ocr_extract_text(image_path)
        return parse_details(ocr_text)
115
+
116
# Function to parse details from extracted text
def parse_details(details):
    """Pick invoice fields out of free-form ``label: value`` text.

    Each line is split at its FIRST colon, so values that contain colons
    (times, URLs) are kept whole. The label part is matched case-insensitively
    against field keywords; a line without a colon is matched and stored as a
    whole. Lines with an empty value are ignored so that headings such as
    "Invoice details:" cannot overwrite a real value. When a field matches on
    several lines, the last non-empty value wins.

    Args:
        details: Text produced by the model or by OCR. ``None`` or an empty
            string yields the defaults.

    Returns:
        A dict with the keys "Invoice Number", "Date", "Place", "Amount" and
        "Category". Fields that were not found are ``None``, except
        "Category", which defaults to "General".
    """
    parsed_data = {
        "Invoice Number": None,
        "Date": None,
        "Place": None,
        "Amount": None,
        "Category": None,
    }

    # Order matters: "category" and "date" are tested before "invoice" so
    # labels like "Invoice Date" or "category (based on the invoice type)"
    # are not mistaken for the invoice number.
    field_keywords = (
        ("Category", ("category",)),
        ("Date", ("date",)),
        ("Invoice Number", ("invoice",)),
        ("Place", ("place",)),
        ("Amount", ("total", "amount", "cost")),
    )

    for line in (details or "").splitlines():
        label, colon, remainder = line.partition(":")
        value = (remainder if colon else line).strip()
        if not value:
            continue
        lowered_label = label.lower()
        for field, keywords in field_keywords:
            if any(keyword in lowered_label for keyword in keywords):
                parsed_data[field] = value
                break

    if parsed_data["Category"] is None:
        parsed_data["Category"] = "General"

    return parsed_data
141
+
142
# Store extracted data in Azure SQL Database
def store_to_azure_sql(dataframe):
    """Insert the extracted invoice rows into the ``Invoices`` table.

    The table is created if it does not exist, then one row is inserted per
    DataFrame row. Failures are reported on stdout instead of being raised,
    so a database problem does not break the caller.

    Args:
        dataframe: pandas DataFrame with the columns "Invoice Number",
            "Date", "Place", "Amount" and "Category".

    Returns:
        None.
    """
    import os

    columns = ["Invoice Number", "Date", "Place", "Amount", "Category"]

    if dataframe is None or dataframe.empty:
        print("No invoice data to store.")
        return

    # NOTE(review): credentials should not be committed to source. They can
    # now be overridden through the environment; the defaults only preserve
    # the previous behaviour and should be removed once secrets are set up.
    user = os.environ.get("SQL_USER", "pio-admin")
    password = os.environ.get("SQL_PASSWORD", "Poctest123#")
    database = os.environ.get("SQL_DATABASE", "Invoices")

    create_table_query = """
        IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='Invoices' AND xtype='U')
        CREATE TABLE Invoices (
            InvoiceNumber NVARCHAR(255),
            Date NVARCHAR(255),
            Place NVARCHAR(255),
            Amount NVARCHAR(255),
            Category NVARCHAR(255)
        )
    """
    insert_query = """
        INSERT INTO Invoices (InvoiceNumber, Date, Place, Amount, Category)
        VALUES (%s, %s, %s, %s, %s)
    """

    try:
        # Missing values become None so they are stored as SQL NULL.
        selected = dataframe[columns].astype(object)
        selected = selected.where(selected.notna(), None)
        rows = [tuple(row) for row in selected.itertuples(index=False, name=None)]

        # pymssql takes discrete connection arguments, not an ODBC
        # connection string (a string in first position is the server name).
        with pymssql.connect(
            server=SERVER_IP,
            user=user,
            password=password,
            database=database,
        ) as conn:
            cursor = conn.cursor()
            cursor.execute(create_table_query)
            # Parameters must be passed as one tuple per statement.
            cursor.executemany(insert_query, rows)
            conn.commit()
        print("Data successfully stored in Azure SQL Database.")
    except Exception as e:
        print(f"Error storing data to database: {e}")
176
+
177
# Gradio interface for invoice processing
def gradio_interface(image_files):
    """Process the uploaded invoice images and return the extracted fields.

    Every file is run through ``process_image``; the results are collected in
    a DataFrame, passed to ``store_to_azure_sql`` and returned for display.

    Args:
        image_files: Iterable of uploaded image files as supplied by the
            Gradio ``Files`` input, or ``None`` when nothing was uploaded.

    Returns:
        A pandas DataFrame with one row per image and the columns
        "Invoice Number", "Date", "Place", "Amount" and "Category". It is
        empty (but keeps those columns) when no files were given.
    """
    columns = ["Invoice Number", "Date", "Place", "Amount", "Category"]

    # Without an upload the input is None; treat that as "no images".
    results = [process_image(image_file) for image_file in (image_files or [])]

    # Explicit columns keep the table shape stable even with zero rows.
    df = pd.DataFrame(results, columns=columns)
    if not df.empty:
        store_to_azure_sql(df)
    return df
187
+
188
# Build the Gradio UI: a multi-file upload feeding gradio_interface, with the
# resulting table shown in an interactive Dataframe component.
grpc_interface = gr.Interface(
    title="Invoice Extraction System",
    fn=gradio_interface,
    inputs=gr.Files(label="Upload Invoice Images"),
    outputs=gr.Dataframe(interactive=True),
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    grpc_interface.launch(share=True)