import base64
import io
import logging
import os

import numpy as np
import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

logger = logging.getLogger(__name__)


class HuggingFaceClassifier:
    """Image classifier wrapping a Hugging Face image-classification model."""

    def __init__(self, model_name="microsoft/resnet-50"):
        """
        Initialize Hugging Face model and feature extractor

        Args:
            model_name (str): Hugging Face model identifier

        Raises:
            ValueError: If the model or feature extractor cannot be loaded.
        """
        try:
            # Load pre-trained model and feature extractor
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            self.model = AutoModelForImageClassification.from_pretrained(model_name)

            # Move to GPU if available
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise ValueError(f"Model loading error: {e}") from e

    def preprocess_image(self, image):
        """
        Preprocess image for model input

        Args:
            image (PIL.Image): Input image

        Returns:
            torch.Tensor: Preprocessed image tensor, moved to ``self.device``
        """
        # Normalise PIL inputs to RGB so the extractor always receives a
        # 3-channel image regardless of the source file's mode (L, P, RGBA...).
        # Non-PIL inputs are passed through untouched.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        inputs = self.feature_extractor(images=image, return_tensors="pt")
        return inputs["pixel_values"].to(self.device)

    def predict(self, image, top_k=5):
        """
        Predict image classification

        Args:
            image (PIL.Image): Input image
            top_k (int): Maximum number of predictions to return. Clamped to
                the range [1, number of classes the model outputs].

        Returns:
            list: Dicts with ``label`` and ``score`` keys, highest score first

        Raises:
            RuntimeError: If preprocessing or inference fails.
        """
        try:
            pixel_values = self.preprocess_image(image)

            with torch.no_grad():
                outputs = self.model(pixel_values)

            probabilities = torch.softmax(outputs.logits, dim=-1)

            # torch.topk raises when k exceeds the size of the dimension, so
            # never ask for more entries than the model has classes.
            k = max(1, min(int(top_k), probabilities.shape[-1]))
            best = torch.topk(probabilities, k=k)

            id2label = self.model.config.id2label
            return [
                {"label": id2label[idx.item()], "score": prob.item()}
                for idx, prob in zip(best.indices[0], best.values[0])
            ]
        except Exception as e:
            raise RuntimeError(f"Prediction error: {e}") from e


# Flask API Setup
app = Flask(__name__)

# Initialize classifier (can be changed to any model)
classifier = HuggingFaceClassifier(
    model_name="microsoft/resnet-50"
)


def _load_request_image():
    """
    Extract the image from the current request.

    Looks for a base64 string under the ``image`` key of a JSON body first,
    then for an uploaded file under the ``file`` form field.

    Returns:
        PIL.Image or None: The decoded RGB image, or None when the request
        carries neither an ``image`` JSON key nor a ``file`` upload.

    Raises:
        TypeError, ValueError, OSError: If the payload cannot be decoded as
        base64 or opened as an image.
    """
    # silent=True returns None instead of raising when the body is not JSON
    # (e.g. a multipart file upload), so the upload branch stays reachable.
    payload = request.get_json(silent=True)

    if isinstance(payload, dict) and 'image' in payload:
        encoded = payload['image']
        # Accept data URIs ("data:image/png;base64,....") by dropping the header.
        if isinstance(encoded, str) and encoded.startswith('data:'):
            encoded = encoded.partition(',')[2]
        source = io.BytesIO(base64.b64decode(encoded))
    elif 'file' in request.files:
        source = request.files['file']
    else:
        return None

    # convert() forces the pixel data to be read here, so truncated or corrupt
    # files fail during request validation rather than inside the model call.
    with Image.open(source) as opened:
        return opened.convert('RGB')


@app.route('/predict', methods=['POST'])
def predict_image():
    """
    Image classification endpoint
    Supports base64 and file upload

    Responds with 400 when no image is supplied or it cannot be decoded,
    500 when anything else fails, and 200 with the predictions otherwise.
    """
    try:
        try:
            image = _load_request_image()
        except (TypeError, ValueError, OSError) as e:
            # Bad base64 / unreadable image data is a client error.
            return jsonify({
                'error': f'Invalid image: {e}',
                'status': 'failed'
            }), 400

        if image is None:
            return jsonify({
                'error': 'No image provided',
                'status': 'failed'
            }), 400

        # Perform prediction
        predictions = classifier.predict(image)

        return jsonify({
            'predictions': predictions,
            'status': 'success'
        })
    except Exception as e:
        logger.exception("Prediction request failed")
        return jsonify({
            'error': str(e),
            'status': 'failed'
        }), 500


@app.route('/models', methods=['GET'])
def available_models():
    """
    List available pre-trained models
    """
    models = [
        "microsoft/resnet-50",
        "google/vit-base-patch16-224",
        "facebook/vit-mae-base",
        "microsoft/beit-base-patch16-224"
    ]
    return jsonify({
        'models': models,
        'total_models': len(models)
    })


@app.route('/health', methods=['GET'])
def health_check():
    """
    API health check endpoint
    """
    return jsonify({
        'status': 'healthy',
        'model': classifier.model.config.model_type,
        'device': str(classifier.device)
    })


if __name__ == '__main__':
    # The interactive debugger must not be enabled by default on a server
    # bound to all interfaces; opt in explicitly via FLASK_DEBUG.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in {'1', 'true', 'yes'}
    app.run(host='0.0.0.0', port=5000, debug=debug_enabled)