File size: 2,334 Bytes
c025a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import sys
import torch
import numpy as np
import csv
import argparse
import open_clip

def load_descriptions(file_path):
    """Load text descriptions from the first column of a CSV file.

    Args:
        file_path: Path to a CSV file whose first row is a header.

    Returns:
        list[str]: First-column value of every non-empty data row.
    """
    descriptions = []
    # newline="" is the documented way to open files for the csv module.
    with open(file_path, 'r', newline='', encoding='utf-8') as file:
        csv_reader = csv.reader(file)
        # Skip the header; the None default avoids StopIteration on an empty file.
        next(csv_reader, None)
        for row in csv_reader:
            if row:  # tolerate blank lines instead of raising IndexError
                descriptions.append(row[0])
    return descriptions

def generate_embeddings(descriptions, model, tokenizer, device, batch_size):
    """Encode text descriptions into one stacked (N, dim) embedding matrix.

    Args:
        descriptions: Sequence of strings to encode.
        model: Model exposing ``encode_text(tokens) -> torch.Tensor``.
        tokenizer: Callable mapping a list of strings to a token tensor.
        device: Device string the token tensor is moved to (e.g. "cuda:0").
        batch_size: Number of descriptions encoded per forward pass.

    Returns:
        np.ndarray: All batch embeddings stacked row-wise; an empty
        (0, 0) float32 array when *descriptions* is empty.
    """
    if not descriptions:
        # np.vstack raises ValueError on an empty list; return empty instead.
        return np.empty((0, 0), dtype=np.float32)
    final_embeddings = []
    # Inference only: no_grad avoids building autograd graphs, which would
    # otherwise retain activations for every batch and exhaust memory.
    with torch.no_grad():
        for i in range(0, len(descriptions), batch_size):
            batch_desc = descriptions[i:i + batch_size]
            texts = tokenizer(batch_desc).to(device)
            batch_embeddings = model.encode_text(texts).cpu().numpy()
            final_embeddings.append(batch_embeddings)
            del texts, batch_embeddings
            if torch.cuda.is_available() and str(device).startswith("cuda"):
                # Release cached blocks between batches to bound GPU memory.
                torch.cuda.empty_cache()
    return np.vstack(final_embeddings)

def save_embeddings(output_file, embeddings):
    """Persist *embeddings* to *output_file* in NumPy's binary ``.npy`` format."""
    array = np.asarray(embeddings)  # no-op for an existing ndarray
    np.save(output_file, array)

def main():
    """CLI entry point: embed CSV text descriptions with CLIP and save a .npy file."""
    parser = argparse.ArgumentParser(description="Generate text embeddings using CLIP.")
    parser.add_argument("--input_csv", type=str, required=True, help="Path to the input CSV file containing text descriptions.")
    parser.add_argument("--output_file", type=str, required=True, help="Path to save the output .npy file.")
    parser.add_argument("--batch_size", type=int, default=100, help="Batch size for processing embeddings.")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device to run the model on (e.g., 'cuda:0' or 'cpu').")
    args = parser.parse_args()

    # Pretrained checkpoint identifier shared by the model and its tokenizer.
    model_id = 'hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K'
    model, _, _ = open_clip.create_model_and_transforms(model_id)
    model.to(args.device)
    tokenizer = open_clip.get_tokenizer(model_id)

    # Read the input texts, embed them batch by batch, then persist the matrix.
    texts = load_descriptions(args.input_csv)
    embeddings = generate_embeddings(texts, model, tokenizer, args.device, args.batch_size)
    save_embeddings(args.output_file, embeddings)


if __name__ == "__main__":
    main()