Update clip_chat.py
clip_chat.py CHANGED (+11 -7)
```diff
@@ -7,15 +7,14 @@ from random import choice
 
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-
-available_models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
-model, preprocess = clip.load(available_models[-1], device=device)
+model, preprocess = clip.load("ViT-L/14@336px", device=device)
 
 COCO = glob.glob(os.path.join(os.getcwd(), "images", "*"))
+available_models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
 
 
 def load_random_image():
-    image_path = choice(COCO)
+    image_path = COCO[0] # choice(COCO)
     image = Image.open(image_path)
     return image
 
@@ -26,6 +25,10 @@ def next_image():
     image = preprocess(Image.fromarray(image_org)).unsqueeze(0).to(device)
 
 
+# def calculate_logits(image, text):
+#     return model(image, text)[0]
+
+
 def calculate_logits(image_features, text_features):
     image_features = image_features / image_features.norm(dim=1, keepdim=True)
     text_features = text_features / text_features.norm(dim=1, keepdim=True)
@@ -37,7 +40,7 @@ def calculate_logits(image_features, text_features):
 last = -1
 best = -1
 
-goal =
+goal = 23
 
 image_org = load_random_image()
 image = preprocess(image_org).unsqueeze(0).to(device)
@@ -52,8 +55,9 @@ def answer(message):
 
     with torch.no_grad():
         text_features = model.encode_text(text)
-        logits_per_image, _ = model(image, text)
+        # logits_per_image, _ = model(image, text)
         logits = calculate_logits(image_features, text_features).cpu().numpy().flatten()[0]
+        # logits = calculate_logits(image, text)
 
     if last == -1:
         is_better = -1
@@ -78,6 +82,6 @@ def reset_everything():
     global last, best, goal, image, image_org
     last = -1
     best = -1
-    goal =
+    goal = 23
     image_org = load_random_image()
     image = preprocess(image_org).unsqueeze(0).to(device)
```