# CHEST-XRAY / app.py
# Last commit: "Update app.py" by Snigs98 (a6b95e7, verified)
import logging
import os

import gradio as gr
import tensorflow as tf
import numpy as np
import cv2
# Location of the trained Keras model file. This is a relative path, so it is
# resolved against the process's current working directory.
MODEL_PATH = "chest_xray_model.h5"

# Stop at startup with an actionable message rather than letting the model
# loader fail later with a less helpful error.
if not os.path.exists(MODEL_PATH):
    missing_model_message = f"Model file '{MODEL_PATH}' not found. Please upload it to your Hugging Face Space."
    raise FileNotFoundError(missing_model_message)
# Load the trained Keras model once, at import time; every request reuses it.
model = tf.keras.models.load_model(MODEL_PATH)

# Class labels are hard-coded here; they are NOT read from the model file.
# Entry i must name output unit i of the model's final layer.
# NOTE(review): the alphabetical order looks like Keras directory-based class
# indexing (classes sorted by folder name) -- confirm against the training code.
class_labels = ["COVID-19", "NORMAL", "PNEUMONIA"]  # Update if needed

# Fail fast if the label list and the model's output width disagree. Otherwise
# every prediction would be mislabelled or hit an IndexError at request time.
# Multi-output models report a list here and are skipped.
_model_output_shape = getattr(model, "output_shape", None)
if (
    isinstance(_model_output_shape, tuple)
    and len(_model_output_shape) > 0
    and _model_output_shape[-1] is not None
    and _model_output_shape[-1] != len(class_labels)
):
    raise ValueError(
        f"Model has {_model_output_shape[-1]} output units but "
        f"{len(class_labels)} class labels are defined; "
        "update class_labels to match the model."
    )
# Function to preprocess the input image
def preprocess_image(img):
"""Prepares the image for model prediction."""
img = cv2.resize(img, (150, 150)) # Resize to match model input shape
img = img.astype(np.float32) / 255.0 # Normalize pixel values
img = np.expand_dims(img, axis=0) # Add batch dimension
return img
# Function to make predictions
def predict_chest_xray(img):
    """Classify one uploaded chest X-ray and format the result for display.

    Args:
        img: Image array from the Gradio ``Image(type="numpy")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        ``"Prediction: <label> (Confidence: <pct>%)"`` on success. If ``img``
        is ``None``, or preprocessing/inference raises, returns a string
        starting with ``"Error: "``.
    """
    if img is None:
        # The image input yields None when nothing was uploaded; answer clearly
        # instead of surfacing an OpenCV error from cv2.resize(None, ...).
        return "Error: No image provided. Please upload a chest X-ray image."
    try:
        processed_img = preprocess_image(img)
        # verbose=0 stops Keras from printing a progress bar for every request.
        prediction = model.predict(processed_img, verbose=0)[0]
        best_index = int(np.argmax(prediction))
        predicted_class = class_labels[best_index]
        # The raw output value of the winning unit, as a percentage. It is only
        # a probability if the final layer uses softmax (not visible here).
        confidence = round(100 * float(prediction[best_index]), 2)
        return f"Prediction: {predicted_class} (Confidence: {confidence}%)"
    except Exception as e:
        # Top-level UI boundary: keep the request alive and show a short
        # message, but log the full traceback so failures are diagnosable.
        logging.getLogger(__name__).exception("Prediction failed")
        return f"Error: {str(e)}"
# Build the web UI: a single NumPy image upload feeds predict_chest_xray, and
# its returned string is shown in a text output.
interface = gr.Interface(
    predict_chest_xray,
    gr.Image(type="numpy"),
    "text",
    title="Chest X-Ray Diagnosis",
    description="Upload a chest X-ray image to get a diagnosis prediction.",
)

# Start the server only when this file runs as a script, not when imported.
if __name__ == "__main__":
    interface.launch()