File size: 2,087 Bytes
1f094d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import  List, Callable

import numpy as np
from PIL import Image
import torch

def l2_normalize(embedding: np.ndarray) -> np.ndarray:
    """Scale a vector to unit Euclidean (L2) length.

    Args:
        embedding (np.ndarray): Vector to rescale.

    Returns:
        np.ndarray: ``embedding`` divided by its L2 norm. When the norm is
        not greater than zero, the input object itself is returned unchanged.
    """
    magnitude = np.linalg.norm(embedding)

    # A vector whose norm is not positive cannot be scaled to unit length;
    # hand it back untouched instead of dividing by zero.
    if not magnitude > 0:
        return embedding

    return embedding / magnitude

def encode_image(
        image: Image.Image,
        preprocess: Callable[[Image.Image], torch.Tensor],
        model: torch.nn.Module,
        device: torch.device,
    ) -> List[float]:
    """Preprocess an image, encode it with the model, and L2-normalize it.

    Steps performed:
      1. Run ``preprocess`` on the image and add a leading batch dimension.
      2. Move the resulting tensor to ``device``.
      3. Call ``model.encode_image`` with gradient tracking disabled.
      4. Take the first entry of the output, convert it to a NumPy array
         (upcasting half-precision values to float32), and L2-normalize it.

    Args:
        image (Image.Image): Input image to encode.
        preprocess (Callable[[Image.Image], torch.Tensor]):
            Callable that turns the image into a tensor (without a batch
            dimension; one is added here).
        model (torch.nn.Module): Model used for encoding. It must expose an
            ``encode_image`` method that accepts the batched tensor.
        device (torch.device): Device the image tensor is moved to before
            it is passed to the model.

    Returns:
        List[float]: The normalized embedding as a plain Python list.
    """
    # Preprocess the input image and add a batch dimension
    image_input = preprocess(image).unsqueeze(0).to(device)

    # Use the model to encode the image without computing gradients
    with torch.no_grad():
        image_features = model.encode_image(image_input)

    # Extract the first (and only) embedding from the batch. detach() keeps
    # the later .numpy() call valid even if the model returned a tensor that
    # still requires grad.
    features = image_features[0].detach()

    # NumPy has no bfloat16 dtype (Tensor.numpy() raises for it), and
    # computing the norm in float16 loses precision and can overflow, so
    # half-precision outputs are upcast to float32 first. Other dtypes are
    # left as they are.
    if features.dtype in (torch.float16, torch.bfloat16):
        features = features.float()

    embedding = features.cpu().numpy()

    # Normalize the embedding using L2 normalization and return it as a list
    return l2_normalize(embedding).tolist()