JohanDL committed on
Commit
9b179e0
·
1 Parent(s): 231c6ed

Unsloth needs a GPU

Browse files
Files changed (1) hide show
  1. app.py +17 -11
app.py CHANGED
@@ -5,7 +5,6 @@ import re, os, torch, cairosvg, lpips, clip, gradio as gr
5
  from io import BytesIO
6
  from pathlib import Path
7
  from PIL import Image
8
- from unsloth import FastLanguageModel
9
  from transformers import BitsAndBytesConfig, AutoTokenizer
10
  import gradio as gr
11
  import spaces
@@ -51,17 +50,24 @@ def fused_sim(a:Image.Image,b:Image.Image,α=.5):
51
  return α*clip_sim + (1-α)*lp_sim
52
 
53
  # ---------- load models once at startup ---------------------
54
- bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
55
- print("Loading BASE …")
56
- base, tok = FastLanguageModel.from_pretrained(
57
- BASE_MODEL, max_seq_length=2048,
58
- load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
59
- tok.pad_token = tok.eos_token
 
 
 
 
 
 
 
 
 
 
60
 
61
- print("Loading LoRA …")
62
- lora, _ = FastLanguageModel.from_pretrained(
63
- ADAPTER_DIR, max_seq_length=2048,
64
- load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
65
 
66
  def build_prompt(desc:str):
67
  msgs=[{"role":"system","content":"You are an SVG illustrator."},
 
5
  from io import BytesIO
6
  from pathlib import Path
7
  from PIL import Image
 
8
  from transformers import BitsAndBytesConfig, AutoTokenizer
9
  import gradio as gr
10
  import spaces
 
50
  return α*clip_sim + (1-α)*lp_sim
51
 
52
  # ---------- load models once at startup ---------------------
53
+ @spaces.GPU
54
+ def load_models():
55
+ from unsloth import FastLanguageModel
56
+ bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
57
+ print("Loading BASE ")
58
+ base, tok = FastLanguageModel.from_pretrained(
59
+ BASE_MODEL, max_seq_length=2048,
60
+ load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
61
+ tok.pad_token = tok.eos_token
62
+
63
+ print("Loading LoRA …")
64
+ lora, _ = FastLanguageModel.from_pretrained(
65
+ ADAPTER_DIR, max_seq_length=2048,
66
+ load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
67
+ return base, tok, lora
68
+
69
 
70
+ base, tok, lora = load_models()
 
 
 
71
 
72
  def build_prompt(desc:str):
73
  msgs=[{"role":"system","content":"You are an SVG illustrator."},