# Hugging Face Spaces page status (captured together with this source): Runtime error
import os | |
import cv2 | |
import numpy as np | |
import pandas as pd | |
import tensorflow as tf | |
from tensorflow.keras.models import load_model | |
import gradio as gr | |
import biosppy.signals.ecg as ecg | |
from PIL import Image | |
import traceback | |
# Create the directory that holds uploaded files and converted signals.
UPLOAD_FOLDER = "/tmp/uploads"
# exist_ok=True replaces a separate exists() check, which could race with
# another process creating the directory in between.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# Load the pre-trained model. ecgScratchEpoch2.hdf5 is expected in the app's
# root directory, i.e. next to this script; resolving it relative to the script
# makes loading independent of the process's current working directory.
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "ecgScratchEpoch2.hdf5"
)
try:
    # compile=False: the model is only used for inference (model.predict), so
    # there is no need to restore the saved training configuration; skipping it
    # avoids failures when an older .hdf5 optimizer config cannot be
    # deserialized by a newer Keras release.
    model = load_model(MODEL_PATH, compile=False)
except Exception as e:
    # Chain the original error so its traceback is preserved.
    raise RuntimeError(f"Failed to load model: {str(e)}") from e
def _pil_to_grayscale(pil_image):
    """Convert a PIL image to a 2-D uint8 grayscale array.

    Transparent pixels are flattened onto a white background first, so a
    transparent backdrop is not mistaken for dark ink by the threshold.
    """
    has_alpha = pil_image.mode in ("RGBA", "LA") or (
        pil_image.mode == "P" and "transparency" in pil_image.info
    )
    if has_alpha:
        rgba = pil_image.convert("RGBA")
        background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        pil_image = Image.alpha_composite(background, rgba)
    return np.array(pil_image.convert("L"))


def _image_to_grayscale(image):
    """Return *image* (PIL image, file path, or NumPy array) as a 2-D grayscale array."""
    if isinstance(image, (str, os.PathLike)):
        with Image.open(image) as opened:
            return _pil_to_grayscale(opened)
    if isinstance(image, Image.Image):
        return _pil_to_grayscale(image)
    array = np.asarray(image)
    if array.ndim == 2:  # already single-channel
        return array
    if array.ndim == 3 and array.shape[2] == 4:
        return cv2.cvtColor(array, cv2.COLOR_RGBA2GRAY)
    if array.ndim == 3 and array.shape[2] == 3:
        return cv2.cvtColor(array, cv2.COLOR_RGB2GRAY)
    raise ValueError(f"Unsupported image shape: {array.shape}")


def _contour_to_signal(contour, width, height):
    """Collapse a contour into one amplitude value per image column.

    Each column's value is the mean row of the contour points in that column.
    Columns without points repeat the nearest observed column to their left;
    leading columns (before the first observed one) take the first observed
    value. Image rows grow downward, so the result is flipped so that upward
    deflections of the trace become larger values.
    """
    points = contour.reshape(-1, 2)  # OpenCV contours are (N, 1, 2) of (x, y)
    xs = points[:, 0]
    ys = points[:, 1].astype(np.float64)
    counts = np.bincount(xs, minlength=width)
    sums = np.bincount(xs, weights=ys, minlength=width)

    # Forward-fill: for each column, the index of the last observed column at
    # or before it (the first observed column for the leading gap).
    observed = counts > 0
    first = int(np.flatnonzero(observed)[0])
    source = np.maximum.accumulate(np.where(observed, np.arange(width), first))
    rows = sums[source] / counts[source]
    return (height - 1) - rows


def image_to_signal(image):
    """Convert an ECG image to a 1-D signal and save it as a CSV file.

    The image is converted to grayscale, resized to 1000x500, and thresholded.
    The largest external contour is taken as the waveform and reduced to one
    value per column. The values are min-max scaled to [0, 1000] and written
    to ``UPLOAD_FOLDER/converted_signal.csv`` under the column
    ``" Sample Value"``.

    Args:
        image: A PIL image, a path to an image file, or a NumPy array
            (H x W grayscale, H x W x 3 RGB, or H x W x 4 RGBA).

    Returns:
        The path of the written CSV file.

    Raises:
        Exception: Wrapping any failure, with message prefix
            "Image processing error: ".
    """
    try:
        gray = _image_to_grayscale(image)
        # Resize to a standard size (width 1000, height 500).
        gray = cv2.resize(gray, (1000, 500))
        # Dark ink becomes foreground (255) against a light background.
        _, binary = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY_INV)
        # CHAIN_APPROX_NONE keeps every boundary pixel; SIMPLE would drop the
        # interior points of straight runs and leave gaps in the trace. The
        # [-2] index works with both the OpenCV 3 (3-tuple) and OpenCV 4
        # (2-tuple) return signatures.
        contours = cv2.findContours(
            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )[-2]
        if not contours:
            raise ValueError("No waveform detected in the image")
        contour = max(contours, key=cv2.contourArea)

        height, width = gray.shape
        signal = _contour_to_signal(contour, width, height)

        low, high = signal.min(), signal.max()
        if high == low:
            # Min-max scaling would divide by zero and produce NaNs.
            raise ValueError("Detected waveform is flat; cannot normalize it")
        signal = (signal - low) / (high - low) * 1000

        csv_path = os.path.join(UPLOAD_FOLDER, "converted_signal.csv")
        pd.DataFrame(signal, columns=[" Sample Value"]).to_csv(csv_path, index=False)
        return csv_path
    except Exception as e:
        raise Exception(f"Image processing error: {str(e)}") from e
def model_predict(csv_path, column=" Sample Value", sampling_rate=200):
    """Segment an ECG recording into beats and classify each beat.

    Samples are read from ``column`` of the CSV at ``csv_path``. Peaks are
    located with biosppy's ``christov_segmenter``. Every interior peak defines
    one beat, spanning from halfway back to the previous peak to halfway
    forward to the next one. Each beat longer than 10 samples is drawn into a
    128x128 image (row = sample index, column = amplitude / 10), dilated, and
    classified by the module-level ``model``.

    Args:
        csv_path: Path to a CSV file containing the signal.
        column: Column holding the samples. The default matches the header
            written by ``image_to_signal``.
        sampling_rate: Sampling rate in Hz passed to the peak detector.

    Returns:
        A one-element list ``[{"file": csv_path, "results": results}]``.
        ``results`` maps each class name to a list of ``(start, end)``
        sample-index tuples (Python ints) of the beats assigned to that class.

    Raises:
        Exception: Wrapping any failure, with message prefix "Prediction error: ".
    """
    try:
        # Maps the model's argmax index to a class name.
        classes = ["Normal", "APC", "LBB", "PAB", "PVC", "RBB", "VEB"]
        # Report order, which intentionally differs from the index order above.
        result = {
            name: [] for name in ("APC", "Normal", "LBB", "PAB", "PVC", "RBB", "VEB")
        }
        kernel = np.ones((4, 4), np.uint8)
        size = 128  # beats are rendered as size x size single-channel images
        batch_size = 256  # beats per model.predict call; bounds peak memory

        data = pd.read_csv(csv_path)[column].to_numpy()
        peaks = ecg.christov_segmenter(signal=data, sampling_rate=sampling_rate)[0]

        images, indices = [], []
        for prev, cur, nxt in zip(peaks, peaks[1:], peaks[2:]):
            start = int(prev + abs(cur - prev) // 2)
            end = int(nxt - abs(nxt - cur) // 2)
            beat = data[start:end]
            if len(beat) <= 10:
                continue
            # Only the first `size` samples fit in the image, and amplitudes are
            # clipped to the valid column range. Otherwise a long or
            # high-amplitude beat raises IndexError, and a negative amplitude
            # silently wraps to the opposite edge of the image.
            drawn = beat[:size]
            cols = np.clip((drawn / 10).astype(int), 0, size - 1)
            img = np.zeros((size, size), np.uint8)
            img[np.arange(len(drawn)), cols] = 255
            images.append(cv2.dilate(img, kernel, iterations=1))
            indices.append((start, end))

        # Classify in batches instead of one predict() call per beat.
        labels = []
        for offset in range(0, len(images), batch_size):
            batch = np.asarray(images[offset:offset + batch_size], dtype=np.float32)
            probabilities = model.predict(batch[..., np.newaxis], verbose=0)
            labels.extend(int(label) for label in probabilities.argmax(axis=1))

        for label, index in zip(labels, indices):
            result[classes[label]].append(index)

        return [{"file": csv_path, "results": result}]
    except Exception as e:
        raise Exception(f"Prediction error: {str(e)}") from e
def classify_ecg(file):
    """Classify an uploaded ECG file (CSV or image) and return a text report.

    Args:
        file: The upload handed over by Gradio. Depending on the Gradio
            version this is a file path (``str``/``os.PathLike``) or a
            tempfile wrapper exposing the path as ``.name``. An in-memory PIL
            image is also accepted and treated as a PNG.

    Returns:
        A report with a ``File: ...`` line followed by one
        ``<class>: <beat intervals>`` line per class that received at least
        one beat. Special cases return fixed messages instead:
        ``"No file uploaded."`` for ``None``, an unsupported-type message for
        other extensions, and ``"Error: ..."`` plus the traceback if
        processing fails.
    """
    try:
        if file is None:
            return "No file uploaded."

        # Find the upload's real name. A path string is NOT necessarily a CSV;
        # the extension must come from the name itself.
        if isinstance(file, (str, os.PathLike)):
            source_path = os.fspath(file)
        elif isinstance(file, Image.Image):
            source_path = None  # in-memory image; saved as PNG below
        elif hasattr(file, "name"):
            source_path = file.name
        else:
            raise TypeError(f"Unsupported upload object: {type(file).__name__}")

        if source_path is None:
            ext = "png"
        else:
            ext = os.path.splitext(source_path)[1].lstrip(".").lower()
        if ext not in ("csv", "png", "jpg", "jpeg"):
            return "Unsupported file type. Use CSV, PNG, or JPG."

        # Keep a copy of the upload in UPLOAD_FOLDER under its real extension.
        file_path = os.path.join(UPLOAD_FOLDER, f"uploaded_file.{ext}")
        if source_path is None:
            file.save(file_path)
        else:
            with open(source_path, "rb") as src, open(file_path, "wb") as dst:
                dst.write(src.read())

        if ext == "csv":
            csv_path = file_path
        else:
            with Image.open(file_path) as image:
                csv_path = image_to_signal(image)

        results = model_predict(csv_path)

        lines = []
        for result in results:
            lines.append(f"File: {result['file']}")
            lines.extend(
                f"{key}: {value}" for key, value in result["results"].items() if value
            )
        return "".join(f"{line}\n" for line in lines)
    except Exception as e:
        return f"Error: {str(e)}\n{traceback.format_exc()}"
# Gradio interface: a single file upload in, a plain-text report out.
iface = gr.Interface(
    fn=classify_ecg,
    inputs=gr.File(
        label="Upload ECG Image (PNG/JPG) or CSV",
        # Limit the file picker to the formats classify_ecg handles.
        file_types=[".csv", ".png", ".jpg", ".jpeg"],
    ),
    outputs=gr.Textbox(label="Classification Results"),
    title="ECG Arrhythmia Classification",
    description="Upload an ECG image (PNG/JPG) or CSV file to classify arrhythmias. Images will be converted to CSV before processing.",
)

if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)