import csv
import argparse

import numpy as np
import torch
import open_clip
def load_descriptions(file_path):
    """Load text descriptions from the first column of a CSV file."""
    descriptions = []
    with open(file_path, 'r') as file:
        csv_reader = csv.reader(file)
        next(csv_reader)  # Skip the header row
        for row in csv_reader:
            descriptions.append(row[0])
    return descriptions
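
# A minimal sketch of the CSV layout the loader above assumes (not taken from the
# original source): one header row, then one description per row in the first column.
#
#   description
#   a photo of a red bicycle leaning against a brick wall
#   an aerial view of a river delta at sunset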
def generate_embeddings(descriptions, model, tokenizer, device, batch_size):
    """Generate text embeddings in batches."""
    final_embeddings = []
    with torch.no_grad():  # Inference only: avoid building the autograd graph
        for i in range(0, len(descriptions), batch_size):
            batch_desc = descriptions[i:i + batch_size]
            texts = tokenizer(batch_desc).to(device)
            batch_embeddings = model.encode_text(texts)
            final_embeddings.append(batch_embeddings.cpu().numpy())
            # Free GPU memory held by the batch before processing the next one
            del texts, batch_embeddings
            torch.cuda.empty_cache()
    return np.vstack(final_embeddings)
def save_embeddings(output_file, embeddings):
    """Save embeddings to a .npy file."""
    np.save(output_file, embeddings)
def main():
    parser = argparse.ArgumentParser(description="Generate text embeddings using CLIP.")
    parser.add_argument("--input_csv", type=str, required=True, help="Path to the input CSV file containing text descriptions.")
    parser.add_argument("--output_file", type=str, required=True, help="Path to save the output .npy file.")
    parser.add_argument("--batch_size", type=int, default=100, help="Batch size for processing embeddings.")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device to run the model on (e.g., 'cuda:0' or 'cpu').")
    args = parser.parse_args()

    # Load the CLIP model and tokenizer
    model, _, _ = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
    model.to(args.device)
    model.eval()  # Inference mode: disable dropout and other training-only behavior
    tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K')

    # Load descriptions from CSV
    descriptions = load_descriptions(args.input_csv)

    # Generate embeddings
    embeddings = generate_embeddings(descriptions, model, tokenizer, args.device, args.batch_size)

    # Save embeddings to the output file
    save_embeddings(args.output_file, embeddings)


if __name__ == "__main__":
    main()
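
# Example invocation (a sketch; the script filename below is a placeholder,
# while the flags match the argparse definitions above):
#
#   python generate_text_embeddings.py \
#       --input_csv descriptions.csv \
#       --output_file embeddings.npy \
#       --batch_size 100 \
#       --device cuda:0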