Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -4,10 +4,17 @@ from transformers import BertTokenizer, BertModel
|
|
4 |
import torchvision.transforms as transforms
|
5 |
from PIL import Image
|
6 |
import numpy as np
|
7 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
# Imposta il dispositivo
|
10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
11 |
|
12 |
# Trasformazioni per le immagini
|
13 |
transform = transforms.Compose([
|
@@ -15,7 +22,7 @@ transform = transforms.Compose([
|
|
15 |
transforms.ToTensor(),
|
16 |
])
|
17 |
|
18 |
-
# Definizione del modello Animator2D
|
19 |
class Animator2DModel(torch.nn.Module):
|
20 |
def __init__(self):
|
21 |
super(Animator2DModel, self).__init__()
|
@@ -41,40 +48,49 @@ class Animator2DModel(torch.nn.Module):
|
|
41 |
return generated_frames.view(-1, num_frames, 3, 64, 64)
|
42 |
|
43 |
# Funzione per generare i frame
|
44 |
-
def generate_animation(description, base_frame_image, num_frames
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
|
59 |
-
|
60 |
-
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
|
75 |
-
|
|
|
|
|
|
|
|
|
76 |
|
77 |
# Interfaccia Gradio
|
|
|
78 |
with gr.Blocks(title="Animator2D-v1.0.0") as demo:
|
79 |
gr.Markdown("# Animator2D-v1.0.0\nInserisci una descrizione e un'immagine di base per generare un'animazione!")
|
80 |
|
@@ -88,4 +104,11 @@ with gr.Blocks(title="Animator2D-v1.0.0") as demo:
|
|
88 |
with gr.Column():
|
89 |
output_gallery = gr.Gallery(label="Frame generati", show_label=True)
|
90 |
|
91 |
-
submit_button.click
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import torchvision.transforms as transforms
|
5 |
from PIL import Image
|
6 |
import numpy as np
|
7 |
+
import logging
|
8 |
+
|
9 |
+
# Configura il logging
|
10 |
+
logging.basicConfig(level=logging.INFO)
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
logger.info("Inizio inizializzazione dell'app")
|
14 |
|
15 |
# Imposta il dispositivo
|
16 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
17 |
+
logger.info(f"Dispositivo selezionato: {device}")
|
18 |
|
19 |
# Trasformazioni per le immagini
|
20 |
transform = transforms.Compose([
|
|
|
22 |
transforms.ToTensor(),
|
23 |
])
|
24 |
|
25 |
+
# Definizione del modello Animator2D
|
26 |
class Animator2DModel(torch.nn.Module):
|
27 |
def __init__(self):
|
28 |
super(Animator2DModel, self).__init__()
|
|
|
48 |
return generated_frames.view(-1, num_frames, 3, 64, 64)
|
49 |
|
50 |
# Funzione per generare i frame
|
51 |
+
def generate_animation(description, base_frame_image, num_frames):
|
52 |
+
logger.info("Inizio generazione animazione")
|
53 |
+
try:
|
54 |
+
# Carica il modello da Hugging Face
|
55 |
+
model = Animator2DModel().to(device)
|
56 |
+
logger.info("Modello inizializzato, caricamento pesi...")
|
57 |
+
model.load_state_dict(torch.hub.load_state_dict_from_url(
|
58 |
+
"https://huggingface.co/Lod34/Animator2D-v1.0.0/resolve/main/animator2d_v1_0_0.pth",
|
59 |
+
map_location=device))
|
60 |
+
model.eval()
|
61 |
+
logger.info("Modello caricato con successo")
|
62 |
|
63 |
+
# Prepara il testo
|
64 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
65 |
+
inputs = tokenizer(description, return_tensors='pt', padding='max_length',
|
66 |
+
truncation=True, max_length=512)
|
67 |
+
input_ids = inputs['input_ids'].to(device)
|
68 |
+
attention_mask = inputs['attention_mask'].to(device)
|
69 |
|
70 |
+
# Prepara l'immagine di base
|
71 |
+
base_frame = transform(base_frame_image).unsqueeze(0).to(device)
|
72 |
|
73 |
+
# Genera i frame
|
74 |
+
with torch.no_grad():
|
75 |
+
generated_frames = model(input_ids, attention_mask, base_frame, num_frames)
|
76 |
+
|
77 |
+
# Converte i frame generati in immagini PIL
|
78 |
+
generated_frames = generated_frames.squeeze(0).cpu().numpy()
|
79 |
+
output_frames = []
|
80 |
+
for i in range(num_frames):
|
81 |
+
frame = generated_frames[i].transpose(1, 2, 0) # Da (C, H, W) a (H, W, C)
|
82 |
+
frame = np.clip(frame, 0, 1) # Normalizza tra 0 e 1
|
83 |
+
frame = (frame * 255).astype(np.uint8) # Converte in formato immagine
|
84 |
+
output_frames.append(Image.fromarray(frame))
|
85 |
|
86 |
+
logger.info("Animazione generata con successo")
|
87 |
+
return output_frames
|
88 |
+
except Exception as e:
|
89 |
+
logger.error(f"Errore durante la generazione: {str(e)}")
|
90 |
+
raise
|
91 |
|
92 |
# Interfaccia Gradio
|
93 |
+
logger.info("Inizio configurazione interfaccia Gradio")
|
94 |
with gr.Blocks(title="Animator2D-v1.0.0") as demo:
|
95 |
gr.Markdown("# Animator2D-v1.0.0\nInserisci una descrizione e un'immagine di base per generare un'animazione!")
|
96 |
|
|
|
104 |
with gr.Column():
|
105 |
output_gallery = gr.Gallery(label="Frame generati", show_label=True)
|
106 |
|
107 |
+
submit_button.click(
|
108 |
+
fn=generate_animation,
|
109 |
+
inputs=[description_input, base_frame_input, num_frames_input],
|
110 |
+
outputs=output_gallery
|
111 |
+
)
|
112 |
+
|
113 |
+
logger.info("Interfaccia Gradio configurata, avvio...")
|
114 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|