ritiksh committed
Commit ceb4613 · 1 Parent(s): 74a94d0

Upload app.py

Files changed (1)
  1. app.py +218 -0
app.py ADDED
@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-
"""Text_to_Image_Demo.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1mkGloXbrNHKFh99ryB6PQDyCJ3u4RqD5

## Generate Images from Text
"""

# Important installations
# pip install openai
# pip install python-dotenv
# pip install transformers datasets -q
# pip install streamlit

import os
import openai

# The API key used to be read from a local text file:
# open_ai_key_file = "openai_api_key_llm_2023.txt"
# with open(open_ai_key_file, "r") as f:
#     for line in f:
#         OPENAI_KEY = line
#         break

# Load environment variables (e.g. the OpenAI API key) from a .env file.
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())
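# With the .env file loaded, the key can be pulled from the environment instead
# of the text file above. "OPENAI_API_KEY" is an assumed variable name here;
# use whatever name the .env file actually defines.
OPENAI_KEY = os.getenv("OPENAI_API_KEY")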

# Read 100 flower names from 100flowers.txt
# openai.api_key = OPENAI_KEY
with open('./100flowers.txt', 'r') as file1:
    Lines = [line.strip() for line in file1.readlines()]

from openai import OpenAI
from PIL import Image
import urllib.request
from io import BytesIO
from IPython.display import display

# client = OpenAI(api_key=OPENAI_KEY)

# Code to generate images from the names in 100flowers.txt
# for prompt in Lines:
#     response = client.images.generate(
#         model="dall-e-3",
#         prompt=prompt,
#         size="1024x1024",
#         quality="standard",
#         n=1,
#     )

# Code to save generated images as PNG in the Flowers folder
# (as written, this fragment sits outside the loop above, so on its own it
# would only save the last response; see the combined sketch below)
# image_url = response.data[0].url
# with urllib.request.urlopen(image_url) as image_url:
#     img = Image.open(BytesIO(image_url.read()))
# img.save(f'./Flowers/{prompt}.png')
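# A minimal sketch combining the two commented fragments above into one loop, so
# each prompt is generated and saved in turn. It assumes OPENAI_API_KEY is set in
# the environment (e.g. via the .env file) and that ./Flowers exists; it is not
# called anywhere by the Streamlit app below.
def generate_flower_images(prompts, out_dir="./Flowers"):
    client = OpenAI()  # the v1 client reads OPENAI_API_KEY from the environment
    for prompt in prompts:
        response = client.images.generate(
            model="dall-e-3",
            prompt=prompt,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        # Download the returned image URL and save it as <prompt>.png
        with urllib.request.urlopen(response.data[0].url) as resp:
            img = Image.open(BytesIO(resp.read()))
        img.save(f"{out_dir}/{prompt}.png")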

# from transformers.utils import send_example_telemetry
# send_example_telemetry("image_similarity_notebook", framework="pytorch")

# Creates a list of flower names from the PNG files already present in ./Flowers
directory = './Flowers'
png_files = [file[:-len('.png')].strip() for file in os.listdir(directory) if file.endswith(".png")]

from datasets import Dataset  # importing Image from datasets would shadow PIL.Image, so only Dataset is imported

# Gets the list of image file names inside the images directory
def get_paths_to_images(images_directory):
    paths = []
    for file in os.listdir(images_directory):
        print(file)
        paths.append(file)
    return paths

# Creates a dataset whose "image" column holds the file names
# (note: this local helper shadows datasets' own load_dataset)
def load_dataset(images_directory):
    paths_images = get_paths_to_images(images_directory)
    print(paths_images[0])
    dataset = Dataset.from_dict({"image": paths_images})
    return dataset

path_images = "./Flowers"
dataset = load_dataset(path_images)
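# Optional alternative sketch: the datasets library can build an image dataset
# straight from a folder with its "imagefolder" loader. It is left commented out
# because the rest of this app expects the "image" column to hold plain file
# names rather than decoded images, and the import is aliased to avoid the name
# clash with the helper above.
# from datasets import load_dataset as hf_load_dataset
# flowers_ds = hf_load_dataset("imagefolder", data_dir="./Flowers", split="train")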

from transformers import AutoFeatureExtractor, AutoModel

model_ckpt = "jafdxc/vit-base-patch16-224-finetuned-flower"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

import torchvision.transforms as T
import torch
from PIL import Image


# Data transformation chain.
transformation_chain = T.Compose(
    [
        # Resize the shorter side, then center-crop to the input size the
        # feature extractor expects (224x224 for this checkpoint).
        T.Resize(int((256 / 224) * extractor.size["height"])),
        T.CenterCrop(extractor.size["height"]),
        T.ToTensor(),
        T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ]
)


def extract_embeddings(model: torch.nn.Module):
    """Utility to compute embeddings."""
    device = model.device

    def pp(batch):
        images = batch["image"]
        image_batch_transformed = torch.stack(
            [transformation_chain(Image.open("./Flowers/" + image)) for image in images]
        )
        new_batch = {"pixel_values": image_batch_transformed.to(device)}
        with torch.no_grad():
            # The [CLS] token's last hidden state serves as the image embedding.
            embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
        return {"embeddings": embeddings}

    return pp

import numpy as np

# Map the embedding-extraction utility over the candidate images.
batch_size = 1
device = "cuda" if torch.cuda.is_available() else "cpu"
extract_fn = extract_embeddings(model.to(device))
candidate_subset_emb = dataset.map(extract_fn, batched=True, batch_size=batch_size)

all_candidate_embeddings = np.array(candidate_subset_emb["embeddings"])
all_candidate_embeddings = torch.from_numpy(all_candidate_embeddings)

print(all_candidate_embeddings.shape[0])


def compute_scores(emb_one, emb_two):
    """Computes cosine similarity between two batches of vectors."""
    scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
    return scores.numpy().tolist()
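# Quick illustration of compute_scores with toy tensors (not used by the app):
# identical vectors score 1.0 and orthogonal vectors score 0.0.
# >>> a = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
# >>> b = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
# >>> compute_scores(a, b)
# [1.0, 0.0]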


def fetch_similar(image, top_k=5):
    """Fetches the `top_k` most similar candidate images for the query `image`."""
    # Prepare the input query image for embedding computation.
    image_transformed = transformation_chain(image).unsqueeze(0)
    new_batch = {"pixel_values": image_transformed.to(device)}

    # Compute the embedding.
    with torch.no_grad():
        query_embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()

    # Compute similarity scores with all the candidate images in one go.
    # We also create a mapping between the candidate image identifiers
    # and their similarity scores with the query image.
    sim_scores = compute_scores(all_candidate_embeddings, query_embeddings)
    similarity_mapping = dict(
        zip([str(index) for index in range(all_candidate_embeddings.shape[0])], sim_scores)
    )

    # Sort the mapping dictionary and return the `top_k` candidates.
    similarity_mapping_sorted = dict(
        sorted(similarity_mapping.items(), key=lambda x: x[1], reverse=True)
    )
    id_entries = list(similarity_mapping_sorted.keys())[:top_k]

    # The keys are plain dataset indices, so they convert back directly.
    ids = [int(entry) for entry in id_entries]
    return ids

import matplotlib.pyplot as plt


def plot_images(images):
    count = 0  # defined up front so it exists regardless of item order
    for image, name in images:
        if name == 'original':
            st.write("Showing the original image")
            st.image(image, caption=name, width=None, use_column_width=None, clamp=False, channels='RGB', output_format='auto')
        else:
            count += 1
            st.write(f"Showing similar image {count}")
            img = Image.open(image)
            st.image(img, caption=name, width=None, use_column_width=None, clamp=False, channels='RGB', output_format='auto')

# Streamlit webpage code
import streamlit as st
from io import StringIO

# Text Search
st.title("Flower Type Demo")
st.subheader("Upload an image of a flower and you will get 5 similar flowers from our dataset")

upload_file = st.file_uploader('Upload a Flower Image')

images = []

if upload_file:
    # Convert to RGB so the 3-channel normalization transform always applies.
    test_sample = Image.open(upload_file).convert('RGB')

    sim_ids = fetch_similar(test_sample)

    for idx in sim_ids:
        images.append(("./Flowers/" + candidate_subset_emb[idx]["image"], candidate_subset_emb[idx]["image"]))

    images.insert(0, (test_sample, 'original'))
    print(images)
    plot_images(images)
    st.write("")
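# To try the app locally (with the Flowers folder and 100flowers.txt in place):
#   streamlit run app.py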