Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -17,14 +17,14 @@ class SpriteGenerator(nn.Module):
|
|
17 |
|
18 |
# Proiezione dal testo al latent space
|
19 |
self.text_projection = nn.Sequential(
|
20 |
-
nn.Linear(768, latent_dim),
|
21 |
nn.LeakyReLU(0.2),
|
22 |
-
nn.Linear(latent_dim, latent_dim)
|
23 |
)
|
24 |
|
25 |
# Generator
|
26 |
self.generator = nn.Sequential(
|
27 |
-
#
|
28 |
nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
|
29 |
nn.BatchNorm2d(512),
|
30 |
nn.ReLU(True),
|
@@ -95,7 +95,6 @@ class SpriteGenerator(nn.Module):
|
|
95 |
# Generate frame
|
96 |
frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3)
|
97 |
frame = self.generator(frame_latent_reshaped)
|
98 |
-
# Normalizzazione dell'output
|
99 |
frame = torch.tanh(frame)
|
100 |
all_frames.append(frame)
|
101 |
|
@@ -104,100 +103,77 @@ class SpriteGenerator(nn.Module):
|
|
104 |
|
105 |
return sprites
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
Carica il modello
|
114 |
-
"""
|
115 |
try:
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
filename="pytorch_model.bin",
|
121 |
-
cache_dir=CACHE_DIR
|
122 |
-
)
|
123 |
-
model.load_state_dict(torch.load(model_path, map_location='cpu'))
|
124 |
model.eval()
|
125 |
-
|
|
|
126 |
except Exception as e:
|
127 |
print(f"Errore nel caricamento del modello: {str(e)}")
|
128 |
-
|
129 |
-
|
130 |
-
# Inizializzazione globale
|
131 |
-
print("Caricamento del modello...")
|
132 |
-
model = load_model()
|
133 |
-
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
134 |
-
|
135 |
-
def generate_animated_sprite(character_description, num_frames, character_action, viewing_direction):
|
136 |
-
"""
|
137 |
-
Genera un'animazione sprite utilizzando il modello
|
138 |
-
"""
|
139 |
-
if model is None:
|
140 |
-
raise Exception("Il modello non è stato caricato correttamente")
|
141 |
-
|
142 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
143 |
-
model.to(device)
|
144 |
-
|
145 |
-
# Prepara l'input
|
146 |
-
text_input = f"""
|
147 |
-
Description: {character_description}
|
148 |
-
Action: {character_action}
|
149 |
-
Direction: {viewing_direction}
|
150 |
-
Number of frames: {num_frames}
|
151 |
-
"""
|
152 |
-
|
153 |
-
# Tokenizzazione
|
154 |
-
encoded_text = tokenizer(
|
155 |
-
text_input,
|
156 |
-
padding="max_length",
|
157 |
-
max_length=128,
|
158 |
-
truncation=True,
|
159 |
-
return_tensors="pt"
|
160 |
-
)
|
161 |
-
|
162 |
-
input_ids = encoded_text['input_ids'].to(device)
|
163 |
-
attention_mask = encoded_text['attention_mask'].to(device)
|
164 |
-
num_frames_tensor = torch.tensor([int(num_frames)], device=device)
|
165 |
|
|
|
166 |
try:
|
167 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
with torch.no_grad():
|
169 |
-
|
170 |
-
input_ids=input_ids,
|
171 |
-
attention_mask=attention_mask,
|
172 |
-
num_frames=
|
173 |
)
|
174 |
-
|
175 |
-
# Conversione in immagini
|
176 |
-
frames = []
|
177 |
-
for i in range(int(num_frames)):
|
178 |
-
frame = output_sprites[0, i].cpu()
|
179 |
-
frame = ((frame + 1) * 127.5).clamp(0, 255).to(torch.uint8)
|
180 |
-
frame = frame.permute(1, 2, 0).numpy()
|
181 |
-
frame_img = Image.fromarray(frame)
|
182 |
-
frames.append(frame_img)
|
183 |
-
|
184 |
-
# Salvataggio GIF
|
185 |
-
os.makedirs("tmp", exist_ok=True)
|
186 |
-
output_path = os.path.join("tmp", f"sprite_{hash(character_description)}.gif")
|
187 |
-
|
188 |
-
frames[0].save(
|
189 |
-
output_path,
|
190 |
-
format='GIF',
|
191 |
-
append_images=frames[1:],
|
192 |
-
save_all=True,
|
193 |
-
duration=200,
|
194 |
-
loop=0
|
195 |
-
)
|
196 |
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
except Exception as e:
|
199 |
print(f"Errore nella generazione: {str(e)}")
|
200 |
-
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
# Interfaccia Gradio
|
203 |
def create_interface():
|
|
|
17 |
|
18 |
# Proiezione dal testo al latent space
|
19 |
self.text_projection = nn.Sequential(
|
20 |
+
nn.Linear(768, latent_dim),
|
21 |
nn.LeakyReLU(0.2),
|
22 |
+
nn.Linear(latent_dim, latent_dim)
|
23 |
)
|
24 |
|
25 |
# Generator
|
26 |
self.generator = nn.Sequential(
|
27 |
+
# Input: latent_dim x 1 x 1 -> 512 x 4 x 4
|
28 |
nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
|
29 |
nn.BatchNorm2d(512),
|
30 |
nn.ReLU(True),
|
|
|
95 |
# Generate frame
|
96 |
frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3)
|
97 |
frame = self.generator(frame_latent_reshaped)
|
|
|
98 |
frame = torch.tanh(frame)
|
99 |
all_frames.append(frame)
|
100 |
|
|
|
103 |
|
104 |
return sprites
|
105 |
|
106 |
+
def initialize_model():
|
107 |
+
print("Inizializzazione del modello...")
|
108 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
109 |
+
|
110 |
+
model = SpriteGenerator()
|
111 |
+
|
|
|
|
|
112 |
try:
|
113 |
+
# Carica il modello
|
114 |
+
state_dict = torch.load("Animator2D-v2.pth", map_location=device)
|
115 |
+
model.load_state_dict(state_dict)
|
116 |
+
model = model.to(device)
|
|
|
|
|
|
|
|
|
117 |
model.eval()
|
118 |
+
print("Modello caricato con successo!")
|
119 |
+
return model, device
|
120 |
except Exception as e:
|
121 |
print(f"Errore nel caricamento del modello: {str(e)}")
|
122 |
+
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
+
def generate_sprite(prompt, num_frames=8):
|
125 |
try:
|
126 |
+
# Usa il modello e il device globali
|
127 |
+
global model, device, tokenizer
|
128 |
+
|
129 |
+
# Tokenizza il testo
|
130 |
+
tokens = tokenizer(prompt, return_tensors="pt", padding=True)
|
131 |
+
tokens = {k: v.to(device) for k, v in tokens.items()}
|
132 |
+
|
133 |
+
# Genera l'immagine
|
134 |
with torch.no_grad():
|
135 |
+
frames = model(
|
136 |
+
input_ids=tokens["input_ids"],
|
137 |
+
attention_mask=tokens["attention_mask"],
|
138 |
+
num_frames=torch.tensor([num_frames], device=device)
|
139 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
+
# Converte il tensore in immagine
|
142 |
+
frames = (frames * 0.5 + 0.5).clamp(0, 1)
|
143 |
+
frames = frames.cpu().numpy()
|
144 |
+
|
145 |
+
# Ritorna il primo frame come esempio
|
146 |
+
frame = frames[0, 0] # Prende il primo frame del batch
|
147 |
+
frame = (frame * 255).astype('uint8').transpose(1, 2, 0)
|
148 |
+
|
149 |
+
return Image.fromarray(frame)
|
150 |
except Exception as e:
|
151 |
print(f"Errore nella generazione: {str(e)}")
|
152 |
+
raise
|
153 |
+
|
154 |
+
# Inizializzazione globale
|
155 |
+
print("Caricamento del modello...")
|
156 |
+
try:
|
157 |
+
model, device = initialize_model()
|
158 |
+
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
159 |
+
|
160 |
+
# Creazione dell'interfaccia Gradio
|
161 |
+
interface = gr.Interface(
|
162 |
+
fn=generate_sprite,
|
163 |
+
inputs=[
|
164 |
+
gr.Textbox(label="Descrivi lo sprite che vuoi generare"),
|
165 |
+
gr.Slider(minimum=1, maximum=16, value=8, step=1, label="Numero di frame")
|
166 |
+
],
|
167 |
+
outputs=gr.Image(label="Sprite generato"),
|
168 |
+
title="Animator2D-v2 Sprite Generator",
|
169 |
+
description="Genera sprite animati da descrizioni testuali"
|
170 |
+
)
|
171 |
+
|
172 |
+
# Avvio dell'interfaccia
|
173 |
+
interface.launch()
|
174 |
+
except Exception as e:
|
175 |
+
print(f"Errore nell'inizializzazione dell'applicazione: {str(e)}")
|
176 |
+
raise e
|
177 |
|
178 |
# Interfaccia Gradio
|
179 |
def create_interface():
|