Boni98 committed
Commit b94b68d · 1 Parent(s): 6e120a4

Update clip_chat.py

Files changed (1)
  1. clip_chat.py +11 -7
clip_chat.py CHANGED
@@ -7,15 +7,14 @@ from random import choice
 
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-
-available_models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
-model, preprocess = clip.load(available_models[-1], device=device)
+model, preprocess = clip.load("ViT-L/14@336px", device=device)
 
 COCO = glob.glob(os.path.join(os.getcwd(), "images", "*"))
+available_models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
 
 
 def load_random_image():
-    image_path = choice(COCO)
+    image_path = COCO[0] # choice(COCO)
     image = Image.open(image_path)
     return image
 
@@ -26,6 +25,10 @@ def next_image():
     image = preprocess(Image.fromarray(image_org)).unsqueeze(0).to(device)
 
 
+# def calculate_logits(image, text):
+#     return model(image, text)[0]
+
+
 def calculate_logits(image_features, text_features):
     image_features = image_features / image_features.norm(dim=1, keepdim=True)
     text_features = text_features / text_features.norm(dim=1, keepdim=True)
@@ -37,7 +40,7 @@ def calculate_logits(image_features, text_features):
 last = -1
 best = -1
 
-goal = 21
+goal = 23
 
 image_org = load_random_image()
 image = preprocess(image_org).unsqueeze(0).to(device)
@@ -52,8 +55,9 @@ def answer(message):
 
     with torch.no_grad():
         text_features = model.encode_text(text)
-        logits_per_image, _ = model(image, text)
+        # logits_per_image, _ = model(image, text)
         logits = calculate_logits(image_features, text_features).cpu().numpy().flatten()[0]
+        # logits = calculate_logits(image, text)
 
     if last == -1:
         is_better = -1
@@ -78,6 +82,6 @@ def reset_everything():
     global last, best, goal, image, image_org
     last = -1
     best = -1
-    goal = 21
+    goal = 23
     image_org = load_random_image()
     image = preprocess(image_org).unsqueeze(0).to(device)
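
For reference, the commit keeps the cosine-similarity path: the direct model(image, text) call is commented out and the score comes from calculate_logits, which normalizes both feature tensors and takes their scaled dot product. A minimal standalone sketch of that computation, assuming a logit scale of 100.0 and illustrative feature shapes (neither detail is taken from this commit):

import torch

def calculate_logits(image_features, text_features):
    # Normalize both feature tensors to unit length, then take the scaled
    # dot product, mirroring CLIP's cosine-similarity logits.
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)
    return 100.0 * image_features @ text_features.T  # scale factor is an assumption

# Hypothetical usage with already-encoded features (shapes are illustrative):
image_features = torch.randn(1, 768)
text_features = torch.randn(1, 768)
score = calculate_logits(image_features, text_features).cpu().numpy().flatten()[0]
print(score)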