"""Gradio app that classifies ECG arrhythmias from a CSV signal or an ECG image.

Images are first converted to a 1-D signal (saved as CSV). The signal is then
split into beats around R-peaks found by biosppy's Christov segmenter, and each
beat is rasterised into a 128x128 image and classified by a pre-trained Keras
model.
"""

import os
import shutil
import traceback

import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import load_model
import gradio as gr
import biosppy.signals.ecg as ecg
from PIL import Image

# Directory for uploaded files and converted signals (created if missing).
UPLOAD_FOLDER = "/tmp/uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Load the pre-trained model. The path is relative, so ecgScratchEpoch2.hdf5
# must be in the current working directory.
try:
    model = load_model("ecgScratchEpoch2.hdf5")
except Exception as e:
    raise RuntimeError(f"Failed to load model: {e}") from e

# Class labels indexed by the model's argmax output.
_CLASSES = ["Normal", "APC", "LBB", "PAB", "PVC", "RBB", "VEB"]
# Key order of the per-file result dict, and therefore of the printed report.
_RESULT_ORDER = ("APC", "Normal", "LBB", "PAB", "PVC", "RBB", "VEB")
# Side length of the square image each beat is drawn into (model input is
# (_BEAT_IMAGE_SIZE, _BEAT_IMAGE_SIZE, 1)).
_BEAT_IMAGE_SIZE = 128
_DILATE_KERNEL = np.ones((4, 4), np.uint8)
_IMAGE_EXTENSIONS = ("png", "jpg", "jpeg")


def _to_grayscale(image):
    """Return *image* (PIL image or array) as a 2-D grayscale array."""
    if isinstance(image, Image.Image):
        # Handles RGB, RGBA, palette and already-grayscale PIL modes uniformly.
        return np.asarray(image.convert("L"))
    img = np.asarray(image)
    if img.ndim == 3:
        code = cv2.COLOR_RGBA2GRAY if img.shape[2] == 4 else cv2.COLOR_RGB2GRAY
        img = cv2.cvtColor(img, code)
    return img


def image_to_signal(image):
    """Convert an ECG image to a 1-D signal and save it as CSV.

    The image is resized to 1000x500 and thresholded. The largest external
    contour is taken as the waveform, and each of the 1000 columns becomes one
    sample: the mean y of the contour points in that column.

    Args:
        image: A PIL image, or an array-like grayscale/RGB/RGBA image.

    Returns:
        Path of the CSV written to ``UPLOAD_FOLDER``. It has a single
        ``" Sample Value"`` column scaled to ``[0, 1000]``.

    Raises:
        RuntimeError: prefixed with "Image processing error", wrapping any
            failure (including no waveform found or a flat trace).
    """
    try:
        img = _to_grayscale(image)
        # Resize to a standard size.
        img = cv2.resize(img, (1000, 500))
        # Dark trace on light background -> white trace on black for contours.
        _, binary = cv2.threshold(img, 200, 255, cv2.THRESH_BINARY_INV)
        contours, _ = cv2.findContours(
            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if not contours:
            raise ValueError("No waveform detected in the image")

        # Contours are (N, 1, 2) arrays of (x, y); flatten to (N, 2).
        points = max(contours, key=cv2.contourArea).reshape(-1, 2)
        xs, ys = points[:, 0], points[:, 1]
        width = img.shape[1]
        counts = np.bincount(xs, minlength=width)
        sums = np.bincount(xs, weights=ys, minlength=width)
        covered = np.flatnonzero(counts)
        # Columns without contour points (straight runs compressed by
        # CHAIN_APPROX_SIMPLE, margins) are linearly interpolated; np.interp
        # holds the first/last covered value beyond the covered range.
        signal = np.interp(
            np.arange(width), covered, sums[covered] / counts[covered]
        )

        span = signal.max() - signal.min()
        if span == 0:
            raise ValueError("Detected waveform is flat")
        # Image rows grow downward, so flip to make upward deflections larger.
        signal = (signal.max() - signal) / span * 1000

        csv_path = os.path.join(UPLOAD_FOLDER, "converted_signal.csv")
        pd.DataFrame(signal, columns=[" Sample Value"]).to_csv(csv_path, index=False)
        return csv_path
    except Exception as e:
        raise RuntimeError(f"Image processing error: {e}") from e


def _beat_to_image(beat):
    """Rasterise one beat into a dilated (128, 128, 1) image for the model.

    Row ``i`` holds sample ``i``; the lit column is ``int(amplitude / 10)``.
    Samples past row 127 are dropped, and columns are clipped to ``[0, 127]``.
    Long beats or large amplitudes therefore no longer raise ``IndexError``,
    and negative amplitudes no longer wrap to the opposite edge.
    """
    size = _BEAT_IMAGE_SIZE
    samples = np.asarray(beat[:size], dtype=float)
    # astype(int) truncates toward zero, like the int() used previously.
    cols = np.clip((samples / 10).astype(int), 0, size - 1)
    img = np.zeros((size, size))
    img[np.arange(len(samples)), cols] = 255
    img = cv2.dilate(img, _DILATE_KERNEL, iterations=1)
    return img.reshape(size, size, 1)


def model_predict(csv_path, sampling_rate=200):
    """Predict ECG arrhythmia classes for every beat in a CSV signal.

    Each beat spans from halfway between the previous R-peak and the current
    one, to halfway between the current one and the next. The first and last
    peaks only serve as boundaries. Beats of 10 samples or fewer are skipped.

    Args:
        csv_path: CSV file with a ``" Sample Value"`` column (note the
            leading space).
        sampling_rate: Sampling rate passed to ``christov_segmenter``.

    Returns:
        A one-element list ``[{"file": csv_path, "results": {...}}]``. The
        results map each class label to a list of ``(start, end)`` sample
        ranges of beats predicted as that class.

    Raises:
        RuntimeError: prefixed with "Prediction error", wrapping any failure.
    """
    try:
        result = {label: [] for label in _RESULT_ORDER}
        data = np.array(pd.read_csv(csv_path)[" Sample Value"])
        peaks = ecg.christov_segmenter(signal=data, sampling_rate=sampling_rate)[0]

        ranges, images = [], []
        for prev_peak, peak, next_peak in zip(peaks[:-2], peaks[1:-1], peaks[2:]):
            start = prev_peak + abs(prev_peak - peak) // 2
            end = next_peak - abs(next_peak - peak) // 2
            beat = data[start:end]
            if len(beat) > 10:
                images.append(_beat_to_image(beat))
                # Plain ints so the report doesn't print numpy scalar reprs.
                ranges.append((int(start), int(end)))

        if images:
            # One batched call instead of one predict() per beat.
            predictions = model.predict(np.stack(images), verbose=0).argmax(axis=1)
            for beat_range, prediction in zip(ranges, predictions):
                result[_CLASSES[prediction]].append(beat_range)

        return [{"file": csv_path, "results": result}]
    except Exception as e:
        raise RuntimeError(f"Prediction error: {e}") from e


def classify_ecg(file):
    """Handle an upload from the UI and return a plain-text report.

    Args:
        file: ``None`` when nothing was uploaded. Otherwise one of: a path
            string, an object exposing its path as ``.name``, or a PIL image.

    Returns:
        A report with a ``File:`` line followed by ``<label>: [(start, end), ...]``
        for every class with at least one beat. On a missing or unsupported
        upload, a usage message is returned; on any failure, an error message
        with the traceback.
    """
    try:
        if file is None:
            return "No file uploaded."

        # Route by the uploaded file's real extension, not by its Python type:
        # uploads typically arrive as paths, whatever their content.
        if isinstance(file, Image.Image):
            ext = "png"
            src_path = None
        else:
            src_path = file if isinstance(file, str) else getattr(file, "name", None)
            ext = os.path.splitext(src_path or "")[1].lstrip(".").lower()
        if ext not in _IMAGE_EXTENSIONS and ext != "csv":
            return "Unsupported file type. Use CSV, PNG, or JPG."

        # Keep a copy of the upload in UPLOAD_FOLDER.
        file_path = os.path.join(UPLOAD_FOLDER, f"uploaded_file.{ext}")
        if src_path is None:
            file.save(file_path)
        else:
            shutil.copyfile(src_path, file_path)

        if ext == "csv":
            csv_path = file_path
        else:
            with Image.open(file_path) as image:
                csv_path = image_to_signal(image)

        results = model_predict(csv_path)

        lines = []
        for result in results:
            lines.append(f"File: {result['file']}")
            lines.extend(
                f"{key}: {value}" for key, value in result["results"].items() if value
            )
        return "".join(f"{line}\n" for line in lines)
    except Exception as e:
        return f"Error: {str(e)}\n{traceback.format_exc()}"


# Gradio interface
iface = gr.Interface(
    fn=classify_ecg,
    inputs=gr.File(
        label="Upload ECG Image (PNG/JPG) or CSV",
        file_types=[".csv", ".png", ".jpg", ".jpeg"],
    ),
    outputs=gr.Textbox(label="Classification Results"),
    title="ECG Arrhythmia Classification",
    description="Upload an ECG image (PNG/JPG) or CSV file to classify arrhythmias. Images will be converted to CSV before processing.",
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)