Upload sd_token_similarity_calculator.ipynb
Browse files- sd_token_similarity_calculator.ipynb +539 -505
sd_token_similarity_calculator.ipynb
CHANGED
@@ -29,143 +29,9 @@
|
|
29 |
"cell_type": "code",
|
30 |
"source": [
|
31 |
"# @title ✳️ Load/initialize values\n",
|
32 |
-
"# Load the tokens into the colab\n",
|
33 |
-
"!git clone https://huggingface.co/datasets/codeShare/sd_tokens\n",
|
34 |
-
"import torch\n",
|
35 |
-
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
36 |
-
"from torch import linalg as LA\n",
|
37 |
-
"%cd /content/sd_tokens\n",
|
38 |
-
"token = torch.load('sd15_tensors.pt', map_location= torch.device('cpu'), weights_only=True)\n",
|
39 |
-
"#-----#\n",
|
40 |
-
"VOCAB_FILENAME = 'tokens_most_similiar_to_girl'\n",
|
41 |
-
"ACTIVE_IMG = ''\n",
|
42 |
-
"#-----#\n",
|
43 |
-
"\n",
|
44 |
-
"# Define functions/constants\n",
|
45 |
-
"NUM_TOKENS = 49407\n",
|
46 |
-
"NUM_PREFIX = 13662\n",
|
47 |
-
"NUM_SUFFIX = 32901\n",
|
48 |
-
"\n",
|
49 |
-
"PREFIX_ENC_VOCAB = ['encoded_prefix_to_girl',]\n",
|
50 |
-
"SUFFIX_ENC_VOCAB = ['a_-_encoded_suffix' ,]\n",
|
51 |
-
" #'from_-encoded_suffix',\n",
|
52 |
-
" #'by_-encoded_suffix' ,\n",
|
53 |
-
" #'encoded_suffix-_like']\n",
|
54 |
-
"\n",
|
55 |
-
"# Make sure these match above results\n",
|
56 |
-
"NUM_PREFIX_LISTS = len(PREFIX_ENC_VOCAB)\n",
|
57 |
-
"NUM_SUFFIX_LISTS = len(SUFFIX_ENC_VOCAB)\n",
|
58 |
-
"#-----#\n",
|
59 |
-
"\n",
|
60 |
-
"\n",
|
61 |
-
"#Import the vocab.json\n",
|
62 |
-
"import json\n",
|
63 |
-
"import pandas as pd\n",
|
64 |
-
"\n",
|
65 |
-
"# Read suffix.json\n",
|
66 |
-
"with open('suffix.json', 'r') as f:\n",
|
67 |
-
" data = json.load(f)\n",
|
68 |
-
"_df = pd.DataFrame({'count': data})['count']\n",
|
69 |
-
"suffix = {\n",
|
70 |
-
" key : value for key, value in _df.items()\n",
|
71 |
-
"}\n",
|
72 |
-
"# Read prefix json\n",
|
73 |
-
"with open('prefix.json', 'r') as f:\n",
|
74 |
-
" data = json.load(f)\n",
|
75 |
-
"_df = pd.DataFrame({'count': data})['count']\n",
|
76 |
-
"prefix = {\n",
|
77 |
-
" key : value for key, value in _df.items()\n",
|
78 |
-
"}\n",
|
79 |
-
"\n",
|
80 |
-
"# Read to_suffix.json\n",
|
81 |
-
"with open('to_suffix.json', 'r') as f:\n",
|
82 |
-
" data = json.load(f)\n",
|
83 |
-
"_df = pd.DataFrame({'count': data})['count']\n",
|
84 |
-
"suffix_to_vocab = {\n",
|
85 |
-
" key : value for key, value in _df.items()\n",
|
86 |
-
"}\n",
|
87 |
-
"\n",
|
88 |
-
"# Read to_prefix.json\n",
|
89 |
-
"with open('to_prefix.json', 'r') as f:\n",
|
90 |
-
" data = json.load(f)\n",
|
91 |
-
"_df = pd.DataFrame({'count': data})['count']\n",
|
92 |
-
"prefix_to_vocab = {\n",
|
93 |
-
" key : value for key, value in _df.items()\n",
|
94 |
-
"}\n",
|
95 |
-
"\n",
|
96 |
-
"#-----#\n",
|
97 |
-
"\n",
|
98 |
-
"\n",
|
99 |
-
"# Read to_suffix.json (reversing key and value)\n",
|
100 |
-
"with open('to_suffix.json', 'r') as f:\n",
|
101 |
-
" data = json.load(f)\n",
|
102 |
-
"_df = pd.DataFrame({'count': data})['count']\n",
|
103 |
-
"vocab_to_suffix = {\n",
|
104 |
-
" value : key for key, value in _df.items()\n",
|
105 |
-
"}\n",
|
106 |
-
"\n",
|
107 |
-
"# Read to_prefix.json (reversing key and value)\n",
|
108 |
-
"with open('to_prefix.json', 'r') as f:\n",
|
109 |
-
" data = json.load(f)\n",
|
110 |
-
"_df = pd.DataFrame({'count': data})['count']\n",
|
111 |
-
"vocab_to_prefix = {\n",
|
112 |
-
" value : key for key, value in _df.items()\n",
|
113 |
-
"}\n",
|
114 |
-
"\n",
|
115 |
-
"\n",
|
116 |
-
"#-----#\n",
|
117 |
-
"\n",
|
118 |
-
"#get token from id (excluding tokens with special symbols)\n",
|
119 |
-
"def vocab(id):\n",
|
120 |
-
" _id = f'{id}'\n",
|
121 |
-
" if _id in vocab_to_suffix:\n",
|
122 |
-
" _id = vocab_to_suffix[_id]\n",
|
123 |
-
" return suffix[_id]\n",
|
124 |
-
" if _id in vocab_to_prefix:\n",
|
125 |
-
" _id = vocab_to_prefix[_id]\n",
|
126 |
-
" return prefix[_id]\n",
|
127 |
-
" return ' ' #<---- return whitespace if other id like emojis etc.\n",
|
128 |
-
"#--------#\n",
|
129 |
-
"\n",
|
130 |
-
"#get token from id (excluding tokens with special symbols)\n",
|
131 |
-
"def get_suffix(id):\n",
|
132 |
-
" _id = f'{id}'\n",
|
133 |
-
" if int(id) <= NUM_SUFFIX:\n",
|
134 |
-
" return suffix[_id]\n",
|
135 |
-
" return ' ' #<---- return whitespace if out of bounds\n",
|
136 |
-
"#--------#\n",
|
137 |
-
"\n",
|
138 |
-
"#get token from id (excluding tokens with special symbols)\n",
|
139 |
-
"def get_prefix(id):\n",
|
140 |
-
" _id = f'{id}'\n",
|
141 |
-
" if int(id) <= NUM_PREFIX:\n",
|
142 |
-
" return prefix[_id]\n",
|
143 |
-
" return ' ' #<---- return whitespace if out of bounds\n",
|
144 |
-
"#--------#\n",
|
145 |
-
"\n",
|
146 |
-
"\n",
|
147 |
-
"def _modulus(_id,id_max):\n",
|
148 |
-
" id = _id\n",
|
149 |
-
" while(id>id_max):\n",
|
150 |
-
" id = id-id_max\n",
|
151 |
-
" return id\n",
|
152 |
-
"\n",
|
153 |
-
"#print(get_token(35894))\n"
|
154 |
-
],
|
155 |
-
"metadata": {
|
156 |
-
"id": "w8O0TX7PBh5m"
|
157 |
-
},
|
158 |
-
"execution_count": null,
|
159 |
-
"outputs": []
|
160 |
-
},
|
161 |
-
{
|
162 |
-
"cell_type": "code",
|
163 |
-
"source": [
|
164 |
-
"# @title Load/initialize values (new version - ignore this cell)\n",
|
165 |
"#Imports\n",
|
166 |
"!pip install safetensors\n",
|
167 |
"from safetensors.torch import load_file\n",
|
168 |
-
"\n",
|
169 |
"import json , os , shelve , torch\n",
|
170 |
"import pandas as pd\n",
|
171 |
"#----#\n",
|
@@ -229,7 +95,7 @@
|
|
229 |
" continue\n",
|
230 |
" #-------#\n",
|
231 |
" #--------#\n",
|
232 |
-
" _text_encodings.close() #close the text_encodings file\n",
|
233 |
" file_index = file_index + 1\n",
|
234 |
" #----------#\n",
|
235 |
" NUM_ITEMS = index\n",
|
@@ -245,175 +111,25 @@
|
|
245 |
"metadata": {
|
246 |
"id": "rUXQ73IbonHY"
|
247 |
},
|
248 |
-
"execution_count":
|
249 |
"outputs": []
|
250 |
},
|
251 |
{
|
252 |
"cell_type": "code",
|
253 |
"source": [
|
254 |
-
"# @title
|
|
|
|
|
255 |
"%cd /content/\n",
|
256 |
"!git clone https://huggingface.co/datasets/codeShare/text-to-image-prompts\n",
|
257 |
"#------#\n",
|
258 |
"path = '/content/text-to-image-prompts/civitai-prompts/green'\n",
|
259 |
-
"prompts , text_encodings,
|
260 |
-
],
|
261 |
-
"metadata": {
|
262 |
-
"id": "ZMG4CThUAmwW"
|
263 |
-
},
|
264 |
-
"execution_count": null,
|
265 |
-
"outputs": []
|
266 |
-
},
|
267 |
-
{
|
268 |
-
"cell_type": "code",
|
269 |
-
"source": [
|
270 |
-
"# @title ⚡ Get similiar tokens\n",
|
271 |
-
"import torch\n",
|
272 |
-
"from transformers import AutoTokenizer\n",
|
273 |
-
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
274 |
-
"\n",
|
275 |
-
"# @markdown Write name of token to match against\n",
|
276 |
-
"token_name = \"banana \" # @param {type:'string',\"placeholder\":\"leave empty for random value token\"}\n",
|
277 |
-
"\n",
|
278 |
-
"prompt = token_name\n",
|
279 |
-
"# @markdown (optional) Mix the token with something else\n",
|
280 |
-
"mix_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for random value token\"}\n",
|
281 |
-
"mix_method = \"None\" # @param [\"None\" , \"Average\", \"Subtract\"] {allow-input: true}\n",
|
282 |
-
"w = 0.5 # @param {type:\"slider\", min:0, max:1, step:0.01}\n",
|
283 |
-
"# @markdown Limit char size of included token\n",
|
284 |
-
"\n",
|
285 |
-
"min_char_size = 0 # param {type:\"slider\", min:0, max: 50, step:1}\n",
|
286 |
-
"char_range = 50 # param {type:\"slider\", min:0, max: 50, step:1}\n",
|
287 |
-
"\n",
|
288 |
-
"tokenizer_output = tokenizer(text = prompt)\n",
|
289 |
-
"input_ids = tokenizer_output['input_ids']\n",
|
290 |
-
"id_A = input_ids[1]\n",
|
291 |
-
"A = torch.tensor(token[id_A])\n",
|
292 |
-
"A = A/A.norm(p=2, dim=-1, keepdim=True)\n",
|
293 |
-
"#-----#\n",
|
294 |
-
"tokenizer_output = tokenizer(text = mix_with)\n",
|
295 |
-
"input_ids = tokenizer_output['input_ids']\n",
|
296 |
-
"id_C = input_ids[1]\n",
|
297 |
-
"C = torch.tensor(token[id_C])\n",
|
298 |
-
"C = C/C.norm(p=2, dim=-1, keepdim=True)\n",
|
299 |
-
"#-----#\n",
|
300 |
-
"sim_AC = torch.dot(A,C)\n",
|
301 |
-
"#-----#\n",
|
302 |
-
"print(input_ids)\n",
|
303 |
-
"#-----#\n",
|
304 |
-
"\n",
|
305 |
-
"#if no imput exists we just randomize the entire thing\n",
|
306 |
-
"if (prompt == \"\"):\n",
|
307 |
-
" id_A = -1\n",
|
308 |
-
" print(\"Tokenized prompt tensor A is a random valued tensor with no ID\")\n",
|
309 |
-
" R = torch.rand(A.shape)\n",
|
310 |
-
" R = R/R.norm(p=2, dim=-1, keepdim=True)\n",
|
311 |
-
" A = R\n",
|
312 |
-
" name_A = 'random_A'\n",
|
313 |
-
"\n",
|
314 |
-
"#if no imput exists we just randomize the entire thing\n",
|
315 |
-
"if (mix_with == \"\"):\n",
|
316 |
-
" id_C = -1\n",
|
317 |
-
" print(\"Tokenized prompt 'mix_with' tensor C is a random valued tensor with no ID\")\n",
|
318 |
-
" R = torch.rand(A.shape)\n",
|
319 |
-
" R = R/R.norm(p=2, dim=-1, keepdim=True)\n",
|
320 |
-
" C = R\n",
|
321 |
-
" name_C = 'random_C'\n",
|
322 |
-
"\n",
|
323 |
-
"name_A = \"A of random type\"\n",
|
324 |
-
"if (id_A>-1):\n",
|
325 |
-
" name_A = vocab(id_A)\n",
|
326 |
-
"\n",
|
327 |
-
"name_C = \"token C of random type\"\n",
|
328 |
-
"if (id_C>-1):\n",
|
329 |
-
" name_C = vocab(id_C)\n",
|
330 |
-
"\n",
|
331 |
-
"print(f\"The similarity between A '{name_A}' and C '{name_C}' is {round(sim_AC.item()*100,2)} %\")\n",
|
332 |
-
"\n",
|
333 |
-
"if (mix_method == \"None\"):\n",
|
334 |
-
" print(\"No operation\")\n",
|
335 |
-
"\n",
|
336 |
-
"if (mix_method == \"Average\"):\n",
|
337 |
-
" A = w*A + (1-w)*C\n",
|
338 |
-
" _A = LA.vector_norm(A, ord=2)\n",
|
339 |
-
" print(f\"Tokenized prompt tensor A '{name_A}' token has been recalculated as A = w*A + (1-w)*C , where C is '{name_C}' token , for w = {w} \")\n",
|
340 |
-
"\n",
|
341 |
-
"if (mix_method == \"Subtract\"):\n",
|
342 |
-
" tmp = w*A - (1-w)*C\n",
|
343 |
-
" tmp = tmp/tmp.norm(p=2, dim=-1, keepdim=True)\n",
|
344 |
-
" A = tmp\n",
|
345 |
-
" #//---//\n",
|
346 |
-
" print(f\"Tokenized prompt tensor A '{name_A}' token has been recalculated as A = _A*norm(w*A - (1-w)*C) , where C is '{name_C}' token , for w = {w} \")\n",
|
347 |
-
"\n",
|
348 |
-
"#OPTIONAL : Add/subtract + normalize above result with another token. Leave field empty to get a random value tensor\n",
|
349 |
-
"\n",
|
350 |
-
"dots = torch.zeros(NUM_TOKENS)\n",
|
351 |
-
"for index in range(NUM_TOKENS):\n",
|
352 |
-
" id_B = index\n",
|
353 |
-
" B = torch.tensor(token[id_B])\n",
|
354 |
-
" B = B/B.norm(p=2, dim=-1, keepdim=True)\n",
|
355 |
-
" sim_AB = torch.dot(A,B)\n",
|
356 |
-
" dots[index] = sim_AB\n",
|
357 |
-
"\n",
|
358 |
-
"\n",
|
359 |
-
"sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
360 |
-
"#----#\n",
|
361 |
-
"if (mix_method == \"Average\"):\n",
|
362 |
-
" print(f'Calculated all cosine-similarities between the average of token {name_A} and {name_C} with Id_A = {id_A} and mixed Id_C = {id_C} as a 1x{sorted.shape[0]} tensor')\n",
|
363 |
-
"if (mix_method == \"Subtract\"):\n",
|
364 |
-
" print(f'Calculated all cosine-similarities between the subtract of token {name_A} and {name_C} with Id_A = {id_A} and mixed Id_C = {id_C} as a 1x{sorted.shape[0]} tensor')\n",
|
365 |
-
"if (mix_method == \"None\"):\n",
|
366 |
-
" print(f'Calculated all cosine-similarities between the token {name_A} with Id_A = {id_A} with the the rest of the {NUM_TOKENS} tokens as a 1x{sorted.shape[0]} tensor')\n",
|
367 |
-
"\n",
|
368 |
-
"#Produce a list id IDs that are most similiar to the prompt ID at positiion 1 based on above result\n",
|
369 |
-
"\n",
|
370 |
-
"# @markdown Set print options\n",
|
371 |
-
"list_size = 100 # @param {type:'number'}\n",
|
372 |
-
"print_ID = False # @param {type:\"boolean\"}\n",
|
373 |
-
"print_Similarity = True # @param {type:\"boolean\"}\n",
|
374 |
-
"print_Name = True # @param {type:\"boolean\"}\n",
|
375 |
-
"print_Divider = True # @param {type:\"boolean\"}\n",
|
376 |
-
"\n",
|
377 |
-
"\n",
|
378 |
-
"if (print_Divider):\n",
|
379 |
-
" print('//---//')\n",
|
380 |
-
"\n",
|
381 |
-
"print('')\n",
|
382 |
-
"print('Here is the result : ')\n",
|
383 |
-
"print('')\n",
|
384 |
-
"\n",
|
385 |
-
"for index in range(list_size):\n",
|
386 |
-
" id = indices[index].item()\n",
|
387 |
-
" if (print_Name):\n",
|
388 |
-
" print(f'{vocab(id)}') # vocab item\n",
|
389 |
-
" if (print_ID):\n",
|
390 |
-
" print(f'ID = {id}') # IDs\n",
|
391 |
-
" if (print_Similarity):\n",
|
392 |
-
" print(f'similiarity = {round(sorted[index].item()*100,2)} %')\n",
|
393 |
-
" if (print_Divider):\n",
|
394 |
-
" print('--------')\n",
|
395 |
-
"\n",
|
396 |
-
"#Print the sorted list from above result\n",
|
397 |
-
"\n",
|
398 |
-
"#The prompt will be enclosed with the <|start-of-text|> and <|end-of-text|> tokens, which is why output will be [49406, ... , 49407].\n",
|
399 |
"\n",
|
400 |
-
"
|
401 |
-
"\n",
|
402 |
-
"# Save results as .db file\n",
|
403 |
-
"import shelve\n",
|
404 |
-
"VOCAB_FILENAME = 'tokens_most_similiar_to_' + name_A.replace('</w>','').strip()\n",
|
405 |
-
"d = shelve.open(VOCAB_FILENAME)\n",
|
406 |
-
"#NUM TOKENS == 49407\n",
|
407 |
-
"for index in range(NUM_TOKENS):\n",
|
408 |
-
" #print(d[f'{index}']) #<-----Use this to read values from the .db file\n",
|
409 |
-
" d[f'{index}']= vocab(indices[index].item()) #<---- write values to .db file\n",
|
410 |
-
"#----#\n",
|
411 |
-
"d.close() #close the file\n",
|
412 |
-
"# See this link for additional stuff to do with shelve: https://docs.python.org/3/library/shelve.html"
|
413 |
],
|
414 |
"metadata": {
|
415 |
-
"id": "
|
416 |
-
"cellView": "form"
|
417 |
},
|
418 |
"execution_count": null,
|
419 |
"outputs": []
|
@@ -424,7 +140,6 @@
|
|
424 |
"# @title 📝 Get Prompt text_encoding similarity to the pre-calc. text_encodings\n",
|
425 |
"prompt = \" a fast car on the road \" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
|
426 |
"\n",
|
427 |
-
"\n",
|
428 |
"from transformers import AutoTokenizer\n",
|
429 |
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
430 |
"from transformers import CLIPProcessor, CLIPModel\n",
|
@@ -438,47 +153,13 @@
|
|
438 |
"name_A = prompt\n",
|
439 |
"#------#\n",
|
440 |
"\n",
|
441 |
-
"
|
442 |
-
"
|
443 |
-
"
|
444 |
-
"
|
445 |
-
"NUM_PREFIX_LISTS = 1\n",
|
446 |
-
"dots = results_sim = torch.zeros(RANGE*NUM_PREFIX_LISTS)\n",
|
447 |
-
"for _PREFIX_ENC_VOCAB in PREFIX_ENC_VOCAB:\n",
|
448 |
-
" _iters = _iters + 1\n",
|
449 |
-
" d = shelve.open(_PREFIX_ENC_VOCAB)\n",
|
450 |
-
" for _index in range(RANGE):\n",
|
451 |
-
" index = _iters*RANGE + _index\n",
|
452 |
-
" text_features = d[f'{_index}']\n",
|
453 |
-
" text_features = text_features/text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
454 |
-
" sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
455 |
-
" dots[index] = sim\n",
|
456 |
-
" #----#\n",
|
457 |
-
" d.close() #close the file\n",
|
458 |
-
"#------#\n",
|
459 |
-
"prefix_sorted, prefix_indices = torch.sort(dots,dim=0 , descending=True)\n",
|
460 |
"#------#\n",
|
461 |
"\n",
|
462 |
-
"
|
463 |
-
"import shelve\n",
|
464 |
-
"_iters = -1\n",
|
465 |
-
"RANGE = NUM_SUFFIX\n",
|
466 |
-
"dots = results_sim = torch.zeros(RANGE*NUM_SUFFIX_LISTS)\n",
|
467 |
-
"for _SUFFIX_ENC_VOCAB in SUFFIX_ENC_VOCAB:\n",
|
468 |
-
" _iters = _iters + 1\n",
|
469 |
-
" d = shelve.open(_SUFFIX_ENC_VOCAB)\n",
|
470 |
-
" for _index in range(RANGE):\n",
|
471 |
-
" index = _iters*RANGE + _index\n",
|
472 |
-
" text_features = d[f'{_index}']\n",
|
473 |
-
" text_features = text_features/text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
474 |
-
" sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
475 |
-
" dots[index] = sim\n",
|
476 |
-
" #----#\n",
|
477 |
-
" d.close() #close the file\n",
|
478 |
-
"#------#\n",
|
479 |
-
"suffix_sorted, suffix_indices = torch.sort(dots,dim=0 , descending=True)\n",
|
480 |
-
"#------#\n",
|
481 |
-
"\n"
|
482 |
],
|
483 |
"metadata": {
|
484 |
"id": "xc-PbIYF428y"
|
@@ -493,75 +174,43 @@
|
|
493 |
"list_size = 100 # @param {type:'number'}\n",
|
494 |
"start_at_index = 0 # @param {type:'number'}\n",
|
495 |
"print_Similarity = True # @param {type:\"boolean\"}\n",
|
496 |
-
"
|
497 |
"print_Prefix = True # @param {type:\"boolean\"}\n",
|
498 |
"print_Descriptions = True # @param {type:\"boolean\"}\n",
|
499 |
-
"compact_Output =
|
|
|
500 |
"\n",
|
501 |
"# title Show the 100 most similiar suffix and prefix text-encodings to the text encoding\n",
|
502 |
"RANGE = list_size\n",
|
503 |
-
"
|
|
|
|
|
|
|
|
|
504 |
"_sims = '{'\n",
|
505 |
-
"for
|
506 |
-
" if
|
507 |
-
"
|
508 |
-
"
|
509 |
-
"
|
510 |
-
" if(id>NUM_SUFFIX*1):\n",
|
511 |
-
" ahead = \"a \"\n",
|
512 |
-
" if(id>NUM_SUFFIX*2):\n",
|
513 |
-
" ahead = \"by \"\n",
|
514 |
-
" if(id>NUM_SUFFIX*3):\n",
|
515 |
-
" ahead = \"\"\n",
|
516 |
-
" behind = \"like\"\n",
|
517 |
-
" id = _modulus(id,NUM_SUFFIX)\n",
|
518 |
-
" #------#\n",
|
519 |
-
" sim = suffix_sorted[index].item()\n",
|
520 |
-
" name = ahead + get_suffix(id) + behind\n",
|
521 |
-
" if(get_suffix(id) == ' '): name = ahead + f'{id}' + behind\n",
|
522 |
-
" _suffixes = _suffixes + name + '|'\n",
|
523 |
-
" _sims = _sims + f'{round(sim*100,2)} %' + '|'\n",
|
524 |
"#------#\n",
|
525 |
-
"
|
526 |
-
"
|
527 |
"#------#\n",
|
528 |
"\n",
|
529 |
-
"\n",
|
530 |
-
"
|
531 |
-
"sims = _sims\n",
|
532 |
-
"if(not print_Suffix): suffixes = ''\n",
|
533 |
-
"if(not print_Similarity): sims = ''\n",
|
534 |
"\n",
|
535 |
"if(not compact_Output):\n",
|
536 |
" if(print_Descriptions):\n",
|
537 |
-
" print(f'The {start_at_index}-{start_at_index + RANGE} most similiar
|
538 |
-
" print(f'The {start_at_index}-{start_at_index + RANGE} similarity % for
|
539 |
" print('')\n",
|
540 |
" else:\n",
|
541 |
-
" print(
|
542 |
-
"#-------#\n",
|
543 |
-
"\n",
|
544 |
-
"_prefixes = '{'\n",
|
545 |
-
"for index in range(start_at_index + RANGE):\n",
|
546 |
-
" if index < start_at_index : continue\n",
|
547 |
-
" id = f'{prefix_indices[index]}'\n",
|
548 |
-
" #sim = prefix_sorted[index]\n",
|
549 |
-
" name = get_prefix(id)\n",
|
550 |
-
" _prefixes = _prefixes + name + '|'\n",
|
551 |
-
"#------#\n",
|
552 |
-
"_prefixes = (_prefixes + '}').replace('|}', '}')\n",
|
553 |
-
"\n",
|
554 |
-
"\n",
|
555 |
-
"prefixes = _prefixes\n",
|
556 |
-
"if(not print_Prefix): prefixes = ''\n",
|
557 |
-
"\n",
|
558 |
-
"if(print_Descriptions):\n",
|
559 |
-
" print(f'The {start_at_index}-{start_at_index + RANGE} most similiar prefixes to prompt : ' + prefixes)\n",
|
560 |
"else:\n",
|
561 |
-
"
|
562 |
-
"
|
563 |
-
" else:\n",
|
564 |
-
" print(prefixes)"
|
565 |
],
|
566 |
"metadata": {
|
567 |
"id": "_vnVbxcFf7WV"
|
@@ -589,6 +238,9 @@
|
|
589 |
" for k, v in uploaded.items():\n",
|
590 |
" open(k, 'wb').write(v)\n",
|
591 |
" return list(uploaded.keys())\n",
|
|
|
|
|
|
|
592 |
"#Get image\n",
|
593 |
"# You can use \"http://images.cocodataset.org/val2017/000000039769.jpg\" for testing\n",
|
594 |
"image_url = \"\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for local upload (scroll down to see it)\"}\n",
|
@@ -608,11 +260,11 @@
|
|
608 |
" if colab_image_path == \"\":\n",
|
609 |
" keys = upload_files()\n",
|
610 |
" for key in keys:\n",
|
611 |
-
" image_A = cv2.imread(
|
612 |
-
" colab_image_path =
|
613 |
-
" image_path =
|
614 |
" else:\n",
|
615 |
-
" image_A = cv2.imread(
|
616 |
"else:\n",
|
617 |
" image_A = Image.open(requests.get(image_url, stream=True).raw)\n",
|
618 |
"#------#\n",
|
@@ -622,13 +274,13 @@
|
|
622 |
],
|
623 |
"metadata": {
|
624 |
"id": "ke6mZ1RZDOeB",
|
625 |
-
"outputId": "
|
626 |
"colab": {
|
627 |
"base_uri": "https://localhost:8080/",
|
628 |
"height": 1000
|
629 |
}
|
630 |
},
|
631 |
-
"execution_count":
|
632 |
"outputs": [
|
633 |
{
|
634 |
"output_type": "display_data",
|
@@ -647,14 +299,6 @@
|
|
647 |
"source": [
|
648 |
"# @title 🖼️ Get image_encoding similarity to the pre-calc. text_encodings\n",
|
649 |
"\n",
|
650 |
-
"list_size = 100 # @param {type:'number'}\n",
|
651 |
-
"start_at_index = 0 # @param {type:'number'}\n",
|
652 |
-
"print_Similarity = True # @param {type:\"boolean\"}\n",
|
653 |
-
"print_Suffix = True # @param {type:\"boolean\"}\n",
|
654 |
-
"print_Prefix = True # @param {type:\"boolean\"}\n",
|
655 |
-
"print_Descriptions = True # @param {type:\"boolean\"}\n",
|
656 |
-
"compact_Output = False # @param {type:\"boolean\"}\n",
|
657 |
-
"\n",
|
658 |
"from transformers import AutoTokenizer\n",
|
659 |
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
660 |
"from transformers import CLIPProcessor, CLIPModel\n",
|
@@ -668,48 +312,14 @@
|
|
668 |
"name_A = \"the image\"\n",
|
669 |
"#-----#\n",
|
670 |
"\n",
|
671 |
-
"
|
672 |
-
"
|
673 |
-
"
|
674 |
-
"
|
675 |
-
"
|
676 |
-
"
|
677 |
-
"
|
678 |
-
"
|
679 |
-
" d = shelve.open(_PREFIX_ENC_VOCAB)\n",
|
680 |
-
" for _index in range(RANGE):\n",
|
681 |
-
" index = _iters*RANGE + _index\n",
|
682 |
-
" text_features = d[f'{_index}']\n",
|
683 |
-
" logit_scale = model.logit_scale.exp()\n",
|
684 |
-
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
685 |
-
" sim = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
686 |
-
" dots[index] = sim\n",
|
687 |
-
" #----#\n",
|
688 |
-
" d.close() #close the file\n",
|
689 |
-
"#------#\n",
|
690 |
-
"prefix_sorted, prefix_indices = torch.sort(dots,dim=0 , descending=True)\n",
|
691 |
-
"#------#\n",
|
692 |
-
"\n",
|
693 |
-
"# Load the .db file for prefix encodings\n",
|
694 |
-
"import shelve\n",
|
695 |
-
"_iters = -1\n",
|
696 |
-
"RANGE = NUM_SUFFIX\n",
|
697 |
-
"dots = results_sim = torch.zeros(RANGE*NUM_SUFFIX_LISTS)\n",
|
698 |
-
"for _SUFFIX_ENC_VOCAB in SUFFIX_ENC_VOCAB:\n",
|
699 |
-
" _iters = _iters + 1\n",
|
700 |
-
" d = shelve.open(_SUFFIX_ENC_VOCAB)\n",
|
701 |
-
" for _index in range(RANGE):\n",
|
702 |
-
" index = _iters*RANGE + _index\n",
|
703 |
-
" text_features = d[f'{_index}']\n",
|
704 |
-
" logit_scale = model.logit_scale.exp()\n",
|
705 |
-
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
706 |
-
" sim = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
707 |
-
" dots[index] = sim\n",
|
708 |
-
" #----#\n",
|
709 |
-
" d.close() #close the file\n",
|
710 |
-
"#------#\n",
|
711 |
-
"suffix_sorted, suffix_indices = torch.sort(dots,dim=0 , descending=True)\n",
|
712 |
-
"#------#"
|
713 |
],
|
714 |
"metadata": {
|
715 |
"id": "rebogpoyOG8k"
|
@@ -722,80 +332,312 @@
|
|
722 |
"source": [
|
723 |
"# @title 🖼️ Print the results\n",
|
724 |
"list_size = 100 # @param {type:'number'}\n",
|
725 |
-
"start_at_index = 0 # @param {type:'number'}\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
726 |
"print_Similarity = True # @param {type:\"boolean\"}\n",
|
727 |
-
"
|
728 |
-
"
|
729 |
-
"print_Descriptions = True # @param {type:\"boolean\"}\n",
|
730 |
-
"compact_Output = False # @param {type:\"boolean\"}\n",
|
731 |
"\n",
|
732 |
-
"# title Show the 100 most similiar suffix and prefix text-encodings to the text encoding\n",
|
733 |
-
"RANGE = list_size\n",
|
734 |
-
"_suffixes = '{'\n",
|
735 |
-
"_sims = '{'\n",
|
736 |
-
"for index in range(start_at_index + RANGE):\n",
|
737 |
-
" if index < start_at_index : continue\n",
|
738 |
-
" id = int(suffix_indices[index])\n",
|
739 |
-
" ahead = \"from \"\n",
|
740 |
-
" behind = \"\"\n",
|
741 |
-
" if(id>NUM_SUFFIX*1):\n",
|
742 |
-
" ahead = \"a \"\n",
|
743 |
-
" if(id>NUM_SUFFIX*2):\n",
|
744 |
-
" ahead = \"by \"\n",
|
745 |
-
" if(id>NUM_SUFFIX*3):\n",
|
746 |
-
" ahead = \"\"\n",
|
747 |
-
" behind = \"like\"\n",
|
748 |
-
" id = _modulus(id,NUM_SUFFIX)\n",
|
749 |
-
" #------#\n",
|
750 |
-
" sim = suffix_sorted[index].item()\n",
|
751 |
-
" name = ahead + get_suffix(id) + behind\n",
|
752 |
-
" if(get_suffix(id) == ' '): name = ahead + f'{id}' + behind\n",
|
753 |
-
" _suffixes = _suffixes + name + '|'\n",
|
754 |
-
" _sims = _sims + f'{round(sim*100,2)} %' + '|'\n",
|
755 |
-
"#------#\n",
|
756 |
-
"_suffixes = (_suffixes + '}').replace('|}', '}')\n",
|
757 |
-
"_sims = (_sims + '}').replace('|}', '}')\n",
|
758 |
-
"#------#\n",
|
759 |
"\n",
|
|
|
|
|
760 |
"\n",
|
761 |
-
"
|
762 |
-
"
|
763 |
-
"
|
764 |
-
"if(not print_Similarity): sims = ''\n",
|
765 |
"\n",
|
766 |
-
"
|
767 |
-
"
|
768 |
-
"
|
769 |
-
" print(f'
|
770 |
-
"
|
771 |
-
"
|
772 |
-
"
|
773 |
-
"
|
|
|
|
|
774 |
"\n",
|
775 |
-
"
|
776 |
-
"for index in range(start_at_index + RANGE):\n",
|
777 |
-
" if index < start_at_index : continue\n",
|
778 |
-
" id = f'{prefix_indices[index]}'\n",
|
779 |
-
" #sim = prefix_sorted[index]\n",
|
780 |
-
" name = get_prefix(id)\n",
|
781 |
-
" _prefixes = _prefixes + name + '|'\n",
|
782 |
-
"#------#\n",
|
783 |
-
"_prefixes = (_prefixes + '}').replace('|}', '}')\n",
|
784 |
"\n",
|
|
|
785 |
"\n",
|
786 |
-
"
|
787 |
-
"if(not print_Prefix): prefixes = ''\n",
|
788 |
"\n",
|
789 |
-
"
|
790 |
-
"
|
791 |
-
"
|
792 |
-
"
|
793 |
-
"
|
794 |
-
"
|
795 |
-
"
|
|
|
|
|
|
|
|
|
796 |
],
|
797 |
"metadata": {
|
798 |
-
"id": "
|
799 |
},
|
800 |
"execution_count": null,
|
801 |
"outputs": []
|
@@ -1490,6 +1332,198 @@
|
|
1490 |
"metadata": {
|
1491 |
"id": "njeJx_nSSA8H"
|
1492 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1493 |
}
|
1494 |
]
|
1495 |
}
|
|
|
29 |
"cell_type": "code",
|
30 |
"source": [
|
31 |
"# @title ✳️ Load/initialize values\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
"#Imports\n",
|
33 |
"!pip install safetensors\n",
|
34 |
"from safetensors.torch import load_file\n",
|
|
|
35 |
"import json , os , shelve , torch\n",
|
36 |
"import pandas as pd\n",
|
37 |
"#----#\n",
|
|
|
95 |
" continue\n",
|
96 |
" #-------#\n",
|
97 |
" #--------#\n",
|
98 |
+
" #_text_encodings.close() #close the text_encodings file\n",
|
99 |
" file_index = file_index + 1\n",
|
100 |
" #----------#\n",
|
101 |
" NUM_ITEMS = index\n",
|
|
|
111 |
"metadata": {
|
112 |
"id": "rUXQ73IbonHY"
|
113 |
},
|
114 |
+
"execution_count": null,
|
115 |
"outputs": []
|
116 |
},
|
117 |
{
|
118 |
"cell_type": "code",
|
119 |
"source": [
|
120 |
+
"# @title 📝 Choose which vocab to load\n",
|
121 |
+
"use_vocab = '🦜 fusion-t2i-prompt-features' # @param ['🦜 fusion-t2i-prompt-features']\n",
|
122 |
+
"\n",
|
123 |
"%cd /content/\n",
|
124 |
"!git clone https://huggingface.co/datasets/codeShare/text-to-image-prompts\n",
|
125 |
"#------#\n",
|
126 |
"path = '/content/text-to-image-prompts/civitai-prompts/green'\n",
|
127 |
+
"prompts , text_encodings, NUM_VOCAB_ITEMS = getPrompts(path)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
"\n",
|
129 |
+
"append_Whitespace = True\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
],
|
131 |
"metadata": {
|
132 |
+
"id": "ZMG4CThUAmwW"
|
|
|
133 |
},
|
134 |
"execution_count": null,
|
135 |
"outputs": []
|
|
|
140 |
"# @title 📝 Get Prompt text_encoding similarity to the pre-calc. text_encodings\n",
|
141 |
"prompt = \" a fast car on the road \" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
|
142 |
"\n",
|
|
|
143 |
"from transformers import AutoTokenizer\n",
|
144 |
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
145 |
"from transformers import CLIPProcessor, CLIPModel\n",
|
|
|
153 |
"name_A = prompt\n",
|
154 |
"#------#\n",
|
155 |
"\n",
|
156 |
+
"sims = torch.zeros(NUM_VOCAB_ITEMS)\n",
|
157 |
+
"for index in range(NUM_VOCAB_ITEMS):\n",
|
158 |
+
" text_features = text_encodings[f'{index}']\n",
|
159 |
+
" sims[index] = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
"#------#\n",
|
161 |
"\n",
|
162 |
+
"sorted , indices = torch.sort(sims,dim=0 , descending=True)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
],
|
164 |
"metadata": {
|
165 |
"id": "xc-PbIYF428y"
|
|
|
174 |
"list_size = 100 # @param {type:'number'}\n",
|
175 |
"start_at_index = 0 # @param {type:'number'}\n",
|
176 |
"print_Similarity = True # @param {type:\"boolean\"}\n",
|
177 |
+
"print_Prompts = True # @param {type:\"boolean\"}\n",
|
178 |
"print_Prefix = True # @param {type:\"boolean\"}\n",
|
179 |
"print_Descriptions = True # @param {type:\"boolean\"}\n",
|
180 |
+
"compact_Output = True # @param {type:\"boolean\"}\n",
|
181 |
+
"newline_Separator = True # @param {type:\"boolean\"}\n",
|
182 |
"\n",
|
183 |
"# title Show the 100 most similiar suffix and prefix text-encodings to the text encoding\n",
|
184 |
"RANGE = list_size\n",
|
185 |
+
"separator = '|'\n",
|
186 |
+
"if append_Whitespace : separator = ' ' + separator\n",
|
187 |
+
"if newline_Separator : separator = separator + '\\n'\n",
|
188 |
+
"\n",
|
189 |
+
"_prompts = '{'\n",
|
190 |
"_sims = '{'\n",
|
191 |
+
"for _index in range(start_at_index + RANGE):\n",
|
192 |
+
" if _index < start_at_index : continue\n",
|
193 |
+
" index = indices[_index]\n",
|
194 |
+
" _prompts = _prompts + prompts[f'{index}'] + separator\n",
|
195 |
+
" _sims = _sims + f'{round(100*sims[index].item(), 2)} %' + separator\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
"#------#\n",
|
197 |
+
"__prompts = (_prompts + '}').replace(separator + '}', '}')\n",
|
198 |
+
"__sims = (_sims + '}').replace(separator + '}', '}')\n",
|
199 |
"#------#\n",
|
200 |
"\n",
|
201 |
+
"if(not print_Prompts): __prompts = ''\n",
|
202 |
+
"if(not print_Similarity): __sims = ''\n",
|
|
|
|
|
|
|
203 |
"\n",
|
204 |
"if(not compact_Output):\n",
|
205 |
" if(print_Descriptions):\n",
|
206 |
+
" print(f'The {start_at_index}-{start_at_index + RANGE} most similiar items to prompt : \\n\\n ' + __prompts)\n",
|
207 |
+
" print(f'The {start_at_index}-{start_at_index + RANGE} similarity % for items : \\n\\n' + __sims)\n",
|
208 |
" print('')\n",
|
209 |
" else:\n",
|
210 |
+
" print(__prompts)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
"else:\n",
|
212 |
+
" print(__prompts)\n",
|
213 |
+
"#-------#"
|
|
|
|
|
214 |
],
|
215 |
"metadata": {
|
216 |
"id": "_vnVbxcFf7WV"
|
|
|
238 |
" for k, v in uploaded.items():\n",
|
239 |
" open(k, 'wb').write(v)\n",
|
240 |
" return list(uploaded.keys())\n",
|
241 |
+
"\n",
|
242 |
+
"\n",
|
243 |
+
"colab_image_folder = '/content/text-to-image-prompts/images/'\n",
|
244 |
"#Get image\n",
|
245 |
"# You can use \"http://images.cocodataset.org/val2017/000000039769.jpg\" for testing\n",
|
246 |
"image_url = \"\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for local upload (scroll down to see it)\"}\n",
|
|
|
260 |
" if colab_image_path == \"\":\n",
|
261 |
" keys = upload_files()\n",
|
262 |
" for key in keys:\n",
|
263 |
+
" image_A = cv2.imread(colab_image_folder + key)\n",
|
264 |
+
" colab_image_path = colab_image_folder + key\n",
|
265 |
+
" image_path = colab_image_folder + key\n",
|
266 |
" else:\n",
|
267 |
+
" image_A = cv2.imread(colab_image_folder + colab_image_path)\n",
|
268 |
"else:\n",
|
269 |
" image_A = Image.open(requests.get(image_url, stream=True).raw)\n",
|
270 |
"#------#\n",
|
|
|
274 |
],
|
275 |
"metadata": {
|
276 |
"id": "ke6mZ1RZDOeB",
|
277 |
+
"outputId": "8ced884a-bf07-4fcb-c108-0f873d71a73c",
|
278 |
"colab": {
|
279 |
"base_uri": "https://localhost:8080/",
|
280 |
"height": 1000
|
281 |
}
|
282 |
},
|
283 |
+
"execution_count": 4,
|
284 |
"outputs": [
|
285 |
{
|
286 |
"output_type": "display_data",
|
|
|
299 |
"source": [
|
300 |
"# @title 🖼️ Get image_encoding similarity to the pre-calc. text_encodings\n",
|
301 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
302 |
"from transformers import AutoTokenizer\n",
|
303 |
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
304 |
"from transformers import CLIPProcessor, CLIPModel\n",
|
|
|
312 |
"name_A = \"the image\"\n",
|
313 |
"#-----#\n",
|
314 |
"\n",
|
315 |
+
"sims = torch.zeros(NUM_VOCAB_ITEMS)\n",
|
316 |
+
"for index in range(NUM_VOCAB_ITEMS):\n",
|
317 |
+
" text_features = text_encodings[f'{index}']\n",
|
318 |
+
" logit_scale = model.logit_scale.exp()\n",
|
319 |
+
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
320 |
+
" sims[index] = torch.nn.functional.cosine_similarity(text_features, image_features)\n",
|
321 |
+
"#-------#\n",
|
322 |
+
"sorted , indices = torch.sort(sims,dim=0 , descending=True)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
],
|
324 |
"metadata": {
|
325 |
"id": "rebogpoyOG8k"
|
|
|
332 |
"source": [
|
333 |
"# @title 🖼️ Print the results\n",
|
334 |
"list_size = 100 # @param {type:'number'}\n",
|
335 |
+
"start_at_index = 0 # @param {type:'number'}\n",
|
336 |
+
"print_Similarity = True # @param {type:\"boolean\"}\n",
|
337 |
+
"print_Prompts = True # @param {type:\"boolean\"}\n",
|
338 |
+
"print_Prefix = True # @param {type:\"boolean\"}\n",
|
339 |
+
"print_Descriptions = True # @param {type:\"boolean\"}\n",
|
340 |
+
"compact_Output = True # @param {type:\"boolean\"}\n",
|
341 |
+
"newline_Separator = True # @param {type:\"boolean\"}\n",
|
342 |
+
"\n",
|
343 |
+
"# title Show the 100 most similiar suffix and prefix text-encodings to the text encoding\n",
|
344 |
+
"RANGE = list_size\n",
|
345 |
+
"separator = '|'\n",
|
346 |
+
"if append_Whitespace : separator = ' ' + separator\n",
|
347 |
+
"if newline_Separator : separator = separator + '\\n'\n",
|
348 |
+
"\n",
|
349 |
+
"_prompts = '{'\n",
|
350 |
+
"_sims = '{'\n",
|
351 |
+
"for _index in range(start_at_index + RANGE):\n",
|
352 |
+
" if _index < start_at_index : continue\n",
|
353 |
+
" index = indices[_index]\n",
|
354 |
+
" _prompts = _prompts + prompts[f'{index}'] + separator\n",
|
355 |
+
" _sims = _sims + f'{round(100*sims[index].item(), 2)} %' + separator\n",
|
356 |
+
"#------#\n",
|
357 |
+
"__prompts = (_prompts + '}').replace(separator + '}', '}')\n",
|
358 |
+
"__sims = (_sims + '}').replace(separator + '}', '}')\n",
|
359 |
+
"#------#\n",
|
360 |
+
"\n",
|
361 |
+
"if(not print_Prompts): __prompts = ''\n",
|
362 |
+
"if(not print_Similarity): __sims = ''\n",
|
363 |
+
"\n",
|
364 |
+
"if(not compact_Output):\n",
|
365 |
+
" if(print_Descriptions):\n",
|
366 |
+
" print(f'The {start_at_index}-{start_at_index + RANGE} most similiar items to prompt : \\n\\n ' + __prompts)\n",
|
367 |
+
" print(f'The {start_at_index}-{start_at_index + RANGE} similarity % for items : \\n\\n' + __sims)\n",
|
368 |
+
" print('')\n",
|
369 |
+
" else:\n",
|
370 |
+
" print(__prompts)\n",
|
371 |
+
"else:\n",
|
372 |
+
" print(__prompts)\n",
|
373 |
+
"#-------#"
|
374 |
+
],
|
375 |
+
"metadata": {
|
376 |
+
"id": "JkzncP8SgKtS",
|
377 |
+
"outputId": "37351bed-c5e2-4554-c5e0-a9dc84da700b",
|
378 |
+
"colab": {
|
379 |
+
"base_uri": "https://localhost:8080/"
|
380 |
+
}
|
381 |
+
},
|
382 |
+
"execution_count": 6,
|
383 |
+
"outputs": [
|
384 |
+
{
|
385 |
+
"output_type": "stream",
|
386 |
+
"name": "stdout",
|
387 |
+
"text": [
|
388 |
+
"{beautiful avatar pictures |\n",
|
389 |
+
"purple hair crowned standing in storm background |\n",
|
390 |
+
"beautiful celebrity futuristic sci-fi |\n",
|
391 |
+
"by magali villeneuve |\n",
|
392 |
+
"visually striking spectacle inspired by the works |\n",
|
393 |
+
"visually striking spectacle inspired by the works |\n",
|
394 |
+
"a beautiful female warrior |\n",
|
395 |
+
"a sexy scifi warrior |\n",
|
396 |
+
"a sexy scifi warrior |\n",
|
397 |
+
"film still from halo live action adaptation |\n",
|
398 |
+
"cinematic film still from captain marvel |\n",
|
399 |
+
"beautiful female warrior |\n",
|
400 |
+
"beautiful female warrior |\n",
|
401 |
+
"film still from halo live-action movie adaptation |\n",
|
402 |
+
"outlandish costume design |\n",
|
403 |
+
"beautiful indian warrior queen |\n",
|
404 |
+
"of female space soldier |\n",
|
405 |
+
"blue light on her face she appears calm |\n",
|
406 |
+
"a female scifi warrior |\n",
|
407 |
+
"a female scifi warrior |\n",
|
408 |
+
"nebula in her streak hair |\n",
|
409 |
+
"of brown skinned indian warrior queen |\n",
|
410 |
+
"played by young dove cameron |\n",
|
411 |
+
"has runes on her body |\n",
|
412 |
+
"has runes on her body |\n",
|
413 |
+
"beautiful light makeup female sorceress |\n",
|
414 |
+
"a gorgeous female void thrall |\n",
|
415 |
+
"a gorgeous female void thrall |\n",
|
416 |
+
"beautiful female elf queen |\n",
|
417 |
+
"captivating mystique |\n",
|
418 |
+
"captivating mystique |\n",
|
419 |
+
"symbolizing her role as the goddess |\n",
|
420 |
+
"character integrated into the background |\n",
|
421 |
+
"swirling black light around the character |\n",
|
422 |
+
"very beautiful jean grey wearing |\n",
|
423 |
+
"lightly blued metal armor |\n",
|
424 |
+
"multiple different characters in the background |\n",
|
425 |
+
"cinematic still from conan |\n",
|
426 |
+
"yo person as dark elf queen |\n",
|
427 |
+
"trending at cgsociety |\n",
|
428 |
+
"femaleastronaut exalted human futuristic warrior |\n",
|
429 |
+
"femaleastronaut exalted human futuristic warrior |\n",
|
430 |
+
"epic fantasy greek priestess |\n",
|
431 |
+
"epic fantasy greek priestess |\n",
|
432 |
+
"female draenei world |\n",
|
433 |
+
"genetically engineered soldiers |\n",
|
434 |
+
"genetically engineered soldiers |\n",
|
435 |
+
"visually striking scene the lighting |\n",
|
436 |
+
"the female soldier marches in formation |\n",
|
437 |
+
"the female soldier marches in formation |\n",
|
438 |
+
"revealing costume design |\n",
|
439 |
+
"pandora_smith_magister |\n",
|
440 |
+
"pandora_smith_magister |\n",
|
441 |
+
"his sorceress in the back ground |\n",
|
442 |
+
"gorgeous muscular elven ukrainian |\n",
|
443 |
+
"gorgeous muscular elven ukrainian |\n",
|
444 |
+
"water elemental officer jenny |\n",
|
445 |
+
"norse female goddess |\n",
|
446 |
+
"matte fantasy painting |\n",
|
447 |
+
"periwinkle purple skin |\n",
|
448 |
+
"of norse female goddess |\n",
|
449 |
+
"mujer de ojos rojos y pelo azulado |\n",
|
450 |
+
"strength the battle scene around her |\n",
|
451 |
+
"a beautiful young redhead warrior |\n",
|
452 |
+
"a beautiful young redhead warrior |\n",
|
453 |
+
"moody cinematic epic concept art |\n",
|
454 |
+
"hypnotically beautiful wood elf in |\n",
|
455 |
+
"hypnotically beautiful wood elf in |\n",
|
456 |
+
"loraemmawatsonlora_v |\n",
|
457 |
+
"alphonse mucha cinematic epic + rule |\n",
|
458 |
+
"alphonse mucha cinematic epic + rule |\n",
|
459 |
+
"an actress standing behind |\n",
|
460 |
+
"fking_scifi_v amazing |\n",
|
461 |
+
"shine like sapphires |\n",
|
462 |
+
"female elemental water wizard |\n",
|
463 |
+
"female elemental water wizard |\n",
|
464 |
+
"beautiful character design |\n",
|
465 |
+
"female warriors protecting an underwater temple |\n",
|
466 |
+
"indian mary_winstead |\n",
|
467 |
+
"sarah kerrigan queen |\n",
|
468 |
+
"sarah kerrigan queen |\n",
|
469 |
+
"the female barbarian stands tall |\n",
|
470 |
+
"widowmaker from overwatchscarlett johannson |\n",
|
471 |
+
"intricate costume design |\n",
|
472 |
+
"award winning character concept art of |\n",
|
473 |
+
"movie still from braveheart |\n",
|
474 |
+
"movie still from braveheart |\n",
|
475 |
+
"an extremely beautiful young female elf |\n",
|
476 |
+
"an extremely beautiful young female elf |\n",
|
477 |
+
"cinematic character design d |\n",
|
478 |
+
"dark purple bodypaint |\n",
|
479 |
+
"dark purple bodypaint |\n",
|
480 |
+
"a daring reimagining |\n",
|
481 |
+
"captivating film still |\n",
|
482 |
+
"glowing dark blue skin |\n",
|
483 |
+
"of sexy cyberpunk female mage |\n",
|
484 |
+
"military men gaze at her longingly |\n",
|
485 |
+
"military men gaze at her longingly |\n",
|
486 |
+
"female space soldier |\n",
|
487 |
+
"norse goddess fighting in}\n"
|
488 |
+
]
|
489 |
+
}
|
490 |
+
]
|
491 |
+
},
|
492 |
+
{
|
493 |
+
"cell_type": "code",
|
494 |
+
"source": [
|
495 |
+
"# @title ⚡ Get similiar tokens (not updated yet)\n",
|
496 |
+
"import torch\n",
|
497 |
+
"from transformers import AutoTokenizer\n",
|
498 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
499 |
+
"\n",
|
500 |
+
"# @markdown Write name of token to match against\n",
|
501 |
+
"token_name = \"banana \" # @param {type:'string',\"placeholder\":\"leave empty for random value token\"}\n",
|
502 |
+
"\n",
|
503 |
+
"prompt = token_name\n",
|
504 |
+
"# @markdown (optional) Mix the token with something else\n",
|
505 |
+
"mix_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for random value token\"}\n",
|
506 |
+
"mix_method = \"None\" # @param [\"None\" , \"Average\", \"Subtract\"] {allow-input: true}\n",
|
507 |
+
"w = 0.5 # @param {type:\"slider\", min:0, max:1, step:0.01}\n",
|
508 |
+
"# @markdown Limit char size of included token\n",
|
509 |
+
"\n",
|
510 |
+
"min_char_size = 0 # param {type:\"slider\", min:0, max: 50, step:1}\n",
|
511 |
+
"char_range = 50 # param {type:\"slider\", min:0, max: 50, step:1}\n",
|
512 |
+
"\n",
|
513 |
+
"tokenizer_output = tokenizer(text = prompt)\n",
|
514 |
+
"input_ids = tokenizer_output['input_ids']\n",
|
515 |
+
"id_A = input_ids[1]\n",
|
516 |
+
"A = torch.tensor(token[id_A])\n",
|
517 |
+
"A = A/A.norm(p=2, dim=-1, keepdim=True)\n",
|
518 |
+
"#-----#\n",
|
519 |
+
"tokenizer_output = tokenizer(text = mix_with)\n",
|
520 |
+
"input_ids = tokenizer_output['input_ids']\n",
|
521 |
+
"id_C = input_ids[1]\n",
|
522 |
+
"C = torch.tensor(token[id_C])\n",
|
523 |
+
"C = C/C.norm(p=2, dim=-1, keepdim=True)\n",
|
524 |
+
"#-----#\n",
|
525 |
+
"sim_AC = torch.dot(A,C)\n",
|
526 |
+
"#-----#\n",
|
527 |
+
"print(input_ids)\n",
|
528 |
+
"#-----#\n",
|
529 |
+
"\n",
|
530 |
+
"#if no imput exists we just randomize the entire thing\n",
|
531 |
+
"if (prompt == \"\"):\n",
|
532 |
+
" id_A = -1\n",
|
533 |
+
" print(\"Tokenized prompt tensor A is a random valued tensor with no ID\")\n",
|
534 |
+
" R = torch.rand(A.shape)\n",
|
535 |
+
" R = R/R.norm(p=2, dim=-1, keepdim=True)\n",
|
536 |
+
" A = R\n",
|
537 |
+
" name_A = 'random_A'\n",
|
538 |
+
"\n",
|
539 |
+
"#if no imput exists we just randomize the entire thing\n",
|
540 |
+
"if (mix_with == \"\"):\n",
|
541 |
+
" id_C = -1\n",
|
542 |
+
" print(\"Tokenized prompt 'mix_with' tensor C is a random valued tensor with no ID\")\n",
|
543 |
+
" R = torch.rand(A.shape)\n",
|
544 |
+
" R = R/R.norm(p=2, dim=-1, keepdim=True)\n",
|
545 |
+
" C = R\n",
|
546 |
+
" name_C = 'random_C'\n",
|
547 |
+
"\n",
|
548 |
+
"name_A = \"A of random type\"\n",
|
549 |
+
"if (id_A>-1):\n",
|
550 |
+
" name_A = vocab(id_A)\n",
|
551 |
+
"\n",
|
552 |
+
"name_C = \"token C of random type\"\n",
|
553 |
+
"if (id_C>-1):\n",
|
554 |
+
" name_C = vocab(id_C)\n",
|
555 |
+
"\n",
|
556 |
+
"print(f\"The similarity between A '{name_A}' and C '{name_C}' is {round(sim_AC.item()*100,2)} %\")\n",
|
557 |
+
"\n",
|
558 |
+
"if (mix_method == \"None\"):\n",
|
559 |
+
" print(\"No operation\")\n",
|
560 |
+
"\n",
|
561 |
+
"if (mix_method == \"Average\"):\n",
|
562 |
+
" A = w*A + (1-w)*C\n",
|
563 |
+
" _A = LA.vector_norm(A, ord=2)\n",
|
564 |
+
" print(f\"Tokenized prompt tensor A '{name_A}' token has been recalculated as A = w*A + (1-w)*C , where C is '{name_C}' token , for w = {w} \")\n",
|
565 |
+
"\n",
|
566 |
+
"if (mix_method == \"Subtract\"):\n",
|
567 |
+
" tmp = w*A - (1-w)*C\n",
|
568 |
+
" tmp = tmp/tmp.norm(p=2, dim=-1, keepdim=True)\n",
|
569 |
+
" A = tmp\n",
|
570 |
+
" #//---//\n",
|
571 |
+
" print(f\"Tokenized prompt tensor A '{name_A}' token has been recalculated as A = _A*norm(w*A - (1-w)*C) , where C is '{name_C}' token , for w = {w} \")\n",
|
572 |
+
"\n",
|
573 |
+
"#OPTIONAL : Add/subtract + normalize above result with another token. Leave field empty to get a random value tensor\n",
|
574 |
+
"\n",
|
575 |
+
"dots = torch.zeros(NUM_TOKENS)\n",
|
576 |
+
"for index in range(NUM_TOKENS):\n",
|
577 |
+
" id_B = index\n",
|
578 |
+
" B = torch.tensor(token[id_B])\n",
|
579 |
+
" B = B/B.norm(p=2, dim=-1, keepdim=True)\n",
|
580 |
+
" sim_AB = torch.dot(A,B)\n",
|
581 |
+
" dots[index] = sim_AB\n",
|
582 |
+
"\n",
|
583 |
+
"\n",
|
584 |
+
"sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
585 |
+
"#----#\n",
|
586 |
+
"if (mix_method == \"Average\"):\n",
|
587 |
+
" print(f'Calculated all cosine-similarities between the average of token {name_A} and {name_C} with Id_A = {id_A} and mixed Id_C = {id_C} as a 1x{sorted.shape[0]} tensor')\n",
|
588 |
+
"if (mix_method == \"Subtract\"):\n",
|
589 |
+
" print(f'Calculated all cosine-similarities between the subtract of token {name_A} and {name_C} with Id_A = {id_A} and mixed Id_C = {id_C} as a 1x{sorted.shape[0]} tensor')\n",
|
590 |
+
"if (mix_method == \"None\"):\n",
|
591 |
+
" print(f'Calculated all cosine-similarities between the token {name_A} with Id_A = {id_A} with the the rest of the {NUM_TOKENS} tokens as a 1x{sorted.shape[0]} tensor')\n",
|
592 |
+
"\n",
|
593 |
+
"#Produce a list id IDs that are most similiar to the prompt ID at positiion 1 based on above result\n",
|
594 |
+
"\n",
|
595 |
+
"# @markdown Set print options\n",
|
596 |
+
"list_size = 100 # @param {type:'number'}\n",
|
597 |
+
"print_ID = False # @param {type:\"boolean\"}\n",
|
598 |
"print_Similarity = True # @param {type:\"boolean\"}\n",
|
599 |
+
"print_Name = True # @param {type:\"boolean\"}\n",
|
600 |
+
"print_Divider = True # @param {type:\"boolean\"}\n",
|
|
|
|
|
601 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
602 |
"\n",
|
603 |
+
"if (print_Divider):\n",
|
604 |
+
" print('//---//')\n",
|
605 |
"\n",
|
606 |
+
"print('')\n",
|
607 |
+
"print('Here is the result : ')\n",
|
608 |
+
"print('')\n",
|
|
|
609 |
"\n",
|
610 |
+
"for index in range(list_size):\n",
|
611 |
+
" id = indices[index].item()\n",
|
612 |
+
" if (print_Name):\n",
|
613 |
+
" print(f'{vocab(id)}') # vocab item\n",
|
614 |
+
" if (print_ID):\n",
|
615 |
+
" print(f'ID = {id}') # IDs\n",
|
616 |
+
" if (print_Similarity):\n",
|
617 |
+
" print(f'similiarity = {round(sorted[index].item()*100,2)} %')\n",
|
618 |
+
" if (print_Divider):\n",
|
619 |
+
" print('--------')\n",
|
620 |
"\n",
|
621 |
+
"#Print the sorted list from above result\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
622 |
"\n",
|
623 |
+
"#The prompt will be enclosed with the <|start-of-text|> and <|end-of-text|> tokens, which is why output will be [49406, ... , 49407].\n",
|
624 |
"\n",
|
625 |
+
"#You can leave the 'prompt' field empty to get a random value tensor. Since the tensor is random value, it will not correspond to any tensor in the vocab.json list , and this it will have no ID.\n",
|
|
|
626 |
"\n",
|
627 |
+
"# Save results as .db file\n",
|
628 |
+
"import shelve\n",
|
629 |
+
"VOCAB_FILENAME = 'tokens_most_similiar_to_' + name_A.replace('</w>','').strip()\n",
|
630 |
+
"d = shelve.open(VOCAB_FILENAME)\n",
|
631 |
+
"#NUM TOKENS == 49407\n",
|
632 |
+
"for index in range(NUM_TOKENS):\n",
|
633 |
+
" #print(d[f'{index}']) #<-----Use this to read values from the .db file\n",
|
634 |
+
" d[f'{index}']= vocab(indices[index].item()) #<---- write values to .db file\n",
|
635 |
+
"#----#\n",
|
636 |
+
"d.close() #close the file\n",
|
637 |
+
"# See this link for additional stuff to do with shelve: https://docs.python.org/3/library/shelve.html"
|
638 |
],
|
639 |
"metadata": {
|
640 |
+
"id": "iWeFnT1gAx6A"
|
641 |
},
|
642 |
"execution_count": null,
|
643 |
"outputs": []
|
|
|
1332 |
"metadata": {
|
1333 |
"id": "njeJx_nSSA8H"
|
1334 |
}
|
1335 |
+
},
|
1336 |
+
{
|
1337 |
+
"cell_type": "code",
|
1338 |
+
"source": [
|
1339 |
+
"# @title Deprecated\n",
|
1340 |
+
"prompt = \" a fast car on the road \" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
|
1341 |
+
"\n",
|
1342 |
+
"from transformers import AutoTokenizer\n",
|
1343 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
1344 |
+
"from transformers import CLIPProcessor, CLIPModel\n",
|
1345 |
+
"processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\" , clean_up_tokenization_spaces = True)\n",
|
1346 |
+
"model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
|
1347 |
+
"\n",
|
1348 |
+
"# Get text features for user input\n",
|
1349 |
+
"inputs = tokenizer(text = prompt, padding=True, return_tensors=\"pt\")\n",
|
1350 |
+
"text_features_A = model.get_text_features(**inputs)\n",
|
1351 |
+
"text_features_A = text_features_A/text_features_A.norm(p=2, dim=-1, keepdim=True)\n",
|
1352 |
+
"name_A = prompt\n",
|
1353 |
+
"#------#\n",
|
1354 |
+
"\n",
|
1355 |
+
"# Load the .db file for prefix encodings\n",
|
1356 |
+
"import shelve\n",
|
1357 |
+
"_iters = -1\n",
|
1358 |
+
"RANGE = NUM_PREFIX\n",
|
1359 |
+
"NUM_PREFIX_LISTS = 1\n",
|
1360 |
+
"dots = results_sim = torch.zeros(RANGE*NUM_PREFIX_LISTS)\n",
|
1361 |
+
"for _PREFIX_ENC_VOCAB in PREFIX_ENC_VOCAB:\n",
|
1362 |
+
" _iters = _iters + 1\n",
|
1363 |
+
" d = shelve.open(_PREFIX_ENC_VOCAB)\n",
|
1364 |
+
" for _index in range(RANGE):\n",
|
1365 |
+
" index = _iters*RANGE + _index\n",
|
1366 |
+
" text_features = text_encodings[f'{_index}']\n",
|
1367 |
+
" sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
1368 |
+
" dots[index] = sim\n",
|
1369 |
+
" #----#\n",
|
1370 |
+
" d.close() #close the file\n",
|
1371 |
+
"#------#\n",
|
1372 |
+
"prefix_sorted, prefix_indices = torch.sort(dots,dim=0 , descending=True)\n",
|
1373 |
+
"#------#\n",
|
1374 |
+
"\n",
|
1375 |
+
"\n",
|
1376 |
+
"_prefixes = '{'\n",
|
1377 |
+
"for index in range(start_at_index + RANGE):\n",
|
1378 |
+
" if index < start_at_index : continue\n",
|
1379 |
+
" id = f'{prefix_indices[index]}'\n",
|
1380 |
+
" #sim = prefix_sorted[index]\n",
|
1381 |
+
" name = get_prefix(id)\n",
|
1382 |
+
" _prefixes = _prefixes + name + '|'\n",
|
1383 |
+
"#------#\n",
|
1384 |
+
"_prefixes = (_prefixes + '}').replace('|}', '}')\n",
|
1385 |
+
"\n",
|
1386 |
+
"\n",
|
1387 |
+
"prefixes = _prefixes\n",
|
1388 |
+
"if(not print_Prefix): prefixes = ''\n",
|
1389 |
+
"\n",
|
1390 |
+
"if(print_Descriptions):\n",
|
1391 |
+
" print(f'The {start_at_index}-{start_at_index + RANGE} most similiar prefixes to prompt : ' + prefixes)\n",
|
1392 |
+
"else:\n",
|
1393 |
+
" if(compact_Output):\n",
|
1394 |
+
" print((prefixes + _suffixes).replace('}{', '|'))\n",
|
1395 |
+
" else:\n",
|
1396 |
+
" print(prefixes)\n",
|
1397 |
+
"\n",
|
1398 |
+
"# @title ✳️ Load/initialize values\n",
|
1399 |
+
"# Load the tokens into the colab\n",
|
1400 |
+
"!git clone https://huggingface.co/datasets/codeShare/sd_tokens\n",
|
1401 |
+
"import torch\n",
|
1402 |
+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
1403 |
+
"from torch import linalg as LA\n",
|
1404 |
+
"%cd /content/sd_tokens\n",
|
1405 |
+
"token = torch.load('sd15_tensors.pt', map_location= torch.device('cpu'), weights_only=True)\n",
|
1406 |
+
"#-----#\n",
|
1407 |
+
"VOCAB_FILENAME = 'tokens_most_similiar_to_girl'\n",
|
1408 |
+
"ACTIVE_IMG = ''\n",
|
1409 |
+
"#-----#\n",
|
1410 |
+
"\n",
|
1411 |
+
"# Define functions/constants\n",
|
1412 |
+
"NUM_TOKENS = 49407\n",
|
1413 |
+
"NUM_PREFIX = 13662\n",
|
1414 |
+
"NUM_SUFFIX = 32901\n",
|
1415 |
+
"\n",
|
1416 |
+
"PREFIX_ENC_VOCAB = ['encoded_prefix_to_girl',]\n",
|
1417 |
+
"SUFFIX_ENC_VOCAB = ['a_-_encoded_suffix' ,]\n",
|
1418 |
+
" #'from_-encoded_suffix',\n",
|
1419 |
+
" #'by_-encoded_suffix' ,\n",
|
1420 |
+
" #'encoded_suffix-_like']\n",
|
1421 |
+
"\n",
|
1422 |
+
"# Make sure these match above results\n",
|
1423 |
+
"NUM_PREFIX_LISTS = len(PREFIX_ENC_VOCAB)\n",
|
1424 |
+
"NUM_SUFFIX_LISTS = len(SUFFIX_ENC_VOCAB)\n",
|
1425 |
+
"#-----#\n",
|
1426 |
+
"\n",
|
1427 |
+
"\n",
|
1428 |
+
"#Import the vocab.json\n",
|
1429 |
+
"import json\n",
|
1430 |
+
"import pandas as pd\n",
|
1431 |
+
"\n",
|
1432 |
+
"# Read suffix.json\n",
|
1433 |
+
"with open('suffix.json', 'r') as f:\n",
|
1434 |
+
" data = json.load(f)\n",
|
1435 |
+
"_df = pd.DataFrame({'count': data})['count']\n",
|
1436 |
+
"suffix = {\n",
|
1437 |
+
" key : value for key, value in _df.items()\n",
|
1438 |
+
"}\n",
|
1439 |
+
"# Read prefix json\n",
|
1440 |
+
"with open('prefix.json', 'r') as f:\n",
|
1441 |
+
" data = json.load(f)\n",
|
1442 |
+
"_df = pd.DataFrame({'count': data})['count']\n",
|
1443 |
+
"prefix = {\n",
|
1444 |
+
" key : value for key, value in _df.items()\n",
|
1445 |
+
"}\n",
|
1446 |
+
"\n",
|
1447 |
+
"# Read to_suffix.json\n",
|
1448 |
+
"with open('to_suffix.json', 'r') as f:\n",
|
1449 |
+
" data = json.load(f)\n",
|
1450 |
+
"_df = pd.DataFrame({'count': data})['count']\n",
|
1451 |
+
"suffix_to_vocab = {\n",
|
1452 |
+
" key : value for key, value in _df.items()\n",
|
1453 |
+
"}\n",
|
1454 |
+
"\n",
|
1455 |
+
"# Read to_prefix.json\n",
|
1456 |
+
"with open('to_prefix.json', 'r') as f:\n",
|
1457 |
+
" data = json.load(f)\n",
|
1458 |
+
"_df = pd.DataFrame({'count': data})['count']\n",
|
1459 |
+
"prefix_to_vocab = {\n",
|
1460 |
+
" key : value for key, value in _df.items()\n",
|
1461 |
+
"}\n",
|
1462 |
+
"\n",
|
1463 |
+
"#-----#\n",
|
1464 |
+
"\n",
|
1465 |
+
"\n",
|
1466 |
+
"# Read to_suffix.json (reversing key and value)\n",
|
1467 |
+
"with open('to_suffix.json', 'r') as f:\n",
|
1468 |
+
" data = json.load(f)\n",
|
1469 |
+
"_df = pd.DataFrame({'count': data})['count']\n",
|
1470 |
+
"vocab_to_suffix = {\n",
|
1471 |
+
" value : key for key, value in _df.items()\n",
|
1472 |
+
"}\n",
|
1473 |
+
"\n",
|
1474 |
+
"# Read to_prefix.json (reversing key and value)\n",
|
1475 |
+
"with open('to_prefix.json', 'r') as f:\n",
|
1476 |
+
" data = json.load(f)\n",
|
1477 |
+
"_df = pd.DataFrame({'count': data})['count']\n",
|
1478 |
+
"vocab_to_prefix = {\n",
|
1479 |
+
" value : key for key, value in _df.items()\n",
|
1480 |
+
"}\n",
|
1481 |
+
"\n",
|
1482 |
+
"\n",
|
1483 |
+
"#-----#\n",
|
1484 |
+
"\n",
|
1485 |
+
"#get token from id (excluding tokens with special symbols)\n",
|
1486 |
+
"def vocab(id):\n",
|
1487 |
+
" _id = f'{id}'\n",
|
1488 |
+
" if _id in vocab_to_suffix:\n",
|
1489 |
+
" _id = vocab_to_suffix[_id]\n",
|
1490 |
+
" return suffix[_id]\n",
|
1491 |
+
" if _id in vocab_to_prefix:\n",
|
1492 |
+
" _id = vocab_to_prefix[_id]\n",
|
1493 |
+
" return prefix[_id]\n",
|
1494 |
+
" return ' ' #<---- return whitespace if other id like emojis etc.\n",
|
1495 |
+
"#--------#\n",
|
1496 |
+
"\n",
|
1497 |
+
"#get token from id (excluding tokens with special symbols)\n",
|
1498 |
+
"def get_suffix(id):\n",
|
1499 |
+
" _id = f'{id}'\n",
|
1500 |
+
" if int(id) <= NUM_SUFFIX:\n",
|
1501 |
+
" return suffix[_id]\n",
|
1502 |
+
" return ' ' #<---- return whitespace if out of bounds\n",
|
1503 |
+
"#--------#\n",
|
1504 |
+
"\n",
|
1505 |
+
"#get token from id (excluding tokens with special symbols)\n",
|
1506 |
+
"def get_prefix(id):\n",
|
1507 |
+
" _id = f'{id}'\n",
|
1508 |
+
" if int(id) <= NUM_PREFIX:\n",
|
1509 |
+
" return prefix[_id]\n",
|
1510 |
+
" return ' ' #<---- return whitespace if out of bounds\n",
|
1511 |
+
"#--------#\n",
|
1512 |
+
"\n",
|
1513 |
+
"\n",
|
1514 |
+
"def _modulus(_id,id_max):\n",
|
1515 |
+
" id = _id\n",
|
1516 |
+
" while(id>id_max):\n",
|
1517 |
+
" id = id-id_max\n",
|
1518 |
+
" return id\n",
|
1519 |
+
"\n",
|
1520 |
+
"#print(get_token(35894))\n"
|
1521 |
+
],
|
1522 |
+
"metadata": {
|
1523 |
+
"id": "8BWq7SY8mzKD"
|
1524 |
+
},
|
1525 |
+
"execution_count": null,
|
1526 |
+
"outputs": []
|
1527 |
}
|
1528 |
]
|
1529 |
}
|