Kilos1 commited on
Commit
fdb58e3
·
verified ·
1 Parent(s): 573e67a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -19
app.py CHANGED
@@ -1,22 +1,17 @@
 
 
1
  import torch
2
  import gradio as gr
3
  from PIL import Image
4
- from transformers import AutoProcessor, AutoModel
5
 
6
  # Load the model and processor
7
- model_id = "OpenGVLab/InternVL2_5-78B"
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
  # Initialize the model and processor
11
- model = AutoModel.from_pretrained(
12
- model_id,
13
- torch_dtype=torch.bfloat16,
14
- low_cpu_mem_usage=True,
15
- use_flash_attn=True,
16
- trust_remote_code=True
17
- ).eval().to(device)
18
-
19
- processor = AutoProcessor.from_pretrained(model_id)
20
 
21
  def generate_model_response(image_file, user_query):
22
  """
@@ -34,18 +29,14 @@ def generate_model_response(image_file, user_query):
34
  raw_image = Image.open(image_file).convert("RGB")
35
 
36
  # Prepare inputs for the model using the processor
37
- inputs = processor(
38
- text=user_query,
39
- images=raw_image,
40
- return_tensors="pt"
41
- ).to(device)
42
 
43
  # Generate response from the model
44
- outputs = model.generate(**inputs, max_new_tokens=50)
45
 
46
  # Decode and return the response
47
- response_text = processor.decode(outputs[0], skip_special_tokens=True)
48
- return response_text
49
 
50
  except Exception as e:
51
  print(f"Error in generating response: {e}")
 
1
+ import re
2
+ import io
3
  import torch
4
  import gradio as gr
5
  from PIL import Image
6
+ from transformers import OwlViTProcessor, OwlViTForImageClassification
7
 
8
  # Load the model and processor
9
+ model_id = "google/owlvit-base-patch16"
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
  # Initialize the model and processor
13
+ model = OwlViTForImageClassification.from_pretrained(model_id).to(device)
14
+ processor = OwlViTProcessor.from_pretrained(model_id)
 
 
 
 
 
 
 
15
 
16
  def generate_model_response(image_file, user_query):
17
  """
 
29
  raw_image = Image.open(image_file).convert("RGB")
30
 
31
  # Prepare inputs for the model using the processor
32
+ inputs = processor(images=raw_image, text=user_query, return_tensors="pt").to(device)
 
 
 
 
33
 
34
  # Generate response from the model
35
+ outputs = model(**inputs)
36
 
37
  # Decode and return the response
38
+ response_text = outputs.logits.argmax(dim=-1) # Example of how to process output
39
+ return f"Detected class ID: {response_text.item()}"
40
 
41
  except Exception as e:
42
  print(f"Error in generating response: {e}")