File size: 3,518 Bytes
07ae9fd
b835b89
 
 
b426f35
07ae9fd
 
b835b89
b426f35
07ae9fd
 
b835b89
 
 
 
 
 
 
 
07ae9fd
b426f35
 
b835b89
 
 
 
 
 
 
b426f35
b835b89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07ae9fd
b426f35
 
 
 
 
 
 
 
 
 
b835b89
 
b426f35
 
b835b89
 
 
 
b426f35
 
 
 
 
35ce90a
b426f35
 
 
 
b835b89
 
b426f35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07ae9fd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
from fastai.vision.all import *
from fastai.learner import load_learner
from pathlib import Path
import os

"""
Warning Lamp Detector using FastAI
This application allows users to upload images of warning lamps and get classification results.
"""

# Load the FastAI model once at import time so every request reuses it.
try:
    # NOTE(review): path is relative to the current working directory —
    # assumes the app is launched from the folder containing the .pkl file.
    model_path = Path("WarningLampClassifier.pkl")
    learn_inf = load_learner(model_path)
    print("Model loaded successfully")
except Exception as e:
    # Fail fast: the app cannot serve predictions without the model,
    # so log the cause and re-raise to abort startup.
    print(f"Error loading model: {e}")
    raise

def detect_warning_lamp(image, history: list[tuple[str, str]], system_message):
    """
    Classify an uploaded warning-lamp image with the FastAI model and append
    the result to the chat history.

    Args:
        image: PIL Image from the Gradio image component; may be None when
            the user clicks Analyze without uploading anything.
        history: Chat history as (user, bot) message pairs; mutated in place.
        system_message: System prompt from the UI (unused here — inference
            goes directly through the model, not an LLM).

    Returns:
        The updated chat history, with one new (None, message) entry.
    """
    # Guard the no-upload case explicitly so the user sees a clear prompt
    # instead of a raw image-conversion traceback message.
    if image is None:
        history.append((None, "Please upload an image first."))
        return history
    try:
        # Convert the PIL image into FastAI's image type.
        img = PILImage(image)

        # predict() returns (decoded class, class index, per-class probabilities).
        pred_class, pred_idx, probs = learn_inf.predict(img)

        # probs is a tensor; convert to float for clean percent formatting.
        confidence = float(probs[pred_idx])
        response = f"Detected Warning Lamp: {pred_class}\nConfidence: {confidence:.2%}"

        # Append the full per-class probability breakdown. The vocab order
        # matches the probability vector's order, so a plain zip suffices
        # (the original wrapped this in enumerate() but never used the index).
        response += "\n\nProbabilities for all classes:"
        for cls, prob in zip(learn_inf.dls.vocab, probs):
            response += f"\n- {cls}: {float(prob):.2%}"

        history.append((None, response))
        return history
    except Exception as e:
        # Surface any inference failure in the chat rather than crashing the UI.
        error_msg = f"Error processing image: {str(e)}"
        history.append((None, error_msg))
        return history

# Create a custom interface with image upload
with gr.Blocks(title="Warning Lamp Detector", theme=gr.themes.Soft()) as demo:
    # Static header and usage instructions shown above the widgets.
    gr.Markdown("""
    # 🚨 Warning Lamp Detector
    Upload an image of a warning lamp to get its classification.
    
    ### Instructions:
    1. Upload a clear image of the warning lamp
    2. Wait for the analysis
    3. View the detailed classification results
    
    ### Supported Warning Lamps:
    """)
    
    # Display supported classes if available
    # NOTE(review): at module scope locals() is globals(), and the loader
    # above re-raises on failure, so learn_inf always exists here — this
    # guard is effectively always true; confirm intent.
    if 'learn_inf' in locals():
        gr.Markdown("\n".join([f"- {cls}" for cls in learn_inf.dls.vocab]))
    
    with gr.Row():
        with gr.Column(scale=1):
            # Left column: the image uploader feeding the classifier.
            image_input = gr.Image(
                label="Upload Warning Lamp Image",
                type="pil",
                sources="upload"
            )
            # Kept for interface compatibility but hidden: classification is
            # done by the FastAI model directly, not an LLM prompt.
            system_message = gr.Textbox(
                value="You are an expert in warning lamp classification. Analyze the image and provide detailed information about the type, color, and status of the warning lamp.",
                label="System Message",
                lines=3,
                visible=False  # Hide this since we're using direct model inference
            )
        
        with gr.Column(scale=1):
            # Right column: chat-style display of classification results.
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                bubble_full_width=False,
                avatar_images=(None, "🚨"),
                height=400
            )
    
    # Add a submit button
    # The chatbot component serves as both the history input and the output,
    # so each click appends one result message.
    submit_btn = gr.Button("Analyze Warning Lamp", variant="primary")
    submit_btn.click(
        detect_warning_lamp,
        inputs=[image_input, chatbot, system_message],
        outputs=chatbot
    )

if __name__ == "__main__":
    # Start the Gradio server (blocking) when run as a script.
    demo.launch()