Lod34 commited on
Commit
76ce627
·
verified ·
1 Parent(s): 104351f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import BertTokenizer, BertModel
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+ import os
8
+
9
+ # Imposta il dispositivo
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ # Trasformazioni per le immagini
13
+ transform = transforms.Compose([
14
+ transforms.Resize((64, 64)),
15
+ transforms.ToTensor(),
16
+ ])
17
+
18
+ # Definizione del modello Animator2D (uguale al training)
19
+ class Animator2DModel(torch.nn.Module):
20
+ def __init__(self):
21
+ super(Animator2DModel, self).__init__()
22
+ self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
23
+ self.image_encoder = torch.nn.Sequential(
24
+ torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
25
+ torch.nn.ReLU(),
26
+ torch.nn.MaxPool2d(2),
27
+ torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
28
+ torch.nn.ReLU(),
29
+ torch.nn.MaxPool2d(2)
30
+ )
31
+ self.decoder = torch.nn.LSTM(input_size=768 + 128, hidden_size=256, num_layers=2, batch_first=True)
32
+ self.frame_generator = torch.nn.Linear(256, 64 * 64 * 3)
33
+
34
+ def forward(self, input_ids, attention_mask, base_frame, num_frames):
35
+ text_features = self.text_encoder(input_ids, attention_mask=attention_mask).pooler_output
36
+ image_features = self.image_encoder(base_frame).flatten(start_dim=1)
37
+ combined_features = torch.cat((text_features, image_features), dim=1)
38
+ combined_features = combined_features.unsqueeze(1).repeat(1, num_frames, 1)
39
+ output, _ = self.decoder(combined_features)
40
+ generated_frames = self.frame_generator(output)
41
+ return generated_frames.view(-1, num_frames, 3, 64, 64)
42
+
43
+ # Funzione per generare i frame
44
+ def generate_animation(description, base_frame_image, num_frames=3):
45
+ # Carica il modello da Hugging Face
46
+ model = Animator2DModel().to(device)
47
+ model.load_state_dict(torch.hub.load_state_dict_from_url(
48
+ "https://huggingface.co/Lod34/Animator2D-v1.0.0/resolve/main/animator2d_v1_0_0.pth",
49
+ map_location=device))
50
+ model.eval()
51
+
52
+ # Prepara il testo
53
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
54
+ inputs = tokenizer(description, return_tensors='pt', padding='max_length',
55
+ truncation=True, max_length=512)
56
+ input_ids = inputs['input_ids'].to(device)
57
+ attention_mask = inputs['attention_mask'].to(device)
58
+
59
+ # Prepara l'immagine di base
60
+ base_frame = transform(base_frame_image).unsqueeze(0).to(device)
61
+
62
+ # Genera i frame
63
+ with torch.no_grad():
64
+ generated_frames = model(input_ids, attention_mask, base_frame, num_frames)
65
+
66
+ # Converte i frame generati in immagini PIL
67
+ generated_frames = generated_frames.squeeze(0).cpu().numpy()
68
+ output_frames = []
69
+ for i in range(num_frames):
70
+ frame = generated_frames[i].transpose(1, 2, 0) # Da (C, H, W) a (H, W, C)
71
+ frame = np.clip(frame, 0, 1) # Normalizza tra 0 e 1
72
+ frame = (frame * 255).astype(np.uint8) # Converte in formato immagine
73
+ output_frames.append(Image.fromarray(frame))
74
+
75
+ return output_frames
76
+
77
+ # Interfaccia Gradio
78
+ with gr.Blocks(title="Animator2D-v1.0.0") as demo:
79
+ gr.Markdown("# Animator2D-v1.0.0\nInserisci una descrizione e un'immagine di base per generare un'animazione!")
80
+
81
+ with gr.Row():
82
+ with gr.Column():
83
+ description_input = gr.Textbox(label="Descrizione dell'animazione", placeholder="Es: 'A character jumping'")
84
+ base_frame_input = gr.Image(label="Immagine di base", type="pil")
85
+ num_frames_input = gr.Slider(1, 5, value=3, step=1, label="Numero di frame")
86
+ submit_button = gr.Button("Genera Animazione")
87
+
88
+ with gr.Column():
89
+ output_gallery = gr.Gallery(label="Frame generati", show_label=True)
90
+
91
+ submit_button.click