Nayefleb commited on
Commit
aa8be0a
·
verified ·
1 Parent(s): 3365d41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -75
app.py CHANGED
@@ -1,82 +1,145 @@
1
- import gradio as gr
2
- import matplotlib.pyplot as plt
3
- import pandas as pd
4
- import neurokit2 as nk
5
- import cv2
6
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- try:
9
- from ecg_image_kit.digitizer import digitize_ecg # Adjust based on actual module/function
10
- except ImportError as e:
11
- raise ImportError(f"Failed to import ecg_image_kit: {str(e)}. Ensure the ecg_image_kit directory is included.")
12
 
13
- def load_and_digitize_ecg(image_path):
14
- try:
15
- image = cv2.imread(image_path)
16
- if image is None:
17
- raise ValueError("Failed to load image")
18
- time_series = digitize_ecg(image) # Adjust based on actual function
19
- return time_series
20
- except Exception as e:
21
- return f"Error digitizing image: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- def analyze_ecg(time_series, sampling_rate=250):
24
- try:
25
- results = {}
26
- for lead_idx, ecg_signal in enumerate(time_series):
27
- signals, info = nk.ecg_process(ecg_signal, sampling_rate=sampling_rate)
28
- analysis = nk.ecg_analyze(signals, sampling_rate=sampling_rate)
29
- results[f"Lead_{lead_idx+1}"] = {
30
- "signals": signals,
31
- "info": info,
32
- "analysis": analysis
33
- }
34
- return results
35
- except Exception as e:
36
- return f"Error analyzing ECG: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- def process_ecg_image(image):
39
- try:
40
- image_path = "temp_ecg_image.png"
41
- image.save(image_path)
42
- time_series = load_and_digitize_ecg(image_path)
43
- if isinstance(time_series, str):
44
- return time_series, None
45
- results = analyze_ecg(time_series, sampling_rate=250)
46
- if isinstance(results, str):
47
- return results, None
48
- plots = []
49
- tables = []
50
- for lead, data in results.items():
51
- fig, ax = plt.subplots()
52
- signals = data["signals"]
53
- info = data["info"]
54
- ax.plot(signals["ECG_Clean"], label="Clean ECG")
55
- ax.plot(info["ECG_R_Peaks"], signals["ECG_Clean"][info["ECG_R_Peaks"]], "ro", label="R-peaks")
56
- ax.set_title(f"{lead} ECG")
57
- ax.legend()
58
- plots.append(fig)
59
- analysis = data["analysis"]
60
- table = pd.DataFrame({
61
- "Feature": ["Heart Rate (Mean)", "ECG Quality"],
62
- "Value": [analysis.get("ECG_Rate_Mean", "N/A"), analysis.get("ECG_Quality_Mean", "N/A")]
63
- })
64
- tables.append(table)
65
- os.remove(image_path)
66
- return plots, tables
67
- except Exception as e:
68
- return f"Error processing image: {str(e)}", None
69
 
70
- iface = gr.Interface(
71
- fn=process_ecg_image,
72
- inputs=gr.Image(type="pil", label="Upload ECG Image"),
73
- outputs=[
74
- gr.Gallery(label="ECG Plots"),
75
- gr.Dataframe(label="Analysis Results")
76
- ],
77
- title="ECG Image Analysis",
78
- description="Upload a 12-lead ECG image to digitize and analyze it."
79
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- if __name__ == "__main__":
82
- iface.launch()
 
 
 
 
 
 
 
 
1
  import os
2
+ import cv2
3
+ import numpy as np
4
+ import pandas as pd
5
+ import tensorflow as tf
6
+ from tensorflow.keras.models import load_model
7
+ from flask import Flask, request, render_template
8
+ from werkzeug.utils import secure_filename
9
+ import biosppy.signals.ecg as ecg
10
+
11
+ app = Flask(__name__)
12
+
13
+ UPLOAD_FOLDER = 'uploads'
14
+ ALLOWED_EXTENSIONS = {'csv', 'png', 'jpg', 'jpeg'}
15
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
16
+
17
+ # Load the pre-trained model
18
+ model = load_model('ecgScratchEpoch2.hdf5')
19
 
20
+ def allowed_file(filename):
21
+ return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
 
 
22
 
23
+ def image_to_signal(image_path):
24
+ # Read and preprocess the image
25
+ img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
26
+ if img is None:
27
+ raise ValueError("Failed to load image")
28
+
29
+ # Resize to a standard height for consistency (e.g., 500 pixels)
30
+ img = cv2.resize(img, (1000, 500))
31
+
32
+ # Apply thresholding to isolate the waveform
33
+ _, binary = cv2.threshold(img, 200, 255, cv2.THRESH_BINARY_INV)
34
+
35
+ # Find contours of the waveform
36
+ contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
37
+ if not contours:
38
+ raise ValueError("No waveform detected in the image")
39
+
40
+ # Assume the largest contour is the ECG waveform
41
+ contour = max(contours, key=cv2.contourArea)
42
+
43
+ # Extract y-coordinates (signal amplitude) along x-axis
44
+ signal = []
45
+ width = img.shape[1]
46
+ for x in range(width):
47
+ column = contour[contour[:, :, 0] == x]
48
+ if len(column) > 0:
49
+ # Take the average y-coordinate if multiple points exist
50
+ y = np.mean(column[:, :, 1])
51
+ signal.append(y)
52
+ else:
53
+ # Interpolate if no point is found
54
+ signal.append(signal[-1] if signal else 0)
55
+
56
+ # Normalize signal to match expected amplitude range
57
+ signal = np.array(signal)
58
+ signal = (signal - np.min(signal)) / (np.max(signal) - np.min(signal)) * 1000
59
+
60
+ # Save to CSV
61
+ csv_path = os.path.join(app.config['UPLOAD_FOLDER'], 'converted_signal.csv')
62
+ df = pd.DataFrame(signal, columns=[' Sample Value'])
63
+ df.to_csv(csv_path, index=False)
64
+
65
+ return csv_path
66
 
67
+ def model_predict(uploaded_files, model):
68
+ output = []
69
+ for path in uploaded_files:
70
+ APC, NORMAL, LBB, PVC, PAB, RBB, VEB = [], [], [], [], [], [], []
71
+ output.append(str(path))
72
+ result = {"APC": APC, "Normal": NORMAL, "LBB": LBB, "PAB": PAB, "PVC": PVC, "RBB": RBB, "VEB": VEB}
73
+
74
+ kernel = np.ones((4,4), np.uint8)
75
+ csv = pd.read_csv(path)
76
+ csv_data = csv[' Sample Value']
77
+ data = np.array(csv_data)
78
+ signals = []
79
+ count = 1
80
+ peaks = ecg.christov_segmenter(signal=data, sampling_rate=200)[0]
81
+ indices = []
82
+
83
+ for i in peaks[1:-1]:
84
+ diff1 = abs(peaks[count - 1] - i)
85
+ diff2 = abs(peaks[count + 1] - i)
86
+ x = peaks[count - 1] + diff1 // 2
87
+ y = peaks[count + 1] - diff2 // 2
88
+ signal = data[x:y]
89
+ signals.append(signal)
90
+ count += 1
91
+ indices.append((x, y))
92
+
93
+ for signal, index in zip(signals, indices):
94
+ if len(signal) > 10:
95
+ img = np.zeros((128, 128))
96
+ for i in range(len(signal)):
97
+ img[i, int(signal[i] / 10)] = 255
98
+ img = cv2.dilate(img, kernel, iterations=1)
99
+ img = img.reshape(128, 128, 1)
100
+ prediction = model.predict(np.array([img])).argmax()
101
+ classes = ['Normal', 'APC', 'LBB', 'PAB', 'PVC', 'RBB', 'VEB']
102
+ result[classes[prediction]].append(index)
103
+
104
+ output.append(result)
105
+
106
+ return output
107
 
108
+ @app.route('/', methods=['GET'])
109
+ def index():
110
+ return render_template('index.html')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
+ @app.route('/', methods=['POST'])
113
+ def upload_file():
114
+ if 'files[]' not in request.files:
115
+ return render_template('index.html', message='No file part')
116
+
117
+ files = request.files.getlist('files[]')
118
+ file_paths = []
119
+
120
+ for file in files:
121
+ if file and allowed_file(file.filename):
122
+ filename = secure_filename(file.filename)
123
+ file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
124
+ file.save(file_path)
125
+
126
+ # If the file is an image, convert to CSV
127
+ if filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg'}:
128
+ try:
129
+ csv_path = image_to_signal(file_path)
130
+ file_paths.append(csv_path)
131
+ except Exception as e:
132
+ return render_template('index.html', message=f'Error processing image: {str(e)}')
133
+ else:
134
+ file_paths.append(file_path)
135
+
136
+ if not file_paths:
137
+ return render_template('index.html', message='No valid files uploaded')
138
+
139
+ results = model_predict(file_paths, model)
140
+ return render_template('index.html', prediction=results)
141
 
142
+ if __name__ == '__main__':
143
+ if not os.path.exists(UPLOAD_FOLDER):
144
+ os.makedirs(UPLOAD_FOLDER)
145
+ app.run(debug=True, host='0.0.0.0', port=5000)