Spaces:
Runtime error
Runtime error
import torch | |
import clip | |
from PIL import Image | |
import glob | |
import os | |
from random import choice | |
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# clip.load returns the model and the transform used below to turn a PIL
# image into the model's input tensor.
model, preprocess = clip.load("ViT-L/14@336px", device=device)
# Candidate target images: every regular file in ./images, relative to the
# current working directory. Non-files (e.g. subdirectories) are skipped so
# load_random_image never tries to open a directory.
COCO = [
    path
    for path in glob.glob(os.path.join(os.getcwd(), "images", "*"))
    if os.path.isfile(path)
]
# Disabled alternative: read the image paths from coco_paths.txt and remap
# them from one machine's COCO location to another.
# with open('coco_paths.txt', 'r') as file:
#     COCO = file.read().split('\n')[:-1]
# COCO = [c.replace("/media/bonilla/My Book/coco/train2017/", "D:\\coco\\train2017\\") for c in COCO]
def load_random_image(paths=None):
    """Open a randomly chosen image and return it fully loaded in memory.

    Args:
        paths: Sequence of image file paths to choose from. Defaults to the
            module-level ``COCO`` list when omitted.

    Returns:
        A PIL image whose pixel data has already been read, so the file it
        came from is no longer held open.

    Raises:
        FileNotFoundError: If there are no image paths to choose from.
    """
    candidates = COCO if paths is None else paths
    if not candidates:
        # random.choice would only say "Cannot choose from an empty sequence";
        # report what is actually missing instead.
        raise FileNotFoundError(
            "No images available to choose from (is the 'images' directory empty?)"
        )
    image_path = choice(candidates)
    # Image.open is lazy and keeps the file open until the data is read.
    # copy() forces the load, and leaving the with-block releases the handle.
    with Image.open(image_path) as img:
        return img.copy()
def next_image():
    """Draw a new random image and make it the current target.

    Rebinds the module-level ``image_org`` (the PIL image) and ``image`` (the
    preprocessed tensor with a leading batch dimension, moved to ``device``).
    ``last`` and ``best`` are not touched; reset_everything resets those too.
    """
    global image_org, image
    image_org = load_random_image()
    # image_org is already a PIL image, so it goes straight into preprocess,
    # exactly as the initial setup and reset_everything do. The previous
    # Image.fromarray wrapper expects an array-like, not a PIL image.
    image = preprocess(image_org).unsqueeze(0).to(device)
# Initial game state (reset_everything restores these same values).
# `last` holds the previous guess's score and `best` the highest score seen;
# answer() treats last == -1 as "no previous guess". `goal` is the score
# threshold answer() compares against.
last, best = -1, -1
goal = 21
# Pick the first target: keep the PIL original plus the batched model input.
image_org = load_random_image()
image = preprocess(image_org).unsqueeze(0).to(device)
def answer(message):
    """Score a text guess against the current image and classify the result.

    The guess is tokenized with CLIP and scored against the module-level
    ``image`` tensor. Module-level ``last`` (previous score) and ``best``
    (highest score so far) are updated as a side effect.

    Args:
        message: The player's text guess.

    Returns:
        A ``(logits, is_better)`` tuple. ``logits`` is the image-text logit
        for this guess. ``is_better`` is decided with this precedence:
            2  -- the score exceeds ``goal``;
            3  -- otherwise, the score beats ``best`` (``best`` is updated);
            1  -- otherwise, the score is higher than the previous guess;
            0  -- otherwise, the score is lower than the previous guess;
           -1  -- otherwise (no previous guess, i.e. ``last == -1``, or the
                  score did not differ from the previous one).
    """
    global last, best
    text = clip.tokenize([message]).to(device)
    with torch.no_grad():
        logits_per_image, _ = model(image, text)
    logits = logits_per_image.cpu().numpy().flatten()[0]

    # Trend relative to the previous guess.
    if last == -1:
        is_better = -1
    elif logits < last:
        is_better = 0
    elif logits > last:
        is_better = 1
    else:
        is_better = -1

    if logits > best:
        best = logits
        is_better = 3

    # The goal check is separate from the trend comparison. In the old elif
    # chain it followed the </> comparisons, so it was only reached when a
    # guess scored exactly the same as the previous one.
    if logits > goal:
        is_better = 2

    last = logits
    return logits, is_better
def reset_everything():
    """Start a fresh round.

    Puts the bookkeeping back to its starting values (``last`` and ``best``
    to -1, ``goal`` to 21), then draws a new random target image. The PIL
    original is stored in ``image_org``. ``image`` holds the preprocessed
    tensor with a leading batch dimension, moved to ``device``.
    """
    global image_org, image, last, best, goal
    last, best, goal = -1, -1, 21
    fresh = load_random_image()
    image_org = fresh
    image = preprocess(fresh).unsqueeze(0).to(device)