JohanDL committed on
Commit
d71891a
·
1 Parent(s): eaa86c1

Adding initial eval code

Browse files
Files changed (1) hide show
  1. app.py +57 -28
app.py CHANGED
@@ -53,28 +53,46 @@ def fused_sim(a:Image.Image,b:Image.Image,α=.5):
53
  bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
54
 
55
  # ---------- load models once at startup ---------------------
56
- base = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  @spaces.GPU
58
- def load_models():
 
59
  from unsloth import FastLanguageModel
60
- global base, tok, lora
61
- if base is None:
62
- print("Loading BASE …")
63
- base, tok = FastLanguageModel.from_pretrained(
64
  BASE_MODEL, max_seq_length=2048,
65
- load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
66
- tok.pad_token = tok.eos_token
67
-
68
- print("Loading LoRA …")
69
- lora, _ = FastLanguageModel.from_pretrained(
70
  ADAPTER_DIR, max_seq_length=2048,
71
- load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
72
- print("✔ models loaded")
73
 
74
- @spaces.GPU
75
- def ensure_models():
76
- load_models()
77
- return True # small, pickle-able sentinel
78
 
79
 
80
  def build_prompt(desc:str):
@@ -85,26 +103,37 @@ def build_prompt(desc:str):
85
 
86
  @spaces.GPU
87
  @torch.no_grad()
88
- def draw(model, desc:str):
89
  ensure_models()
90
- prompt = build_prompt(desc)
91
- ids = tok(prompt, return_tensors="pt").to(DEVICE)
 
 
 
 
 
92
  out = model.generate(**ids, max_new_tokens=MAX_NEW,
93
  do_sample=True, temperature=.7, top_p=.8)
94
- txt = tok.decode(out[0], skip_special_tokens=True)
95
- svg = extract_svg(txt)
96
  img = svg2pil(svg) if svg else None
97
  return img, svg or "(no SVG found)"
98
 
99
  # ---------- gradio interface --------------------------------
 
100
  def compare(desc):
101
- ensure_models()
102
- img_base, svg_base = draw(base, desc)
103
- img_lora, svg_lora = draw(lora, desc)
104
- # sim = (fused_sim(img_lora, img_base) if img_base and img_lora else float("nan"))
105
-
106
  caption = "Thanks for trying our model 😊\nIf you don't see an image for the base or GRPO model that means it didn't generate a valid SVG!"
107
- return img_base, img_lora, caption, svg_base, svg_lora
 
 
 
 
 
 
 
 
 
108
 
109
  with gr.Blocks(css="body{background:#111;color:#eee}") as demo:
110
  gr.Markdown("## 🖌️ Qwen-2.5 SVG Generator — base vs GRPO-LoRA")
 
53
  bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
54
 
55
  # ---------- load models once at startup ---------------------
56
+ _base = None
57
+ # @spaces.GPU
58
+ # def load_models():
59
+ # from unsloth import FastLanguageModel
60
+ # global base, tok, lora
61
+ # if base is None:
62
+ # print("Loading BASE …")
63
+ # base, tok = FastLanguageModel.from_pretrained(
64
+ # BASE_MODEL, max_seq_length=2048,
65
+ # load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
66
+ # tok.pad_token = tok.eos_token
67
+
68
+ # print("Loading LoRA …")
69
+ # lora, _ = FastLanguageModel.from_pretrained(
70
+ # ADAPTER_DIR, max_seq_length=2048,
71
+ # load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
72
+ # print("✔ models loaded")
73
+
74
+ _base = _lora = _tok = None
75
+ _CLIP = _PREP = _LP = None
76
+
77
  @spaces.GPU
78
+ def ensure_models():
79
+ """Create base, lora, tok **once per worker**."""
80
  from unsloth import FastLanguageModel
81
+ global _base, _lora, _tok
82
+ if _base is None:
83
+ _base, _tok = FastLanguageModel.from_pretrained(
 
84
  BASE_MODEL, max_seq_length=2048,
85
+ quantization_config=bnb_cfg, device_map="auto")
86
+ _tok.pad_token = _tok.eos_token
87
+ _lora, _ = FastLanguageModel.from_pretrained(
 
 
88
  ADAPTER_DIR, max_seq_length=2048,
89
+ quantization_config=bnb_cfg, device_map="auto")
90
+ return True
91
 
92
+ # @spaces.GPU
93
+ # def ensure_models():
94
+ # load_models()
95
+ # return True # small, pickle-able sentinel
96
 
97
 
98
  def build_prompt(desc:str):
 
103
 
104
  @spaces.GPU
105
  @torch.no_grad()
106
+ def draw(model_flag, desc):
107
  ensure_models()
108
+ model = _base if model_flag == "base" else _lora
109
+ prompt = _tok.apply_chat_template(
110
+ [{"role":"system","content":"You are an SVG illustrator."},
111
+ {"role":"user",
112
+ "content":f"ONLY reply with a valid, complete <svg>…</svg> file that depicts: {desc}"}],
113
+ tokenize=False, add_generation_prompt=True)
114
+ ids = _tok(prompt, return_tensors="pt").to(DEVICE)
115
  out = model.generate(**ids, max_new_tokens=MAX_NEW,
116
  do_sample=True, temperature=.7, top_p=.8)
117
+ svg = extract_svg(_tok.decode(out[0], skip_special_tokens=True))
 
118
  img = svg2pil(svg) if svg else None
119
  return img, svg or "(no SVG found)"
120
 
121
  # ---------- gradio interface --------------------------------
122
+ #
123
  def compare(desc):
124
+ img_b, svg_b = draw("base", desc)
125
+ img_l, svg_l = draw("lora", desc)
 
 
 
126
  caption = "Thanks for trying our model 😊\nIf you don't see an image for the base or GRPO model that means it didn't generate a valid SVG!"
127
+ return img_b, img_l, caption, svg_b, svg_l
128
+
129
+ # def compare(desc):
130
+ # ensure_models()
131
+ # img_base, svg_base = draw(base, desc)
132
+ # img_lora, svg_lora = draw(lora, desc)
133
+ # # sim = (fused_sim(img_lora, img_base) if img_base and img_lora else float("nan"))
134
+
135
+ # caption = "Thanks for trying our model 😊\nIf you don't see an image for the base or GRPO model that means it didn't generate a valid SVG!"
136
+ # return img_base, img_lora, caption, svg_base, svg_lora
137
 
138
  with gr.Blocks(css="body{background:#111;color:#eee}") as demo:
139
  gr.Markdown("## 🖌️ Qwen-2.5 SVG Generator — base vs GRPO-LoRA")