Upload app.py
app.py
ADDED
@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-
"""Text_to_Image_Demo.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1mkGloXbrNHKFh99ryB6PQDyCJ3u4RqD5

## Generate Images from Text
"""
# Required installations:
#   pip install openai python-dotenv streamlit
#   pip install transformers datasets -q

import os
import openai

# Originally the key was read from a local file:
# open_ai_key_file = "openai_api_key_llm_2023.txt"
# with open(open_ai_key_file, "r") as f:
#     for line in f:
#         OPENAI_KEY = line
#         break

from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())
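
# Sketch (assumption): with python-dotenv loaded, the key would normally come
# from the environment. OPENAI_API_KEY is the conventional variable name, but
# confirm it matches the .env file used with this app.
OPENAI_KEY = os.getenv("OPENAI_API_KEY")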

# Read the 100 flower names from 100flowers.txt
# openai.api_key = OPENAI_KEY
with open('./100flowers.txt', 'r') as file1:
    Lines = [line.strip() for line in file1]

from openai import OpenAI
from PIL import Image
import urllib.request
from io import BytesIO
# from IPython.display import display  # notebook-only; unused in the Streamlit app

# client = OpenAI(api_key=OPENAI_KEY)

# Code to generate an image for each name in 100flowers.txt and save it
# as a PNG in the Flowers folder:
# for prompt in Lines:
#     response = client.images.generate(
#         model="dall-e-3",
#         prompt=prompt,
#         size="1024x1024",
#         quality="standard",
#         n=1,
#     )
#     image_url = response.data[0].url
#     with urllib.request.urlopen(image_url) as resp:
#         img = Image.open(BytesIO(resp.read()))
#     img.save(f'./Flowers/{prompt}.png')
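
# A minimal runnable sketch of that one-off generation step (assumptions: a
# valid key in OPENAI_KEY and an existing ./Flowers directory; the helper is
# hypothetical and is not called anywhere on app startup):
def generate_flower_images(prompts, out_dir="./Flowers"):
    """Generate one DALL-E 3 image per prompt and save it as a PNG."""
    client = OpenAI(api_key=OPENAI_KEY)
    for prompt in prompts:
        response = client.images.generate(
            model="dall-e-3",
            prompt=prompt,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        with urllib.request.urlopen(response.data[0].url) as resp:
            Image.open(BytesIO(resp.read())).save(f"{out_dir}/{prompt}.png")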

# from transformers.utils import send_example_telemetry
# send_example_telemetry("image_similarity_notebook", framework="pytorch")

# Build a list of flower names from the PNGs in ./Flowers (currently unused).
directory = './Flowers'
png_files = [file[:-len('.png')].strip() for file in os.listdir(directory) if file.endswith(".png")]

# Import only Dataset here: `from datasets import Dataset, Image` would shadow
# PIL.Image, which the rest of this script relies on.
from datasets import Dataset

# Gets the list of image file names
def get_paths_to_images(images_directory):
    paths = []
    for file in os.listdir(images_directory):
        paths.append(file)
    return paths

# Creates a dataset whose "image" column holds the file names
def load_dataset(images_directory):
    paths_images = get_paths_to_images(images_directory)
    dataset = Dataset.from_dict({"image": paths_images})
    return dataset

path_images = "./Flowers"
dataset = load_dataset(path_images)
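# Note (illustrative): the "image" column stores bare file names rather than
# decoded images, so dataset[0] looks like {"image": "<some flower>.png"}.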

from transformers import AutoFeatureExtractor, AutoModel

model_ckpt = "jafdxc/vit-base-patch16-224-finetuned-flower"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

import torchvision.transforms as T
import torch
from PIL import Image


# Data transformation chain. Note: `extractor.size` is a dict such as
# {"height": 224, "width": 224} on recent transformers releases; on older
# releases it is a bare int, and the indexing below would fail.
transformation_chain = T.Compose(
    [
        # Resize the shorter edge to 256/224 of the crop size, then center-crop.
        T.Resize(int((256 / 224) * extractor.size["height"])),
        T.CenterCrop(extractor.size["height"]),
        T.ToTensor(),
        T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
    ]
)

def extract_embeddings(model: torch.nn.Module):
    """Utility to compute embeddings for a batch of image file names."""
    device = model.device

    def pp(batch):
        images = batch["image"]
        # Convert to RGB so the normalization never sees an RGBA/grayscale image.
        image_batch_transformed = torch.stack(
            [transformation_chain(Image.open("./Flowers/" + image).convert("RGB")) for image in images]
        )
        new_batch = {"pixel_values": image_batch_transformed.to(device)}
        with torch.no_grad():
            # Use the CLS-token slice of the final hidden state as the embedding.
            embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
        return {"embeddings": embeddings}

    return pp

+
|
133 |
+
|
134 |
+
import numpy as np
|
135 |
+
# Here, we map embedding extraction utility on our subset of candidate images.
|
136 |
+
batch_size = 1
|
137 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
138 |
+
extract_fn = extract_embeddings(model.to(device))
|
139 |
+
candidate_subset_emb = dataset.map(extract_fn, batched=True, batch_size=1)
|
140 |
+
|
141 |
+
all_candidate_embeddings = np.array(candidate_subset_emb["embeddings"])
|
142 |
+
all_candidate_embeddings = torch.from_numpy(all_candidate_embeddings)
|
143 |
+
|
144 |
+
print(all_candidate_embeddings.shape[0])
|
145 |
+
|
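
# Optional sketch (assumption): Streamlit re-executes this whole script on
# every interaction, so the embeddings above are recomputed on each rerun. An
# on-disk cache is one way around that; the path is hypothetical, and the
# dataset.map call above would also need to be guarded for this to save time.
# EMB_CACHE = "./flower_embeddings.npy"
# if os.path.exists(EMB_CACHE):
#     all_candidate_embeddings = torch.from_numpy(np.load(EMB_CACHE))
# else:
#     np.save(EMB_CACHE, all_candidate_embeddings.numpy())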

def compute_scores(emb_one, emb_two):
    """Computes cosine similarity between a batch of embeddings and a query embedding."""
    scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
    return scores.numpy().tolist()
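
# Illustrative sanity check: identical vectors broadcast to a similarity of 1.0,
# e.g. compute_scores(torch.ones(2, 4), torch.ones(1, 4)) -> [1.0, 1.0]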

def fetch_similar(image, top_k=5):
    """Fetches the `top_k` candidate images most similar to the query `image`."""
    # Prepare the query image for embedding computation; converting to RGB
    # guards against RGBA or grayscale uploads.
    image_transformed = transformation_chain(image.convert("RGB")).unsqueeze(0)
    new_batch = {"pixel_values": image_transformed.to(device)}

    # Compute the query embedding.
    with torch.no_grad():
        query_embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()

    # Compute similarity scores against all candidates in one go, mapping each
    # candidate index to its score.
    sim_scores = compute_scores(all_candidate_embeddings, query_embeddings)
    similarity_mapping = dict(
        zip([str(index) for index in range(all_candidate_embeddings.shape[0])], sim_scores)
    )

    # Sort by score and keep the `top_k` candidate ids.
    similarity_mapping_sorted = dict(
        sorted(similarity_mapping.items(), key=lambda x: x[1], reverse=True)
    )
    id_entries = list(similarity_mapping_sorted.keys())[:top_k]

    # Keys are plain indices, so the split("_") kept from the original
    # notebook is effectively a no-op.
    ids = list(map(lambda x: int(x.split("_")[0]), id_entries))
    return ids
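
# Example usage (illustrative; the file name is hypothetical):
# sim_ids = fetch_similar(Image.open("./Flowers/rose.png"))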

# Streamlit rendering helper. (matplotlib was imported here originally but
# never used; `st` is imported just below and is defined by the time this
# function runs.)
def plot_images(images):
    count = 0
    for image, name in images:
        if name == 'original':
            st.write("Showing the original image")
            st.image(image, caption=name, channels='RGB')
        else:
            count += 1
            st.write(f"Showing similar image {count}")
            img = Image.open(image)
            st.image(img, caption=name, channels='RGB')

# Streamlit webpage code
import streamlit as st

st.title("Flower Type Demo")
st.subheader("Upload an image of a flower and get the 5 most similar flowers from our dataset")

upload_file = st.file_uploader('Upload a Flower Image')

images = []

if upload_file:
    test_sample = Image.open(upload_file)

    sim_ids = fetch_similar(test_sample)

    for idx in sim_ids:
        file_name = candidate_subset_emb[idx]["image"]
        images.append(("./Flowers/" + file_name, file_name))

    images.insert(0, (test_sample, 'original'))
    plot_images(images)
    st.write("")
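
# To launch the demo locally:
#   streamlit run app.py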