Spaces:
Running
Running
File size: 7,443 Bytes
846540c 2274519 846540c 902125a 2274519 57ee356 70a5d06 95b77dc 2274519 be2a526 95b77dc be2a526 95b77dc 2274519 57ee356 95b77dc 2274519 95b77dc 2274519 95b77dc 2274519 95b77dc 2274519 95b77dc 2274519 846540c 95b77dc 846540c 95b77dc 2274519 95b77dc 57ee356 95b77dc 57ee356 95b77dc 57ee356 95b77dc 57ee356 95b77dc 57ee356 95b77dc 57ee356 95b77dc 57ee356 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
import gradio as gr
import torch
from PIL import Image
import os
from transformers import AutoTokenizer, AutoModel, T5ForConditionalGeneration
from huggingface_hub import hf_hub_download
import torch.nn as nn
class SpriteGenerator(nn.Module):
def __init__(self, text_encoder_name="t5-base", latent_dim=512):
super(SpriteGenerator, self).__init__()
# Text encoder (T5 with lm_head)
self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder_name)
for param in self.text_encoder.parameters():
param.requires_grad = False
# Proiezione dal testo al latent space
self.text_projection = nn.Sequential(
nn.Linear(768, latent_dim),
nn.LeakyReLU(0.2),
nn.Linear(latent_dim, latent_dim)
)
# Generator
self.generator = nn.Sequential(
# Input: latent_dim x 1 x 1 -> 512 x 4 x 4
nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
# 512 x 4 x 4 -> 256 x 8 x 8
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
# 256 x 8 x 8 -> 128 x 16 x 16
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
# 128 x 16 x 16 -> 64 x 32 x 32
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
# 64 x 32 x 32 -> 32 x 64 x 64
nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(True),
# 32 x 64 x 64 -> 16 x 128 x 128
nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False),
nn.BatchNorm2d(16),
nn.ReLU(True),
# 16 x 128 x 128 -> 3 x 256 x 256
nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False),
)
# Frame interpolator
self.frame_interpolator = nn.Sequential(
nn.Linear(latent_dim + 1, latent_dim),
nn.LeakyReLU(0.2),
nn.Linear(latent_dim, latent_dim),
nn.LeakyReLU(0.2)
)
def forward(self, input_ids, attention_mask, num_frames=1):
batch_size = input_ids.shape[0]
# Encode text usando il T5 completo
text_outputs = self.text_encoder.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
)
# Get text features
text_features = text_outputs.last_hidden_state.mean(dim=1)
# Project to latent space
latent_vector = self.text_projection(text_features)
# Generate multiple frames if needed
all_frames = []
for frame_idx in range(max(num_frames.max().item(), 1)):
frame_info = torch.ones((batch_size, 1), device=latent_vector.device) * frame_idx / max(num_frames.max().item(), 1)
# Combine latent vector with frame info
frame_latent = self.frame_interpolator(
torch.cat([latent_vector, frame_info], dim=1)
)
# Generate frame
frame_latent_reshaped = frame_latent.unsqueeze(2).unsqueeze(3)
frame = self.generator(frame_latent_reshaped)
frame = torch.tanh(frame)
all_frames.append(frame)
# Stack all frames
sprites = torch.stack(all_frames, dim=1)
return sprites
def initialize_model():
print("Inizializzazione del modello...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SpriteGenerator()
try:
# Scarica il modello da Hugging Face Hub
model_path = hf_hub_download(
repo_id="Lod34/Animator2D-v2",
filename="pytorch_model.bin",
repo_type="model"
)
# Carica il modello
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
print(f"Modello caricato con successo su {device}!")
return model, device
except Exception as e:
print(f"Errore nel caricamento del modello: {str(e)}")
raise
def generate_sprite(prompt, num_frames=8):
try:
# Usa il modello e il device globali
global model, device, tokenizer
# Tokenizza il testo
tokens = tokenizer(prompt, return_tensors="pt", padding=True)
tokens = {k: v.to(device) for k, v in tokens.items()}
# Genera l'immagine
with torch.no_grad():
frames = model(
input_ids=tokens["input_ids"],
attention_mask=tokens["attention_mask"],
num_frames=torch.tensor([num_frames], device=device)
)
# Converte il tensore in immagine
frames = (frames * 0.5 + 0.5).clamp(0, 1)
frames = frames.cpu().numpy()
# Ritorna il primo frame come esempio
frame = frames[0, 0] # Prende il primo frame del batch
frame = (frame * 255).astype('uint8').transpose(1, 2, 0)
return Image.fromarray(frame)
except Exception as e:
print(f"Errore nella generazione: {str(e)}")
raise
# Inizializzazione globale
print("Caricamento del modello e configurazione dell'interfaccia...")
try:
# Inizializzazione del modello e del tokenizer
model, device = initialize_model()
tokenizer = AutoTokenizer.from_pretrained("t5-base")
# Configurazione dell'interfaccia Gradio
interface = gr.Interface(
fn=generate_sprite,
inputs=[
gr.Textbox(
label="Descrivi lo sprite che vuoi generare",
placeholder="Esempio: un personaggio pixel art che cammina"
),
gr.Slider(
minimum=1,
maximum=16,
value=8,
step=1,
label="Numero di frame",
info="Più frame = animazione più fluida ma generazione più lenta"
)
],
outputs=gr.Image(label="Sprite generato"),
title="🎮 Animator2D-v2 Sprite Generator",
description="""
## Generatore di Sprite Animati
Questo strumento genera sprite pixel art da descrizioni testuali.
### Come usare:
1. Inserisci una descrizione dello sprite che vuoi generare
2. Regola il numero di frame desiderati
3. Clicca su Submit e attendi la generazione
### Tips:
- Sii specifico nella descrizione
- Prova diversi numeri di frame per risultati diversi
- Le descrizioni in inglese potrebbero funzionare meglio
""",
article="""
### Note:
- La generazione può richiedere alcuni secondi
- Vengono mostrati solo i primi frame dell'animazione
- Per risultati migliori, usa descrizioni dettagliate
Creato da [Lod34](https://huggingface.co/Lod34)
"""
)
# Avvio dell'interfaccia
interface.launch()
except Exception as e:
print(f"Errore nell'inizializzazione dell'applicazione: {str(e)}")
raise |