import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
import spaces
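# `spaces` provides the @spaces.GPU decorator: on a ZeroGPU Space it requests a GPU
# for the decorated function and releases it again when the call returns.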

# Initial setup without loading the model onto a device
print("Setting up the application...")

# The model is loaded lazily inside the GPU functions to avoid CPU memory issues at startup
model = None
tokenizer = AutoTokenizer.from_pretrained("sagar007/Lava_phi")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
print("Tokenizer and processor loaded successfully!")
# Text-only generation; the GPU is allocated on demand via the ZeroGPU decorator
@spaces.GPU
def generate_text(prompt, max_length=128):
    try:
        global model
        # Load the model on the first request
        if model is None:
            print("Loading model on first request...")
            model = AutoModelForCausalLM.from_pretrained(
                "sagar007/Lava_phi",
                torch_dtype=torch.float16,  # use float16 on GPU
                device_map="auto",          # place the model on the GPU automatically
            )
            print("Model loaded successfully!")

        inputs = tokenizer(f"human: {prompt}\ngpt:", return_tensors="pt").to(model.device)

        # Generate on the GPU
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the model's response
        if "gpt:" in generated_text:
            generated_text = generated_text.split("gpt:", 1)[1].strip()

        return generated_text
    except Exception as e:
        # Surface any errors in the UI instead of crashing
        return f"Error generating text: {str(e)}"
# Image and text processing; the GPU is allocated on demand via the ZeroGPU decorator
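# The CLIP processor converts the uploaded image into a pixel tensor, and the prompt
# embeds an <image> placeholder token; the model is assumed to consume both in a
# LLaVA-style fashion, since sagar007/Lava_phi's generate() accepts an `images` argument.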
@spaces.GPU
def process_image_and_prompt(image, prompt, max_length=128):
    try:
        if image is None:
            return "No image provided. Please upload an image."

        global model
        # Load the model on the first request
        if model is None:
            print("Loading model on first request...")
            model = AutoModelForCausalLM.from_pretrained(
                "sagar007/Lava_phi",
                torch_dtype=torch.float16,  # use float16 on GPU
                device_map="auto",          # place the model on the GPU automatically
            )
            print("Model loaded successfully!")

        # Preprocess the image into pixel values
        image_tensor = processor(images=image, return_tensors="pt").pixel_values.to(model.device)

        # Tokenize the prompt with the image placeholder token
        inputs = tokenizer(f"human: <image>\n{prompt}\ngpt:", return_tensors="pt").to(model.device)

        # Generate on the GPU
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                images=image_tensor,
                max_new_tokens=max_length,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the model's response
        if "gpt:" in generated_text:
            generated_text = generated_text.split("gpt:", 1)[1].strip()

        return generated_text
    except Exception as e:
        # Surface any errors in the UI instead of crashing
        return f"Error processing image: {str(e)}"
# Create the Gradio interface
with gr.Blocks(title="LLaVA-Phi: Vision-Language Model") as demo:
    gr.Markdown("# LLaVA-Phi: Vision-Language Model")
    gr.Markdown("This Space uses ZeroGPU technology: GPU resources are allocated only while a response is being generated and released afterward.")

    with gr.Tab("Text Generation"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Enter your prompt", lines=3, placeholder="What is artificial intelligence?")
                text_max_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Maximum response length")
                text_button = gr.Button("Generate")
            with gr.Column():
                text_output = gr.Textbox(label="Generated response", lines=8)
                text_status = gr.Markdown("*Status: Ready*")

        # The status string is returned as a second output; calling .update() on a
        # component inside the callback does not refresh the UI.
        def text_fn(prompt, max_length):
            try:
                response = generate_text(prompt, max_length)
                return response, "*Status: Complete*"
            except Exception as e:
                return f"Error: {str(e)}", "*Status: Error*"

        text_button.click(
            fn=text_fn,
            inputs=[text_input, text_max_length],
            outputs=[text_output, text_status],
        )
with gr.Tab("Image + Text Analysis"): | |
with gr.Row(): | |
with gr.Column(): | |
image_input = gr.Image(type="pil", label="Upload an image") | |
image_text_input = gr.Textbox(label="Enter your prompt about the image", | |
lines=2, | |
placeholder="Describe this image in detail.") | |
image_max_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Maximum response length") | |
image_button = gr.Button("Analyze") | |
with gr.Column(): | |
image_output = gr.Textbox(label="Model response", lines=8) | |
image_status = gr.Markdown("*Status: Ready*") | |
def image_fn(image, prompt, max_length): | |
image_status.update("*Status: Analyzing image...*") | |
try: | |
response = process_image_and_prompt(image, prompt, max_length) | |
image_status.update("*Status: Complete*") | |
return response | |
except Exception as e: | |
image_status.update("*Status: Error*") | |
return f"Error: {str(e)}" | |
image_button.click( | |
fn=image_fn, | |
inputs=[image_input, image_text_input, image_max_length], | |
outputs=image_output | |
) | |
    # Example prompts for the text tab
    gr.Examples(
        examples=["What is the advantage of vision-language models?",
                  "Explain how multimodal AI models work.",
                  "Tell me a short story about robots."],
        inputs=text_input,
    )

    # Usage note
    with gr.Row():
        gr.Markdown("*Note: When you click Generate or Analyze, a GPU is temporarily allocated to process your request and then released. The first request may take longer because the model has to be loaded first.*")
# Launch the app
if __name__ == "__main__":
    # Enable the queue explicitly; the enable_queue= argument to launch() is
    # deprecated/removed in newer Gradio releases
    demo.queue()
    demo.launch(show_error=True)
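
# Deployment sketch (assumptions, not part of the original script): on a Hugging Face
# Space this file runs as app.py with the hardware set to ZeroGPU; requirements.txt
# would list at least gradio, torch, and transformers, while the `spaces` package is
# provided by the ZeroGPU runtime.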