seung275 committed
Commit 83c2e7f · verified · Parent: 11a6c62

Upload 9 files
Files changed (10)
  1. .gitattributes +2 -0
  2. README.md +3 -10
  3. app.py +245 -0
  4. capsule_crack.png +3 -0
  5. carpet_normal.jpg +0 -0
  6. ffffff.png +0 -0
  7. gitattributes +40 -0
  8. hazelnut_cut.png +3 -0
  9. header.py +35 -0
  10. requirements.txt +29 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ capsule_crack.png filter=lfs diff=lfs merge=lfs -text
+ hazelnut_cut.png filter=lfs diff=lfs merge=lfs -text
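The two new rules route the large sample PNGs through Git LFS while the small images (carpet_normal.jpg, ffffff.png) stay in regular Git. For flat patterns like these, Python's fnmatch approximates gitattributes glob matching well enough to sanity-check which files a rule set captures; a quick sketch, not part of this repo:

    import fnmatch

    rules = ['*.zip', '*.zst', '*tfevents*', 'capsule_crack.png', 'hazelnut_cut.png']
    for name in ['capsule_crack.png', 'carpet_normal.jpg', 'ffffff.png', 'hazelnut_cut.png']:
        tracked = any(fnmatch.fnmatch(name, rule) for rule in rules)
        print(f"{name}: {'LFS' if tracked else 'regular git'}")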
README.md CHANGED
@@ -1,12 +1,5 @@
  ---
- title: AnomalyGPT1
- emoji: 😻
- colorFrom: yellow
- colorTo: purple
+ license: cc-by-sa-4.0
+ title: AnomalyGPT
  sdk: gradio
- sdk_version: 5.25.2
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
app.py ADDED
@@ -0,0 +1,245 @@
+ import os
+
+ os.system("cp /home/user/.pyenv/versions/3.10.13/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so /home/user/.pyenv/versions/3.10.13/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so")  # Spaces workaround: substitute the CUDA 11.8 build of bitsandbytes for the CPU stub
+
+
+ import gradio as gr
+ import mdtex2html
+ from model.openllama import OpenLLAMAPEFTModel
+ import torch
+ from io import BytesIO
+ from PIL import Image as PILImage
+ import cv2
+ import numpy as np
+ from matplotlib import pyplot as plt
+ from torchvision import transforms
+
+ # init the model
+ args = {
+     'model': 'openllama_peft',
+     'imagebind_ckpt_path': './pretrained_ckpt/imagebind_ckpt/imagebind_huge.pth',
+     'vicuna_ckpt_path': './pretrained_ckpt/vicuna_ckpt/7b_v0',
+     'anomalygpt_ckpt_path': './ckpt/train_supervised/pytorch_model.pt',
+     'delta_ckpt_path': './pretrained_ckpt/pandagpt_ckpt/7b/pytorch_model.pt',
+     'stage': 2,
+     'max_tgt_len': 128,
+     'lora_r': 32,
+     'lora_alpha': 32,
+     'lora_dropout': 0.1
+ }
+
+ model = OpenLLAMAPEFTModel(**args)
+ delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))  # PandaGPT delta weights
+ model.load_state_dict(delta_ckpt, strict=False)
+ delta_ckpt = torch.load(args['anomalygpt_ckpt_path'], map_location=torch.device('cpu'))  # AnomalyGPT weights on top
+ model.load_state_dict(delta_ckpt, strict=False)
+ model = model.eval()  # .half() / .cuda() disabled: CPU inference
+ # model.image_decoder = model.image_decoder.cuda()
+ # model.prompt_learner = model.prompt_learner.cuda()
+
+ def postprocess(self, y):
+     """Override Chatbot.postprocess to render messages with mdtex2html."""
+     if y is None:
+         return []
+     for i, (message, response) in enumerate(y):
+         y[i] = (
+             None if message is None else mdtex2html.convert(message),
+             None if response is None else mdtex2html.convert(response),
+         )
+     return y
+
+
+ gr.Chatbot.postprocess = postprocess
+
+
+ def parse_text(text):
+     """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
+     lines = text.split("\n")
+     lines = [line for line in lines if line != ""]
+     count = 0
+     for i, line in enumerate(lines):
+         if "```" in line:
+             count += 1
+             items = line.split('`')
+             if count % 2 == 1:
+                 lines[i] = f'<pre><code class="language-{items[-1]}">'
+             else:
+                 lines[i] = '<br></code></pre>'
+         else:
+             if i > 0:
+                 if count % 2 == 1:
+                     line = line.replace("`", "\\`")
+                     line = line.replace("<", "&lt;")
+                     line = line.replace(">", "&gt;")
+                     line = line.replace(" ", "&nbsp;")
+                     line = line.replace("*", "&ast;")
+                     line = line.replace("_", "&lowbar;")
+                     line = line.replace("-", "&#45;")
+                     line = line.replace(".", "&#46;")
+                     line = line.replace("!", "&#33;")
+                     line = line.replace("(", "&#40;")
+                     line = line.replace(")", "&#41;")
+                     line = line.replace("$", "&#36;")
+                 lines[i] = "<br>" + line
+     text = "".join(lines)
+     return text
+
+
+ def predict(
+     input,
+     image_path,
+     normal_img_path,
+     chatbot,
+     max_length,
+     top_p,
+     temperature,
+     history,
+     modality_cache,
+ ):
+
+     if image_path is None and normal_img_path is None:
+         return [(input, "There is no input data provided! Please upload your data and start the conversation.")], history, modality_cache, None  # match the four declared outputs
+     else:
+         print(f'[!] image path: {image_path}\n[!] normal image path: {normal_img_path}\n')
+
+     # prepare the prompt
+     prompt_text = ''
+     for idx, (q, a) in enumerate(history):
+         if idx == 0:
+             prompt_text += f'{q}\n### Assistant: {a}\n###'
+         else:
+             prompt_text += f' Human: {q}\n### Assistant: {a}\n###'
+     if len(history) == 0:
+         prompt_text += f'{input}'
+     else:
+         prompt_text += f' Human: {input}'
+
+     response, pixel_output = model.generate({
+         'prompt': prompt_text,
+         'image_paths': [image_path] if image_path else [],
+         'normal_img_paths': [normal_img_path] if normal_img_path else [],
+         'audio_paths': [],
+         'video_paths': [],
+         'thermal_paths': [],
+         'top_p': top_p,
+         'temperature': temperature,
+         'max_tgt_len': max_length,
+         'modality_embeds': modality_cache
+     }, web_demo=True)
+     chatbot.append((parse_text(input), parse_text(response)))
+     history.append((input, response))
+
+     plt.imshow(pixel_output.to(torch.float16).reshape(224, 224).detach().cpu(), cmap='binary_r')  # render the anomaly map
+     plt.axis('off')
+     plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
+
+     target_size = 435
+     original_width, original_height = PILImage.open(image_path).size
+     if original_width > original_height:
+         new_width = target_size
+         new_height = int(target_size * (original_height / original_width))
+     else:
+         new_height = target_size
+         new_width = int(target_size * (original_width / original_height))
+
+     new_image = PILImage.new('L', (target_size, target_size), 255)  # 'L' mode for grayscale
+
+     paste_x = (target_size - new_width) // 2
+     paste_y = (target_size - new_height) // 2
+
+     pixel_output = PILImage.open('output.png').resize((new_width, new_height), PILImage.LANCZOS)
+
+     new_image.paste(pixel_output, (paste_x, paste_y))
+
+     new_image.save('output.png')
+
+     image = cv2.imread('output.png', cv2.IMREAD_GRAYSCALE)
+     kernel = np.ones((3, 3), np.uint8)
+     eroded_image = cv2.erode(image, kernel, iterations=1)  # light erosion to clean up the mask
+     cv2.imwrite('output.png', eroded_image)
+
+     output = PILImage.open('output.png').convert('L')
+
+     return chatbot, history, modality_cache, output
+
+
+ def reset_user_input():
+     return gr.update(value='')
+
+
+ def reset_state():
+     return gr.update(value=''), None, None, [], [], [], PILImage.open('ffffff.png')
+
+
+ examples = ['hazelnut_cut.png', 'capsule_crack.png', 'carpet_normal.jpg']
+
+ with gr.Blocks() as demo:
+     gr.HTML("""<h1 align="center">Demo of AnomalyGPT</h1>""")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             with gr.Row():
+                 image_path = gr.Image(type="filepath", label="Query Image", value=examples[0])
+             with gr.Row():
+                 normal_img_path = gr.Image(type="filepath", label="Normal Image (optional)", value=None)
+             with gr.Row():
+                 gr.Examples(examples=examples, inputs=[image_path])
+             with gr.Row():
+                 max_length = gr.Slider(0, 512, value=512, step=1.0, label="Max length", interactive=True)
+                 top_p = gr.Slider(0, 1, value=0.01, step=0.01, label="Top P", interactive=True)
+                 temperature = gr.Slider(0, 1, value=1.0, step=0.01, label="Temperature", interactive=True)
+
+         with gr.Column(scale=3):
+             with gr.Row():
+                 with gr.Column(scale=6):
+                     chatbot = gr.Chatbot().style(height=440)
+                 with gr.Column(scale=4):
+                     # gr.Image(output)
+                     image_output = gr.Image(interactive=False, label="Localization Output", type='pil', value=PILImage.open('ffffff.png'))
+             with gr.Row():
+                 user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=12).style(container=False)
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     submitBtn = gr.Button("Submit", variant="primary")
+                 with gr.Column(scale=1):
+                     emptyBtn = gr.Button("Clear History")
+
+     history = gr.State([])
+     modality_cache = gr.State([])
+
+     submitBtn.click(
+         predict, [
+             user_input,
+             image_path,
+             normal_img_path,
+             chatbot,
+             max_length,
+             top_p,
+             temperature,
+             history,
+             modality_cache,
+         ], [
+             chatbot,
+             history,
+             modality_cache,
+             image_output
+         ],
+         show_progress=True
+     )
+
+     submitBtn.click(reset_user_input, [], [user_input])
+     emptyBtn.click(reset_state, outputs=[
+         user_input,
+         image_path,
+         normal_img_path,
+         chatbot,
+         history,
+         modality_cache,
+         image_output
+     ], show_progress=True)
+
+
+ demo.queue().launch()
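The prompt assembly in predict() stitches prior turns into a single Vicuna-style string, with "### Assistant:" markers and a " Human: " prefix on every turn after the first. Factored out for clarity (build_prompt is a hypothetical helper, not a function in app.py):

    def build_prompt(history, user_input):
        # The first turn carries no ' Human: ' prefix; later turns do.
        prompt = ''
        for idx, (q, a) in enumerate(history):
            prefix = '' if idx == 0 else ' Human: '
            prompt += f'{prefix}{q}\n### Assistant: {a}\n###'
        return prompt + (user_input if not history else f' Human: {user_input}')

For example, build_prompt([("Any defects?", "Yes, a crack.")], "Where?") yields "Any defects?\n### Assistant: Yes, a crack.\n### Human: Where?".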
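The mask post-processing in predict() is a letterbox: the 224x224 anomaly map is resized to the query image's aspect ratio, centered on a white 435x435 canvas, and then lightly eroded. A self-contained sketch of the resize-and-pad step (letterbox_mask is a hypothetical name):

    from PIL import Image

    def letterbox_mask(mask, query_w, query_h, target=435):
        # Scale so the longer query side maps to `target`, preserving aspect ratio.
        if query_w > query_h:
            nw, nh = target, int(target * query_h / query_w)
        else:
            nh, nw = target, int(target * query_w / query_h)
        canvas = Image.new('L', (target, target), 255)  # white square background
        canvas.paste(mask.convert('L').resize((nw, nh), Image.LANCZOS),
                     ((target - nw) // 2, (target - nh) // 2))
        return canvas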
capsule_crack.png ADDED

Git LFS Details

  • SHA256: dd07c258e465acf0dc3770da851f3671fb4721df60bc460e053a95b9b21acccb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
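The "Pointer size: 132 Bytes" above refers to the small text stub Git stores in place of the binary; it contains exactly the oid and size listed here. A minimal reader for that stub, assuming the git-lfs v1 pointer format (read_lfs_pointer is a hypothetical helper, not part of this repo):

    def read_lfs_pointer(path):
        with open(path, 'rb') as f:
            head = f.read(512)  # real pointers are ~130 bytes of ASCII
        if not head.startswith(b'version https://git-lfs.github.com/spec/v1'):
            return None  # an already-smudged binary, not a pointer
        fields = dict(line.split(' ', 1) for line in head.decode().splitlines() if ' ' in line)
        return {'sha256': fields['oid'].split(':', 1)[1], 'size': int(fields['size'])}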
carpet_normal.jpg ADDED
ffffff.png ADDED
gitattributes ADDED
@@ -0,0 +1,40 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ .bin filter=lfs diff=lfs merge=lfs -text
+ .pt filter=lfs diff=lfs merge=lfs -text
+ .pth filter=lfs diff=lfs merge=lfs -text
+ hazelnut_cut.png filter=lfs diff=lfs merge=lfs -text
+ capsule_crack.png filter=lfs diff=lfs merge=lfs -text
hazelnut_cut.png ADDED

Git LFS Details

  • SHA256: cd5d45c2c2a12aa99dac4e084a91fa21948238f660a70578dd28c34f5bb7325c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
header.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ import datetime
+ import types
+ import deepspeed
+ from transformers.deepspeed import HfDeepSpeedConfig
+ import transformers
+ import numpy as np
+ from collections import OrderedDict
+ from torch.utils.data import Dataset, DataLoader
+ from torch.nn.utils import clip_grad_norm_
+ from torch.cuda.amp import autocast, GradScaler
+ from torch.nn import DataParallel
+ from torch.optim import lr_scheduler
+ import torch.optim as optim
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from tqdm import tqdm
+ import os
+ import re
+ import math
+ import random
+ import json
+ import time
+ import logging
+ from copy import deepcopy
+ import ipdb
+ import argparse
+ from model.ImageBind import data
+ from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+ from torch.nn.utils.rnn import pad_sequence
+ from peft import LoraConfig, TaskType, get_peft_model
+
+ logging.getLogger("transformers").setLevel(logging.WARNING)
+ logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
requirements.txt ADDED
@@ -0,0 +1,29 @@
+ deepspeed==0.9.2
+ easydict==1.10
+ einops==0.6.1
+ ftfy==6.1.1
+ gradio==3.41.2
+ h5py==3.9.0
+ iopath==0.1.10
+ ipdb==0.13.13
+ kornia==0.7.0
+ matplotlib==3.7.2
+ mdtex2html==1.2.0
+ numpy==1.24.3
+ open3d_python==0.3.0.0
+ opencv_python==4.8.0.74
+ peft==0.3.0
+ Pillow==10.0.0
+ pytorchvideo==0.1.5
+ PyYAML==6.0.1
+ regex==2022.10.31
+ timm==0.6.7
+ torch==1.13.1
+ torchaudio==0.13.1
+ torchvision==0.14.1
+ tqdm==4.64.1
+ transformers==4.30.2
+ sentencepiece
+ accelerate==0.21.0
+ bitsandbytes==0.41.1
+ scipy
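These pins predate several breaking API changes (torch 1.13.1 alongside transformers 4.30.2 and peft 0.3.0, plus gradio 3.41.2, whose .style() calls app.py relies on), so it is worth confirming the resolved environment matches before loading the checkpoints. A hypothetical sanity check, not part of this repo:

    from importlib.metadata import PackageNotFoundError, version

    for pkg, want in {'torch': '1.13.1', 'transformers': '4.30.2',
                      'peft': '0.3.0', 'gradio': '3.41.2'}.items():
        try:
            got = version(pkg)
            print(f"{pkg}: {got}" + ('' if got == want else f' (pinned: {want})'))
        except PackageNotFoundError:
            print(f"{pkg}: not installed")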