File size: 7,443 Bytes
846540c
2274519
 
846540c
902125a
2274519
 
 
 
57ee356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70a5d06
95b77dc
 
 
 
 
 
2274519
be2a526
 
 
 
 
 
 
95b77dc
be2a526
95b77dc
 
2274519
57ee356
95b77dc
2274519
 
95b77dc
2274519
95b77dc
2274519
95b77dc
 
 
 
 
 
 
 
2274519
95b77dc
 
 
 
2274519
846540c
95b77dc
 
 
 
 
 
 
846540c
95b77dc
2274519
 
95b77dc
 
 
57ee356
95b77dc
57ee356
95b77dc
 
 
57ee356
95b77dc
 
 
57ee356
 
 
 
 
 
 
 
 
 
 
 
95b77dc
 
57ee356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95b77dc
 
 
 
57ee356
95b77dc
 
57ee356
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import gradio as gr
import torch
from PIL import Image
import os
from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration
from huggingface_hub import hf_hub_download
import torch.nn as nn

class SpriteGenerator(nn.Module):
    def __init__(self, text_encoder_name="t5-base", latent_dim=512):
        super(SpriteGenerator, self).__init__()
        
        # Text encoder (T5 with lm_head)
        self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder_name)
        for param in self.text_encoder.parameters():
            param.requires_grad = False
            
        # Proiezione dal testo al latent space
        self.text_projection = nn.Sequential(
            nn.Linear(768, latent_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(latent_dim, latent_dim)
        )
        
        # Generator
        self.generator = nn.Sequential(
            # Input: latent_dim x 1 x 1 -> 512 x 4 x 4
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            # 512 x 4 x 4 -> 256 x 8 x 8
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # 256 x 8 x 8 -> 128 x 16 x 16
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # 128 x 16 x 16 -> 64 x 32 x 32
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 64 x 32 x 32 -> 32 x 64 x 64
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            # 32 x 64 x 64 -> 16 x 128 x 128
            nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(True),

            # 16 x 128 x 128 -> 3 x 256 x 256
            nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False),
        )
        
        # Frame interpolator
        self.frame_interpolator = nn.Sequential(
            nn.Linear(latent_dim + 1, latent_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(latent_dim, latent_dim),
            nn.LeakyReLU(0.2)
        )

    def forward(self, input_ids, attention_mask, num_frames=1):
        batch_size = input_ids.shape[0]
        
        # Encode text usando il T5 completo
        text_outputs = self.text_encoder.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        
        # Get text features
        text_features = text_outputs.last_hidden_state.mean(dim=1)
        
        # Project to latent space
        latent_vector = self.text_projection(text_features)
        
        # Generate multiple frames if needed
        all_frames = []
        for frame_idx in range(max(num_frames.max().item(), 1)):
            frame_info = torch.ones((batch_size, 1), device=latent_vector.device) * frame_idx / max(num_frames.max().item(), 1)
            
            # Combine latent vector with frame info
            frame_latent = self.frame_interpolator(
                torch.cat([latent_vector, frame_info], dim=1)
            )
            
            # Generate frame
            frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3)
            frame = self.generator(frame_latent_reshaped)
            frame = torch.tanh(frame)
            all_frames.append(frame)
        
        # Stack all frames
        sprites = torch.stack(all_frames, dim=1)
        
        return sprites

def initialize_model():
    print("Inizializzazione del modello...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model = SpriteGenerator()
    
    try:
        # Scarica il modello da Hugging Face Hub
        model_path = hf_hub_download(
            repo_id="Lod34/Animator2D-v2",
            filename="pytorch_model.bin",
            repo_type="model"
        )
        
        # Carica il modello
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        print(f"Modello caricato con successo su {device}!")
        return model, device
    except Exception as e:
        print(f"Errore nel caricamento del modello: {str(e)}")
        raise

def generate_sprite(prompt, num_frames=8):
    try:
        # Usa il modello e il device globali
        global model, device, tokenizer
        
        # Tokenizza il testo
        tokens = tokenizer(prompt, return_tensors="pt", padding=True)
        tokens = {k: v.to(device) for k, v in tokens.items()}
        
        # Genera l'immagine
        with torch.no_grad():
            frames = model(
                input_ids=tokens["input_ids"],
                attention_mask=tokens["attention_mask"],
                num_frames=torch.tensor([num_frames], device=device)
            )
        
        # Converte il tensore in immagine
        frames = (frames * 0.5 + 0.5).clamp(0, 1)
        frames = frames.cpu().numpy()
        
        # Ritorna il primo frame come esempio
        frame = frames[0, 0]  # Prende il primo frame del batch
        frame = (frame * 255).astype('uint8').transpose(1, 2, 0)
        
        return Image.fromarray(frame)
    except Exception as e:
        print(f"Errore nella generazione: {str(e)}")
        raise

# Inizializzazione globale
print("Caricamento del modello e configurazione dell'interfaccia...")
try:
    # Inizializzazione del modello e del tokenizer
    model, device = initialize_model()
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    
    # Configurazione dell'interfaccia Gradio
    interface = gr.Interface(
        fn=generate_sprite,
        inputs=[
            gr.Textbox(
                label="Descrivi lo sprite che vuoi generare",
                placeholder="Esempio: un personaggio pixel art che cammina"
            ),
            gr.Slider(
                minimum=1,
                maximum=16,
                value=8,
                step=1,
                label="Numero di frame",
                info="Più frame = animazione più fluida ma generazione più lenta"
            )
        ],
        outputs=gr.Image(label="Sprite generato"),
        title="🎮 Animator2D-v2 Sprite Generator",
        description="""
        ## Generatore di Sprite Animati
        Questo strumento genera sprite pixel art da descrizioni testuali.
        
        ### Come usare:
        1. Inserisci una descrizione dello sprite che vuoi generare
        2. Regola il numero di frame desiderati
        3. Clicca su Submit e attendi la generazione
        
        ### Tips:
        - Sii specifico nella descrizione
        - Prova diversi numeri di frame per risultati diversi
        - Le descrizioni in inglese potrebbero funzionare meglio
        """,
        article="""
        ### Note:
        - La generazione può richiedere alcuni secondi
        - Vengono mostrati solo i primi frame dell'animazione
        - Per risultati migliori, usa descrizioni dettagliate
        
        Creato da [Lod34](https://huggingface.co/Lod34)
        """
    )
    
    # Avvio dell'interfaccia
    interface.launch()
    
except Exception as e:
    print(f"Errore nell'inizializzazione dell'applicazione: {str(e)}")
    raise