import gradio as gr
import torch
from transformers.generation.streamers import TextIteratorStreamer
from PIL import Image
import requests
from io import BytesIO
from threading import Thread
import os
# Import the LLaVA modules
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token
# Point the Hugging Face cache at a local "weights" directory
os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(os.getcwd(), "weights")
# Load LLaVA-1.5-13B
disable_torch_init()
model_id = "Yanqing0327/LLaVA-project"  # replace with your Hugging Face model repo
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_id, model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False
)
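# load_pretrained_model returns the tokenizer, the LLaVA model, the vision
# tower's image processor, and the model's maximum context length.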
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
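# Note: the 13B checkpoint in float16 needs roughly 26 GB of GPU memory;
# pass load_8bit=True or load_4bit=True above to fit on smaller cards.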
def load_image(image_file):
    """Normalize the input to an RGB `PIL.Image`."""
    if isinstance(image_file, Image.Image):
        return image_file.convert("RGB")  # already a `PIL.Image`
    elif isinstance(image_file, str) and image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        return Image.open(BytesIO(response.content)).convert("RGB")
    else:  # `image_file` is a local path
        return Image.open(image_file).convert("RGB")
def llava_infer(image, text, temperature, top_p, max_tokens):
    """Run LLaVA inference on one image/text pair."""
    if image is None or text.strip() == "":
        return "Please provide both an image and a text prompt."

    # Preprocess the image into a pixel-value tensor
    image_data = load_image(image)
    image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values']

    # Move the tensor to the right device and dtype
    image_tensor = image_tensor.to(device)
    if torch.cuda.is_available():
        image_tensor = image_tensor.half()  # GPU: float16, matching the fp16 weights
    else:
        image_tensor = image_tensor.float()  # CPU: float32
    # Build the conversation prompt
    conv_mode = "llava_v1"
    conv = conv_templates[conv_mode].copy()

    # Prepend the image token to the user's text
    inp = DEFAULT_IMAGE_TOKEN + '\n' + text
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
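    # tokenizer_image_token tokenizes the text around DEFAULT_IMAGE_TOKEN and
    # splices in the IMAGE_TOKEN_INDEX placeholder, which the model later
    # replaces with the image features.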
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
    # llava_v1 uses SeparatorStyle.TWO, so the stop string is conv.sep2
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    # skip_prompt=True drops the echoed prompt; timeout=20.0 aborts the stream
    # if no new tokens arrive within 20 seconds
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)
    # Run generation in a worker thread so tokens can be streamed as they
    # arrive. torch.inference_mode() is thread-local, so it must be entered
    # inside the worker, not around Thread() in the caller.
    def generate():
        with torch.inference_mode():
            model.generate(
                inputs=input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=int(max_tokens),  # sliders return floats; generate expects an int
                streamer=streamer,
                use_cache=True,
            )

    thread = Thread(target=generate)
    thread.start()
    # The streamer may emit a bare " " chunk; hold it back and re-attach it to
    # the next chunk so trimming the stop string leaves no stray spaces.
    response = ""
    prepend_space = False
    for new_text in streamer:
        if new_text == " ":
            prepend_space = True
            continue
        if new_text.endswith(stop_str):
            new_text = new_text[:-len(stop_str)].strip()
            prepend_space = False
        elif prepend_space:
            new_text = " " + new_text
            prepend_space = False
        response += new_text
    if prepend_space:
        response += " "
    thread.join()
    return response
# Build the Gradio web UI
with gr.Blocks(title="LLaVA 1.5-13B Web UI") as demo:
    gr.Markdown("# 🌋 LLaVA-1.5-13B Web Interface")
    gr.Markdown("Upload an image and enter a prompt; LLaVA will answer.")
    with gr.Row():
        with gr.Column(scale=3):
            image_input = gr.Image(type="pil", label="Upload image")
            text_input = gr.Textbox(placeholder="Enter text...", label="Prompt")
            temperature = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
            top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="Top P")
            max_tokens = gr.Slider(10, 1024, value=512, step=10, label="Max Tokens")
            submit_button = gr.Button("Submit")
        with gr.Column(scale=7):
            chatbot_output = gr.Textbox(label="LLaVA output", interactive=False)
    submit_button.click(fn=llava_infer, inputs=[image_input, text_input, temperature, top_p, max_tokens], outputs=chatbot_output)
# Launch the Gradio app (share=True also exposes a public gradio.live URL)
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)