import gradio as gr
import numpy as np
import pandas as pd
import wfdb
import tensorflow as tf
from scipy import signal
import os
import subprocess
import shutil
import requests
import zipfile
# Disable GPU usage to avoid CUDA warnings
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Force TensorFlow to use CPU only
# Require HF_TOKEN from the environment (read at startup; not used directly below)
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found. Please set it in the Space's environment variables.")
# Define repository and dataset details
REPO_URL = "https://github.com/AutoECG/Automated-ECG-Interpretation.git"
REPO_DIR = "Automated-ECG-Interpretation"
DATASET_URL = "https://physionet.org/static/published-projects/ptb-xl/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3.zip"
DATASET_DIR = "ptb-xl"
PERSISTENT_DIR = "/data"
# Ensure persistent directory exists
def ensure_persistent_dir():
    os.makedirs(PERSISTENT_DIR, exist_ok=True)  # exist_ok=True already handles the "exists" case
# Function to clone the repository
def clone_repository():
    if not os.path.exists(REPO_DIR):
        print("Cloning repository...")
        try:
            subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
            print("Repository cloned successfully.")
        except subprocess.CalledProcessError as e:
            print(f"Error cloning repository: {e}")
    else:
        print("Repository already cloned.")
# Function to download and extract PTB-XL dataset
def download_and_extract_dataset():
    ensure_persistent_dir()  # Ensure /data exists
    zip_path = os.path.join(PERSISTENT_DIR, "ptb-xl.zip")
    extract_path = os.path.join(PERSISTENT_DIR, DATASET_DIR)
    if not os.path.exists(extract_path):
        print("Downloading PTB-XL dataset...")
        response = requests.get(DATASET_URL, stream=True)
        response.raise_for_status()  # Fail fast on a bad download instead of unzipping garbage
        with open(zip_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            top_level = zip_ref.namelist()[0].split("/")[0]
            zip_ref.extractall(PERSISTENT_DIR)
        # The PhysioNet zip unpacks into a long versioned top-level folder;
        # rename it to DATASET_DIR so the rest of the app can find it
        extracted = os.path.join(PERSISTENT_DIR, top_level)
        if extracted != extract_path:
            shutil.move(extracted, extract_path)
        os.remove(zip_path)  # Clean up zip file
        print("Dataset extracted successfully.")
    else:
        print("Dataset already extracted.")
# Clone repo and download dataset on startup
clone_repository()
download_and_extract_dataset()
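# Note (assumption about the hosting setup): /data survives restarts only when
# the Space has persistent storage enabled; without it, the clone and dataset
# download above rerun on every boot.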
# Load the pre-trained model
MODEL_FILENAME = "model.h5"
MODEL_PATH = os.path.join(REPO_DIR, MODEL_FILENAME)
PERSISTENT_MODEL_PATH = os.path.join(PERSISTENT_DIR, MODEL_FILENAME)
if not os.path.exists(PERSISTENT_MODEL_PATH):
    if os.path.exists(MODEL_PATH):
        shutil.copy(MODEL_PATH, PERSISTENT_MODEL_PATH)
    else:
        raise FileNotFoundError(
            f"Model file not found at {MODEL_PATH}. Please ensure it's in the repository or upload it manually."
        )
model = tf.keras.models.load_model(PERSISTENT_MODEL_PATH)
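# Optional sanity check (assumption: the classifier takes a (batch, 3600, 1)
# single-lead window and emits one sigmoid probability, matching the
# preprocessing and thresholding below); uncomment to verify after loading:
# print(model.input_shape)   # expected: (None, 3600, 1)
# print(model.output_shape)  # expected: (None, 1)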
# Preprocess a WFDB record: resample to 360 Hz, normalize, and window to 3600 samples
def preprocess_ecg(file_path):
    record = wfdb.rdrecord(file_path.replace(".dat", ""))  # wfdb expects the record name without extension
    ecg_signal = record.p_signal[:, 0]  # First lead
    # Resample to 360 Hz, then normalize to zero mean and unit variance
    target_fs = 360
    num_samples = int(len(ecg_signal) * target_fs / record.fs)
    ecg_resampled = signal.resample(ecg_signal, num_samples)
    ecg_normalized = (ecg_resampled - np.mean(ecg_resampled)) / np.std(ecg_resampled)
    # Pad or truncate to a fixed 10-second window (3600 samples at 360 Hz)
    if len(ecg_normalized) < 3600:
        ecg_normalized = np.pad(ecg_normalized, (0, 3600 - len(ecg_normalized)), mode="constant")
    else:
        ecg_normalized = ecg_normalized[:3600]
    ecg_input = ecg_normalized.reshape(1, 3600, 1)  # (batch, samples, channels)
    return ecg_input
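# Example call (hypothetical record path; PTB-XL ships records as .dat/.hea
# pairs, and preprocess_ecg strips the .dat extension before handing the
# record name to wfdb):
# x = preprocess_ecg("/data/ptb-xl/records500/00000/00001_hr.dat")
# x.shape  # -> (1, 3600, 1)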
# Run the model on an uploaded file or a selected PTB-XL sample
def predict_ecg(file=None, dataset_file=None):
    if file:
        file_path = file.name  # gr.File hands over a temp-file object; .name is its path on disk
    elif dataset_file:
        file_path = os.path.join(PERSISTENT_DIR, DATASET_DIR, "records500", dataset_file)
    else:
        return "Please upload a file or select a dataset sample."
    ecg_data = preprocess_ecg(file_path)
    prediction = model.predict(ecg_data)
    # Single sigmoid output: > 0.5 is treated as Abnormal
    label = "Abnormal" if prediction[0][0] > 0.5 else "Normal"
    confidence = float(prediction[0][0]) if label == "Abnormal" else float(1 - prediction[0][0])
    return f"Prediction: {label}\nConfidence: {confidence:.2%}"
# Collect PTB-XL record paths (relative to records500/) for the dropdown
records_root = os.path.join(PERSISTENT_DIR, DATASET_DIR, "records500")
dataset_files = []
if os.path.exists(records_root):
    for root, _, files in os.walk(records_root):
        for file in files:
            if file.endswith(".dat"):
                dataset_files.append(os.path.relpath(os.path.join(root, file), records_root))
# Gradio interface
interface = gr.Interface(
    fn=predict_ecg,
    inputs=[
        gr.File(label="Upload ECG File (.dat format)"),
        gr.Dropdown(choices=dataset_files, label="Or Select a PTB-XL Sample"),
    ],
    outputs=gr.Textbox(label="ECG Interpretation"),
    title="Automated ECG Interpretation",
    description="Upload an ECG file (.dat) or select a sample from the PTB-XL dataset for automated interpretation.",
)
# Launch the app
interface.launch()
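# When running outside Hugging Face Spaces, the same app can be exposed
# explicitly (these are standard gr.Interface.launch() options):
# interface.launch(server_name="0.0.0.0", server_port=7860, share=False)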