JohanDL commited on
Commit
567ff97
·
1 Parent(s): f8dfd1f

Adding initial eval code

Browse files
Files changed (1) hide show
  1. app.py +19 -13
app.py CHANGED
@@ -50,26 +50,31 @@ def fused_sim(a:Image.Image,b:Image.Image,α=.5):
50
  lp_sim = 1 - _LP(ta,tb,normalize=True).item()
51
  return α*clip_sim + (1-α)*lp_sim
52
 
 
 
53
  # ---------- load models once at startup ---------------------
54
  @spaces.GPU
55
  def load_models():
56
  from unsloth import FastLanguageModel
57
  global base, tok, lora
58
- bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
59
- print("Loading BASE …")
60
- base, tok = FastLanguageModel.from_pretrained(
61
- BASE_MODEL, max_seq_length=2048,
62
- load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
63
- tok.pad_token = tok.eos_token
64
-
65
- print("Loading LoRA …")
66
- lora, _ = FastLanguageModel.from_pretrained(
67
- ADAPTER_DIR, max_seq_length=2048,
68
- load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
69
- print("✔ models loaded")
70
 
 
 
 
 
71
 
72
- load_models()
73
 
74
  def build_prompt(desc:str):
75
  msgs=[{"role":"system","content":"You are an SVG illustrator."},
@@ -80,6 +85,7 @@ def build_prompt(desc:str):
80
  @spaces.GPU
81
  @torch.no_grad()
82
  def draw(model, desc:str):
 
83
  prompt = build_prompt(desc)
84
  ids = tok(prompt, return_tensors="pt").to(DEVICE)
85
  out = model.generate(**ids, max_new_tokens=MAX_NEW,
 
50
  lp_sim = 1 - _LP(ta,tb,normalize=True).item()
51
  return α*clip_sim + (1-α)*lp_sim
52
 
53
+ bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
54
+
55
  # ---------- load models once at startup ---------------------
56
  @spaces.GPU
57
  def load_models():
58
  from unsloth import FastLanguageModel
59
  global base, tok, lora
60
+ if base is None:
61
+ print("Loading BASE …")
62
+ base, tok = FastLanguageModel.from_pretrained(
63
+ BASE_MODEL, max_seq_length=2048,
64
+ load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
65
+ tok.pad_token = tok.eos_token
66
+
67
+ print("Loading LoRA …")
68
+ lora, _ = FastLanguageModel.from_pretrained(
69
+ ADAPTER_DIR, max_seq_length=2048,
70
+ load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
71
+ print("✔ models loaded")
72
 
73
+ @spaces.GPU
74
+ def ensure_models():
75
+ load_models()
76
+ return True # small, pickle-able sentinel
77
 
 
78
 
79
  def build_prompt(desc:str):
80
  msgs=[{"role":"system","content":"You are an SVG illustrator."},
 
85
  @spaces.GPU
86
  @torch.no_grad()
87
  def draw(model, desc:str):
88
+ ensure_models()
89
  prompt = build_prompt(desc)
90
  ids = tok(prompt, return_tensors="pt").to(DEVICE)
91
  out = model.generate(**ids, max_new_tokens=MAX_NEW,