Dhan98 committed · verified
Commit 1d25574 · Parent(s): 762ce93

new changes

Files changed (1):
  1. app.py +39 -21
app.py CHANGED
@@ -1,21 +1,26 @@
 import streamlit as st
-from transformers import AutoProcessor, AutoModelForCausalLM, BlipProcessor, BlipForConditionalGeneration
+from transformers import BlipProcessor, BlipForConditionalGeneration
+from diffusers import DiffusionPipeline
 import torch
 import cv2
 import numpy as np
 from PIL import Image
+import tempfile
+import os

 @st.cache_resource
 def load_models():
-    # Text-to-video model
-    video_model = AutoModelForCausalLM.from_pretrained("damo-vilab/text-to-video-ms-1.7b", trust_remote_code=True)
-    video_processor = AutoProcessor.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
+    pipeline = DiffusionPipeline.from_pretrained(
+        "cerspense/zeroscope_v2_576w",
+        torch_dtype=torch.float16
+    )
+    if torch.cuda.is_available():
+        pipeline.to("cuda")

-    # Image captioning
     blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
     blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

-    return video_model, video_processor, blip, blip_processor
+    return pipeline, blip, blip_processor

 def enhance_image(image):
     img = np.array(image)
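One caveat with the new `load_models`: the pipeline is loaded in `torch.float16` unconditionally, but only moved to CUDA when a GPU is present, and half-precision inference is generally not supported on CPU. A minimal sketch of a dtype-aware variant; the device/dtype selection logic is an assumption, not part of this commit:

```python
# Sketch only, not part of this commit: pick the dtype to match the
# available device, since float16 inference generally fails on CPU.
import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipeline = DiffusionPipeline.from_pretrained(
    "cerspense/zeroscope_v2_576w",
    torch_dtype=dtype,
)
pipeline.to(device)

# Optional memory savers on GPU (standard diffusers methods):
# pipeline.enable_attention_slicing()
# pipeline.enable_model_cpu_offload()  # requires the accelerate package
```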
@@ -32,36 +37,49 @@ def get_description(image, blip_model, blip_processor):
     output = blip_model.generate(**inputs, max_length=50)
     return blip_processor.decode(output[0], skip_special_tokens=True)

-def generate_video(model, processor, description):
-    inputs = processor(text=description, return_tensors="pt")
-    with torch.no_grad():
-        video_frames = model.generate(
-            **inputs,
-            num_frames=16,
-            num_inference_steps=50,
-            guidance_scale=7.5
-        )
-    return video_frames
+def save_video_frames(frames, fps=8):
+    temp_dir = tempfile.mkdtemp()
+    temp_path = os.path.join(temp_dir, "output.mp4")
+
+    height, width = frames[0].shape[:2]
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    video_writer = cv2.VideoWriter(temp_path, fourcc, fps, (width, height))
+
+    for frame in frames:
+        video_writer.write(frame)
+    video_writer.release()
+
+    return temp_path
+
+def generate_video(pipeline, description):
+    video_frames = pipeline(
+        description,
+        num_inference_steps=50,
+        num_frames=24
+    ).frames
+
+    video_path = save_video_frames(video_frames)
+    return video_path

 def main():
     st.title("Video Generator")

-    models = load_models()
-    video_model, video_processor, blip, blip_processor = models
+    pipeline, blip, blip_processor = load_models()

     image_file = st.file_uploader("Upload Image", type=['png', 'jpg', 'jpeg'])
     if image_file:
         image = Image.open(image_file)
         enhanced_image = enhance_image(image)
+
         st.image(enhanced_image, caption="Enhanced Image")

         description = get_description(enhanced_image, blip, blip_processor)
         st.write("Image Description:", description)

         if st.button("Generate Video"):
-            with st.spinner("Generating..."):
-                video = generate_video(video_model, video_processor, description)
-                st.video(video)
+            with st.spinner("Generating video..."):
+                video_path = generate_video(pipeline, description)
+                st.video(video_path)

 if __name__ == "__main__":
     main()
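A note on the new `save_video_frames`: `cv2.VideoWriter` expects 8-bit BGR frames, while the pipeline produces RGB output whose exact container varies by diffusers version (float arrays in [0, 1], uint8 arrays, or PIL images, sometimes nested one level for the batch dimension). A hedged normalization sketch; the helper name is hypothetical and not part of this commit:

```python
# Hypothetical helper, not part of this commit: normalize one pipeline
# frame to the uint8 BGR layout that cv2.VideoWriter expects.
import cv2
import numpy as np

def to_bgr_uint8(frame):
    frame = np.asarray(frame)            # also converts PIL images
    if frame.dtype != np.uint8:          # assume float RGB in [0, 1]
        frame = (frame * 255.0).clip(0, 255).astype(np.uint8)
    return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

# e.g. inside save_video_frames:
#     for frame in frames:
#         video_writer.write(to_bgr_uint8(frame))
```

Alternatively, `export_to_video` from `diffusers.utils` handles the RGB-to-BGR conversion and temp-file bookkeeping itself, which would let this commit drop the hand-rolled writer entirely.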
 
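For reference, the captioning step that the unchanged `get_description` helper wraps can be exercised on its own. A minimal standalone sketch using the same BLIP checkpoint the app loads; the input file name is a placeholder:

```python
# Standalone captioning sketch with the same checkpoint app.py loads.
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

image = Image.open("photo.jpg").convert("RGB")   # placeholder input image
inputs = processor(images=image, return_tensors="pt")
output = model.generate(**inputs, max_length=50)
print(processor.decode(output[0], skip_special_tokens=True))
```

The app itself still launches with `streamlit run app.py`; the switch to zeroscope means the Space now needs diffusers alongside transformers, torch, opencv-python, and streamlit (plus accelerate, if CPU offload is enabled).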