medtty committed
Commit 067b419 · 1 Parent(s): 528ce19
Files changed (4):
  1. .gitignore +2 -1
  2. app.py +287 -90
  3. examples/.DS_Store +0 -0
  4. requirements.txt +3 -1
.gitignore CHANGED
@@ -1 +1,2 @@
-training.py
+trainig.py
+*.bak
app.py CHANGED
@@ -3,14 +3,24 @@ import tensorflow as tf
 import numpy as np
 import json
 from PIL import Image
-from fastapi import FastAPI, UploadFile, File
+from fastapi import FastAPI, UploadFile, File, WebSocket, Request, Response
+from fastapi.responses import StreamingResponse
 import uvicorn
-import cv2 # Import OpenCV
-import mediapipe as mp # Import MediaPipe
+import cv2
+import mediapipe as mp
+import io
+import base64
+import asyncio
+import time
+from typing import List, Dict, Any
+from pydantic import BaseModel
 
 # Initialize MediaPipe Hands
 mp_hands = mp.solutions.hands
-hands = mp_hands.Hands(static_image_mode=True, max_num_hands=1, min_detection_confidence=0.5)
+# For static images, we use static_image_mode=True
+hands_static = mp_hands.Hands(static_image_mode=True, max_num_hands=1, min_detection_confidence=0.5)
+# For video streams, we use static_image_mode=False for better performance
+hands_video = mp_hands.Hands(static_image_mode=False, max_num_hands=1, min_detection_confidence=0.5, min_tracking_confidence=0.5)
 mp_drawing = mp.solutions.drawing_utils
 
 # Create both Gradio and FastAPI apps
@@ -28,129 +38,208 @@ with open('model/class_indices.json') as f:
 
 index_to_class = {int(k): v for k, v in class_indices.items()}
 
+# Model and processing parameters
+MODEL_INPUT_SIZE = (224, 224)
+DETECTION_FREQUENCY = 5  # Process every Nth frame for performance
+CONFIDENCE_THRESHOLD = 0.5  # Minimum confidence to report a gesture
+
+# Data models for API
+class GestureResponse(BaseModel):
+    class_name: str
+    confidence: float
+    timestamp: float
+    all_predictions: Dict[str, float] = None
+
+class StreamRequest(BaseModel):
+    stream_id: str = None
+    width: int = 640
+    height: int = 480
+    fps: int = 15
+
+# Cache to store most recent detection results
+detection_cache = {}
+
 # Preprocess function now expects a PIL Image (already cropped)
 def preprocess_image(image):
     # Ensure image is RGB before resizing and converting
     if image.mode != 'RGB':
         image = image.convert('RGB')
-    image = image.resize((224, 224))
+    image = image.resize(MODEL_INPUT_SIZE)
     image_array = np.array(image) / 255.0
-    # The input tensor is expected to be float32
     return np.expand_dims(image_array, axis=0).astype(np.float32)
 
-# Modified predict function to include hand detection and cropping
-def predict(image_pil):
+def detect_and_crop_hand(image_rgb):
+    """Detect hand in the image and return cropped hand region if found"""
+    h, w = image_rgb.shape[:2]
+    results = hands_static.process(image_rgb)
+
+    if not results.multi_hand_landmarks:
+        return None, "No hand detected"
+
+    # Get the first hand detected
+    hand_landmarks = results.multi_hand_landmarks[0]
+
+    # Calculate bounding box from landmarks
+    x_min, y_min = w, h
+    x_max, y_max = 0, 0
+    for landmark in hand_landmarks.landmark:
+        x, y = int(landmark.x * w), int(landmark.y * h)
+        if x < x_min: x_min = x
+        if y < y_min: y_min = y
+        if x > x_max: x_max = x
+        if y > y_max: y_max = y
+
+    # Add padding to the bounding box
+    padding = 30
+    x_min = max(0, x_min - padding)
+    y_min = max(0, y_min - padding)
+    x_max = min(w, x_max + padding)
+    y_max = min(h, y_max + padding)
+
+    # Check for valid dimensions
+    if x_min >= x_max or y_min >= y_max:
+        return None, "Invalid bounding box"
+
+    # Crop the hand region
+    cropped_image = image_rgb[y_min:y_max, x_min:x_max]
+
+    if cropped_image.size == 0:
+        return None, "Empty cropped image"
+
+    return cropped_image, None
+
+def process_frame_for_gesture(frame):
+    """Process a single frame for hand gesture recognition"""
     try:
-        print(f"Original image mode: {image_pil.mode}, size: {image_pil.size}")
+        # Convert to RGB for MediaPipe
+        if frame.shape[2] == 4:  # RGBA
+            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
+        elif frame.shape[2] == 3 and frame.dtype == np.uint8:
+            # Assuming BGR from OpenCV
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+        # Detect and crop hand
+        cropped_hand, error = detect_and_crop_hand(frame)
+        if error:
+            return {"error": error}
+
+        # Convert cropped NumPy array to PIL Image
+        cropped_pil = Image.fromarray(cropped_hand)
+
+        # Preprocess and predict
+        processed_image = preprocess_image(cropped_pil)
+        interpreter.set_tensor(input_details[0]['index'], processed_image)
+        interpreter.invoke()
+        output_data = interpreter.get_tensor(output_details[0]['index'])
+        prediction = output_data[0]
+
+        # Get the prediction result
+        predicted_class_idx = int(np.argmax(prediction))
+        confidence = float(prediction[predicted_class_idx])
+        predicted_class = index_to_class.get(predicted_class_idx, f"unknown_{predicted_class_idx}")
+
+        # Return prediction info
+        return {
+            "class": predicted_class,
+            "confidence": confidence,
+            "timestamp": time.time(),
+            "all_predictions": {
+                index_to_class.get(i, f"class_{i}"): float(prediction[i])
+                for i in range(len(prediction))
+            }
+        }
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        return {"error": str(e)}
 
-        # Convert PIL image to OpenCV format (NumPy array)
+def predict(image_pil):
+    """Original prediction function for Gradio interface"""
+    try:
+        # Convert PIL image to OpenCV format
         image_cv = np.array(image_pil)
-        # Convert RGB (from PIL) to BGR (for OpenCV display if needed) then back to RGB for MediaPipe
+
+        # Process the image with MediaPipe Hands
         image_rgb = cv2.cvtColor(image_cv, cv2.COLOR_RGB2BGR)
         image_rgb = cv2.cvtColor(image_rgb, cv2.COLOR_BGR2RGB)
-
-        # Process the image with MediaPipe Hands
-        results = hands.process(image_rgb)
-
-        if not results.multi_hand_landmarks:
-            print("No hand detected in the image.")
-            return {"error": "No hand detected"}
-
-        # Assuming only one hand is detected (max_num_hands=1)
-        hand_landmarks = results.multi_hand_landmarks[0]
-
-        # Calculate bounding box from landmarks
-        h, w, _ = image_rgb.shape
-        x_min, y_min = w, h
-        x_max, y_max = 0, 0
-        for landmark in hand_landmarks.landmark:
-            x, y = int(landmark.x * w), int(landmark.y * h)
-            if x < x_min: x_min = x
-            if y < y_min: y_min = y
-            if x > x_max: x_max = x
-            if y > y_max: y_max = y
-
-        # Add some padding to the bounding box
-        padding = 30
-        x_min = max(0, x_min - padding)
-        y_min = max(0, y_min - padding)
-        x_max = min(w, x_max + padding)
-        y_max = min(h, y_max + padding)
-
-        # Ensure the box has valid dimensions
-        if x_min >= x_max or y_min >= y_max:
-            print("Invalid bounding box calculated.")
-            return {"error": "Could not calculate valid hand bounding box"}
-
-        # Crop the original RGB image using the bounding box
-        cropped_image_np = image_rgb[y_min:y_max, x_min:x_max]
-
-        # Check if cropping resulted in an empty image
-        if cropped_image_np.size == 0:
-            print("Cropping resulted in an empty image.")
-            return {"error": "Cropping failed, possibly invalid bounding box"}
-
-        # Convert cropped NumPy array back to PIL Image
-        cropped_image_pil = Image.fromarray(cropped_image_np)
-        print(f"Cropped image size: {cropped_image_pil.size}")
-
-        # Preprocess the cropped image
-        processed_image = preprocess_image(cropped_image_pil)
-        print(f"Processed image shape: {processed_image.shape}, dtype: {processed_image.dtype}")
-
-        # --- Inference ---
+
+        # Detect hand and get cropped image
+        cropped_hand, error = detect_and_crop_hand(image_rgb)
+        if error:
+            return {"error": error}
+
+        # Convert cropped NumPy array to PIL Image
+        cropped_pil = Image.fromarray(cropped_hand)
+
+        # Preprocess and predict
+        processed_image = preprocess_image(cropped_pil)
         interpreter.set_tensor(input_details[0]['index'], processed_image)
         interpreter.invoke()
         output_data = interpreter.get_tensor(output_details[0]['index'])
         prediction = output_data[0]
-        # --- End Inference ---
-
-        print(f"Raw prediction output: {prediction}")
-
+
+        # Get the prediction result
        predicted_class_idx = int(np.argmax(prediction))
        confidence = float(prediction[predicted_class_idx])
-        # Use the correct class mapping loaded earlier
         predicted_class = index_to_class.get(predicted_class_idx, f"unknown_{predicted_class_idx}")
-
-        print(f"Predicted class index: {predicted_class_idx}, Confidence: {confidence}, Class: {predicted_class}")
-
+
         return {
             "class": predicted_class,
             "confidence": confidence,
             "all_predictions": {
-                # Use the correct class mapping here too
                 index_to_class.get(i, f"class_{i}"): float(prediction[i])
                 for i in range(len(prediction))
             }
         }
     except Exception as e:
-        print(f"Error during prediction: {e}")
-        # Also print traceback for detailed debugging
         import traceback
         traceback.print_exc()
         return {"error": str(e)}
 
-# Gradio Interface
+# Define the Gradio interface
 with gradio_app:
     gr.Markdown("# Hand Gesture Recognition")
-    with gr.Row():
-        input_image = gr.Image(type="pil", label="Upload Image")
-        output_json = gr.JSON(label="Prediction Results")
-    submit = gr.Button("Predict")
-    submit.click(
-        fn=predict,
-        inputs=input_image,
-        outputs=output_json
-    )
-    gr.Examples(
-        examples=[["examples/two_up.jpg"], ["examples/call.jpg"], ["examples/stop.jpg"]],
-        inputs=input_image
-    )
+    with gr.Tabs():
+        with gr.TabItem("Image Upload"):
+            with gr.Row():
+                input_image = gr.Image(type="pil", label="Upload Image")
+                output_json = gr.JSON(label="Prediction Results")
+            submit = gr.Button("Predict")
+            submit.click(
+                fn=predict,
+                inputs=input_image,
+                outputs=output_json
+            )
+            gr.Examples(
+                examples=[["examples/two_up.jpg"], ["examples/stop.jpg"]],
+                inputs=input_image
+            )
+
+        with gr.TabItem("Live Demo"):
+            gr.Markdown("""
+            ## Live Demo
+            Try the live demo using your webcam!
+            - Please allow camera access when prompted
+            - Hold your hand gesture in front of the camera
+            """)
+            camera_input = gr.Image(source="webcam", streaming=True, label="Camera Input")
+            live_output = gr.JSON(label="Live Detection Results")
+
+            def process_camera_input(img):
+                if img is None:
+                    return {"message": "No image received"}
+                return predict(img)
+
+            camera_input.change(
+                fn=process_camera_input,
+                inputs=camera_input,
+                outputs=live_output
+            )
 
 # Mount Gradio app to FastAPI
 fastapi_app = gr.mount_gradio_app(fastapi_app, gradio_app, path="/")
 
-# API endpoint
+# API endpoint for single image prediction
 @fastapi_app.post("/api/predict")
 async def api_predict(file: UploadFile = File(...)):
     try:
@@ -164,13 +253,121 @@ async def api_predict(file: UploadFile = File(...)):
         # Convert BGR (OpenCV default) to RGB for PIL
         img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
         image_pil = Image.fromarray(img_rgb)
-        return predict(image_pil) # Call the modified predict function
+        return predict(image_pil)
     except Exception as e:
-        print(f"Error processing uploaded file: {e}")
         import traceback
         traceback.print_exc()
         return {"error": f"Failed to process image: {e}"}
 
+# WebSocket endpoint for video stream processing
+@fastapi_app.websocket("/api/stream")
+async def websocket_endpoint(websocket: WebSocket):
+    await websocket.accept()
+
+    try:
+        # Get stream configuration
+        config_data = await websocket.receive_text()
+        config = json.loads(config_data)
+        stream_id = config.get("stream_id", f"stream_{int(time.time())}")
+
+        frame_count = 0
+        last_detection_time = time.time()
+        processing_interval = 1.0 / DETECTION_FREQUENCY  # Process every N frames
+
+        while True:
+            # Receive frame data
+            data = await websocket.receive_bytes()
+
+            # Decode the image
+            nparr = np.frombuffer(data, np.uint8)
+            frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+            if frame is None:
+                await websocket.send_json({"error": "Invalid frame data"})
+                continue
+
+            frame_count += 1
+            current_time = time.time()
+
+            # Process every N frames for performance
+            if frame_count % DETECTION_FREQUENCY == 0 or (current_time - last_detection_time) >= processing_interval:
+                # Process the frame for gesture recognition
+                result = process_frame_for_gesture(frame)
+
+                if "error" not in result:
+                    # Cache the result
+                    detection_cache[stream_id] = result
+                    last_detection_time = current_time
+                # Send results back to client
+                await websocket.send_json(result)
+
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        print(f"WebSocket error: {e}")
+    finally:
+        print(f"WebSocket connection closed")
+
+# REST API endpoints for mobile integration
+@fastapi_app.post("/api/video/frame")
+async def process_video_frame(request: Request):
+    """Process a single video frame sent from Android app"""
+    try:
+        # Get the raw bytes from the request
+        content = await request.body()
+
+        # Get stream ID from header if available
+        stream_id = request.headers.get("X-Stream-ID", f"stream_{int(time.time())}")
+
+        # Decode the image
+        nparr = np.frombuffer(content, np.uint8)
+        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+        if frame is None:
+            return {"error": "Could not decode image data"}
+
+        # Process the frame
+        result = process_frame_for_gesture(frame)
+
+        if "error" not in result:
+            # Cache the result for this stream
+            detection_cache[stream_id] = result
+            # Return the result
+            return result
+        else:
+            return result
+
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        return {"error": f"Failed to process frame: {e}"}
+
+@fastapi_app.get("/api/gestures")
+def get_available_gestures():
+    """Return all available gesture classes the model can recognize"""
+    return {"gestures": list(index_to_class.values())}
+
+@fastapi_app.get("/health")
+def health_check():
+    """Simple health check endpoint"""
+    return {"status": "healthy", "timestamp": time.time()}
+
+# Documentation for Android integration
+@fastapi_app.get("/")
+async def root():
+    return {
+        "app": "Hand Gesture Recognition API",
+        "usage": {
+            "image_prediction": "POST /api/predict with image file",
+            "video_streaming": "WebSocket /api/stream or POST frames to /api/video/frame",
+            "available_gestures": "GET /api/gestures"
+        },
+        "android_integration": {
+            "single_image": "Send image as multipart/form-data to /api/predict",
+            "video_stream": "Send individual frames to /api/video/frame with X-Stream-ID header",
+            "websocket": "Connect to /api/stream for bidirectional communication"
+        }
+    }
 
 if __name__ == "__main__":
     # Modified for Hugging Face Spaces environment
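
As a quick sanity check of the HTTP endpoints this commit adds, here is a minimal client sketch. It is not part of the commit: the base URL, stream ID, and `sample.jpg` path are placeholders, and the `requests` package is assumed on the client side (it is not in this repo's requirements.txt); the endpoint paths and the `X-Stream-ID` header come from app.py above.

```python
# Hypothetical client sketch for the endpoints added in app.py above.
# Assumptions: the Space is reachable at BASE_URL and sample.jpg exists locally.
import requests

BASE_URL = "http://localhost:7860"  # placeholder; replace with the deployed Space URL

# Single-image prediction: multipart upload to POST /api/predict
with open("sample.jpg", "rb") as f:
    resp = requests.post(f"{BASE_URL}/api/predict",
                         files={"file": ("sample.jpg", f, "image/jpeg")})
print(resp.json())  # {"class": ..., "confidence": ..., "all_predictions": {...}}

# Frame-by-frame mode: raw JPEG bytes to POST /api/video/frame with a stream ID header
with open("sample.jpg", "rb") as f:
    frame_bytes = f.read()
resp = requests.post(f"{BASE_URL}/api/video/frame",
                     data=frame_bytes,
                     headers={"X-Stream-ID": "demo-stream-1",
                              "Content-Type": "application/octet-stream"})
print(resp.json())

# List the gesture classes the model can return
print(requests.get(f"{BASE_URL}/api/gestures").json())
```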
examples/.DS_Store ADDED
Binary file (6.15 kB).
 
requirements.txt CHANGED
@@ -6,4 +6,6 @@ numpy>=1.22.0
 pillow>=9.0.0
 python-multipart>=0.0.6
 mediapipe>=0.10.0
-opencv-python-headless>=4.5.0
+opencv-python-headless>=4.5.0
+websockets
+pydantic
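
The new `websockets` and `pydantic` entries are server-side dependencies for the streaming endpoint and the response models above. For illustration only, the sketch below follows the /api/stream protocol implemented in app.py (one JSON configuration message, then raw JPEG frames as binary messages, with JSON predictions returned for the frames the server actually processes); the ws:// URL, stream ID, and frame source are assumptions, and the same `websockets` package is assumed to be installed on the client.

```python
# Hypothetical client for the /api/stream WebSocket endpoint defined in app.py.
import asyncio
import json

import websockets  # the same package this commit adds to requirements.txt


async def stream_frames(frames, url="ws://localhost:7860/api/stream"):
    async with websockets.connect(url) as ws:
        # First message: stream configuration (stream_id is optional on the server side)
        await ws.send(json.dumps({"stream_id": "demo-stream-1", "fps": 15}))
        for jpeg_bytes in frames:
            await ws.send(jpeg_bytes)  # binary frame
            try:
                reply = await asyncio.wait_for(ws.recv(), timeout=1.0)
                print(json.loads(reply))  # prediction or error JSON
            except asyncio.TimeoutError:
                # The server only replies for frames it processes, so a missing
                # reply is not necessarily an error.
                pass


# Usage sketch: send the same test frame a few times.
# with open("sample.jpg", "rb") as f:
#     jpeg = f.read()
# asyncio.run(stream_frames([jpeg] * 10))
```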