Nayefleb commited on
Commit
795e3eb
·
verified ·
1 Parent(s): 7e42a78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -103
app.py CHANGED
@@ -4,76 +4,77 @@ import numpy as np
4
  import pandas as pd
5
  import tensorflow as tf
6
  from tensorflow.keras.models import load_model
7
- from flask import Flask, request, render_template
8
- from werkzeug.utils import secure_filename
9
  import biosppy.signals.ecg as ecg
 
 
10
 
11
- app = Flask(__name__)
 
 
 
12
 
13
- UPLOAD_FOLDER = 'uploads'
14
- ALLOWED_EXTENSIONS = {'csv', 'png', 'jpg', 'jpeg'}
15
- app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
 
 
16
 
17
- # Load the pre-trained model
18
- model = load_model('ecgScratchEpoch2.hdf5')
19
-
20
- def allowed_file(filename):
21
- return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
22
-
23
- def image_to_signal(image_path):
24
- # Read and preprocess the image
25
- img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
26
- if img is None:
27
- raise ValueError("Failed to load image")
28
-
29
- # Resize to a standard height for consistency (e.g., 500 pixels)
30
- img = cv2.resize(img, (1000, 500))
31
-
32
- # Apply thresholding to isolate the waveform
33
- _, binary = cv2.threshold(img, 200, 255, cv2.THRESH_BINARY_INV)
34
-
35
- # Find contours of the waveform
36
- contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
37
- if not contours:
38
- raise ValueError("No waveform detected in the image")
39
-
40
- # Assume the largest contour is the ECG waveform
41
- contour = max(contours, key=cv2.contourArea)
42
-
43
- # Extract y-coordinates (signal amplitude) along x-axis
44
- signal = []
45
- width = img.shape[1]
46
- for x in range(width):
47
- column = contour[contour[:, :, 0] == x]
48
- if len(column) > 0:
49
- # Take the average y-coordinate if multiple points exist
50
- y = np.mean(column[:, :, 1])
51
- signal.append(y)
52
- else:
53
- # Interpolate if no point is found
54
- signal.append(signal[-1] if signal else 0)
55
-
56
- # Normalize signal to match expected amplitude range
57
- signal = np.array(signal)
58
- signal = (signal - np.min(signal)) / (np.max(signal) - np.min(signal)) * 1000
59
-
60
- # Save to CSV
61
- csv_path = os.path.join(app.config['UPLOAD_FOLDER'], 'converted_signal.csv')
62
- df = pd.DataFrame(signal, columns=[' Sample Value'])
63
- df.to_csv(csv_path, index=False)
64
-
65
- return csv_path
66
 
67
- def model_predict(uploaded_files, model):
68
- output = []
69
- for path in uploaded_files:
 
70
  APC, NORMAL, LBB, PVC, PAB, RBB, VEB = [], [], [], [], [], [], []
71
- output.append(str(path))
72
  result = {"APC": APC, "Normal": NORMAL, "LBB": LBB, "PAB": PAB, "PVC": PVC, "RBB": RBB, "VEB": VEB}
73
 
74
- kernel = np.ones((4,4), np.uint8)
75
- csv = pd.read_csv(path)
76
- csv_data = csv[' Sample Value']
77
  data = np.array(csv_data)
78
  signals = []
79
  count = 1
@@ -97,49 +98,64 @@ def model_predict(uploaded_files, model):
97
  img[i, int(signal[i] / 10)] = 255
98
  img = cv2.dilate(img, kernel, iterations=1)
99
  img = img.reshape(128, 128, 1)
100
- prediction = model.predict(np.array([img])).argmax()
101
- classes = ['Normal', 'APC', 'LBB', 'PAB', 'PVC', 'RBB', 'VEB']
102
  result[classes[prediction]].append(index)
103
 
104
- output.append(result)
105
-
106
- return output
107
-
108
- @app.route('/', methods=['GET'])
109
- def index():
110
- return render_template('index.html')
111
 
112
- @app.route('/', methods=['POST'])
113
- def upload_file():
114
- if 'files[]' not in request.files:
115
- return render_template('index.html', message='No file part')
116
-
117
- files = request.files.getlist('files[]')
118
- file_paths = []
119
-
120
- for file in files:
121
- if file and allowed_file(file.filename):
122
- filename = secure_filename(file.filename)
123
- file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
 
 
 
124
  file.save(file_path)
125
-
126
- # If the file is an image, convert to CSV
127
- if filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg'}:
128
- try:
129
- csv_path = image_to_signal(file_path)
130
- file_paths.append(csv_path)
131
- except Exception as e:
132
- return render_template('index.html', message=f'Error processing image: {str(e)}')
133
- else:
134
- file_paths.append(file_path)
135
-
136
- if not file_paths:
137
- return render_template('index.html', message='No valid files uploaded')
138
-
139
- results = model_predict(file_paths, model)
140
- return render_template('index.html', prediction=results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
- if __name__ == '__main__':
143
- if not os.path.exists(UPLOAD_FOLDER):
144
- os.makedirs(UPLOAD_FOLDER)
145
- app.run(debug=True, host='0.0.0.0', port=5000)
 
4
  import pandas as pd
5
  import tensorflow as tf
6
  from tensorflow.keras.models import load_model
7
+ import gradio as gr
 
8
  import biosppy.signals.ecg as ecg
9
+ from PIL import Image
10
+ import traceback
11
 
12
+ # Create uploads directory
13
+ UPLOAD_FOLDER = "/tmp/uploads"
14
+ if not os.path.exists(UPLOAD_FOLDER):
15
+ os.makedirs(UPLOAD_FOLDER)
16
 
17
+ # Load the pre-trained model (assumes ecgScratchEpoch2.hdf5 is in the root directory)
18
+ try:
19
+ model = load_model("ecgScratchEpoch2.hdf5")
20
+ except Exception as e:
21
+ raise Exception(f"Failed to load model: {str(e)}")
22
 
23
+ def image_to_signal(image):
24
+ """Convert an ECG image to a 1D signal and save as CSV."""
25
+ try:
26
+ # Convert Gradio image (PIL) to OpenCV format
27
+ img = np.array(image)
28
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
29
+
30
+ # Resize to a standard size
31
+ img = cv2.resize(img, (1000, 500))
32
+
33
+ # Apply thresholding to isolate waveform
34
+ _, binary = cv2.threshold(img, 200, 255, cv2.THRESH_BINARY_INV)
35
+
36
+ # Find contours
37
+ contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
38
+ if not contours:
39
+ raise ValueError("No waveform detected in the image")
40
+
41
+ # Use the largest contour
42
+ contour = max(contours, key=cv2.contourArea)
43
+
44
+ # Extract y-coordinates along x-axis
45
+ signal = []
46
+ width = img.shape[1]
47
+ for x in range(width):
48
+ column = contour[contour[:, :, 0] == x]
49
+ if len(column) > 0:
50
+ y = np.mean(column[:, :, 1])
51
+ signal.append(y)
52
+ else:
53
+ signal.append(signal[-1] if signal else 0)
54
+
55
+ # Normalize signal
56
+ signal = np.array(signal)
57
+ signal = (signal - np.min(signal)) / (np.max(signal) - np.min(signal)) * 1000
58
+
59
+ # Save to CSV
60
+ csv_path = os.path.join(UPLOAD_FOLDER, "converted_signal.csv")
61
+ df = pd.DataFrame(signal, columns=[" Sample Value"])
62
+ df.to_csv(csv_path, index=False)
63
+
64
+ return csv_path
65
+ except Exception as e:
66
+ raise Exception(f"Image processing error: {str(e)}")
 
 
 
 
 
67
 
68
+ def model_predict(csv_path):
69
+ """Predict ECG arrhythmia classes from a CSV file."""
70
+ try:
71
+ output = []
72
  APC, NORMAL, LBB, PVC, PAB, RBB, VEB = [], [], [], [], [], [], []
 
73
  result = {"APC": APC, "Normal": NORMAL, "LBB": LBB, "PAB": PAB, "PVC": PVC, "RBB": RBB, "VEB": VEB}
74
 
75
+ kernel = np.ones((4, 4), np.uint8)
76
+ csv = pd.read_csv(csv_path)
77
+ csv_data = csv[" Sample Value"]
78
  data = np.array(csv_data)
79
  signals = []
80
  count = 1
 
98
  img[i, int(signal[i] / 10)] = 255
99
  img = cv2.dilate(img, kernel, iterations=1)
100
  img = img.reshape(128, 128, 1)
101
+ prediction = model.predict(np.array([img]), verbose=0).argmax()
102
+ classes = ["Normal", "APC", "LBB", "PAB", "PVC", "RBB", "VEB"]
103
  result[classes[prediction]].append(index)
104
 
105
+ output.append({"file": csv_path, "results": result})
106
+ return output
107
+ except Exception as e:
108
+ raise Exception(f"Prediction error: {str(e)}")
 
 
 
109
 
110
+ def classify_ecg(file):
111
+ """Main function to handle file uploads (CSV or image)."""
112
+ try:
113
+ if file is None:
114
+ return "No file uploaded."
115
+
116
+ # Save uploaded file
117
+ file_path = os.path.join(UPLOAD_FOLDER, "uploaded_file")
118
+ if isinstance(file, str): # CSV file path
119
+ file_path += ".csv"
120
+ with open(file_path, "wb") as f:
121
+ with open(file, "rb") as src:
122
+ f.write(src.read())
123
+ else: # Image file (PIL Image from Gradio)
124
+ file_path += ".png"
125
  file.save(file_path)
126
+
127
+ # Check file type
128
+ ext = file_path.rsplit(".", 1)[1].lower()
129
+ if ext in ["png", "jpg", "jpeg"]:
130
+ csv_path = image_to_signal(file)
131
+ elif ext == "csv":
132
+ csv_path = file_path
133
+ else:
134
+ return "Unsupported file type. Use CSV, PNG, or JPG."
135
+
136
+ # Run prediction
137
+ results = model_predict(csv_path)
138
+
139
+ # Format output
140
+ output = ""
141
+ for result in results:
142
+ output += f"File: {result['file']}\n"
143
+ for key, value in result["results"].items():
144
+ if value:
145
+ output += f"{key}: {value}\n"
146
+
147
+ return output
148
+ except Exception as e:
149
+ return f"Error: {str(e)}\n{traceback.format_exc()}"
150
+
151
+ # Gradio interface
152
+ iface = gr.Interface(
153
+ fn=classify_ecg,
154
+ inputs=gr.File(label="Upload ECG Image (PNG/JPG) or CSV"),
155
+ outputs=gr.Textbox(label="Classification Results"),
156
+ title="ECG Arrhythmia Classification",
157
+ description="Upload an ECG image (PNG/JPG) or CSV file to classify arrhythmias. Images will be converted to CSV before processing.",
158
+ )
159
 
160
+ if __name__ == "__main__":
161
+ iface.launch(server_name="0.0.0.0", server_port=7860)