Nayefleb's picture
Update app.py
795e3eb verified
import os
import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import load_model
import gradio as gr
import biosppy.signals.ecg as ecg
from PIL import Image
import traceback
# Directory for uploaded files and the CSV produced from uploaded images.
UPLOAD_FOLDER = "/tmp/uploads"
# exist_ok avoids the race between an exists() check and makedirs().
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# Load the pre-trained model once at import time so every request reuses it.
# ecgScratchEpoch2.hdf5 is looked up next to this script first and then in the
# current working directory, so the app also starts when launched from a
# different directory.
_MODEL_FILE = "ecgScratchEpoch2.hdf5"
_model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _MODEL_FILE)
if not os.path.exists(_model_path):
    _model_path = _MODEL_FILE
try:
    model = load_model(_model_path)
except Exception as e:
    # Chain the original error so its traceback is not lost.
    raise Exception(f"Failed to load model: {e}") from e
def image_to_signal(image):
    """Convert an ECG image to a 1D signal and save it as CSV.

    The trace is isolated by inverse-thresholding the grayscale image (pixels
    darker than 200 become foreground) after resizing it to 1000x500. The
    largest external contour is kept, and the contour's row coordinates are
    averaged in every pixel column.

    Args:
        image: A ``PIL.Image.Image`` in any mode, or an array-like that
            ``np.array`` turns into a 2-D grayscale or 3/4-channel RGB(A) image.

    Returns:
        Path to ``converted_signal.csv`` inside ``UPLOAD_FOLDER``. The file has a
        single ``" Sample Value"`` column of 1000 samples scaled to
        ``[0, 1000]``, where larger values lie higher up in the image.

    Raises:
        Exception: ``"Image processing error: ..."`` wrapping any failure,
            including when no waveform or only a flat line is detected.
    """
    try:
        if isinstance(image, Image.Image):
            # Normalise every PIL mode (palette, grayscale, RGBA, ...) to RGB
            # so the colour conversion below always receives colour channels.
            image = image.convert("RGB")
        img = np.array(image)
        if img.ndim == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

        # Resize to a standard size.
        img = cv2.resize(img, (1000, 500))
        height, width = img.shape[:2]

        # Dark trace on a light background -> white foreground on black.
        _, binary = cv2.threshold(img, 200, 255, cv2.THRESH_BINARY_INV)
        # CHAIN_APPROX_NONE keeps every boundary pixel, so each column the
        # trace passes through gets real samples (APPROX_SIMPLE drops the
        # points along straight runs). [-2] selects the contour list under both
        # the OpenCV 3 (3-tuple) and OpenCV 4 (2-tuple) return signatures.
        contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
        if not contours:
            raise ValueError("No waveform detected in the image")
        contour = max(contours, key=cv2.contourArea)

        # Contours have shape (N, 1, 2) holding (x, y) pixel coordinates.
        points = contour.reshape(-1, 2)
        xs = points[:, 0]
        counts = np.bincount(xs, minlength=width)[:width]
        sums = np.bincount(xs, weights=points[:, 1], minlength=width)[:width]

        # Mean trace row per column. Columns the contour never touches reuse
        # the nearest covered column to their left; leading gaps use the first
        # covered column. The contour is non-empty, so `covered` is too.
        covered = np.flatnonzero(counts)
        nearest = np.searchsorted(covered, np.arange(width), side="right") - 1
        source = covered[np.maximum(nearest, 0)]
        rows = sums[source] / counts[source]

        # Image rows grow downwards; flip so upward deflections become larger values.
        signal = (height - 1) - rows

        low, high = signal.min(), signal.max()
        if high == low:
            # Would otherwise divide by zero and write a column of NaNs.
            raise ValueError("Detected waveform is flat")
        signal = (signal - low) / (high - low) * 1000

        csv_path = os.path.join(UPLOAD_FOLDER, "converted_signal.csv")
        pd.DataFrame(signal, columns=[" Sample Value"]).to_csv(csv_path, index=False)
        return csv_path
    except Exception as e:
        raise Exception(f"Image processing error: {e}") from e
def _beat_to_image(segment, kernel, size=128):
    """Render one beat segment as a dilated ``(size, size, 1)`` float image.

    Row ``i`` is marked with 255 at column ``int(segment[i] / 10)``. Samples
    past ``size`` are dropped and columns are clipped to ``[0, size - 1]``.
    Without the clipping, long beats or large values would raise IndexError,
    and negative values would silently wrap to the opposite edge.
    """
    img = np.zeros((size, size))
    n_rows = min(len(segment), size)
    # astype(int) truncates toward zero, matching int() on each value.
    cols = np.clip((np.asarray(segment[:n_rows], dtype=float) / 10).astype(int), 0, size - 1)
    img[np.arange(n_rows), cols] = 255
    img = cv2.dilate(img, kernel, iterations=1)
    return img.reshape(size, size, 1)


def model_predict(csv_path):
    """Predict ECG arrhythmia classes for every beat in a CSV file.

    Reads the ``" Sample Value"`` column and detects R peaks with biosppy's
    Christov segmenter. For every interior peak it cuts the window from
    halfway-after the previous peak to halfway-before the next one. Each
    window longer than 10 samples is rendered as a 128x128 image and
    classified with the loaded Keras ``model``.

    Args:
        csv_path: Path of a CSV file containing a ``" Sample Value"`` column.

    Returns:
        ``[{"file": csv_path, "results": results}]``, where ``results`` maps
        each class label to a list of ``(start, end)`` sample-index pairs of
        the beats assigned to it.

    Raises:
        Exception: ``"Prediction error: ..."`` wrapping any failure.
    """
    try:
        # Maps the model's argmax index to a class label.
        classes = ["Normal", "APC", "LBB", "PAB", "PVC", "RBB", "VEB"]
        # Dict order is the order labels appear in the report.
        result = {label: [] for label in ("APC", "Normal", "LBB", "PAB", "PVC", "RBB", "VEB")}
        kernel = np.ones((4, 4), np.uint8)

        data = np.array(pd.read_csv(csv_path)[" Sample Value"])
        # NOTE(review): sampling rate is hard-coded to 200 Hz, including for
        # image-derived signals - confirm this matches the training data.
        peaks = ecg.christov_segmenter(signal=data, sampling_rate=200)[0]

        images, spans = [], []
        for prev_peak, peak, next_peak in zip(peaks[:-2], peaks[1:-1], peaks[2:]):
            start = prev_peak + abs(prev_peak - peak) // 2
            end = next_peak - abs(next_peak - peak) // 2
            segment = data[start:end]
            if len(segment) > 10:  # skip degenerate windows
                images.append(_beat_to_image(segment, kernel))
                # Plain ints so the report shows "(12, 80)", not numpy scalar reprs.
                spans.append((int(start), int(end)))

        if images:
            # One batched call instead of one model.predict per beat.
            predictions = model.predict(np.array(images), verbose=0).argmax(axis=1)
            for label_index, span in zip(predictions, spans):
                result[classes[label_index]].append(span)

        return [{"file": csv_path, "results": result}]
    except Exception as e:
        raise Exception(f"Prediction error: {e}") from e
def classify_ecg(file):
    """Classify an uploaded ECG (CSV signal or PNG/JPG image) and return a text report.

    Args:
        file: The value Gradio passes for the ``gr.File`` input. This is a
            filesystem path (``str``/``os.PathLike``) or, in older Gradio
            versions, a temp-file wrapper exposing ``.name``. A
            ``PIL.Image.Image`` is also accepted, e.g. if the input component
            is swapped for an image input.

    Returns:
        The formatted classification report. When no file or an unsupported
        file type is given, a short user-facing message is returned instead.
        Any other failure is returned as ``"Error: ..."`` plus the traceback
        rather than raised, so it appears in the output textbox.
    """
    unsupported = "Unsupported file type. Use CSV, PNG, or JPG."
    try:
        if file is None:
            return "No file uploaded."

        # The file type must come from the upload's own name. gr.File delivers
        # a path (or a wrapper with .name), not a PIL image, so branching on the
        # Python type would treat every upload, images included, as CSV.
        image = None
        if isinstance(file, (str, os.PathLike)):
            source_path = os.fspath(file)
        elif isinstance(file, Image.Image):
            image = file
            source_path = None
        elif isinstance(getattr(file, "name", None), str):
            source_path = file.name
        else:
            return unsupported

        ext = "png" if image is not None else os.path.splitext(source_path)[1][1:].lower()

        if ext in ("png", "jpg", "jpeg"):
            if image is None:
                with Image.open(source_path) as opened:
                    image = opened.convert("RGB")
            else:
                image = image.convert("RGB")
            # Keep a copy of the uploaded image next to the converted signal.
            image.save(os.path.join(UPLOAD_FOLDER, "uploaded_file.png"))
            csv_path = image_to_signal(image)
        elif ext == "csv":
            csv_path = os.path.join(UPLOAD_FOLDER, "uploaded_file.csv")
            # Opening the destination with "wb" first would truncate the source
            # if both paths are the same file.
            if os.path.abspath(source_path) != os.path.abspath(csv_path):
                with open(source_path, "rb") as src, open(csv_path, "wb") as dst:
                    dst.write(src.read())
        else:
            return unsupported

        results = model_predict(csv_path)

        report = []
        for result in results:
            report.append(f"File: {result['file']}\n")
            report.extend(
                f"{key}: {value}\n" for key, value in result["results"].items() if value
            )
        return "".join(report)
    except Exception as e:
        return f"Error: {str(e)}\n{traceback.format_exc()}"
# Gradio interface: one file upload in, a plain-text classification report out.
iface = gr.Interface(
    fn=classify_ecg,
    inputs=gr.File(
        label="Upload ECG Image (PNG/JPG) or CSV",
        # Limit the file picker to the formats classify_ecg can handle.
        file_types=[".csv", ".png", ".jpg", ".jpeg"],
    ),
    outputs=gr.Textbox(label="Classification Results"),
    title="ECG Arrhythmia Classification",
    description="Upload an ECG image (PNG/JPG) or CSV file to classify arrhythmias. Images will be converted to CSV before processing.",
)
if __name__ == "__main__":
    # Bind to all network interfaces on port 7860 so other hosts can reach the app.
    iface.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )