File size: 5,026 Bytes
07ae9fd
b835b89
 
 
76a82c9
b426f35
f438e63
07ae9fd
 
b835b89
b426f35
07ae9fd
 
76a82c9
 
 
 
 
 
 
 
 
 
 
 
b835b89
 
 
 
 
 
 
 
07ae9fd
b426f35
 
b835b89
 
 
 
 
 
 
b426f35
aefec06
 
 
f438e63
eceb545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f438e63
 
eceb545
 
 
 
 
 
f438e63
eceb545
 
f438e63
eceb545
 
 
 
 
 
 
 
07ae9fd
b426f35
 
 
 
 
 
 
 
 
 
b835b89
 
b426f35
 
b835b89
 
 
 
b426f35
 
 
 
 
35ce90a
b426f35
 
 
 
b835b89
 
b426f35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07ae9fd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import gradio as gr
from fastai.vision.all import *
from fastai.learner import load_learner
from pathlib import Path
import pandas as pd
import os
import time

"""
Warning Lamp Detector using FastAI
This application allows users to upload images of warning lamps and get classification results.
"""

def get_labels(fname):
    """
    Label getter that the exported learner refers to by name.

    The model file was saved while this function was part of its data
    pipeline, so a function with this name must exist for loading to work.
    During inference no labels are needed, so nothing is read from ``fname``.

    Args:
        fname: Path to the image file (ignored)
    Returns:
        list: A new, empty list of active labels
    """
    no_labels = []  # inference never consults ground-truth labels
    return no_labels

# Load the exported FastAI learner once, at import time; the app cannot work
# without it, so a failure is reported and then re-raised to stop startup.
# NOTE: load_learner unpickles the file -- only ever point it at a trusted model.
model_path = Path("WarningLampClassifier.pkl")
try:
    learn_inf = load_learner(model_path)
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    raise
print("Model loaded successfully")

def _format_label(pred_class):
    """
    Render the model's decoded prediction as display text.

    String-like predictions (presumably single-label) are shown as-is; any
    other iterable (presumably a multi-label prediction) is joined with
    commas, or shown as "none" when it is empty.
    """
    if isinstance(pred_class, str):
        return str(pred_class)
    try:
        labels = [str(label) for label in pred_class]
    except TypeError:  # not iterable: fall back to its plain string form
        return str(pred_class)
    return ", ".join(labels) or "none"


def _predicted_indices(pred_idx):
    """
    Return the class indices selected by ``pred_idx`` as a list of ints.

    Accepts a scalar index (int or one-element tensor) or a per-class mask
    tensor (boolean, or any tensor with more than one element), returning the
    positions of the truthy entries in the mask case.
    """
    if isinstance(pred_idx, torch.Tensor):
        flat = pred_idx.detach().reshape(-1)
        # A mask: calling .item() on it would raise, which is why confidence
        # was previously never shown for mask-style predictions.
        if flat.dtype == torch.bool or flat.numel() != 1:
            return [i for i, hit in enumerate(flat.tolist()) if hit]
        return [int(flat.item())]
    return [int(pred_idx)]


def _to_float_list(probs):
    """Return ``probs`` (tensor or sequence of numbers) as a flat list of floats."""
    if isinstance(probs, torch.Tensor):
        return [float(p) for p in probs.detach().reshape(-1).tolist()]
    return [float(p) for p in probs]


def detect_warning_lamp(image, history: list[tuple[str, str]], system_message):
    """
    Classify the uploaded image with the FastAI model and append the result
    to the chat history.

    Both a scalar predicted index and a per-class mask are accepted from
    ``learn_inf.predict`` when reporting confidence, so single- and
    multi-label learners are displayed correctly.

    Args:
        image: PIL Image from Gradio, or None when nothing was uploaded
        history: Chat history as (user, bot) tuples; appended to in place.
            ``None`` is treated as an empty history.
        system_message: System prompt (unused; the model is queried directly)
    Returns:
        The updated chat history, with one new (None, message) entry
    """
    if history is None:
        history = []
    if image is None:
        history.append((None, "Please upload an image first."))
        return history

    try:
        # Convert PIL image to FastAI compatible format
        img = PILImage(image)
        pred_class, pred_idx, probs = learn_inf.predict(img)

        lines = [f"Detected Warning Lamp: {_format_label(pred_class)}"]

        # The extra details below are best-effort: a problem formatting them
        # must not hide the main prediction, so failures are logged and skipped.
        try:
            prob_values = _to_float_list(probs)
        except Exception as prob_error:
            print(f"Could not list all probabilities: {prob_error}")
            prob_values = []

        try:
            class_names = [str(name) for name in learn_inf.dls.vocab]
        except Exception as vocab_error:
            print(f"Could not read class names: {vocab_error}")
            class_names = []

        try:
            hits = [i for i in _predicted_indices(pred_idx) if 0 <= i < len(prob_values)]
            if len(hits) == 1:
                lines.append(f"Confidence: {prob_values[hits[0]]:.2%}")
            else:
                # Several (or zero) active classes: report each one by name.
                for i in hits:
                    name = class_names[i] if i < len(class_names) else str(i)
                    lines.append(f"Confidence ({name}): {prob_values[i]:.2%}")
        except Exception as conf_error:
            print(f"Could not calculate confidence: {conf_error}")

        # Only add the section header when there is something to list under it.
        entries = [f"- {name}: {p:.2%}" for name, p in zip(class_names, prob_values)]
        if entries:
            lines += ["", "Probabilities for all classes:", *entries]

        history.append((None, "\n".join(lines)))
        return history
    except Exception as e:
        # UI boundary: surface the failure in the chat instead of crashing.
        error_msg = f"Error processing image: {e}"
        print(f"Exception in detect_warning_lamp: {e}")
        history.append((None, error_msg))
        return history

# Create a custom interface with image upload
# Layout: a header Markdown, then a row whose left column holds the image upload
# (plus a hidden system prompt) and whose right column holds the results chatbot;
# the button added after the row runs detect_warning_lamp.
with gr.Blocks(title="Warning Lamp Detector", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🚨 Warning Lamp Detector
    Upload an image of a warning lamp to get its classification.
    
    ### Instructions:
    1. Upload a clear image of the warning lamp
    2. Wait for the analysis
    3. View the detailed classification results
    
    ### Supported Warning Lamps:
    """)
    
    # Display supported classes if available
    # At module scope locals() is the module namespace. The model-loading code
    # above re-raises on failure, so learn_inf is always defined when this runs.
    if 'learn_inf' in locals():
        gr.Markdown("\n".join([f"- {cls}" for cls in learn_inf.dls.vocab]))
    
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" makes Gradio hand detect_warning_lamp a PIL image.
            image_input = gr.Image(
                label="Upload Warning Lamp Image",
                type="pil",
                sources="upload"
            )
            system_message = gr.Textbox(
                value="You are an expert in warning lamp classification. Analyze the image and provide detailed information about the type, color, and status of the warning lamp.",
                label="System Message",
                lines=3,
                visible=False  # Hide this since we're using direct model inference
            )
        
        with gr.Column(scale=1):
            # Starts empty; detect_warning_lamp appends (None, text) tuples to it.
            # NOTE(review): bubble_full_width may be deprecated in newer Gradio
            # releases -- confirm against the pinned version.
            # NOTE(review): Gradio documents avatar_images as image paths/URLs;
            # the emoji string here may not render as intended -- verify.
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                bubble_full_width=False,
                avatar_images=(None, "🚨"),
                height=400
            )
    
    # Add a submit button
    # inputs are passed positionally as (image, history, system_message); the
    # returned history replaces the chatbot's contents.
    submit_btn = gr.Button("Analyze Warning Lamp", variant="primary")
    submit_btn.click(
        detect_warning_lamp,
        inputs=[image_input, chatbot, system_message],
        outputs=chatbot
    )

# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()