Lod34 commited on
Commit
2268f5b
·
verified ·
1 Parent(s): 0db91c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -31
app.py CHANGED
@@ -4,10 +4,17 @@ from transformers import BertTokenizer, BertModel
4
  import torchvision.transforms as transforms
5
  from PIL import Image
6
  import numpy as np
7
- import os
 
 
 
 
 
 
8
 
9
  # Imposta il dispositivo
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
11
 
12
  # Trasformazioni per le immagini
13
  transform = transforms.Compose([
@@ -15,7 +22,7 @@ transform = transforms.Compose([
15
  transforms.ToTensor(),
16
  ])
17
 
18
- # Definizione del modello Animator2D (uguale al training)
19
  class Animator2DModel(torch.nn.Module):
20
  def __init__(self):
21
  super(Animator2DModel, self).__init__()
@@ -41,40 +48,49 @@ class Animator2DModel(torch.nn.Module):
41
  return generated_frames.view(-1, num_frames, 3, 64, 64)
42
 
43
  # Funzione per generare i frame
44
- def generate_animation(description, base_frame_image, num_frames=3):
45
- # Carica il modello da Hugging Face
46
- model = Animator2DModel().to(device)
47
- model.load_state_dict(torch.hub.load_state_dict_from_url(
48
- "https://huggingface.co/Lod34/Animator2D-v1.0.0/resolve/main/animator2d_v1_0_0.pth",
49
- map_location=device))
50
- model.eval()
 
 
 
 
51
 
52
- # Prepara il testo
53
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
54
- inputs = tokenizer(description, return_tensors='pt', padding='max_length',
55
- truncation=True, max_length=512)
56
- input_ids = inputs['input_ids'].to(device)
57
- attention_mask = inputs['attention_mask'].to(device)
58
 
59
- # Prepara l'immagine di base
60
- base_frame = transform(base_frame_image).unsqueeze(0).to(device)
61
 
62
- # Genera i frame
63
- with torch.no_grad():
64
- generated_frames = model(input_ids, attention_mask, base_frame, num_frames)
65
-
66
- # Converte i frame generati in immagini PIL
67
- generated_frames = generated_frames.squeeze(0).cpu().numpy()
68
- output_frames = []
69
- for i in range(num_frames):
70
- frame = generated_frames[i].transpose(1, 2, 0) # Da (C, H, W) a (H, W, C)
71
- frame = np.clip(frame, 0, 1) # Normalizza tra 0 e 1
72
- frame = (frame * 255).astype(np.uint8) # Converte in formato immagine
73
- output_frames.append(Image.fromarray(frame))
74
 
75
- return output_frames
 
 
 
 
76
 
77
  # Interfaccia Gradio
 
78
  with gr.Blocks(title="Animator2D-v1.0.0") as demo:
79
  gr.Markdown("# Animator2D-v1.0.0\nInserisci una descrizione e un'immagine di base per generare un'animazione!")
80
 
@@ -88,4 +104,11 @@ with gr.Blocks(title="Animator2D-v1.0.0") as demo:
88
  with gr.Column():
89
  output_gallery = gr.Gallery(label="Frame generati", show_label=True)
90
 
91
- submit_button.click
 
 
 
 
 
 
 
 
4
  import torchvision.transforms as transforms
5
  from PIL import Image
6
  import numpy as np
7
+ import logging
8
+
9
+ # Configura il logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ logger.info("Inizio inizializzazione dell'app")
14
 
15
  # Imposta il dispositivo
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ logger.info(f"Dispositivo selezionato: {device}")
18
 
19
  # Trasformazioni per le immagini
20
  transform = transforms.Compose([
 
22
  transforms.ToTensor(),
23
  ])
24
 
25
+ # Definizione del modello Animator2D
26
  class Animator2DModel(torch.nn.Module):
27
  def __init__(self):
28
  super(Animator2DModel, self).__init__()
 
48
  return generated_frames.view(-1, num_frames, 3, 64, 64)
49
 
50
  # Funzione per generare i frame
51
+ def generate_animation(description, base_frame_image, num_frames):
52
+ logger.info("Inizio generazione animazione")
53
+ try:
54
+ # Carica il modello da Hugging Face
55
+ model = Animator2DModel().to(device)
56
+ logger.info("Modello inizializzato, caricamento pesi...")
57
+ model.load_state_dict(torch.hub.load_state_dict_from_url(
58
+ "https://huggingface.co/Lod34/Animator2D-v1.0.0/resolve/main/animator2d_v1_0_0.pth",
59
+ map_location=device))
60
+ model.eval()
61
+ logger.info("Modello caricato con successo")
62
 
63
+ # Prepara il testo
64
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
65
+ inputs = tokenizer(description, return_tensors='pt', padding='max_length',
66
+ truncation=True, max_length=512)
67
+ input_ids = inputs['input_ids'].to(device)
68
+ attention_mask = inputs['attention_mask'].to(device)
69
 
70
+ # Prepara l'immagine di base
71
+ base_frame = transform(base_frame_image).unsqueeze(0).to(device)
72
 
73
+ # Genera i frame
74
+ with torch.no_grad():
75
+ generated_frames = model(input_ids, attention_mask, base_frame, num_frames)
76
+
77
+ # Converte i frame generati in immagini PIL
78
+ generated_frames = generated_frames.squeeze(0).cpu().numpy()
79
+ output_frames = []
80
+ for i in range(num_frames):
81
+ frame = generated_frames[i].transpose(1, 2, 0) # Da (C, H, W) a (H, W, C)
82
+ frame = np.clip(frame, 0, 1) # Normalizza tra 0 e 1
83
+ frame = (frame * 255).astype(np.uint8) # Converte in formato immagine
84
+ output_frames.append(Image.fromarray(frame))
85
 
86
+ logger.info("Animazione generata con successo")
87
+ return output_frames
88
+ except Exception as e:
89
+ logger.error(f"Errore durante la generazione: {str(e)}")
90
+ raise
91
 
92
  # Interfaccia Gradio
93
+ logger.info("Inizio configurazione interfaccia Gradio")
94
  with gr.Blocks(title="Animator2D-v1.0.0") as demo:
95
  gr.Markdown("# Animator2D-v1.0.0\nInserisci una descrizione e un'immagine di base per generare un'animazione!")
96
 
 
104
  with gr.Column():
105
  output_gallery = gr.Gallery(label="Frame generati", show_label=True)
106
 
107
+ submit_button.click(
108
+ fn=generate_animation,
109
+ inputs=[description_input, base_frame_input, num_frames_input],
110
+ outputs=output_gallery
111
+ )
112
+
113
+ logger.info("Interfaccia Gradio configurata, avvio...")
114
+ demo.launch(server_name="0.0.0.0", server_port=7860)