akukkapa committed on
Commit
239f8ba
·
verified ·
1 Parent(s): dab977e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -1,35 +1,36 @@
1
  from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
2
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
3
  import torch
4
  from PIL import Image, ImageDraw, ImageFont
5
  import gradio as gr
6
  import os
 
 
7
  os.makedirs("./offload", exist_ok=True)
8
- from accelerate import infer_auto_device_map
9
 
 
 
 
10
  torch.backends.cuda.matmul.allow_tf32 = True
11
  torch.backends.cudnn.allow_tf32 = True
12
 
13
-
14
- # For BLIP-2
15
  blip_model = Blip2ForConditionalGeneration.from_pretrained(
16
  "Salesforce/blip2-opt-2.7b",
17
  torch_dtype=torch.float16,
18
- device_map="auto",
19
- offload_folder="./offload",
20
- no_split_module_classes=["Blip2QFormerModel"]
21
- )
22
 
23
- # For Phi-3
24
  phi_model = AutoModelForCausalLM.from_pretrained(
25
  "microsoft/Phi-3-mini-4k-instruct",
26
  trust_remote_code=True,
27
  device_map="auto",
28
  torch_dtype=torch.float16,
29
- offload_folder="./offload",
30
- no_split_module_classes=["PhiDecoderLayer"],
31
- load_in_4bit=True # Add 4-bit quantization
32
- )
33
  phi_tokenizer = AutoTokenizer.from_pretrained(
34
  "microsoft/Phi-3-mini-4k-instruct",
35
  token=HF_TOKEN
 
1
  from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
 
2
  import torch
3
  from PIL import Image, ImageDraw, ImageFont
4
  import gradio as gr
5
  import os
6
+
7
+ # Initialize environment
8
  os.makedirs("./offload", exist_ok=True)
9
+ HF_TOKEN = os.environ.get("HF_TOKEN")
10
 
11
+ # Memory optimization
12
+ torch.cuda.empty_cache()
13
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
14
  torch.backends.cuda.matmul.allow_tf32 = True
15
  torch.backends.cudnn.allow_tf32 = True
16
 
17
+ # Load BLIP-2
18
+ blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
19
  blip_model = Blip2ForConditionalGeneration.from_pretrained(
20
  "Salesforce/blip2-opt-2.7b",
21
  torch_dtype=torch.float16,
22
+ device_map="auto"
23
+ ).eval()
 
 
24
 
25
+ # Load Phi-3
26
  phi_model = AutoModelForCausalLM.from_pretrained(
27
  "microsoft/Phi-3-mini-4k-instruct",
28
  trust_remote_code=True,
29
  device_map="auto",
30
  torch_dtype=torch.float16,
31
+ load_in_4bit=True,
32
+ token=HF_TOKEN
33
+ ).eval()
 
34
  phi_tokenizer = AutoTokenizer.from_pretrained(
35
  "microsoft/Phi-3-mini-4k-instruct",
36
  token=HF_TOKEN