import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoProcessor, MllamaForConditionalGeneration
from PIL import Image
# Model ID
MODEL_ID = "0llheaven/Llama-3.2-11B-Vision-Radiology-mini"

# Load tokenizer and processor
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained(MODEL_ID)
# Load the model with reduced precision and memory optimizations.
# Llama 3.2 Vision checkpoints are Mllama models, so the full vision-language
# model is loaded with MllamaForConditionalGeneration rather than AutoModelForCausalLM.
print("Loading model with memory optimizations...")
model = MllamaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,   # Use half precision
    device_map="auto",           # Let accelerate decide how to map the model
    low_cpu_mem_usage=True,      # Optimize CPU memory usage
    offload_folder="offload",    # Offload weights to disk if needed
    offload_state_dict=True,     # Enable state dict offloading
    trust_remote_code=True,
)
print("Model loaded!")

# Clear CUDA cache after loading
if torch.cuda.is_available():
    torch.cuda.empty_cache()
def generate_response(image_file, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
    try:
        if image_file is not None:
            # Process image if provided
            image = Image.open(image_file).convert("RGB")

            # The Mllama processor expects an <|image|> placeholder in the text;
            # the chat template inserts it for us (this assumes the checkpoint
            # follows the standard Llama 3.2 Vision chat format).
            messages = [
                {"role": "user", "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ]}
            ]
            input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

            # Process inputs and move them to the same device as the model
            inputs = processor(
                images=image,
                text=input_text,
                return_tensors="pt",
            ).to(model.device)

            # Generate response with conservative memory settings
            with torch.no_grad():
                # Clear cache before generation
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                )

            # Decode only the newly generated tokens so the prompt is not echoed back
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            response = processor.decode(new_tokens, skip_special_tokens=True)
        else:
            # Text-only input
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            # Generate response with conservative memory settings
            with torch.no_grad():
                # Clear cache before generation
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                )

            # Decode only the newly generated tokens so the prompt is not echoed back
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True)

        return response.strip()
    except Exception as e:
        return f"Error: {str(e)}"
# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Llama-3.2-11B Vision Radiology Model")
    gr.Markdown("Upload a radiology image (X-ray, CT, MRI, etc.) and ask questions about it.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Upload Radiology Image")
            prompt_input = gr.Textbox(
                label="Question or Prompt",
                placeholder="Describe what you see in this image and identify any abnormalities.",
            )
            with gr.Row():
                max_tokens = gr.Slider(minimum=16, maximum=512, value=256, step=8, label="Max New Tokens")
                temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p")
            submit_btn = gr.Button("Generate Response")
        with gr.Column():
            output = gr.Textbox(label="Model Response", lines=15)

    submit_btn.click(
        generate_response,
        inputs=[image_input, prompt_input, max_tokens, temperature, top_p],
        outputs=[output],
    )
    gr.Examples(
        [
            ["sample_xray.jpg", "What abnormalities do you see in this X-ray?"],
            ["sample_ct.jpg", "Describe this image and any findings."],
        ],
        inputs=[image_input, prompt_input],
    )
# Limit the worker thread pool to a single thread to conserve memory
demo.launch(max_threads=1)
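
# Note: running this app requires (at minimum) gradio, torch, transformers,
# accelerate, and pillow in the Space's requirements.txt; accelerate is needed
# for device_map="auto" and the disk-offloading options used above. Exact
# version pins depend on the checkpoint and are not specified here.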