Spaces:
Running
on
Zero
Running
on
Zero
Adding initial eval code
Browse files
app.py
CHANGED
@@ -53,28 +53,46 @@ def fused_sim(a:Image.Image,b:Image.Image,α=.5):
|
|
53 |
bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
|
54 |
|
55 |
# ---------- load models once at startup ---------------------
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
@spaces.GPU
|
58 |
-
def
|
|
|
59 |
from unsloth import FastLanguageModel
|
60 |
-
global
|
61 |
-
if
|
62 |
-
|
63 |
-
base, tok = FastLanguageModel.from_pretrained(
|
64 |
BASE_MODEL, max_seq_length=2048,
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
print("Loading LoRA …")
|
69 |
-
lora, _ = FastLanguageModel.from_pretrained(
|
70 |
ADAPTER_DIR, max_seq_length=2048,
|
71 |
-
|
72 |
-
|
73 |
|
74 |
-
@spaces.GPU
|
75 |
-
def ensure_models():
|
76 |
-
|
77 |
-
|
78 |
|
79 |
|
80 |
def build_prompt(desc:str):
|
@@ -85,26 +103,37 @@ def build_prompt(desc:str):
|
|
85 |
|
86 |
@spaces.GPU
|
87 |
@torch.no_grad()
|
88 |
-
def draw(
|
89 |
ensure_models()
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
92 |
out = model.generate(**ids, max_new_tokens=MAX_NEW,
|
93 |
do_sample=True, temperature=.7, top_p=.8)
|
94 |
-
|
95 |
-
svg = extract_svg(txt)
|
96 |
img = svg2pil(svg) if svg else None
|
97 |
return img, svg or "(no SVG found)"
|
98 |
|
99 |
# ---------- gradio interface --------------------------------
|
|
|
100 |
def compare(desc):
|
101 |
-
|
102 |
-
|
103 |
-
img_lora, svg_lora = draw(lora, desc)
|
104 |
-
# sim = (fused_sim(img_lora, img_base) if img_base and img_lora else float("nan"))
|
105 |
-
|
106 |
caption = "Thanks for trying our model π\nIf you don't see an image for the base or GRPO model that means it didn't generate a valid SVG!"
|
107 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
with gr.Blocks(css="body{background:#111;color:#eee}") as demo:
|
110 |
gr.Markdown("## ποΈ Qwen-2.5 SVG Generator β base vs GRPO-LoRA")
|
|
|
53 |
bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
|
54 |
|
55 |
# ---------- load models once at startup ---------------------
|
56 |
+
_base = None
|
57 |
+
# @spaces.GPU
|
58 |
+
# def load_models():
|
59 |
+
# from unsloth import FastLanguageModel
|
60 |
+
# global base, tok, lora
|
61 |
+
# if base is None:
|
62 |
+
# print("Loading BASE …")
|
63 |
+
# base, tok = FastLanguageModel.from_pretrained(
|
64 |
+
# BASE_MODEL, max_seq_length=2048,
|
65 |
+
# load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
|
66 |
+
# tok.pad_token = tok.eos_token
|
67 |
+
|
68 |
+
# print("Loading LoRA …")
|
69 |
+
# lora, _ = FastLanguageModel.from_pretrained(
|
70 |
+
# ADAPTER_DIR, max_seq_length=2048,
|
71 |
+
# load_in_4bit=True, quantization_config=bnb_cfg, device_map="auto")
|
72 |
+
# print("β models loaded")
|
73 |
+
|
74 |
+
_base = _lora = _tok = None
|
75 |
+
_CLIP = _PREP = _LP = None
|
76 |
+
|
77 |
@spaces.GPU
|
78 |
+
def ensure_models():
|
79 |
+
"""Create base, lora, tok **once per worker**."""
|
80 |
from unsloth import FastLanguageModel
|
81 |
+
global _base, _lora, _tok
|
82 |
+
if _base is None:
|
83 |
+
_base, _tok = FastLanguageModel.from_pretrained(
|
|
|
84 |
BASE_MODEL, max_seq_length=2048,
|
85 |
+
quantization_config=bnb_cfg, device_map="auto")
|
86 |
+
_tok.pad_token = _tok.eos_token
|
87 |
+
_lora, _ = FastLanguageModel.from_pretrained(
|
|
|
|
|
88 |
ADAPTER_DIR, max_seq_length=2048,
|
89 |
+
quantization_config=bnb_cfg, device_map="auto")
|
90 |
+
return True
|
91 |
|
92 |
+
# @spaces.GPU
|
93 |
+
# def ensure_models():
|
94 |
+
# load_models()
|
95 |
+
# return True # small, pickle-able sentinel
|
96 |
|
97 |
|
98 |
def build_prompt(desc:str):
|
|
|
103 |
|
104 |
@spaces.GPU
|
105 |
@torch.no_grad()
|
106 |
+
def draw(model_flag, desc):
|
107 |
ensure_models()
|
108 |
+
model = _base if model_flag == "base" else _lora
|
109 |
+
prompt = _tok.apply_chat_template(
|
110 |
+
[{"role":"system","content":"You are an SVG illustrator."},
|
111 |
+
{"role":"user",
|
112 |
+
"content":f"ONLY reply with a valid, complete <svg>…</svg> file that depicts: {desc}"}],
|
113 |
+
tokenize=False, add_generation_prompt=True)
|
114 |
+
ids = _tok(prompt, return_tensors="pt").to(DEVICE)
|
115 |
out = model.generate(**ids, max_new_tokens=MAX_NEW,
|
116 |
do_sample=True, temperature=.7, top_p=.8)
|
117 |
+
svg = extract_svg(_tok.decode(out[0], skip_special_tokens=True))
|
|
|
118 |
img = svg2pil(svg) if svg else None
|
119 |
return img, svg or "(no SVG found)"
|
120 |
|
121 |
# ---------- gradio interface --------------------------------
|
122 |
+
#
|
123 |
def compare(desc):
|
124 |
+
img_b, svg_b = draw("base", desc)
|
125 |
+
img_l, svg_l = draw("lora", desc)
|
|
|
|
|
|
|
126 |
caption = "Thanks for trying our model π\nIf you don't see an image for the base or GRPO model that means it didn't generate a valid SVG!"
|
127 |
+
return img_b, img_l, caption, svg_b, svg_l
|
128 |
+
|
129 |
+
# def compare(desc):
|
130 |
+
# ensure_models()
|
131 |
+
# img_base, svg_base = draw(base, desc)
|
132 |
+
# img_lora, svg_lora = draw(lora, desc)
|
133 |
+
# # sim = (fused_sim(img_lora, img_base) if img_base and img_lora else float("nan"))
|
134 |
+
|
135 |
+
# caption = "Thanks for trying our model π\nIf you don't see an image for the base or GRPO model that means it didn't generate a valid SVG!"
|
136 |
+
# return img_base, img_lora, caption, svg_base, svg_lora
|
137 |
|
138 |
with gr.Blocks(css="body{background:#111;color:#eee}") as demo:
|
139 |
gr.Markdown("## ποΈ Qwen-2.5 SVG Generator β base vs GRPO-LoRA")
|