Lod34 commited on
Commit
95b77dc
·
verified ·
1 Parent(s): 91263a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -87
app.py CHANGED
@@ -17,14 +17,14 @@ class SpriteGenerator(nn.Module):
17
 
18
  # Proiezione dal testo al latent space
19
  self.text_projection = nn.Sequential(
20
- nn.Linear(768, latent_dim), # 768 -> 512
21
  nn.LeakyReLU(0.2),
22
- nn.Linear(latent_dim, latent_dim) # 512 -> 512
23
  )
24
 
25
  # Generator
26
  self.generator = nn.Sequential(
27
- # Blocco iniziale: latent_dim x 1 x 1 -> 512 x 4 x 4
28
  nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
29
  nn.BatchNorm2d(512),
30
  nn.ReLU(True),
@@ -95,7 +95,6 @@ class SpriteGenerator(nn.Module):
95
  # Generate frame
96
  frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3)
97
  frame = self.generator(frame_latent_reshaped)
98
- # Normalizzazione dell'output
99
  frame = torch.tanh(frame)
100
  all_frames.append(frame)
101
 
@@ -104,100 +103,77 @@ class SpriteGenerator(nn.Module):
104
 
105
  return sprites
106
 
107
- # Costanti
108
- MODEL_ID = "Lod34/Animator2D-v2"
109
- CACHE_DIR = "model_cache"
110
-
111
- def load_model():
112
- """
113
- Carica il modello
114
- """
115
  try:
116
- model = SpriteGenerator()
117
- # Carica i pesi del modello
118
- model_path = hf_hub_download(
119
- repo_id=MODEL_ID,
120
- filename="pytorch_model.bin",
121
- cache_dir=CACHE_DIR
122
- )
123
- model.load_state_dict(torch.load(model_path, map_location='cpu'))
124
  model.eval()
125
- return model
 
126
  except Exception as e:
127
  print(f"Errore nel caricamento del modello: {str(e)}")
128
- return None
129
-
130
- # Inizializzazione globale
131
- print("Caricamento del modello...")
132
- model = load_model()
133
- tokenizer = AutoTokenizer.from_pretrained("t5-base")
134
-
135
- def generate_animated_sprite(character_description, num_frames, character_action, viewing_direction):
136
- """
137
- Genera un'animazione sprite utilizzando il modello
138
- """
139
- if model is None:
140
- raise Exception("Il modello non è stato caricato correttamente")
141
-
142
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
143
- model.to(device)
144
-
145
- # Prepara l'input
146
- text_input = f"""
147
- Description: {character_description}
148
- Action: {character_action}
149
- Direction: {viewing_direction}
150
- Number of frames: {num_frames}
151
- """
152
-
153
- # Tokenizzazione
154
- encoded_text = tokenizer(
155
- text_input,
156
- padding="max_length",
157
- max_length=128,
158
- truncation=True,
159
- return_tensors="pt"
160
- )
161
-
162
- input_ids = encoded_text['input_ids'].to(device)
163
- attention_mask = encoded_text['attention_mask'].to(device)
164
- num_frames_tensor = torch.tensor([int(num_frames)], device=device)
165
 
 
166
  try:
167
- # Generazione frames
 
 
 
 
 
 
 
168
  with torch.no_grad():
169
- output_sprites = model(
170
- input_ids=input_ids,
171
- attention_mask=attention_mask,
172
- num_frames=num_frames_tensor
173
  )
174
-
175
- # Conversione in immagini
176
- frames = []
177
- for i in range(int(num_frames)):
178
- frame = output_sprites[0, i].cpu()
179
- frame = ((frame + 1) * 127.5).clamp(0, 255).to(torch.uint8)
180
- frame = frame.permute(1, 2, 0).numpy()
181
- frame_img = Image.fromarray(frame)
182
- frames.append(frame_img)
183
-
184
- # Salvataggio GIF
185
- os.makedirs("tmp", exist_ok=True)
186
- output_path = os.path.join("tmp", f"sprite_{hash(character_description)}.gif")
187
-
188
- frames[0].save(
189
- output_path,
190
- format='GIF',
191
- append_images=frames[1:],
192
- save_all=True,
193
- duration=200,
194
- loop=0
195
- )
196
 
197
- return output_path
 
 
 
 
 
 
 
 
198
  except Exception as e:
199
  print(f"Errore nella generazione: {str(e)}")
200
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  # Interfaccia Gradio
203
  def create_interface():
 
17
 
18
  # Proiezione dal testo al latent space
19
  self.text_projection = nn.Sequential(
20
+ nn.Linear(768, latent_dim),
21
  nn.LeakyReLU(0.2),
22
+ nn.Linear(latent_dim, latent_dim)
23
  )
24
 
25
  # Generator
26
  self.generator = nn.Sequential(
27
+ # Input: latent_dim x 1 x 1 -> 512 x 4 x 4
28
  nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
29
  nn.BatchNorm2d(512),
30
  nn.ReLU(True),
 
95
  # Generate frame
96
  frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3)
97
  frame = self.generator(frame_latent_reshaped)
 
98
  frame = torch.tanh(frame)
99
  all_frames.append(frame)
100
 
 
103
 
104
  return sprites
105
 
106
+ def initialize_model():
107
+ print("Inizializzazione del modello...")
108
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
109
+
110
+ model = SpriteGenerator()
111
+
 
 
112
  try:
113
+ # Carica il modello
114
+ state_dict = torch.load("Animator2D-v2.pth", map_location=device)
115
+ model.load_state_dict(state_dict)
116
+ model = model.to(device)
 
 
 
 
117
  model.eval()
118
+ print("Modello caricato con successo!")
119
+ return model, device
120
  except Exception as e:
121
  print(f"Errore nel caricamento del modello: {str(e)}")
122
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ def generate_sprite(prompt, num_frames=8):
125
  try:
126
+ # Usa il modello e il device globali
127
+ global model, device, tokenizer
128
+
129
+ # Tokenizza il testo
130
+ tokens = tokenizer(prompt, return_tensors="pt", padding=True)
131
+ tokens = {k: v.to(device) for k, v in tokens.items()}
132
+
133
+ # Genera l'immagine
134
  with torch.no_grad():
135
+ frames = model(
136
+ input_ids=tokens["input_ids"],
137
+ attention_mask=tokens["attention_mask"],
138
+ num_frames=torch.tensor([num_frames], device=device)
139
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
+ # Converte il tensore in immagine
142
+ frames = (frames * 0.5 + 0.5).clamp(0, 1)
143
+ frames = frames.cpu().numpy()
144
+
145
+ # Ritorna il primo frame come esempio
146
+ frame = frames[0, 0] # Prende il primo frame del batch
147
+ frame = (frame * 255).astype('uint8').transpose(1, 2, 0)
148
+
149
+ return Image.fromarray(frame)
150
  except Exception as e:
151
  print(f"Errore nella generazione: {str(e)}")
152
+ raise
153
+
154
+ # Inizializzazione globale
155
+ print("Caricamento del modello...")
156
+ try:
157
+ model, device = initialize_model()
158
+ tokenizer = AutoTokenizer.from_pretrained("t5-base")
159
+
160
+ # Creazione dell'interfaccia Gradio
161
+ interface = gr.Interface(
162
+ fn=generate_sprite,
163
+ inputs=[
164
+ gr.Textbox(label="Descrivi lo sprite che vuoi generare"),
165
+ gr.Slider(minimum=1, maximum=16, value=8, step=1, label="Numero di frame")
166
+ ],
167
+ outputs=gr.Image(label="Sprite generato"),
168
+ title="Animator2D-v2 Sprite Generator",
169
+ description="Genera sprite animati da descrizioni testuali"
170
+ )
171
+
172
+ # Avvio dell'interfaccia
173
+ interface.launch()
174
+ except Exception as e:
175
+ print(f"Errore nell'inizializzazione dell'applicazione: {str(e)}")
176
+ raise e
177
 
178
  # Interfaccia Gradio
179
  def create_interface():