|
import base64
import html
import json
import os
import random
from io import BytesIO

import gradio as gr
from google import genai
from google.genai import types
from PIL import Image
|
|
|
|
|
# Read the Gemini API key once at import time and fail fast if it is missing.
# An empty / whitespace-only value is treated the same as an unset variable,
# and no KeyError traceback is chained onto the user-facing error.
api_key = os.environ.get('GEMINI_API_KEY', '').strip()
if not api_key:
    raise ValueError("Please set the GEMINI_API_KEY environment variable.")

# Shared google-genai client used for both caption and image generation.
client = genai.Client(api_key=api_key)
|
|
|
# Visual variations sampled per image so consecutive feed items look different.
_IMAGE_STYLES = (
    "futuristic neon lighting",
    "soft pastel tones with a dreamy vibe",
    "vibrant and colorful pop art style",
    "minimalist black and white aesthetic",
    "retro 80s synthwave look",
    "golden hour sunlight with warm tones",
)

_IMAGE_PERSPECTIVES = (
    "a close-up view",
    "a wide-angle shot",
    "an aerial perspective",
    "a side profile",
    "a dynamic angled shot",
)

# Placeholder size (width, height) used when no image is returned; 9:16 like the Imagen request.
_PLACEHOLDER_SIZE = (360, 640)


def _parse_caption(raw, tag):
    """Extract the 'caption' field from the model's JSON reply, or return a canned caption."""
    # Hashtags cannot contain spaces, so collapse multi-word tags for the fallback.
    fallback = f"Obsessed with {tag}! 🔥 #{tag.replace(' ', '')}"
    if not raw:
        # response.text is None/empty when the model returned no usable candidate.
        return fallback

    cleaned = raw.strip()
    if cleaned.startswith("```"):
        # Drop a Markdown code fence (``` or ```json) the model may wrap around JSON.
        lines = cleaned.splitlines()[1:]
        if lines and lines[-1].strip().startswith("```"):
            lines = lines[:-1]
        cleaned = "\n".join(lines).strip()

    try:
        caption = json.loads(cleaned)['caption']
    except (json.JSONDecodeError, KeyError, TypeError):
        # TypeError covers valid JSON that is not an object (e.g. a list or string).
        return fallback

    if not isinstance(caption, str) or not caption.strip():
        return fallback
    return caption.strip()


def _generate_caption(tag):
    """Ask Gemini for a short TikTok-style caption about *tag*."""
    prompt = f"""
Generate a short, engaging TikTok-style caption about {tag}.
Return the response as a JSON object with a single key 'caption' containing the caption text.
Example: {{"caption": "Craving this yummy treat! 😍 #foodie"}}
Do not include additional commentary or options.
Use creative and varied language to ensure uniqueness.
"""
    response = client.models.generate_content(
        model='gemini-2.5-flash-preview-04-17',
        contents=[prompt],
        # google-genai takes sampling options through `config`; the
        # `generation_config` keyword belongs to the legacy SDK and is rejected.
        # Requesting a JSON MIME type makes the reply directly parseable.
        config=types.GenerateContentConfig(
            temperature=1.2,
            response_mime_type="application/json",
        ),
    )
    return _parse_caption(response.text, tag)


def _generate_image(tag):
    """Generate one 9:16 Imagen picture for *tag*; return a gray placeholder if none comes back."""
    style = random.choice(_IMAGE_STYLES)
    perspective = random.choice(_IMAGE_PERSPECTIVES)

    image_prompt = f"""
A high-quality visual scene representing {tag}, designed for a TikTok video.
The image should be {perspective} with a {style}.
Ensure the image is colorful, engaging, and has no text or letters.
"""
    response = client.models.generate_images(
        model='imagen-3.0-generate-002',
        prompt=image_prompt,
        config=types.GenerateImagesConfig(
            number_of_images=1,
            aspect_ratio="9:16",
            person_generation="DONT_ALLOW",
        ),
    )

    generated = response.generated_images or []
    first = generated[0].image if generated else None
    # The entry can exist without an image/bytes (e.g. filtered output), so check both.
    if first is not None and first.image_bytes:
        return Image.open(BytesIO(first.image_bytes))
    return Image.new('RGB', _PLACEHOLDER_SIZE, color='gray')


def _encode_png_base64(image):
    """Serialize a PIL image as PNG and return it base64-encoded as a str."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode()


def generate_item(tag, item_index):
    """
    Generate a single feed item: a Gemini caption plus an Imagen picture.

    Args:
        tag (str): The tag to base the content on.
        item_index (int): Position of the item in the feed. This is currently
            unused. Variety comes from the randomly chosen image style and
            perspective and from the caption sampling temperature.

    Returns:
        dict: A dictionary with 'text' (str, the caption) and 'image_base64'
        (str, a base64-encoded PNG).
    """
    text = _generate_caption(tag)
    image = _generate_image(tag)
    return {'text': text, 'image_base64': _encode_png_base64(image)}
|
|
|
def start_feed(tag):
    """
    Start a new feed for *tag* by generating one initial item.

    Args:
        tag (str): The tag to generate content for. Surrounding whitespace is
            ignored. An empty or missing tag falls back to "trending".

    Returns:
        tuple: (current_tag, feed_items, html_content). current_tag is the
        normalized tag actually used, so "Load More" continues with it.
    """
    # Trim the tag so stray spaces don't leak into prompts or hashtags.
    tag = (tag or "").strip() or "trending"
    feed_items = [generate_item(tag, 0)]
    html_content = generate_html(feed_items)
    return tag, feed_items, html_content
|
|
|
def load_more(current_tag, feed_items):
    """
    Append one newly generated item to the feed and re-render it, scrolled to the new item.

    Args:
        current_tag (str): The tag the feed was started with. If "Load More"
            is clicked before "Start Feed", this is empty and "trending" is
            used instead.
        feed_items (list): The current list of feed items (may be empty/None).

    Returns:
        tuple: (current_tag, updated_feed_items, updated_html_content). A new
        list is returned; the incoming list is not mutated.
    """
    # Without this default, an empty state tag would generate content "about ".
    tag = (current_tag or "").strip() or "trending"
    items = list(feed_items or [])
    items.append(generate_item(tag, len(items)))
    html_content = generate_html(items, scroll_to_latest=True)
    return tag, items, html_content
|
|
|
def generate_html(feed_items, scroll_to_latest=False):
    """
    Render feed items as a vertically scroll-snapping, TikTok-like HTML feed.

    Caption text and image data are HTML-escaped, because captions come from
    an LLM (or from a fallback containing the user's tag) and must not be
    able to inject markup.

    Args:
        feed_items (list): Dicts with 'text' (caption) and 'image_base64'
            (base64-encoded PNG) keys.
        scroll_to_latest (bool): If True and the feed is non-empty, append a
            <script> that smooth-scrolls to the last item.

    Returns:
        str: HTML string representing the feed.
    """
    parts = ["""
<div id="feed-container" style="
    display: flex;
    flex-direction: column;
    align-items: center;
    max-width: 360px;
    margin: 0 auto;
    background-color: #000;
    height: 640px;
    overflow-y: scroll;
    scroll-snap-type: y mandatory;
    scrollbar-width: none;
    -ms-overflow-style: none;
    border: 1px solid #333;
    border-radius: 10px;
">
""", """
<style>
    #feed-container::-webkit-scrollbar {
        display: none;
    }
    .feed-item {
        scroll-snap-align: start;
    }
</style>
"""]

    for idx, item in enumerate(feed_items):
        caption = html.escape(str(item['text']))
        image_b64 = html.escape(str(item['image_base64']))
        parts.append(f"""
<div class="feed-item" id="item-{idx}" style="
    width: 100%;
    height: 640px;
    position: relative;
    display: flex;
    flex-direction: column;
    justify-content: flex-end;
    overflow: hidden;
">
    <img src="data:image/png;base64,{image_b64}" style="
        width: 100%;
        height: 100%;
        object-fit: cover;
        position: absolute;
        top: 0;
        left: 0;
        z-index: 1;
    ">
    <div style="
        position: relative;
        z-index: 2;
        background: linear-gradient(to top, rgba(0,0,0,0.7), transparent);
        padding: 20px;
        color: white;
        font-family: Arial, sans-serif;
        font-size: 18px;
        font-weight: bold;
        text-shadow: 1px 1px 2px rgba(0,0,0,0.5);
    ">
        {caption}
    </div>
</div>
""")

    parts.append("</div>")

    if scroll_to_latest and feed_items:
        # NOTE(review): gr.HTML inserts markup client-side and may not execute
        # inline <script> tags; confirm this auto-scroll actually fires.
        parts.append(f"""
<script>
    document.getElementById('item-{len(feed_items) - 1}').scrollIntoView({{ behavior: 'smooth' }});
</script>
""")

    return "".join(parts)
|
|
|
|
|
# Build the UI: a tag picker / custom tag input, two action buttons, and the
# HTML feed. Per-session state holds the active tag and generated items.
with gr.Blocks(
    css="""
body { background-color: #000; color: #fff; font-family: Arial, sans-serif; }
.gradio-container { max-width: 400px; margin: 0 auto; padding: 10px; }
input, select, button { border-radius: 5px; background-color: #222; color: #fff; border: 1px solid #444; }
button { background-color: #ff2d55; border: none; }
button:hover { background-color: #e0264b; }
.gr-button { width: 100%; margin-top: 10px; }
.gr-form { background-color: #111; padding: 15px; border-radius: 10px; }
""",
    title="TikTok-Style Infinite Feed"
) as demo:

    with gr.Column(elem_classes="gr-form"):
        gr.Markdown("### Create Your TikTok Feed")
        with gr.Row():
            suggested_tags = gr.Dropdown(
                choices=["food", "travel", "fashion", "tech"],
                label="Pick a Tag",
                value="food"
            )
            tag_input = gr.Textbox(
                label="Or Enter a Custom Tag",
                value="food",
                placeholder="e.g., sushi, adventure"
            )
        with gr.Row():
            start_button = gr.Button("Start Feed")
            load_more_button = gr.Button("Load More")

    feed_html = gr.HTML()

    # Tag the current feed was started with, and the items generated so far.
    current_tag = gr.State(value="")
    feed_items = gr.State(value=[])

    def set_tag(selected_tag):
        """Update the tag input when a suggested tag is selected."""
        return selected_tag

    suggested_tags.change(fn=set_tag, inputs=suggested_tags, outputs=tag_input)
    start_button.click(
        fn=start_feed,
        inputs=tag_input,
        outputs=[current_tag, feed_items, feed_html]
    )
    load_more_button.click(
        fn=load_more,
        inputs=[current_tag, feed_items],
        outputs=[current_tag, feed_items, feed_html]
    )

# Only start the server when run as a script, so importing this module
# (e.g. for tests or tooling) doesn't block on demo.launch().
if __name__ == "__main__":
    demo.launch()