taesiri committed
Commit defdeae · 1 Parent(s): 8823991
Files changed (1)
app.py +267 -0
app.py ADDED
@@ -0,0 +1,267 @@
import os
import gradio as gr
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from peft import PeftModel
from huggingface_hub import login
import json
import matplotlib.pyplot as plt
import io
import base64


def check_environment():
    required_vars = ["HF_TOKEN"]
    missing_vars = [var for var in required_vars if var not in os.environ]

    if missing_vars:
        raise ValueError(
            f"Missing required environment variables: {', '.join(missing_vars)}\n"
            "Please set the HF_TOKEN environment variable with your Hugging Face token"
        )


# # Login to Hugging Face
# check_environment()
# login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)

# Load model and processor (do this outside the inference function to avoid reloading)
# base_model_path = (
#     "taesiri/BugsBunny-LLama-3.2-11B-Vision-BaseCaptioner-Medium-FullModel"
# )

# processor = AutoProcessor.from_pretrained(base_model_path)
# model = MllamaForConditionalGeneration.from_pretrained(
#     base_model_path,
#     torch_dtype=torch.bfloat16,
#     device_map="cuda",
#     cache_dir="./",
# )
# model = PeftModel.from_pretrained(model, lora_weights_path)


local_model_path = "../merged-llama-3.2-dummy"

# Load model and processor (do this outside the inference function to avoid reloading)
base_model_path = local_model_path
# lora_weights_path = "taesiri/BugsBunny-LLama-3.2-11B-Vision-Base-Medium-LoRA"

processor = AutoProcessor.from_pretrained(base_model_path)
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    cache_dir="./",
)

model.tie_weights()
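# Note: tie_weights() re-ties parameters that are meant to be shared (e.g. the input
# embeddings and the language-model head) after loading the merged checkpoint; it is
# intended as a safety step and is effectively a no-op if the weights are already tied.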


def create_color_palette_image(colors):
    if not colors or not isinstance(colors, list):
        return None

    try:
        # Validate color format
        for color in colors:
            if not isinstance(color, str) or not color.startswith("#"):
                return None

        # Create figure and axis
        fig, ax = plt.subplots(figsize=(10, 2))

        # Create rectangles for each color
        for i, color in enumerate(colors):
            ax.add_patch(plt.Rectangle((i, 0), 1, 1, facecolor=color))

        # Set the view limits and aspect ratio
        ax.set_xlim(0, len(colors))
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])

        return fig  # Return the matplotlib figure directly
    except Exception as e:
        print(f"Error creating color palette: {e}")
        return None

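# Illustrative usage of the helper above (not wired into the app): a call such as
# create_color_palette_image(["#ff0000", "#00ff00", "#0000ff"]) should return a
# matplotlib Figure with one swatch per color, while any non-"#" entry yields None.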


def inference(image):
    if image is None:
        # One value per Gradio output wired up in submit_btn.click below
        return "Please provide an image", "", "", "", "", "", "Error", {}

    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(image)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return "Invalid image format", "", "", "", "", "", "Error", {}

    # Prepare input
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {
                    "type": "text",
                    "text": "Analyze this image for fire, smoke, haze, or other related conditions.",
                },
            ],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    try:
        # Move inputs to the correct device
        inputs = processor(
            image, input_text, add_special_tokens=False, return_tensors="pt"
        ).to(model.device)

        # Clear CUDA cache after inference
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=2048)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    except Exception as e:
        print(f"Inference error: {e}")
        return "Error during inference", "", "", "", "", "", "Error", {}

    # Decode output
    result = processor.decode(output[0], skip_special_tokens=True)
    print("DEBUG: Full decoded output:", result)

    try:
        json_str = result.strip().split("assistant\n")[1].strip()
        parsed_json = json.loads(json_str)

        # Create specific JSON subsets for each section
        fire_analysis = {
            "predictions": parsed_json.get("predictions", "N/A"),
            "description": parsed_json.get("description", "No description available"),
            "confidence_scores": parsed_json.get("confidence_score", {}),
        }

        environment_analysis = {
            "environmental_factors": parsed_json.get("environmental_factors", {})
        }

        detection_analysis = {
            "detections": parsed_json.get("detections", []),
            "detection_count": len(parsed_json.get("detections", [])),
        }

        report_analysis = {
            "uncertainty_factors": parsed_json.get("uncertainty_factors", []),
            "false_positive_indicators": parsed_json.get("false_positive_indicators", []),
        }

        return (
            json.dumps(fire_analysis, indent=2),
            json.dumps(environment_analysis, indent=2),
            json.dumps(detection_analysis, indent=2),
            json.dumps(report_analysis, indent=2),
            json_str,
            "",
            "Analysis complete",
            parsed_json,
        )
    except Exception as e:
        print("DEBUG: Error processing response:", e)
        return (
            "Error processing response",
            "",
            "",
            "",
            str(result),
            str(e),
            "Error",
            {},
        )

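# Note: inference() assumes the model reply is a JSON object emitted after the
# "assistant\n" marker, with keys roughly of this shape (illustrative, not a real
# model output):
# {
#     "predictions": ...,
#     "description": ...,
#     "confidence_score": {...},
#     "environmental_factors": {...},
#     "detections": [...],
#     "uncertainty_factors": [...],
#     "false_positive_indicators": [...]
# }
# Any missing key falls back to the default passed to .get().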


# Update Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Fire Detection Demo")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="large-image",
            )
            submit_btn = gr.Button("Analyze Image", variant="primary")

            # Add examples here
            gr.Examples(
                examples=[
                    "examples/Birch MWF014-0001.png",
                    "examples/Birch MWF014-0006.png",
                    "examples/Blackstone PB-0010.png",
                ],
                inputs=image_input,
                label="Example Images",
                examples_per_page=4,
            )

        with gr.Tabs() as tabs:
            with gr.Tab("Analysis Results"):
                with gr.Row():
                    with gr.Column():
                        fire_output = gr.JSON(label="Fire Details")
                    with gr.Column():
                        environment_output = gr.JSON(label="Environment Details")
                with gr.Row():
                    with gr.Column():
                        detection_output = gr.JSON(label="Detection Details")
                    with gr.Column():
                        report_output = gr.JSON(label="Report Details")

            with gr.Tab("JSON Output", id=0):
                json_output = gr.JSON(label="Detailed JSON Results")

            with gr.Tab("Raw Output"):
                raw_output = gr.Textbox(
                    label="Raw JSON Response",
                    lines=10,
                )

    error_box = gr.Textbox(label="Error Messages", visible=False)
    status_text = gr.Textbox(label="Status", value="Ready", interactive=False)

    # Output order must match the tuple returned by inference()
    submit_btn.click(
        fn=inference,
        inputs=[image_input],
        outputs=[
            fire_output,
            environment_output,
            detection_output,
            report_output,
            raw_output,
            error_box,
            status_text,
            json_output,
        ],
    )

demo.launch(share=True)
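# share=True also requests a temporary public link from Gradio in addition to the
# local server; drop the argument to keep the demo local-only.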