|
import gradio as gr |
|
from fastai.vision.all import * |
|
from fastai.learner import load_learner |
|
from pathlib import Path |
|
import pandas as pd |
|
import os |
|
import time |
|
|
|
""" |
|
Warning Lamp Detector using FastAI |
|
This application allows users to upload images of warning lamps and get classification results. |
|
""" |
|
|
|
def get_labels(fname):
    """
    Stub labelling function that always reports no active labels.

    The exported learner presumably references a labelling function named
    ``get_labels`` from training time, so a function with this name has to be
    defined before ``load_learner`` runs — TODO confirm against the training
    code. Nothing in this file calls it directly during inference.

    Args:
        fname: Path to the image file (ignored).

    Returns:
        list: A new, empty list.
    """
    active_labels = []
    return active_labels
|
|
|
|
|
model_path = Path("WarningLampClassifier.pkl")

try:
    # NOTE(review): .pkl learners are pickled objects — only load trusted files.
    learn_inf = load_learner(model_path)
    print("Model loaded successfully")
except Exception as load_error:
    # Report the reason, then re-raise: the UI below needs a loaded model.
    print(f"Error loading model: {load_error}")
    raise
|
|
|
def detect_warning_lamp(image, history: "list[tuple[str | None, str]]", system_message):
    """
    Classify an uploaded warning-lamp image and append the result to the chat.

    Args:
        image: PIL Image from Gradio, or ``None`` when nothing was uploaded.
        history: Chat history as ``(user, bot)`` tuples; bot-only entries use
            ``None`` for the user side. ``None`` is treated as an empty history.
        system_message: System prompt from the hidden textbox. Not used by the
            classifier; kept so the Gradio event wiring stays unchanged.

    Returns:
        The history list (a new one if ``None`` was passed) with exactly one
        ``(None, message)`` entry appended: the prediction report, an upload
        prompt, or an error message.
    """
    if history is None:
        history = []

    if image is None:
        history.append((None, "Please upload an image first."))
        return history

    try:
        # Wrap the Gradio PIL image in fastai's image type before prediction.
        img = PILImage(image)
        pred_class, pred_idx, probs = learn_inf.predict(img)

        response = f"Detected Warning Lamp: {str(pred_class)}"
        # Both helpers log and return "" on failure, so the predicted class is
        # still reported even when the extra details cannot be computed.
        response += _format_confidence(pred_idx, probs)
        response += _format_class_probabilities(probs)

        history.append((None, response))
    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        print(f"Exception in detect_warning_lamp: {e}")
        history.append((None, error_msg))
    return history


def _format_confidence(pred_idx, probs):
    """Return the confidence line for the predicted class index, or "" if unavailable."""
    try:
        idx = pred_idx.item() if isinstance(pred_idx, torch.Tensor) else int(pred_idx)
        if isinstance(probs, torch.Tensor) and idx < len(probs):
            return f"\nConfidence: {probs[idx].item():.2%}"
    except Exception as conf_error:
        # NOTE(review): a multi-element pred_idx (e.g. a multi-label boolean
        # mask) makes .item() raise, so no confidence line is shown — confirm
        # which kind of learner WarningLampClassifier.pkl is.
        print(f"Could not calculate confidence: {conf_error}")
    return ""


def _format_class_probabilities(probs):
    """Return the per-class probability section, or "" if it cannot be built.

    The section is assembled completely before it is returned, so a failure
    part-way through never leaves a dangling header in the chat reply.
    """
    try:
        lines = []
        for i, cls in enumerate(learn_inf.dls.vocab):
            if i < len(probs):
                if isinstance(probs, torch.Tensor):
                    prob_value = probs[i].item()
                else:
                    prob_value = float(probs[i])
                lines.append(f"\n- {cls}: {prob_value:.2%}")
    except Exception as prob_error:
        print(f"Could not list all probabilities: {prob_error}")
        return ""
    return "\n\nProbabilities for all classes:" + "".join(lines)
|
|
|
|
|
with gr.Blocks(title="Warning Lamp Detector", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🚨 Warning Lamp Detector
    Upload an image of a warning lamp to get its classification.

    ### Instructions:
    1. Upload a clear image of the warning lamp
    2. Wait for the analysis
    3. View the detailed classification results

    ### Supported Warning Lamps:
    """)

    # `learn_inf` is always bound here: the model-loading code above re-raises
    # if loading fails, so the module never reaches this point without a model.
    gr.Markdown("\n".join(f"- {cls}" for cls in learn_inf.dls.vocab))

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Upload Warning Lamp Image",
                type="pil",
                sources="upload",
            )
            # Hidden input: forwarded to detect_warning_lamp, which does not use it.
            system_message = gr.Textbox(
                value="You are an expert in warning lamp classification. Analyze the image and provide detailed information about the type, color, and status of the warning lamp.",
                label="System Message",
                lines=3,
                visible=False,
            )

        with gr.Column(scale=1):
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                bubble_full_width=False,
                # NOTE(review): Gradio documents avatar_images as image file
                # paths or URLs; an emoji string may not render (or may fail
                # to load) depending on the Gradio version — confirm.
                avatar_images=(None, "🚨"),
                height=400,
            )

    submit_btn = gr.Button("Analyze Warning Lamp", variant="primary")
    submit_btn.click(
        detect_warning_lamp,
        inputs=[image_input, chatbot, system_message],
        outputs=chatbot,
    )


if __name__ == "__main__":
    demo.launch()
|
|