File size: 4,350 Bytes
8a9662a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import base64
import io
import os

import numpy as np
import torch
from flask import Flask, request, jsonify
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

class HuggingFaceClassifier:
    """Image classifier backed by a Hugging Face image-classification model.

    The feature extractor and model are loaded once at construction, the model
    is moved to CUDA when available (CPU otherwise) and put in eval mode.
    """

    def __init__(self, model_name="microsoft/resnet-50"):
        """
        Initialize Hugging Face model and feature extractor

        Args:
            model_name (str): Hugging Face model identifier

        Raises:
            ValueError: If loading the feature extractor/model or moving the
                model to the device fails; the original error is chained as
                ``__cause__``.
        """
        try:
            # Load pre-trained model and feature extractor
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            self.model = AutoModelForImageClassification.from_pretrained(model_name)

            # Move to GPU if available
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model.to(self.device)
            self.model.eval()

        except Exception as e:
            raise ValueError(f"Model loading error: {e}") from e

    def preprocess_image(self, image):
        """
        Preprocess image for model input

        Args:
            image (PIL.Image): Input image (any PIL mode; non-RGB images are
                converted to RGB first)

        Returns:
            torch.Tensor: Preprocessed pixel tensor on ``self.device``
        """
        # Uploads are frequently RGBA (PNG with alpha), L (grayscale) or P
        # (palette). The extractor/model expect 3-channel input, so normalise
        # here instead of failing deep inside the extractor. Inputs without a
        # ``mode`` attribute (e.g. arrays) are passed through unchanged.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        return inputs.pixel_values.to(self.device)

    def predict(self, image, top_k=5):
        """
        Predict image classification

        Args:
            image (PIL.Image): Input image
            top_k (int): Maximum number of predictions to return; clamped to
                the number of labels the model has.

        Returns:
            list: Top prediction results, highest score first, as dicts
            ``{"label": str, "score": float}``

        Raises:
            RuntimeError: If any step fails; the original error is chained as
                ``__cause__``.
        """
        try:
            inputs = self.preprocess_image(image)

            with torch.no_grad():
                logits = self.model(inputs).logits

            return self._top_k_predictions(
                logits, self.model.config.id2label, top_k
            )

        except Exception as e:
            raise RuntimeError(f"Prediction error: {e}") from e

    @staticmethod
    def _top_k_predictions(logits, id2label, k):
        """Turn raw logits for the first batch item into top-k labelled scores."""
        probabilities = torch.softmax(logits, dim=-1)
        # torch.topk raises if k exceeds the class count (e.g. binary models).
        k = min(k, probabilities.shape[-1])
        top = torch.topk(probabilities, k=k)
        return [
            {"label": id2label[idx.item()], "score": prob.item()}
            for idx, prob in zip(top.indices[0], top.values[0])
        ]

# Flask API Setup
app = Flask(__name__)

# Initialize the classifier once, at import time. Any Hugging Face
# image-classification checkpoint can be served by setting the MODEL_NAME
# environment variable; the default is unchanged (ResNet-50).
classifier = HuggingFaceClassifier(
    model_name=os.environ.get("MODEL_NAME", "microsoft/resnet-50")
)

def _error_response(message, status_code):
    """Build the JSON error body used by the prediction endpoint."""
    return jsonify({
        'error': message,
        'status': 'failed'
    }), status_code


def _read_request_image():
    """Extract and fully decode the image carried by the current request.

    Accepts either a JSON body ``{"image": "<base64 string>"}`` or a multipart
    form upload in the ``file`` field.

    Returns:
        PIL.Image.Image | None: The decoded image, or ``None`` when the request
        carries neither input.

    Raises:
        TypeError, ValueError, OSError: If the payload is not valid base64 or
        is not a decodable image.
    """
    # silent=True: a multipart upload is not JSON. ``request.json`` would
    # raise an HTTP error on recent Flask/Werkzeug versions (or return None on
    # older ones), which made the file-upload branch unreachable.
    payload = request.get_json(silent=True)

    if isinstance(payload, dict) and 'image' in payload:
        image_data = base64.b64decode(payload['image'])
        image = Image.open(io.BytesIO(image_data))
    elif 'file' in request.files:
        image = Image.open(request.files['file'])
    else:
        return None

    # Image.open is lazy; decode now so corrupt data is reported as a client
    # error instead of failing later inside the model.
    image.load()
    return image


@app.route('/predict', methods=['POST'])
def predict_image():
    """
    Image classification endpoint
    Supports base64 (JSON ``{"image": ...}``) and multipart file upload
    (``file`` field).

    Returns 200 with predictions, 400 when no image or an undecodable image is
    supplied, and 500 when the request cannot be read or inference fails.
    """
    try:
        image = _read_request_image()
    except (TypeError, ValueError, OSError) as e:
        # binascii.Error (bad base64) is a ValueError; PIL signals unreadable
        # image data with OSError subclasses; a non-string payload is a
        # TypeError.
        return _error_response(f'Invalid image data: {e}', 400)
    except Exception as e:
        return _error_response(str(e), 500)

    if image is None:
        return _error_response('No image provided', 400)

    try:
        predictions = classifier.predict(image)
    except Exception as e:
        return _error_response(str(e), 500)

    return jsonify({
        'predictions': predictions,
        'status': 'success'
    })

@app.route('/models', methods=['GET'])
def available_models():
    """
    List pre-trained checkpoints compatible with ``HuggingFaceClassifier``.

    Every entry must ship an image-classification head, because the
    classifier loads models via ``AutoModelForImageClassification``.
    (``facebook/vit-mae-base`` was previously listed but is a pre-training
    checkpoint with no classification head, so it could not be loaded.)
    This list is informational only; the server runs the single model it was
    started with.
    """
    models = [
        "microsoft/resnet-50",
        "google/vit-base-patch16-224",
        "facebook/convnext-tiny-224",
        "microsoft/beit-base-patch16-224"
    ]

    return jsonify({
        'models': models,
        'total_models': len(models)
    })

@app.route('/health', methods=['GET'])
def health_check():
    """
    Liveness probe: report that the API is up, plus the loaded model's type
    and the torch device it runs on.
    """
    model_type = classifier.model.config.model_type
    device_name = str(classifier.device)
    status_body = {
        'status': 'healthy',
        'model': model_type,
        'device': device_name,
    }
    return jsonify(status_body)

if __name__ == '__main__':
    # The Werkzeug interactive debugger allows arbitrary code execution, so it
    # must not be on by default for a server bound to all interfaces. Opt in
    # explicitly for local development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=5000, debug=debug_enabled)