patel18 commited on
Commit
f4ec5ac
·
1 Parent(s): a17db14

Upload Gradio Examples.ipynb

Browse files
Files changed (1) hide show
  1. Gradio Examples.ipynb +762 -0
Gradio Examples.ipynb ADDED
@@ -0,0 +1,762 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "09eb5ef2",
6
+ "metadata": {},
7
+ "source": [
8
+ "#### Gradio Comparing Transfer Learning Models"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "f3f83569",
15
+ "metadata": {},
16
+ "outputs": [
17
+ {
18
+ "name": "stdout",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "2.12.0\n"
22
+ ]
23
+ }
24
+ ],
25
+ "source": [
26
+ "import tensorflow as tf\n",
27
+ "print(tf.__version__)"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 2,
33
+ "id": "c1ca8b20",
34
+ "metadata": {},
35
+ "outputs": [
36
+ {
37
+ "name": "stdout",
38
+ "output_type": "stream",
39
+ "text": [
40
+ "Collecting gradio==1.6.0\n",
41
+ " Downloading gradio-1.6.0-py3-none-any.whl (1.1 MB)\n",
42
+ " 0.0/1.1 MB ? eta -:--:--\n",
43
+ " ---- 0.1/1.1 MB 3.2 MB/s eta 0:00:01\n",
44
+ " ------------ 0.3/1.1 MB 4.2 MB/s eta 0:00:01\n",
45
+ " ---------------------- 0.6/1.1 MB 5.6 MB/s eta 0:00:01\n",
46
+ " ------------------------------- 0.9/1.1 MB 5.6 MB/s eta 0:00:01\n",
47
+ " ------------------------------------- 1.0/1.1 MB 5.1 MB/s eta 0:00:01\n",
48
+ " ------------------------------------- 1.0/1.1 MB 5.1 MB/s eta 0:00:01\n",
49
+ " ---------------------------------------- 1.1/1.1 MB 3.8 MB/s eta 0:00:00\n",
50
+ "Requirement already satisfied: numpy in c:\\users\\user\\anaconda3\\lib\\site-packages (from gradio==1.6.0) (1.24.3)\n",
51
+ "Requirement already satisfied: requests in c:\\users\\user\\anaconda3\\lib\\site-packages (from gradio==1.6.0) (2.28.1)\n",
52
+ "Requirement already satisfied: Flask>=1.1.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from gradio==1.6.0) (2.1.2)\n",
53
+ "Requirement already satisfied: Flask-Cors>=3.0.8 in c:\\users\\user\\anaconda3\\lib\\site-packages (from gradio==1.6.0) (3.0.10)\n",
54
+ "Collecting flask-cachebuster (from gradio==1.6.0)\n",
55
+ " Downloading Flask-CacheBuster-1.0.0.tar.gz (3.1 kB)\n",
56
+ " Preparing metadata (setup.py): started\n",
57
+ " Preparing metadata (setup.py): finished with status 'done'\n",
58
+ "Requirement already satisfied: Flask-BasicAuth in c:\\users\\user\\anaconda3\\lib\\site-packages (from gradio==1.6.0) (0.2.0)\n",
59
+ "Requirement already satisfied: paramiko in c:\\users\\user\\anaconda3\\lib\\site-packages (from gradio==1.6.0) (2.8.1)\n",
60
+ "Requirement already satisfied: scipy in c:\\users\\user\\anaconda3\\lib\\site-packages (from gradio==1.6.0) (1.10.0)\n",
61
+ "Requirement already satisfied: IPython in c:\\users\\user\\anaconda3\\lib\\site-packages (from gradio==1.6.0) (8.10.0)\n",
62
+ "Requirement already satisfied: scikit-image in c:\\users\\user\\anaconda3\\lib\\site-packages (from gradio==1.6.0) (0.19.3)\n",
63
+ "Collecting analytics-python (from gradio==1.6.0)\n",
64
+ " Downloading analytics_python-1.4.post1-py2.py3-none-any.whl (23 kB)\n",
65
+ "Requirement already satisfied: pandas in c:\\users\\user\\anaconda3\\lib\\site-packages (from gradio==1.6.0) (1.5.3)\n",
66
+ "Requirement already satisfied: ffmpy in c:\\users\\user\\anaconda3\\lib\\site-packages (from gradio==1.6.0) (0.3.0)\n",
67
+ "Collecting markdown2 (from gradio==1.6.0)\n",
68
+ " Downloading markdown2-2.4.8-py2.py3-none-any.whl (38 kB)\n",
69
+ "Requirement already satisfied: Werkzeug>=2.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from Flask>=1.1.1->gradio==1.6.0) (2.3.4)\n",
70
+ "Requirement already satisfied: Jinja2>=3.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from Flask>=1.1.1->gradio==1.6.0) (3.1.2)\n",
71
+ "Requirement already satisfied: itsdangerous>=2.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from Flask>=1.1.1->gradio==1.6.0) (2.1.2)\n",
72
+ "Requirement already satisfied: click>=8.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from Flask>=1.1.1->gradio==1.6.0) (8.1.3)\n",
73
+ "Requirement already satisfied: Six in c:\\users\\user\\anaconda3\\lib\\site-packages (from Flask-Cors>=3.0.8->gradio==1.6.0) (1.16.0)\n",
74
+ "Collecting monotonic>=1.5 (from analytics-python->gradio==1.6.0)\n",
75
+ " Downloading monotonic-1.6-py2.py3-none-any.whl (8.2 kB)\n",
76
+ "Collecting backoff==1.10.0 (from analytics-python->gradio==1.6.0)\n",
77
+ " Downloading backoff-1.10.0-py2.py3-none-any.whl (31 kB)\n",
78
+ "Requirement already satisfied: python-dateutil>2.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from analytics-python->gradio==1.6.0) (2.8.2)\n",
79
+ "Requirement already satisfied: charset-normalizer<3,>=2 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->gradio==1.6.0) (2.0.4)\n",
80
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->gradio==1.6.0) (3.4)\n",
81
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->gradio==1.6.0) (1.26.14)\n",
82
+ "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->gradio==1.6.0) (2022.12.7)\n",
83
+ "Requirement already satisfied: backcall in c:\\users\\user\\anaconda3\\lib\\site-packages (from IPython->gradio==1.6.0) (0.2.0)\n",
84
+ "Requirement already satisfied: decorator in c:\\users\\user\\anaconda3\\lib\\site-packages (from IPython->gradio==1.6.0) (5.1.1)\n",
85
+ "Requirement already satisfied: jedi>=0.16 in c:\\users\\user\\anaconda3\\lib\\site-packages (from IPython->gradio==1.6.0) (0.18.1)\n",
86
+ "Requirement already satisfied: matplotlib-inline in c:\\users\\user\\anaconda3\\lib\\site-packages (from IPython->gradio==1.6.0) (0.1.6)\n",
87
+ "Requirement already satisfied: pickleshare in c:\\users\\user\\anaconda3\\lib\\site-packages (from IPython->gradio==1.6.0) (0.7.5)\n",
88
+ "Requirement already satisfied: prompt-toolkit<3.1.0,>=3.0.30 in c:\\users\\user\\anaconda3\\lib\\site-packages (from IPython->gradio==1.6.0) (3.0.36)\n",
89
+ "Requirement already satisfied: pygments>=2.4.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from IPython->gradio==1.6.0) (2.15.1)\n",
90
+ "Requirement already satisfied: stack-data in c:\\users\\user\\anaconda3\\lib\\site-packages (from IPython->gradio==1.6.0) (0.2.0)\n",
91
+ "Requirement already satisfied: traitlets>=5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from IPython->gradio==1.6.0) (5.7.1)\n",
92
+ "Requirement already satisfied: colorama in c:\\users\\user\\anaconda3\\lib\\site-packages (from IPython->gradio==1.6.0) (0.4.6)\n",
93
+ "Requirement already satisfied: pytz>=2020.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from pandas->gradio==1.6.0) (2022.7)\n",
94
+ "Requirement already satisfied: bcrypt>=3.1.3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from paramiko->gradio==1.6.0) (3.2.0)\n",
95
+ "Requirement already satisfied: cryptography>=2.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from paramiko->gradio==1.6.0) (39.0.1)\n",
96
+ "Requirement already satisfied: pynacl>=1.0.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from paramiko->gradio==1.6.0) (1.5.0)\n",
97
+ "Requirement already satisfied: networkx>=2.2 in c:\\users\\user\\anaconda3\\lib\\site-packages (from scikit-image->gradio==1.6.0) (2.8.4)\n",
98
+ "Requirement already satisfied: pillow!=7.1.0,!=7.1.1,!=8.3.0,>=6.1.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from scikit-image->gradio==1.6.0) (9.4.0)\n",
99
+ "Requirement already satisfied: imageio>=2.4.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from scikit-image->gradio==1.6.0) (2.26.0)\n",
100
+ "Requirement already satisfied: tifffile>=2019.7.26 in c:\\users\\user\\anaconda3\\lib\\site-packages (from scikit-image->gradio==1.6.0) (2021.7.2)\n",
101
+ "Requirement already satisfied: PyWavelets>=1.1.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from scikit-image->gradio==1.6.0) (1.4.1)\n",
102
+ "Requirement already satisfied: packaging>=20.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from scikit-image->gradio==1.6.0) (22.0)\n",
103
+ "Requirement already satisfied: cffi>=1.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from bcrypt>=3.1.3->paramiko->gradio==1.6.0) (1.15.1)\n",
104
+ "Requirement already satisfied: parso<0.9.0,>=0.8.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from jedi>=0.16->IPython->gradio==1.6.0) (0.8.3)\n",
105
+ "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from Jinja2>=3.0->Flask>=1.1.1->gradio==1.6.0) (2.1.1)\n",
106
+ "Requirement already satisfied: wcwidth in c:\\users\\user\\anaconda3\\lib\\site-packages (from prompt-toolkit<3.1.0,>=3.0.30->IPython->gradio==1.6.0) (0.2.5)\n",
107
+ "Requirement already satisfied: executing in c:\\users\\user\\anaconda3\\lib\\site-packages (from stack-data->IPython->gradio==1.6.0) (0.8.3)\n",
108
+ "Requirement already satisfied: asttokens in c:\\users\\user\\anaconda3\\lib\\site-packages (from stack-data->IPython->gradio==1.6.0) (2.0.5)\n",
109
+ "Requirement already satisfied: pure-eval in c:\\users\\user\\anaconda3\\lib\\site-packages (from stack-data->IPython->gradio==1.6.0) (0.2.2)\n",
110
+ "Requirement already satisfied: pycparser in c:\\users\\user\\anaconda3\\lib\\site-packages (from cffi>=1.1->bcrypt>=3.1.3->paramiko->gradio==1.6.0) (2.21)\n",
111
+ "Building wheels for collected packages: flask-cachebuster\n",
112
+ " Building wheel for flask-cachebuster (setup.py): started\n",
113
+ " Building wheel for flask-cachebuster (setup.py): finished with status 'done'\n",
114
+ " Created wheel for flask-cachebuster: filename=Flask_CacheBuster-1.0.0-py3-none-any.whl size=3372 sha256=c1b85a8b017ca7784ce61eec4714ca9dd7e500dc251835ef6f9820731268b2c5\n",
115
+ " Stored in directory: c:\\users\\user\\appdata\\local\\pip\\cache\\wheels\\22\\35\\5e\\088242cb16f309a4ff4e94ce97f1ef8a469983fdde92b45f50\n",
116
+ "Successfully built flask-cachebuster\n",
117
+ "Installing collected packages: monotonic, markdown2, backoff, analytics-python, flask-cachebuster, gradio\n",
118
+ " Attempting uninstall: gradio\n",
119
+ " Found existing installation: gradio 3.33.1\n",
120
+ " Uninstalling gradio-3.33.1:\n",
121
+ " Successfully uninstalled gradio-3.33.1\n",
122
+ "Successfully installed analytics-python-1.4.post1 backoff-1.10.0 flask-cachebuster-1.0.0 gradio-1.6.0 markdown2-2.4.8 monotonic-1.6\n",
123
+ "Note: you may need to restart the kernel to use updated packages.\n"
124
+ ]
125
+ },
126
+ {
127
+ "name": "stderr",
128
+ "output_type": "stream",
129
+ "text": [
130
+ "WARNING: Ignoring invalid distribution -orch (c:\\users\\user\\anaconda3\\lib\\site-packages)\n",
131
+ "WARNING: Ignoring invalid distribution -rotobuf (c:\\users\\user\\anaconda3\\lib\\site-packages)\n",
132
+ "WARNING: Ignoring invalid distribution -orch (c:\\users\\user\\anaconda3\\lib\\site-packages)\n",
133
+ "WARNING: Ignoring invalid distribution -rotobuf (c:\\users\\user\\anaconda3\\lib\\site-packages)\n"
134
+ ]
135
+ }
136
+ ],
137
+ "source": [
138
+ "pip install gradio==1.6.0"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": 3,
144
+ "id": "70ec40a3",
145
+ "metadata": {},
146
+ "outputs": [
147
+ {
148
+ "name": "stdout",
149
+ "output_type": "stream",
150
+ "text": [
151
+ "Collecting MarkupSafe==2.1.1\n",
152
+ " Downloading MarkupSafe-2.1.1-cp310-cp310-win_amd64.whl (17 kB)\n",
153
+ "Installing collected packages: MarkupSafe\n",
154
+ " Attempting uninstall: MarkupSafe\n",
155
+ " Found existing installation: MarkupSafe 2.0.1\n",
156
+ " Uninstalling MarkupSafe-2.0.1:\n",
157
+ " Successfully uninstalled MarkupSafe-2.0.1\n",
158
+ "Note: you may need to restart the kernel to use updated packages.\n"
159
+ ]
160
+ },
161
+ {
162
+ "name": "stderr",
163
+ "output_type": "stream",
164
+ "text": [
165
+ "WARNING: Ignoring invalid distribution -orch (c:\\users\\user\\anaconda3\\lib\\site-packages)\n",
166
+ "WARNING: Ignoring invalid distribution -rotobuf (c:\\users\\user\\anaconda3\\lib\\site-packages)\n",
167
+ "WARNING: Ignoring invalid distribution -orch (c:\\users\\user\\anaconda3\\lib\\site-packages)\n",
168
+ "WARNING: Ignoring invalid distribution -rotobuf (c:\\users\\user\\anaconda3\\lib\\site-packages)\n",
169
+ "ERROR: Could not install packages due to an OSError: [WinError 5] Access is denied: 'C:\\\\Users\\\\User\\\\anaconda3\\\\Lib\\\\site-packages\\\\~arkupsafe\\\\_speedups.cp310-win_amd64.pyd'\n",
170
+ "Consider using the `--user` option or check the permissions.\n",
171
+ "\n"
172
+ ]
173
+ }
174
+ ],
175
+ "source": [
176
+ "pip install MarkupSafe==2.1.1"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": 1,
182
+ "id": "961a6510",
183
+ "metadata": {},
184
+ "outputs": [
185
+ {
186
+ "name": "stderr",
187
+ "output_type": "stream",
188
+ "text": [
189
+ "C:\\Users\\User\\anaconda3\\lib\\site-packages\\paramiko\\transport.py:219: CryptographyDeprecationWarning: Blowfish has been deprecated\n",
190
+ " \"class\": algorithms.Blowfish,\n"
191
+ ]
192
+ }
193
+ ],
194
+ "source": [
195
+ "import gradio as gr\n",
196
+ "import tensorflow as tf\n",
197
+ "import numpy as np\n",
198
+ "from PIL import Image\n",
199
+ "import requests\n",
200
+ "\n",
201
+ "\n",
202
+ "# Download human-readable labels for ImageNet.\n",
203
+ "response = requests.get(\"https://git.io/JJkYN\")\n",
204
+ "labels = response.text.split(\"\\n\")\n",
205
+ "\n",
206
+ "mobile_net = tf.keras.applications.MobileNetV2()\n",
207
+ "inception_net = tf.keras.applications.InceptionV3()\n"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": 2,
213
+ "id": "44d83e7d",
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": [
217
+ "def classify_image_with_mobile_net(im):\n",
218
+ " im = Image.fromarray(im.astype('uint8'), 'RGB')\n",
219
+ " im = im.resize((224, 224))\n",
220
+ " arr = np.array(im).reshape((-1, 224, 224, 3))\n",
221
+ " arr = tf.keras.applications.mobilenet.preprocess_input(arr)\n",
222
+ " prediction = mobile_net.predict(arr).flatten()\n",
223
+ " return {labels[i]: float(prediction[i]) for i in range(1000)}\n",
224
+ " "
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": 3,
230
+ "id": "5e77912e",
231
+ "metadata": {},
232
+ "outputs": [],
233
+ "source": [
234
+ "def classify_image_with_inception_net(im):\n",
235
+ " # Resize the image to\n",
236
+ " im = Image.fromarray(im.astype('uint8'), 'RGB')\n",
237
+ " im = im.resize((299, 299))\n",
238
+ " arr = np.array(im).reshape((-1, 299, 299, 3))\n",
239
+ " arr = tf.keras.applications.inception_v3.preprocess_input(arr)\n",
240
+ " prediction = inception_net.predict(arr).flatten()\n",
241
+ " return {labels[i]: float(prediction[i]) for i in range(1000)}"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": 4,
247
+ "id": "f6a9e5fe",
248
+ "metadata": {},
249
+ "outputs": [],
250
+ "source": [
251
+ "imagein = gr.inputs.Image()\n",
252
+ "label = gr.outputs.Label(num_top_classes=3)\n",
253
+ "sample_images = [\n",
254
+ " [\"monkey.jpg\"],\n",
255
+ " [\"sailboat.jpg\"],\n",
256
+ " [\"bicycle.jpg\"],\n",
257
+ " [\"download.jpg\"],\n",
258
+ "]"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": 6,
264
+ "id": "61c325d0",
265
+ "metadata": {},
266
+ "outputs": [
267
+ {
268
+ "name": "stdout",
269
+ "output_type": "stream",
270
+ "text": [
271
+ "IMPORTANT: You are using gradio version 1.6.0, however version 3.14.0 is available, please upgrade.\n",
272
+ "--------\n",
273
+ "Running locally at: http://127.0.0.1:7861/\n",
274
+ "To create a public link, set `share=True` in `launch()`.\n",
275
+ "Interface loading below...\n"
276
+ ]
277
+ },
278
+ {
279
+ "data": {
280
+ "text/html": [
281
+ "\n",
282
+ " <iframe\n",
283
+ " width=\"1000\"\n",
284
+ " height=\"500\"\n",
285
+ " src=\"http://127.0.0.1:7861/\"\n",
286
+ " frameborder=\"0\"\n",
287
+ " allowfullscreen\n",
288
+ " \n",
289
+ " ></iframe>\n",
290
+ " "
291
+ ],
292
+ "text/plain": [
293
+ "<IPython.lib.display.IFrame at 0x1f4f8fe2830>"
294
+ ]
295
+ },
296
+ "metadata": {},
297
+ "output_type": "display_data"
298
+ },
299
+ {
300
+ "name": "stderr",
301
+ "output_type": "stream",
302
+ "text": [
303
+ "[2023-06-05 22:40:39,347] ERROR in app: Exception on /file/monkey.jpg [GET]\n",
304
+ "Traceback (most recent call last):\n",
305
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 2077, in wsgi_app\n",
306
+ " response = self.full_dispatch_request()\n",
307
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1525, in full_dispatch_request\n",
308
+ " rv = self.handle_user_exception(e)\n",
309
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask_cors\\extension.py\", line 165, in wrapped_function\n",
310
+ " return cors_after_request(app.make_response(f(*args, **kwargs)))\n",
311
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1523, in full_dispatch_request\n",
312
+ " rv = self.dispatch_request()\n",
313
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1509, in dispatch_request\n",
314
+ " return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)\n",
315
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\gradio\\networking.py\", line 269, in file\n",
316
+ " return send_file(os.path.join(app.cwd, path))\n",
317
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\helpers.py\", line 610, in send_file\n",
318
+ " return werkzeug.utils.send_file(\n",
319
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\werkzeug\\utils.py\", line 427, in send_file\n",
320
+ " stat = os.stat(path)\n",
321
+ "FileNotFoundError: [WinError 2] The system cannot find the file specified: 'C:\\\\Users\\\\User\\\\Downloads\\\\monkey.jpg'\n",
322
+ "[2023-06-05 22:40:39,356] ERROR in app: Exception on /file/sailboat.jpg [GET]\n",
323
+ "Traceback (most recent call last):\n",
324
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 2077, in wsgi_app\n",
325
+ " response = self.full_dispatch_request()\n",
326
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1525, in full_dispatch_request\n",
327
+ " rv = self.handle_user_exception(e)\n",
328
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask_cors\\extension.py\", line 165, in wrapped_function\n",
329
+ " return cors_after_request(app.make_response(f(*args, **kwargs)))\n",
330
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1523, in full_dispatch_request\n",
331
+ " rv = self.dispatch_request()\n",
332
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1509, in dispatch_request\n",
333
+ " return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)\n",
334
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\gradio\\networking.py\", line 269, in file\n",
335
+ " return send_file(os.path.join(app.cwd, path))\n",
336
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\helpers.py\", line 610, in send_file\n",
337
+ " return werkzeug.utils.send_file(\n",
338
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\werkzeug\\utils.py\", line 427, in send_file\n",
339
+ " stat = os.stat(path)\n",
340
+ "FileNotFoundError: [WinError 2] The system cannot find the file specified: 'C:\\\\Users\\\\User\\\\Downloads\\\\sailboat.jpg'\n",
341
+ "[2023-06-05 22:40:39,357] ERROR in app: Exception on /file/bicycle.jpg [GET]\n",
342
+ "Traceback (most recent call last):\n",
343
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 2077, in wsgi_app\n",
344
+ " response = self.full_dispatch_request()\n",
345
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1525, in full_dispatch_request\n",
346
+ " rv = self.handle_user_exception(e)\n",
347
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask_cors\\extension.py\", line 165, in wrapped_function\n",
348
+ " return cors_after_request(app.make_response(f(*args, **kwargs)))\n",
349
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1523, in full_dispatch_request\n",
350
+ " rv = self.dispatch_request()\n",
351
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1509, in dispatch_request\n",
352
+ " return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)\n",
353
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\gradio\\networking.py\", line 269, in file\n",
354
+ " return send_file(os.path.join(app.cwd, path))\n",
355
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\helpers.py\", line 610, in send_file\n",
356
+ " return werkzeug.utils.send_file(\n",
357
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\werkzeug\\utils.py\", line 427, in send_file\n",
358
+ " stat = os.stat(path)\n",
359
+ "FileNotFoundError: [WinError 2] The system cannot find the file specified: 'C:\\\\Users\\\\User\\\\Downloads\\\\bicycle.jpg'\n"
360
+ ]
361
+ },
362
+ {
363
+ "data": {
364
+ "text/plain": [
365
+ "(<Flask 'gradio.networking'>, 'http://127.0.0.1:7861/', None)"
366
+ ]
367
+ },
368
+ "execution_count": 6,
369
+ "metadata": {},
370
+ "output_type": "execute_result"
371
+ }
372
+ ],
373
+ "source": [
374
+ "gr.Interface(\n",
375
+ " [classify_image_with_mobile_net, classify_image_with_inception_net],\n",
376
+ " imagein,\n",
377
+ " label,\n",
378
+ " title=\"MobileNet vs. InceptionNet\",\n",
379
+ " description=\"\"\"Let's compare 2 state-of-the-art machine learning models that classify images into one of 1,000 categories: MobileNet (top),\n",
380
+ " a lightweight model that has an accuracy of 0.704, vs. InceptionNet\n",
381
+ " (bottom), a much heavier model that has an accuracy of 0.779.\"\"\",\n",
382
+ " examples=sample_images).launch()"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": 6,
388
+ "id": "3dbc1cab",
389
+ "metadata": {},
390
+ "outputs": [
391
+ {
392
+ "name": "stdout",
393
+ "output_type": "stream",
394
+ "text": [
395
+ "Requirement already satisfied: transformers in c:\\users\\user\\anaconda3\\lib\\site-packages (4.27.0)\n",
396
+ "Requirement already satisfied: filelock in c:\\users\\user\\anaconda3\\lib\\site-packages (from transformers) (3.12.0)\n",
397
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from transformers) (0.14.1)\n",
398
+ "Requirement already satisfied: numpy>=1.17 in c:\\users\\user\\anaconda3\\lib\\site-packages (from transformers) (1.24.3)\n",
399
+ "Requirement already satisfied: packaging>=20.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from transformers) (22.0)\n",
400
+ "Requirement already satisfied: pyyaml>=5.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from transformers) (6.0)\n",
401
+ "Requirement already satisfied: regex!=2019.12.17 in c:\\users\\user\\anaconda3\\lib\\site-packages (from transformers) (2022.7.9)\n",
402
+ "Requirement already satisfied: requests in c:\\users\\user\\anaconda3\\lib\\site-packages (from transformers) (2.28.1)\n",
403
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from transformers) (0.11.4)\n",
404
+ "Requirement already satisfied: tqdm>=4.27 in c:\\users\\user\\anaconda3\\lib\\site-packages (from transformers) (4.64.1)\n",
405
+ "Requirement already satisfied: fsspec in c:\\users\\user\\anaconda3\\lib\\site-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (2022.11.0)\n",
406
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.4.0)\n",
407
+ "Requirement already satisfied: colorama in c:\\users\\user\\anaconda3\\lib\\site-packages (from tqdm>=4.27->transformers) (0.4.6)\n",
408
+ "Requirement already satisfied: charset-normalizer<3,>=2 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->transformers) (2.0.4)\n",
409
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->transformers) (3.4)\n",
410
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->transformers) (1.26.14)\n",
411
+ "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\user\\anaconda3\\lib\\site-packages (from requests->transformers) (2022.12.7)\n",
412
+ "Note: you may need to restart the kernel to use updated packages.\n"
413
+ ]
414
+ },
415
+ {
416
+ "name": "stderr",
417
+ "output_type": "stream",
418
+ "text": [
419
+ "WARNING: Ignoring invalid distribution -orch (c:\\users\\user\\anaconda3\\lib\\site-packages)\n",
420
+ "WARNING: Ignoring invalid distribution -rotobuf (c:\\users\\user\\anaconda3\\lib\\site-packages)\n",
421
+ "WARNING: Ignoring invalid distribution -orch (c:\\users\\user\\anaconda3\\lib\\site-packages)\n",
422
+ "WARNING: Ignoring invalid distribution -rotobuf (c:\\users\\user\\anaconda3\\lib\\site-packages)\n"
423
+ ]
424
+ }
425
+ ],
426
+ "source": [
427
+ "pip install transformers"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "code",
432
+ "execution_count": 6,
433
+ "id": "7deaaac7",
434
+ "metadata": {},
435
+ "outputs": [
436
+ {
437
+ "name": "stdout",
438
+ "output_type": "stream",
439
+ "text": [
440
+ "IMPORTANT: You are using gradio version 1.6.0, however version 3.14.0 is available, please upgrade.\n",
441
+ "--------\n",
442
+ "Running locally at: http://127.0.0.1:7861/\n",
443
+ "To create a public link, set `share=True` in `launch()`.\n",
444
+ "Interface loading below...\n"
445
+ ]
446
+ },
447
+ {
448
+ "data": {
449
+ "text/html": [
450
+ "\n",
451
+ " <iframe\n",
452
+ " width=\"1000\"\n",
453
+ " height=\"500\"\n",
454
+ " src=\"http://127.0.0.1:7861/\"\n",
455
+ " frameborder=\"0\"\n",
456
+ " allowfullscreen\n",
457
+ " \n",
458
+ " ></iframe>\n",
459
+ " "
460
+ ],
461
+ "text/plain": [
462
+ "<IPython.lib.display.IFrame at 0x22af27eb940>"
463
+ ]
464
+ },
465
+ "metadata": {},
466
+ "output_type": "display_data"
467
+ },
468
+ {
469
+ "data": {
470
+ "text/plain": [
471
+ "(<Flask 'gradio.networking'>, 'http://127.0.0.1:7861/', None)"
472
+ ]
473
+ },
474
+ "execution_count": 6,
475
+ "metadata": {},
476
+ "output_type": "execute_result"
477
+ }
478
+ ],
479
+ "source": [
480
+ "import gradio as gr\n",
481
+ "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
482
+ "\n",
483
+ "# Load the models and tokenizers\n",
484
+ "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
485
+ "\n",
486
+ "tokenizer1 = AutoTokenizer.from_pretrained(\"textattack/bert-base-uncased-imdb\")\n",
487
+ "tokenizer2 = AutoTokenizer.from_pretrained(\"nlptown/bert-base-multilingual-uncased-sentiment\")\n",
488
+ "model1 = AutoModelForSequenceClassification.from_pretrained(\"textattack/bert-base-uncased-imdb\")\n",
489
+ "model2 = AutoModelForSequenceClassification.from_pretrained(\"nlptown/bert-base-multilingual-uncased-sentiment\")\n",
490
+ "\n",
491
+ "\n",
492
+ "\n",
493
+ "\n",
494
+ "# Define the sentiment prediction functions\n",
495
+ "def predict_sentiment(text):\n",
496
+ " # Predict sentiment using model 1\n",
497
+ " inputs1 = tokenizer1.encode_plus(text, padding=\"longest\", truncation=True, return_tensors=\"pt\")\n",
498
+ " outputs1 = model1(**inputs1)\n",
499
+ " predicted_label1 = outputs1.logits.argmax().item()\n",
500
+ " sentiment1 = \"Positive\" if predicted_label1 == 1 else \"Negative\" if predicted_label1 == 0 else \"Neutral\"\n",
501
+ "\n",
502
+ " # Predict sentiment using model 2\n",
503
+ " inputs2 = tokenizer2.encode_plus(text, padding=\"longest\", truncation=True, return_tensors=\"pt\")\n",
504
+ " outputs2 = model2(**inputs2)\n",
505
+ " predicted_label2 = outputs2.logits.argmax().item()\n",
506
+ " sentiment2 = \"Positive\" if predicted_label2 == 1 else \"Negative\" if predicted_label2 == 0 else \"Neutral\"\n",
507
+ "\n",
508
+ " return sentiment1, sentiment2\n",
509
+ "\n",
510
+ "# Create the Gradio interface\n",
511
+ "iface = gr.Interface(\n",
512
+ " fn=predict_sentiment,\n",
513
+ " inputs=\"text\",\n",
514
+ " outputs=[\"text\", \"text\"],\n",
515
+ " title=\"Sentiment Analysis (Model 1 vs Model 2)\",\n",
516
+ " description=\"Compare sentiment predictions from two models.\",\n",
517
+ ")\n",
518
+ "\n",
519
+ "# Launch the interface\n",
520
+ "iface.launch()\n"
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "code",
525
+ "execution_count": 17,
526
+ "id": "bd93f2a5",
527
+ "metadata": {},
528
+ "outputs": [
529
+ {
530
+ "name": "stderr",
531
+ "output_type": "stream",
532
+ "text": [
533
+ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']\n",
534
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
535
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
536
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
537
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
538
+ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight']\n",
539
+ "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
540
+ "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
541
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
542
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
543
+ ]
544
+ },
545
+ {
546
+ "name": "stdout",
547
+ "output_type": "stream",
548
+ "text": [
549
+ "IMPORTANT: You are using gradio version 1.6.0, however version 3.14.0 is available, please upgrade.\n",
550
+ "--------\n",
551
+ "Running locally at: http://127.0.0.1:7871/\n",
552
+ "To create a public link, set `share=True` in `launch()`.\n",
553
+ "Interface loading below...\n"
554
+ ]
555
+ },
556
+ {
557
+ "data": {
558
+ "text/html": [
559
+ "\n",
560
+ " <iframe\n",
561
+ " width=\"1000\"\n",
562
+ " height=\"500\"\n",
563
+ " src=\"http://127.0.0.1:7871/\"\n",
564
+ " frameborder=\"0\"\n",
565
+ " allowfullscreen\n",
566
+ " \n",
567
+ " ></iframe>\n",
568
+ " "
569
+ ],
570
+ "text/plain": [
571
+ "<IPython.lib.display.IFrame at 0x22a82329c30>"
572
+ ]
573
+ },
574
+ "metadata": {},
575
+ "output_type": "display_data"
576
+ },
577
+ {
578
+ "data": {
579
+ "text/plain": [
580
+ "(<Flask 'gradio.networking'>, 'http://127.0.0.1:7871/', None)"
581
+ ]
582
+ },
583
+ "execution_count": 17,
584
+ "metadata": {},
585
+ "output_type": "execute_result"
586
+ },
587
+ {
588
+ "name": "stderr",
589
+ "output_type": "stream",
590
+ "text": [
591
+ "[2023-06-05 21:25:10,327] ERROR in app: Exception on /api/predict/ [POST]\n",
592
+ "Traceback (most recent call last):\n",
593
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 2077, in wsgi_app\n",
594
+ " response = self.full_dispatch_request()\n",
595
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1525, in full_dispatch_request\n",
596
+ " rv = self.handle_user_exception(e)\n",
597
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask_cors\\extension.py\", line 165, in wrapped_function\n",
598
+ " return cors_after_request(app.make_response(f(*args, **kwargs)))\n",
599
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1523, in full_dispatch_request\n",
600
+ " rv = self.dispatch_request()\n",
601
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1509, in dispatch_request\n",
602
+ " return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)\n",
603
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\gradio\\networking.py\", line 133, in predict\n",
604
+ " prediction, durations = app.interface.process(raw_input)\n",
605
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\gradio\\interface.py\", line 272, in process\n",
606
+ " predictions, durations = self.run_prediction(processed_input, return_duration=True)\n",
607
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\gradio\\interface.py\", line 246, in run_prediction\n",
608
+ " prediction = predict_fn(*processed_input)\n",
609
+ " File \"C:\\Users\\User\\AppData\\Local\\Temp\\ipykernel_9376\\3704131587.py\", line 80, in classify_image\n",
610
+ " prediction = predict(image_file=image_file, model_key=model_key)\n",
611
+ " File \"C:\\Users\\User\\AppData\\Local\\Temp\\ipykernel_9376\\3704131587.py\", line 67, in predict\n",
612
+ " image = preprocess(image_file)\n",
613
+ " File \"C:\\Users\\User\\AppData\\Local\\Temp\\ipykernel_9376\\3704131587.py\", line 51, in preprocess\n",
614
+ " image = Image.open(BytesIO(image_file.read())).convert(\"RGB\")\n",
615
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\PIL\\Image.py\", line 3283, in open\n",
616
+ " raise UnidentifiedImageError(msg)\n",
617
+ "PIL.UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x0000022ABA1C10D0>\n",
618
+ "[2023-06-05 21:39:36,773] ERROR in app: Exception on /api/predict/ [POST]\n",
619
+ "Traceback (most recent call last):\n",
620
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 2077, in wsgi_app\n",
621
+ " response = self.full_dispatch_request()\n",
622
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1525, in full_dispatch_request\n",
623
+ " rv = self.handle_user_exception(e)\n",
624
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask_cors\\extension.py\", line 165, in wrapped_function\n",
625
+ " return cors_after_request(app.make_response(f(*args, **kwargs)))\n",
626
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1523, in full_dispatch_request\n",
627
+ " rv = self.dispatch_request()\n",
628
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\flask\\app.py\", line 1509, in dispatch_request\n",
629
+ " return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)\n",
630
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\gradio\\networking.py\", line 133, in predict\n",
631
+ " prediction, durations = app.interface.process(raw_input)\n",
632
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\gradio\\interface.py\", line 270, in process\n",
633
+ " processed_input = [input_interface.preprocess(raw_input[i])\n",
634
+ " File \"C:\\Users\\User\\anaconda3\\lib\\site-packages\\gradio\\interface.py\", line 270, in <listcomp>\n",
635
+ " processed_input = [input_interface.preprocess(raw_input[i])\n",
636
+ "IndexError: list index out of range\n"
637
+ ]
638
+ }
639
+ ],
640
+ "source": [
641
+ "import gradio as gr\n",
642
+ "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
643
+ "import torch\n",
644
+ "from torchvision import transforms\n",
645
+ "from io import BytesIO\n",
646
+ "from PIL import Image\n",
647
+ "\n",
648
+ "# Define the available models and datasets\n",
649
+ "models = {\n",
650
+ " \"Model 1\": {\n",
651
+ " \"model_name\": \"bert-base-uncased\",\n",
652
+ " \"tokenizer\": None,\n",
653
+ " \"model\": None\n",
654
+ " },\n",
655
+ " \"Model 2\": {\n",
656
+ " \"model_name\": \"distilbert-base-uncased\",\n",
657
+ " \"tokenizer\": None,\n",
658
+ " \"model\": None\n",
659
+ " },\n",
660
+ " # Add more models as needed\n",
661
+ "}\n",
662
+ "\n",
663
+ "datasets = {\n",
664
+ " \"Dataset 1\": {\n",
665
+ " \"name\": \"imdb\",\n",
666
+ " \"split\": \"test\",\n",
667
+ " \"features\": [\"text\"],\n",
668
+ " },\n",
669
+ " \"Dataset 2\": {\n",
670
+ " \"name\": \"ag_news\",\n",
671
+ " \"split\": \"test\",\n",
672
+ " \"features\": [\"text\"],\n",
673
+ " },\n",
674
+ " # Add more datasets as needed\n",
675
+ "}\n",
676
+ "\n",
677
+ "# Load models\n",
678
+ "for model_key, model_info in models.items():\n",
679
+ " tokenizer = AutoTokenizer.from_pretrained(model_info[\"model_name\"])\n",
680
+ " model = AutoModelForSequenceClassification.from_pretrained(model_info[\"model_name\"])\n",
681
+ " model_info[\"tokenizer\"] = tokenizer\n",
682
+ " model_info[\"model\"] = model\n",
683
+ "\n",
684
+ "# Set the device to GPU if available\n",
685
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
686
+ "for model_info in models.values():\n",
687
+ " model_info[\"model\"].to(device)\n",
688
+ "\n",
689
+ "# Define the preprocessing function\n",
690
+ "def preprocess(image_file):\n",
691
+ " image = Image.open(BytesIO(image_file.read())).convert(\"RGB\")\n",
692
+ " preprocess_transform = transforms.Compose([\n",
693
+ " transforms.Resize((224, 224)),\n",
694
+ " transforms.ToTensor(),\n",
695
+ " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
696
+ " ])\n",
697
+ " image = preprocess_transform(image)\n",
698
+ " image = image.unsqueeze(0)\n",
699
+ " return image.to(device)\n",
700
+ "\n",
701
+ "# Define the prediction function\n",
702
+ "def predict(image_file, model_key):\n",
703
+ " model_info = models[model_key]\n",
704
+ " tokenizer = model_info[\"tokenizer\"]\n",
705
+ " model = model_info[\"model\"]\n",
706
+ "\n",
707
+ " image = preprocess(image_file)\n",
708
+ "\n",
709
+ " with torch.no_grad():\n",
710
+ " outputs = model(image)\n",
711
+ "\n",
712
+ " predictions = outputs.logits.argmax(dim=1)\n",
713
+ "\n",
714
+ " return predictions.item()\n",
715
+ "\n",
716
+ "def classify_image(image, model_key):\n",
717
+ " image = Image.fromarray(image.astype('uint8'), 'RGB')\n",
718
+ " image_file = BytesIO()\n",
719
+ " image.save(image_file, format=\"JPEG\")\n",
720
+ " prediction = predict(image_file=image_file, model_key=model_key)\n",
721
+ " return prediction\n",
722
+ "\n",
723
+ "iface = gr.Interface(fn=classify_image,\n",
724
+ " inputs=[\"image\", gr.inputs.Dropdown(list(models.keys()), label=\"Model\")],\n",
725
+ " outputs=\"text\",\n",
726
+ " title=\"Image Classification\",\n",
727
+ " description=\"Classify images using Hugging Face models\")\n",
728
+ "\n",
729
+ "iface.launch()\n"
730
+ ]
731
+ },
732
+ {
733
+ "cell_type": "code",
734
+ "execution_count": null,
735
+ "id": "c25d39b5",
736
+ "metadata": {},
737
+ "outputs": [],
738
+ "source": []
739
+ }
740
+ ],
741
+ "metadata": {
742
+ "kernelspec": {
743
+ "display_name": "Python 3 (ipykernel)",
744
+ "language": "python",
745
+ "name": "python3"
746
+ },
747
+ "language_info": {
748
+ "codemirror_mode": {
749
+ "name": "ipython",
750
+ "version": 3
751
+ },
752
+ "file_extension": ".py",
753
+ "mimetype": "text/x-python",
754
+ "name": "python",
755
+ "nbconvert_exporter": "python",
756
+ "pygments_lexer": "ipython3",
757
+ "version": "3.10.9"
758
+ }
759
+ },
760
+ "nbformat": 4,
761
+ "nbformat_minor": 5
762
+ }