Dhan98 committed on
Commit
685b08e
·
verified ·
1 Parent(s): a79815f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -57
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import streamlit as st
2
  from transformers import BlipProcessor, BlipForConditionalGeneration
3
  from diffusers import DiffusionPipeline
@@ -7,69 +8,133 @@ import numpy as np
7
  from PIL import Image
8
  import tempfile
9
  import os
10
- import base64
 
 
 
 
 
 
11
 
12
  @st.cache_resource
13
  def load_models():
14
- pipeline = DiffusionPipeline.from_pretrained(
15
- "cerspense/zeroscope_v2_576w",
16
- torch_dtype=torch.float16
17
- )
18
- if torch.cuda.is_available():
19
- pipeline.to("cuda")
20
-
21
- blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
22
- blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
23
-
24
- return pipeline, blip, blip_processor
 
 
 
 
 
 
 
 
 
25
 
26
- def generate_video(pipeline, description):
27
- video_frames = pipeline(
28
- description,
29
- num_inference_steps=30, # Reduced from 50
30
- num_frames=16 # Reduced from 24
31
- ).frames
32
-
33
- temp_dir = tempfile.mkdtemp()
34
- temp_path = os.path.join(temp_dir, "output.mp4")
35
-
36
- height, width = video_frames[0].shape[:2]
37
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
38
- video_writer = cv2.VideoWriter(temp_path, fourcc, 8, (width, height))
39
-
40
- for frame in video_frames:
41
- video_writer.write(frame)
42
- video_writer.release()
43
-
44
- return temp_path
45
 
46
- def get_binary_file_downloader_html(bin_file, file_label='File'):
47
- with open(bin_file, 'rb') as f:
48
- data = f.read()
49
- bin_str = base64.b64encode(data).decode()
50
- href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{file_label}">Download {file_label}</a>'
51
- return href
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  def main():
54
- st.title("Video Generator")
55
-
56
- pipeline, blip, blip_processor = load_models()
57
-
58
- image_file = st.file_uploader("Upload Image", type=['png', 'jpg', 'jpeg'])
59
- if image_file:
60
- image = Image.open(image_file)
61
- enhanced_image = enhance_image(image)
62
-
63
- st.image(enhanced_image, caption="Enhanced Image")
64
-
65
- description = get_description(enhanced_image, blip, blip_processor)
66
- st.write("Image Description:", description)
67
-
68
- if st.button("Generate Video"):
69
- with st.spinner("Generating video..."):
70
- video_path = generate_video(pipeline, description)
71
- st.video(video_path)
72
- st.markdown(get_binary_file_downloader_html(video_path, 'video.mp4'), unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  if __name__ == "__main__":
75
- main()
 
1
+ # app.py
2
  import streamlit as st
3
  from transformers import BlipProcessor, BlipForConditionalGeneration
4
  from diffusers import DiffusionPipeline
 
8
  from PIL import Image
9
  import tempfile
10
  import os
11
+
12
# Configure the Streamlit page before any other st.* call.
# NOTE: the icon was mojibake ("πŸŽ₯", UTF-8 bytes mis-decoded) — restored to the
# intended camera emoji.
st.set_page_config(
    page_title="Video Generator",
    page_icon="🎥",
    layout="wide"
)
18
 
19
@st.cache_resource
def load_models():
    """Load and cache the text-to-video pipeline and the BLIP captioning model.

    Cached by Streamlit so the (large) downloads happen once per process.

    Returns:
        tuple: (pipeline, blip, blip_processor) where *pipeline* is the
        zeroscope text-to-video DiffusionPipeline, *blip* the captioning
        model, and *blip_processor* its matching processor.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    # Load text-to-video model.
    # fp16 kernels are generally unavailable on CPU, so only request float16
    # when we will actually run on a GPU; otherwise keep full precision.
    pipeline = DiffusionPipeline.from_pretrained(
        "cerspense/zeroscope_v2_576w",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    pipeline.to(device)

    # Load image captioning model and its processor.
    blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip.to(device)

    return pipeline, blip, blip_processor
41
 
42
def enhance_image(image):
    """Apply a mild contrast/brightness boost to an uploaded image.

    Args:
        image: PIL.Image in any mode (the uploader accepts PNG, so this may
            be RGBA or palette mode, not just RGB).

    Returns:
        PIL.Image.Image: the enhanced image, always in RGB mode.
    """
    # Normalize to 3-channel RGB first — RGBA/palette PNGs would otherwise
    # flow a 4-channel/indexed array into cv2 and the BLIP captioner.
    img_array = np.array(image.convert("RGB"))

    # out = clip(alpha * pixel + beta): ~20% more contrast, +10 brightness.
    enhanced = cv2.convertScaleAbs(img_array, alpha=1.2, beta=10)

    return Image.fromarray(enhanced)
 
 
 
 
 
 
 
 
 
 
 
50
 
51
def get_description(image, blip, blip_processor):
    """Caption *image* with BLIP and return the generated description text."""
    # Turn the image into model-ready tensors.
    encoded = blip_processor(image, return_tensors="pt")

    # Keep the tensors on the same device as the model.
    if torch.cuda.is_available():
        encoded = {name: tensor.to("cuda") for name, tensor in encoded.items()}

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        token_ids = blip.generate(pixel_values=encoded["pixel_values"], max_length=50)

    return blip_processor.decode(token_ids[0], skip_special_tokens=True)
64
+
65
def generate_video(pipeline, description):
    """Render a short MP4 from a text prompt and return its file path.

    Args:
        pipeline: the text-to-video DiffusionPipeline from load_models().
        description (str): text prompt describing the desired video.

    Returns:
        str: path to the written MP4 file in a fresh temporary directory.
    """
    # Generate video frames.
    video_frames = pipeline(
        description,
        num_inference_steps=30,
        num_frames=16
    ).frames

    # Newer diffusers versions return a batched list ([batch][frame]);
    # unwrap the single batch so we always iterate individual frames.
    if len(video_frames) and isinstance(video_frames[0], (list, tuple)):
        video_frames = video_frames[0]

    # Create temporary directory and file path.
    temp_dir = tempfile.mkdtemp()
    temp_path = os.path.join(temp_dir, "output.mp4")

    # Convert frames to video.
    first = np.asarray(video_frames[0])
    height, width = first.shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(temp_path, fourcc, 8, (width, height))
    try:
        for frame in video_frames:
            frame = np.asarray(frame)
            # Pipelines may emit float frames in [0, 1]; VideoWriter needs
            # uint8 — writing floats raw would silently produce garbage.
            if frame.dtype != np.uint8:
                frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
            # Frames are RGB; OpenCV expects BGR channel order.
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always release so the MP4 container is finalized even on error.
        video_writer.release()

    return temp_path
88
 
89
def main():
    """Streamlit entry point: upload an image, caption it, generate a video.

    Flow: load models -> upload image -> show original + enhanced versions ->
    BLIP caption (user-editable) -> generate and offer the video for download.
    """
    # NOTE: the emoji below were mojibake in the original ("πŸŽ₯" etc.,
    # UTF-8 bytes mis-decoded) — restored to the intended characters.
    st.title("🎥 AI Video Generator")
    st.write("Upload an image to generate a video based on its content!")

    try:
        # Load models (cached across reruns by @st.cache_resource).
        pipeline, blip, blip_processor = load_models()

        # File uploader
        image_file = st.file_uploader("Upload Image", type=['png', 'jpg', 'jpeg'])

        if image_file:
            # Display original and enhanced image side by side.
            col1, col2 = st.columns(2)

            with col1:
                image = Image.open(image_file)
                st.image(image, caption="Original Image")

            with col2:
                enhanced_image = enhance_image(image)
                st.image(enhanced_image, caption="Enhanced Image")

            # Get and display description.
            description = get_description(enhanced_image, blip, blip_processor)
            st.write("📝 Generated Description:", description)

            # Let the user tweak the prompt before the (slow) generation step.
            modified_description = st.text_area("Edit description if needed:", description)

            # Generate video button.
            if st.button("🎬 Generate Video"):
                with st.spinner("Generating video... This may take a few minutes."):
                    video_path = generate_video(pipeline, modified_description)
                    st.success("Video generated successfully!")
                    st.video(video_path)

                    # Add download button.
                    with open(video_path, 'rb') as f:
                        st.download_button(
                            label="Download Video",
                            data=f,
                            file_name="generated_video.mp4",
                            mime="video/mp4"
                        )

    except Exception as e:
        # Top-level UI boundary: surface the error instead of crashing the app.
        st.error(f"An error occurred: {str(e)}")
        st.error("Please try again or contact support if the error persists.")

if __name__ == "__main__":
    main()