Snigs98 commited on
Commit
a6b95e7
·
verified ·
1 Parent(s): dc2240a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -19
app.py CHANGED
@@ -1,42 +1,53 @@
 
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
  import cv2
5
- import os
6
 
7
- # Ensure correct class labels (fix folder naming issue if needed)
8
- class_labels = ["COVID-19", "Normal", "Pneumonia"] # Ensure class order is correct
 
 
 
 
 
 
9
 
10
  # Load the trained model
11
- model_path = "chest_xray_model.h5"
12
- if not os.path.exists(model_path):
13
- raise FileNotFoundError(f"Model file '{model_path}' not found. Ensure it's uploaded to the Space.")
14
 
15
- model = tf.keras.models.load_model(model_path)
 
16
 
17
- # Preprocessing function for uploaded images
18
  def preprocess_image(img):
19
- img = cv2.resize(img, (150, 150)) # Resize
 
20
  img = img.astype(np.float32) / 255.0 # Normalize pixel values
21
  img = np.expand_dims(img, axis=0) # Add batch dimension
22
  return img
23
 
24
- # Prediction function
25
  def predict_chest_xray(img):
26
- processed_img = preprocess_image(img)
27
- prediction = model.predict(processed_img)[0]
28
- predicted_class = class_labels[np.argmax(prediction)]
29
- confidence = round(100 * np.max(prediction), 2)
30
- return f"Prediction: {predicted_class} (Confidence: {confidence}%)"
 
 
 
 
31
 
32
- # Create Gradio UI
33
  interface = gr.Interface(
34
  fn=predict_chest_xray,
35
- inputs=gr.Image(type="numpy"), # Expect NumPy array
36
  outputs="text",
37
  title="Chest X-Ray Diagnosis",
38
- description="Upload a chest X-ray image to get a diagnosis prediction."
39
  )
40
 
 
41
  if __name__ == "__main__":
42
- interface.launch()
 
1
+ import os
2
  import gradio as gr
3
  import tensorflow as tf
4
  import numpy as np
5
  import cv2
 
6
 
7
+ # Define the model path
8
+ MODEL_PATH = "chest_xray_model.h5"
9
+
10
+ # Check if the model file exists
11
+ if not os.path.exists(MODEL_PATH):
12
+ raise FileNotFoundError(
13
+ f"Model file '{MODEL_PATH}' not found. Please upload it to your Hugging Face Space."
14
+ )
15
 
16
  # Load the trained model
17
+ model = tf.keras.models.load_model(MODEL_PATH)
 
 
18
 
19
+ # Get class labels from the trained model
20
+ class_labels = ["COVID-19", "NORMAL", "PNEUMONIA"] # Update if needed
21
 
22
+ # Function to preprocess the input image
23
  def preprocess_image(img):
24
+ """Prepares the image for model prediction."""
25
+ img = cv2.resize(img, (150, 150)) # Resize to match model input shape
26
  img = img.astype(np.float32) / 255.0 # Normalize pixel values
27
  img = np.expand_dims(img, axis=0) # Add batch dimension
28
  return img
29
 
30
+ # Function to make predictions
31
  def predict_chest_xray(img):
32
+ """Runs inference on an uploaded X-ray image."""
33
+ try:
34
+ processed_img = preprocess_image(img)
35
+ prediction = model.predict(processed_img)[0]
36
+ predicted_class = class_labels[np.argmax(prediction)]
37
+ confidence = round(100 * np.max(prediction), 2)
38
+ return f"Prediction: {predicted_class} (Confidence: {confidence}%)"
39
+ except Exception as e:
40
+ return f"Error: {str(e)}"
41
 
42
+ # Create Gradio interface
43
  interface = gr.Interface(
44
  fn=predict_chest_xray,
45
+ inputs=gr.Image(type="numpy"),
46
  outputs="text",
47
  title="Chest X-Ray Diagnosis",
48
+ description="Upload a chest X-ray image to get a diagnosis prediction.",
49
  )
50
 
51
+ # Run the Gradio app
52
  if __name__ == "__main__":
53
+ interface.launch()