# Hugging Face Space: multimodal chatbot (image + video) built on
# LLaVA-NeXT-Video-7B with a Gradio front end.
import gradio as gr
from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
import torch
import av
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image
import tempfile

# Model configuration.
# NOTE: 4-bit quantization (BitsAndBytesConfig with load_in_4bit and fp32 CPU
# offload) was previously used here but is disabled; the model loads in its
# default dtype. Re-enable by passing quantization_config=... below if GPU
# memory is tight.
MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"

processor = LlavaNextVideoProcessor.from_pretrained(MODEL_ID)
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    MODEL_ID,
    # Spread layers automatically across available GPU(s)/CPU.
    device_map="auto",
)
def read_video_pyav(container, indices):
    """Decode and return the frames at the given indices from a video.

    Args:
        container: an open ``av`` input container with at least one video
            stream (stream 0 is decoded).
        indices: iterable of frame indices to keep. Duplicates are collapsed
            (a decoded frame is appended at most once, as in the original
            membership test).

    Returns:
        ``np.ndarray`` of shape ``(num_unique_indices, H, W, 3)`` holding the
        selected frames converted to RGB24.

    Raises:
        ValueError: if ``indices`` is empty (the original raised a bare
            IndexError on ``indices[0]``).
    """
    wanted = {int(i) for i in indices}
    if not wanted:
        raise ValueError("indices must contain at least one frame index")
    last = max(wanted)
    frames = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > last:
            # Every requested frame has been collected; stop decoding early.
            break
        if i in wanted:  # set membership is O(1) vs O(n) scan of an ndarray
            frames.append(frame)
    return np.stack([f.to_ndarray(format="rgb24") for f in frames])
def process_input(message, file):
    """Answer a user question about an uploaded image or video.

    Args:
        message: the user's question (str).
        file: Gradio file object exposing a ``.name`` path, or ``None``.

    Returns:
        The model's generated answer (str), or a French error message when the
        file is missing, unreadable, or of an unsupported format.
    """
    if file is None:
        return "Veuillez uploader une image ou une vidéo"

    path = file.name
    suffix = path.lower()  # case-insensitive extension check (.MP4, .Jpg, ...)

    if suffix.endswith((".mp4", ".avi", ".mov")):  # video input
        # Open the uploaded file directly: the original copied it to a
        # NamedTemporaryFile(delete=False) that was never removed (temp-file
        # leak) via an unclosed open(...) handle. The context manager also
        # guarantees the container is closed.
        with av.open(path) as container:
            total_frames = container.streams.video[0].frames
            if total_frames <= 0:
                # Guard against ZeroDivisionError below on unreadable videos.
                return "Impossible de lire la vidéo uploadée."
            # Sample 8 (approximately) evenly spaced frames across the clip.
            indices = np.arange(0, total_frames, total_frames / 8).astype(int)
            video_clip = read_video_pyav(container, indices)
        conversation = [{
            "role": "user",
            "content": [
                {"type": "text", "text": message},
                {"type": "video"},
            ],
        }]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[prompt], videos=[video_clip], padding=True, return_tensors="pt").to(model.device)
    elif suffix.endswith((".jpg", ".jpeg", ".png")):  # image input
        image = Image.open(path)
        conversation = [{
            "role": "user",
            "content": [
                {"type": "text", "text": message},
                {"type": "image"},
            ],
        }]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[prompt], images=[image], padding=True, return_tensors="pt").to(model.device)
    else:
        return "Format de fichier non supporté. Veuillez uploader une image ou une vidéo."

    # Generate the answer with nucleus sampling.
    generate_kwargs = {"max_new_tokens": 1024, "do_sample": True, "top_p": 0.9}
    output = model.generate(**inputs, **generate_kwargs)
    generated_text = processor.batch_decode(output, skip_special_tokens=True)
    return generated_text[0]
# --- Gradio interface ---
with gr.Blocks(title="Chatbot Multimodal LLaVA") as demo:
    gr.Markdown("# Chatbot Multimodal LLaVA")
    gr.Markdown("Parlez avec un modèle IA capable de comprendre à la fois les images et les vidéos")

    with gr.Row():
        # Left column: inputs; right column: the model's answer.
        with gr.Column():
            input_file = gr.File(label="Uploader une image ou une vidéo")
            input_text = gr.Textbox(
                label="Votre message",
                placeholder="Posez votre question ici...",
            )
            submit_btn = gr.Button("Envoyer")
        with gr.Column():
            output_text = gr.Textbox(label="Réponse de l'IA", interactive=False)

    submit_btn.click(
        fn=process_input,
        inputs=[input_text, input_file],
        outputs=output_text,
    )

    gr.Examples(
        examples=[
            ["Décris cette image en détail.", "/content/Psoriasis (1).jpg"],
            ["Que se passe-t-il dans cette vidéo?", "/content/karate.mp4"],
        ],
        inputs=[input_text, input_file],
        outputs=output_text,
        fn=process_input,
        cache_examples=False,
    )

# Start the web app.
demo.launch()