import os
import shutil
import subprocess
import zipfile

import gradio as gr
import numpy as np
import requests
import tensorflow as tf
import wfdb
from scipy import signal


# Run TensorFlow on CPU only; the Space is assumed not to have a GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


# Fail fast if the Hugging Face token is missing from the Space's secrets.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found. Please set it in the Space's environment variables.")


REPO_URL = "https://github.com/AutoECG/Automated-ECG-Interpretation.git"
REPO_DIR = "Automated-ECG-Interpretation"
DATASET_URL = "https://physionet.org/static/published-projects/ptb-xl/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3.zip"
# The PhysioNet zip unpacks into a top-level folder named after the project slug
# and version, so a bare "ptb-xl" would never match and the existence check below
# would re-download the dataset on every start.
DATASET_DIR = "ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3"
PERSISTENT_DIR = "/data"  # Hugging Face Spaces persistent storage mount


def ensure_persistent_dir():
    # exist_ok=True already handles the "directory exists" case; no pre-check needed.
    os.makedirs(PERSISTENT_DIR, exist_ok=True)


def clone_repository():
    if not os.path.exists(REPO_DIR):
        print("Cloning repository...")
        try:
            subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
            print("Repository cloned successfully.")
        except subprocess.CalledProcessError as e:
            print(f"Error cloning repository: {e}")
    else:
        print("Repository already cloned.")


def download_and_extract_dataset():
    ensure_persistent_dir()
    zip_path = os.path.join(PERSISTENT_DIR, "ptb-xl.zip")
    extract_path = os.path.join(PERSISTENT_DIR, DATASET_DIR)
    if not os.path.exists(extract_path):
        print("Downloading PTB-XL dataset...")
        response = requests.get(DATASET_URL, stream=True)
        response.raise_for_status()  # abort on HTTP errors rather than saving an error page
        with open(zip_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(PERSISTENT_DIR)
        os.remove(zip_path)
        print("Dataset extracted successfully.")
    else:
        print("Dataset already extracted.")


clone_repository()
download_and_extract_dataset()


MODEL_FILENAME = "model.h5"
MODEL_PATH = os.path.join(REPO_DIR, MODEL_FILENAME)
PERSISTENT_MODEL_PATH = os.path.join(PERSISTENT_DIR, MODEL_FILENAME)

# Copy the model into persistent storage once, so later restarts don't depend on the clone.
if not os.path.exists(PERSISTENT_MODEL_PATH):
    if os.path.exists(MODEL_PATH):
        shutil.copy(MODEL_PATH, PERSISTENT_MODEL_PATH)
    else:
        raise FileNotFoundError(
            f"Model file not found at {MODEL_PATH}. Please ensure it's in the repository or upload it manually."
        )

model = tf.keras.models.load_model(PERSISTENT_MODEL_PATH)
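
# Optional sanity check: the preprocessing below assumes the network takes input
# of shape (batch, 3600, 1); model.input_shape can confirm that at startup.
# print("Model input shape:", model.input_shape)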


def preprocess_ecg(file_path):
    # wfdb expects the record name without the .dat extension; it also reads the
    # matching .hea header, which must sit next to the signal file.
    record = wfdb.rdrecord(file_path.replace(".dat", ""))
    ecg_signal = record.p_signal[:, 0]  # first lead only
    # Resample to 360 Hz, then z-score normalise (guarding against a flat signal).
    target_fs = 360
    num_samples = int(len(ecg_signal) * target_fs / record.fs)
    ecg_resampled = signal.resample(ecg_signal, num_samples)
    std = np.std(ecg_resampled)
    ecg_normalized = (ecg_resampled - np.mean(ecg_resampled)) / (std if std > 0 else 1.0)
    # Pad or truncate to exactly 3600 samples (10 s at 360 Hz), the model's input length.
    if len(ecg_normalized) < 3600:
        ecg_normalized = np.pad(ecg_normalized, (0, 3600 - len(ecg_normalized)), "constant")
    else:
        ecg_normalized = ecg_normalized[:3600]
    return ecg_normalized.reshape(1, 3600, 1)
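
# Usage sketch (the record path below is hypothetical):
#   x = preprocess_ecg(os.path.join(PERSISTENT_DIR, DATASET_DIR, "records500/00000/00001_hr.dat"))
#   assert x.shape == (1, 3600, 1)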


def predict_ecg(file=None, dataset_file=None):
    if file:
        file_path = file.name
    elif dataset_file:
        file_path = os.path.join(PERSISTENT_DIR, DATASET_DIR, "records500", dataset_file)
    else:
        return "Please upload a file or select a dataset sample."

    ecg_data = preprocess_ecg(file_path)
    # The model is assumed to output a single sigmoid probability of "Abnormal".
    prediction = model.predict(ecg_data)
    abnormal_prob = float(prediction[0][0])
    label = "Abnormal" if abnormal_prob > 0.5 else "Normal"
    confidence = abnormal_prob if label == "Abnormal" else 1 - abnormal_prob
    return f"Prediction: {label}\nConfidence: {confidence:.2%}"
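
# Smoke-test sketch (assumes at least one extracted record; the name is hypothetical):
#   print(predict_ecg(dataset_file="00000/00001_hr.dat"))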


# Collect the PTB-XL 500 Hz records for the dropdown.
dataset_files = []
records_root = os.path.join(PERSISTENT_DIR, DATASET_DIR, "records500")
if os.path.exists(records_root):
    for root, _, files in os.walk(records_root):
        for file in files:
            if file.endswith(".dat"):
                dataset_files.append(os.path.relpath(os.path.join(root, file), records_root))
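
# Sort so the dropdown order is stable across restarts (os.walk order is filesystem-dependent).
dataset_files.sort()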


interface = gr.Interface(
    fn=predict_ecg,
    inputs=[
        gr.File(label="Upload ECG File (.dat format)"),
        gr.Dropdown(choices=dataset_files, label="Or Select a PTB-XL Sample"),
    ],
    outputs=gr.Textbox(label="ECG Interpretation"),
    title="Automated ECG Interpretation",
    description="Upload an ECG file (.dat) or select a sample from the PTB-XL dataset for automated interpretation.",
)

interface.launch()