Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,128 +1,48 @@
|
|
1 |
import gradio as gr
|
2 |
-
import torch
|
3 |
-
from transformers.generation.streamers import TextIteratorStreamer
|
4 |
-
from PIL import Image
|
5 |
import requests
|
6 |
-
|
7 |
-
from
|
8 |
-
import
|
9 |
-
|
10 |
-
# 导入 LLaVA 相关模块
|
11 |
-
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
12 |
-
from llava.conversation import conv_templates, SeparatorStyle
|
13 |
-
from llava.model.builder import load_pretrained_model
|
14 |
-
from llava.utils import disable_torch_init
|
15 |
-
from llava.mm_utils import tokenizer_image_token
|
16 |
-
|
17 |
-
# **确保 Hugging Face 缓存目录正确**
|
18 |
-
os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"
|
19 |
-
|
20 |
-
# **加载 LLaVA-1.5-13B**
|
21 |
-
disable_torch_init()
|
22 |
-
model_id = "Yanqing0327/LLaVA-project" # 替换为你的 Hugging Face 模型仓库
|
23 |
-
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
24 |
-
model_id, model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False, device_map="auto"
|
25 |
-
)
|
26 |
-
|
27 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
28 |
-
model = model.to(device)
|
29 |
-
|
30 |
-
def load_image(image_file):
|
31 |
-
"""确保 image 是 `PIL.Image`"""
|
32 |
-
if isinstance(image_file, Image.Image):
|
33 |
-
return image_file.convert("RGB") # 直接返回 `PIL.Image`
|
34 |
-
|
35 |
-
elif isinstance(image_file, str) and (image_file.startswith('http') or image_file.startswith('https')):
|
36 |
-
response = requests.get(image_file)
|
37 |
-
return Image.open(BytesIO(response.content)).convert('RGB')
|
38 |
|
39 |
-
|
40 |
-
|
41 |
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
def llava_infer(image, text
|
44 |
-
"""
|
45 |
if image is None or text.strip() == "":
|
46 |
return "请提供图片和文本输入"
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values']
|
51 |
-
|
52 |
-
# **确保数据在正确的设备上**
|
53 |
-
image_tensor = image_tensor.to(device)
|
54 |
-
if torch.cuda.is_available():
|
55 |
-
image_tensor = image_tensor.half() # GPU: 用 float16
|
56 |
-
else:
|
57 |
-
image_tensor = image_tensor.float() # CPU: 用 float32
|
58 |
-
|
59 |
-
# **处理对话**
|
60 |
-
conv_mode = "llava_v1"
|
61 |
-
conv = conv_templates[conv_mode].copy()
|
62 |
-
|
63 |
-
# 生成输入文本,添加特殊 token
|
64 |
-
inp = DEFAULT_IMAGE_TOKEN + '\n' + text
|
65 |
-
conv.append_message(conv.roles[0], inp)
|
66 |
-
conv.append_message(conv.roles[1], None)
|
67 |
-
|
68 |
-
prompt = conv.get_prompt()
|
69 |
-
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
|
70 |
-
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
71 |
-
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)
|
72 |
-
|
73 |
-
# **执行推理**
|
74 |
-
with torch.inference_mode():
|
75 |
-
thread = Thread(target=model.generate, kwargs=dict(
|
76 |
-
inputs=input_ids,
|
77 |
-
images=image_tensor,
|
78 |
-
do_sample=True,
|
79 |
-
temperature=temperature,
|
80 |
-
top_p=top_p,
|
81 |
-
max_new_tokens=max_tokens,
|
82 |
-
streamer=streamer,
|
83 |
-
use_cache=True
|
84 |
-
))
|
85 |
-
thread.start()
|
86 |
-
|
87 |
-
response = ""
|
88 |
-
prepend_space = False
|
89 |
-
for new_text in streamer:
|
90 |
-
if new_text == " ":
|
91 |
-
prepend_space = True
|
92 |
-
continue
|
93 |
-
if new_text.endswith(stop_str):
|
94 |
-
new_text = new_text[:-len(stop_str)].strip()
|
95 |
-
prepend_space = False
|
96 |
-
elif prepend_space:
|
97 |
-
new_text = " " + new_text
|
98 |
-
prepend_space = False
|
99 |
-
response += new_text
|
100 |
-
if prepend_space:
|
101 |
-
response += " "
|
102 |
-
|
103 |
-
thread.join()
|
104 |
-
|
105 |
-
return response
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
-
#
|
109 |
-
with gr.Blocks(title="LLaVA
|
110 |
-
gr.Markdown("# 🌋 LLaVA
|
111 |
-
gr.Markdown("上传图片并输入文本,LLaVA
|
112 |
|
113 |
with gr.Row():
|
114 |
with gr.Column(scale=3):
|
115 |
image_input = gr.Image(type="pil", label="上传图片")
|
116 |
text_input = gr.Textbox(placeholder="输入文本...", label="输入文本")
|
117 |
-
temperature = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
|
118 |
-
top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Top P")
|
119 |
-
max_tokens = gr.Slider(10, 1024, value=512, step=10, label="Max Tokens")
|
120 |
submit_button = gr.Button("提交")
|
121 |
|
122 |
with gr.Column(scale=7):
|
123 |
chatbot_output = gr.Textbox(label="LLaVA 输出", interactive=False)
|
124 |
|
125 |
-
submit_button.click(fn=llava_infer, inputs=[image_input, text_input
|
126 |
|
127 |
-
# **启动
|
128 |
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
2 |
import requests
|
3 |
+
import base64
|
4 |
+
from PIL import Image
|
5 |
+
import io
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
+
# Local GPU server API: the endpoint that llava_infer posts the image/text payload to.
# NOTE(review): hard-coded plain-HTTP address — consider making it configurable.
LOCAL_SERVER_URL = "http://169.233.7.2:5000/infer"
|
9 |
|
10 |
+
def image_to_base64(image):
    """Encode a PIL image as PNG and return the bytes as Base64 text.

    Args:
        image: Object exposing ``save(fp, format=...)`` (a PIL image).

    Returns:
        The PNG-encoded image as a UTF-8 decoded Base64 string.
    """
    # Serialize into an in-memory stream so nothing touches the disk.
    with io.BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()

    encoded = base64.b64encode(png_bytes)
    return str(encoded, "utf-8")
|
15 |
|
16 |
+
def llava_infer(image, text):
    """Send the user's image and text to the local inference server.

    Args:
        image: Uploaded image, forwarded to ``image_to_base64``. ``None``
            means no image was provided.
        text: Prompt text. ``None``, empty, or whitespace-only text is
            treated as missing.

    Returns:
        The ``"response"`` field of the server's JSON reply on success.
        Otherwise a user-facing message: the "missing input" prompt, or a
        string prefixed with "服务器错误:" describing the failure.
    """
    # Reject missing input up front; `text` may be None as well as blank.
    if image is None or not text or not text.strip():
        return "请提供图片和文本输入"

    image_base64 = image_to_base64(image)
    payload = {"image": image_base64, "text": text}

    # (connect, read) timeouts in seconds: fail fast when the server is
    # unreachable, but leave room for slow generation. Without a timeout
    # the request could block this handler indefinitely.
    request_timeout = (10, 300)

    try:
        response = requests.post(
            LOCAL_SERVER_URL, json=payload, timeout=request_timeout
        )
        # Surface HTTP 4xx/5xx as errors instead of parsing an error page.
        response.raise_for_status()
        response_data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a reply body that is not valid JSON.
        return f"服务器错误: {e}"

    # Report a malformed reply explicitly rather than via a bare KeyError.
    if not isinstance(response_data, dict) or "response" not in response_data:
        return "服务器错误: 响应中缺少 'response' 字段"
    return response_data["response"]
|
30 |
|
31 |
+
# Gradio web UI: inputs in the narrower column, model output in the wider one.
with gr.Blocks(title="LLaVA Remote Web UI") as demo:
    gr.Markdown("# 🌋 LLaVA Web Interface (Remote Inference)")
    gr.Markdown("上传图片并输入文本,LLaVA 将在远程 GPU 服务器推理")

    with gr.Row():
        with gr.Column(scale=3):
            # Input widgets: image upload, prompt text, and the submit button.
            image_input = gr.Image(type="pil", label="上传图片")
            text_input = gr.Textbox(placeholder="输入文本...", label="输入文本")
            submit_button = gr.Button("提交")

        with gr.Column(scale=7):
            # Output box; created with interactive=False.
            chatbot_output = gr.Textbox(label="LLaVA 输出", interactive=False)

    # Clicking submit calls llava_infer(image, text) and shows its return value.
    submit_button.click(fn=llava_infer, inputs=[image_input, text_input], outputs=chatbot_output)

# Start the Hugging Face web UI on host 0.0.0.0, port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|