|
import gradio as gr |
|
from google import genai |
|
from google.genai import types |
|
from PIL import Image |
|
from io import BytesIO |
|
import base64 |
|
import os |
|
import json |
|
import random |
|
|
|
|
|
# Fail fast at import time: every generation call in this app needs a Gemini key.
# .get() + explicit check avoids a chained KeyError traceback and also rejects an
# empty/whitespace-only value, which would otherwise fail later with a less clear error.
api_key = os.environ.get("GEMINI_API_KEY", "").strip()
if not api_key:
    raise ValueError("Please set the GEMINI_API_KEY environment variable.")

client = genai.Client(api_key=api_key)
|
|
|
def generate_ideas(tag, *, model='gemini-2.5-flash-preview-04-17'):
    """
    Generate a diverse set of ideas related to the tag using the LLM.

    Args:
        tag (str): The tag to base the ideas on.
        model (str): Gemini model used for the request (keyword-only).

    Returns:
        list: A non-empty list of idea strings. If the reply has no text, is not
        valid JSON, or lacks a usable 'ideas' list of strings, a fixed set of
        template ideas built from ``tag`` is returned instead.
    """
    prompt = f"""
Generate a list of 5 diverse and creative ideas related to {tag} that can be used for a TikTok video.
Each idea should be a short sentence describing a specific scene or concept.
Return the response as a JSON object with a single key 'ideas' containing a list of 5 ideas.
Example: {{"ideas": ["A neon-lit gaming setup with RGB lights flashing", "A futuristic robot assembling a gadget"]}}
"""
    response = client.models.generate_content(
        model=model,
        contents=[prompt],
        config=types.GenerateContentConfig(
            temperature=1.2,
            # Ask for a bare JSON body so json.loads is not tripped up by prose.
            response_mime_type="application/json",
        ),
    )

    # response.text may be None (e.g. no text part in the reply); treat that as empty.
    raw = (response.text or "").strip()
    # Tolerate a ```json ... ``` fence in case the model adds one anyway.
    if raw.startswith("```"):
        raw = raw.strip("`").strip()
        if raw.startswith("json"):
            raw = raw[len("json"):].strip()

    try:
        ideas = json.loads(raw)["ideas"]
    except (json.JSONDecodeError, KeyError, TypeError):
        # TypeError: the JSON parsed to something other than an object (list, str, number).
        ideas = None

    if isinstance(ideas, list):
        cleaned = [idea.strip() for idea in ideas if isinstance(idea, str) and idea.strip()]
        # An empty list would later make random.choice() raise, so only return real ideas.
        if cleaned:
            return cleaned

    return [
        f"A vibrant {tag} scene at sunset",
        f"A close-up of {tag} with neon lights",
        f"A futuristic take on {tag} with holograms",
        f"A cozy {tag} moment with warm lighting",
        f"An action-packed {tag} scene with dynamic colors",
    ]
|
|
|
def generate_item(tag, ideas, *, text_model='gemini-2.5-flash-preview-04-17',
                  image_model='imagen-3.0-generate-002'):
    """
    Generate a single feed item (caption + 9:16 image) using one of the ideas.

    Args:
        tag (str): The tag to base the content on.
        ideas (list): Idea strings to choose from. If empty or None, a generic
            idea built from ``tag`` is used instead.
        text_model (str): Gemini model that writes the caption and image prompt.
        image_model (str): Imagen model that renders the picture.

    Returns:
        dict: {'text': caption str, 'image_base64': base64-encoded PNG str,
        'ideas': the idea list used}. 'ideas' is carried along so later items
        can reuse it. If no image comes back, a gray 360x640 placeholder is used.
    """
    if not ideas:
        # random.choice([]) raises IndexError; keep the feed alive with a generic idea.
        ideas = [f"A vibrant {tag} scene"]
    selected_idea = random.choice(ideas)

    text, image_prompt = _write_caption_and_prompt(tag, selected_idea, text_model)
    img_str = _render_image_base64(image_prompt, image_model)

    return {'text': text, 'image_base64': img_str, 'ideas': ideas}


def _write_caption_and_prompt(tag, selected_idea, model):
    """Ask the LLM for a caption and an image prompt; fall back to templates on unusable replies."""
    prompt = f"""
Based on the idea "{selected_idea}", create content for a TikTok video about {tag}.
Return a JSON object with two keys:
- 'caption': A short, viral TikTok-style caption with hashtags.
- 'image_prompt': A detailed image prompt for generating a high-quality visual scene.
The image prompt should describe the scene vividly, specify a perspective and style, and ensure no text or letters are included.
Example: {{"caption": "Neon vibes only! 🌌 #tech", "image_prompt": "A close-up view of a neon-lit gaming setup with RGB lights flashing, in a futuristic style, no text or letters"}}
"""
    response = client.models.generate_content(
        model=model,
        contents=[prompt],
        config=types.GenerateContentConfig(
            temperature=1.2,
            response_mime_type="application/json",
        ),
    )

    # response.text may be None; also tolerate a ```json fence around the body.
    raw = (response.text or "").strip()
    if raw.startswith("```"):
        raw = raw.strip("`").strip()
        if raw.startswith("json"):
            raw = raw[len("json"):].strip()

    try:
        payload = json.loads(raw)
        text = payload['caption']
        image_prompt = payload['image_prompt']
    except (json.JSONDecodeError, KeyError, TypeError):
        text = image_prompt = None

    if not isinstance(text, str) or not text.strip():
        # Hashtags cannot contain whitespace, so collapse multi-word tags.
        hashtag = "".join(tag.split())
        text = f"Obsessed with {tag}! 🔥 #{hashtag}"
    if not isinstance(image_prompt, str) or not image_prompt.strip():
        image_prompt = f"A vivid scene of {selected_idea}, in a vibrant pop art style, no text or letters"
    return text, image_prompt


def _render_image_base64(image_prompt, model):
    """Render ``image_prompt`` with Imagen and return the result as a base64-encoded PNG string."""
    image_response = client.models.generate_images(
        model=model,
        prompt=image_prompt,
        config=types.GenerateImagesConfig(
            number_of_images=1,
            aspect_ratio="9:16",
            person_generation="DONT_ALLOW",
        ),
    )

    # generated_images may be None/empty, and an entry's `image` is optional on the
    # SDK type (presumably absent when filtered); use a placeholder in those cases.
    image_bytes = None
    if image_response.generated_images:
        generated = image_response.generated_images[0].image
        if generated is not None:
            image_bytes = generated.image_bytes

    if image_bytes:
        image = Image.open(BytesIO(image_bytes))
    else:
        image = Image.new('RGB', (360, 640), color='gray')

    # Re-encode so the data URI in the HTML can always say image/png.
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()
|
|
|
def start_feed(tag, current_index, feed_items, is_loading):
    """
    Start or restart the feed for a tag, showing a loading screen first.

    This is a generator: Gradio renders every yielded tuple. The first yield shows
    the loading animation; the last one shows the freshly generated item.

    Args:
        tag (str): The tag to generate content for; blank or None means "trending".
        current_index (int): The current item index.
        feed_items (list): The current list of feed items.
        is_loading (bool): Incoming loading flag (overwritten here).

    Yields:
        tuple: (current_tag, current_index, feed_items, html_content, is_loading)

    Raises:
        Whatever the generation helpers raise, re-raised after the previous feed
        view has been restored so the UI does not stay on the loading screen.
    """
    tag = (tag or "").strip() or "trending"

    is_loading = True
    yield tag, current_index, feed_items, generate_html([], False, 0, tag, is_loading), is_loading

    try:
        ideas = generate_ideas(tag)
        item = generate_item(tag, ideas)
    except Exception:
        # Put the previous view back, then let Gradio surface the error.
        yield tag, current_index, feed_items, generate_html(feed_items, False, current_index, tag, False), False
        raise

    feed_items = [item]
    current_index = 0
    is_loading = False
    # Must be `yield`, not `return`: Gradio only renders yielded values and
    # discards a generator's return value.
    yield tag, current_index, feed_items, generate_html(feed_items, False, current_index, tag, is_loading), is_loading
|
|
|
def load_next(tag, current_index, feed_items, is_loading):
    """
    Load the next item in the feed, generating a new one at the end of the feed.

    This is a generator: Gradio renders every yielded tuple. Stepping onto an
    already-generated item yields once. Generating a new item yields a loading
    screen first, then the new item.

    Args:
        tag (str): The tag to generate content for; blank or None means "trending".
        current_index (int): The current item index.
        feed_items (list): The current list of feed items.
        is_loading (bool): Incoming loading flag (overwritten here).

    Yields:
        tuple: (current_tag, current_index, feed_items, html_content, is_loading)

    Raises:
        Whatever the generation helpers raise, re-raised after the current view
        has been restored so the UI does not stay on the loading screen.
    """
    if current_index + 1 < len(feed_items):
        # Already generated: step forward without flashing the loading screen.
        current_index += 1
        yield tag, current_index, feed_items, generate_html(feed_items, False, current_index, tag, False), False
        return

    tag = (tag or "").strip() or "trending"

    is_loading = True
    yield tag, current_index, feed_items, generate_html(feed_items, False, current_index, tag, is_loading), is_loading

    try:
        # Reuse the idea list stored on the last item to avoid another LLM call.
        ideas = (feed_items[-1].get('ideas') if feed_items else None) or generate_ideas(tag)
        new_item = generate_item(tag, ideas)
    except Exception:
        yield tag, current_index, feed_items, generate_html(feed_items, False, current_index, tag, False), False
        raise

    # Build a new list rather than appending in place to the incoming state object.
    feed_items = feed_items + [new_item]
    current_index = len(feed_items) - 1
    is_loading = False
    # Must be `yield`, not `return`: Gradio discards a generator's return value.
    yield tag, current_index, feed_items, generate_html(feed_items, False, current_index, tag, is_loading), is_loading
|
|
|
def load_previous(tag, current_index, feed_items, is_loading):
    """
    Step back one item in the feed, staying put when already at the first item.

    Args:
        tag (str): The active tag; passed through unchanged.
        current_index (int): Index of the item currently shown.
        feed_items (list): Generated feed items; passed through unchanged.
        is_loading (bool): Loading flag; passed through and forwarded to the renderer.

    Returns:
        tuple: (current_tag, current_index, feed_items, html_content, is_loading)
    """
    new_index = current_index - 1 if current_index > 0 else current_index
    page = generate_html(feed_items, False, new_index, tag, is_loading)
    return tag, new_index, feed_items, page, is_loading
|
|
|
def generate_html(feed_items, scroll_to_latest=False, current_index=0, tag="", is_loading=False):
    """
    Render the feed panel (loading screen, empty placeholder, or one item) as HTML.

    Markup inserted via innerHTML, which is how gr.HTML displays it, does not run
    <script> tags. Navigation therefore uses an inline ``onclick`` attribute, and
    the loading message is chosen here rather than rotated in JavaScript.

    Args:
        feed_items (list): Dicts containing 'text' and 'image_base64'.
        scroll_to_latest (bool): Unused; kept for call compatibility.
        current_index (int): Index of the item to display.
        tag (str): Current tag, shown (HTML-escaped) in the loading message.
        is_loading (bool): If True, render the loading screen instead of an item.

    Returns:
        str: HTML for the feed panel. Tapping the left quarter of an item clicks
        the element with id ``previous-button``; tapping elsewhere clicks
        ``next-button``.
    """
    from html import escape

    # Both the user-entered tag and the LLM caption end up in markup, so escape them.
    safe_tag = escape(tag or "")
    loading_messages = [
        f"Cooking up a {safe_tag} masterpiece... 🍳",
        f"Snapping a vibrant {safe_tag} moment... 📸",
        f"Creating a {safe_tag} vibe that pops... ✨",
        f"Getting that perfect {safe_tag} shot... 🎥",
        f"Bringing {safe_tag} to life... 🌟",
    ]

    if is_loading:
        message = random.choice(loading_messages)
        return f"""
<div id="feed-container" style="
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    max-width: 360px;
    margin: 0 auto;
    background-color: #000;
    height: 640px;
    border: 1px solid #333;
    border-radius: 10px;
    color: white;
    font-family: Arial, sans-serif;
    position: relative;
">
    <div id="loading-message" style="
        font-size: 18px;
        font-weight: bold;
        text-align: center;
        margin-bottom: 20px;
        text-shadow: 1px 1px 2px rgba(0,0,0,0.5);
    ">
        {message}
    </div>
    <div style="
        width: 80%;
        height: 10px;
        background-color: #333;
        border-radius: 5px;
        overflow: hidden;
    ">
        <div style="
            width: 0%;
            height: 100%;
            background: linear-gradient(to right, #ff2d55, #ff5e78);
            animation: loading 2s infinite;
        "></div>
    </div>
    <style>
        @keyframes loading {{
            0% {{ width: 0%; }}
            50% {{ width: 100%; }}
            100% {{ width: 0%; }}
        }}
    </style>
</div>
"""

    # Also reject negative indices, which would otherwise wrap to the end of the list.
    if not feed_items or not 0 <= current_index < len(feed_items):
        return """
<div style="
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    max-width: 360px;
    margin: 0 auto;
    background-color: #000;
    height: 640px;
    border: 1px solid #333;
    border-radius: 10px;
    color: white;
    font-family: Arial, sans-serif;
">
    <p>Select a tag to start your feed!</p>
</div>
"""

    item = feed_items[current_index]
    caption = escape(str(item['text']))
    image_b64 = item['image_base64']
    # Inline handler (single quotes only, so it fits inside onclick="..."): left
    # quarter goes back, the rest goes forward. It clicks the app's navigation
    # buttons by id and does nothing if they are absent.
    nav_js = (
        "const r = this.getBoundingClientRect();"
        " const id = (event.clientX - r.left) < r.width * 0.25"
        " ? 'previous-button' : 'next-button';"
        " const b = document.getElementById(id);"
        " if (b) b.click();"
    )

    return f"""
<div id="feed-container" style="
    display: flex;
    flex-direction: column;
    align-items: center;
    max-width: 360px;
    margin: 0 auto;
    background-color: #000;
    height: 640px;
    border: 1px solid #333;
    border-radius: 10px;
    position: relative;
">
    <div class="feed-item" style="
        width: 100%;
        height: 100%;
        position: relative;
        display: flex;
        flex-direction: column;
        justify-content: flex-end;
        overflow: hidden;
        cursor: pointer;
    " onclick="{nav_js}">
        <img id="feed-image" src="data:image/png;base64,{image_b64}" style="
            width: 100%;
            height: 100%;
            object-fit: cover;
            position: absolute;
            top: 0;
            left: 0;
            z-index: 1;
        ">
        <div style="
            position: relative;
            z-index: 2;
            background: linear-gradient(to top, rgba(0,0,0,0.7), transparent);
            padding: 20px;
            color: white;
            font-family: Arial, sans-serif;
            font-size: 18px;
            font-weight: bold;
            text-shadow: 1px 1px 2px rgba(0,0,0,0.5);
        ">
            {caption}
        </div>
    </div>
</div>
"""
|
|
|
|
|
with gr.Blocks(
    css="""
body { background-color: #000; color: #fff; font-family: Arial, sans-serif; }
.gradio-container { max-width: 400px; margin: 0 auto; padding: 10px; }
input, select { border-radius: 5px; background-color: #222; color: #fff; border: 1px solid #444; }
.gr-form { background-color: #111; padding: 15px; border-radius: 10px; }
/* Navigation buttons must stay in the DOM (the feed HTML clicks them by id) but out of sight. */
.hidden-nav { display: none !important; }
""",
    title="TikTok-Style Infinite Feed",
) as demo:
    # Per-session state shared by the event handlers below.
    current_tag = gr.State(value="")
    current_index = gr.State(value=0)
    feed_items = gr.State(value=[])
    is_loading = gr.State(value=False)

    with gr.Column(elem_classes="gr-form"):
        gr.Markdown("### Create Your TikTok Feed")
        with gr.Row():
            suggested_tags = gr.Dropdown(
                choices=["food", "travel", "fashion", "tech"],
                label="Pick a Tag",
                value="food",
            )
            tag_input = gr.Textbox(
                label="Or Enter a Custom Tag",
                value="food",
                placeholder="e.g., sushi, adventure",
                submit_btn=False,
            )

    feed_html = gr.HTML()

    def set_tag(selected_tag):
        """Copy the selected suggestion into the custom-tag textbox (the chained .then() starts the feed)."""
        return selected_tag

    # Every feed handler returns the same 5-tuple, in this order.
    feed_outputs = [current_tag, current_index, feed_items, feed_html, is_loading]
    start_inputs = [tag_input, current_index, feed_items, is_loading]

    suggested_tags.change(
        fn=set_tag,
        inputs=suggested_tags,
        outputs=tag_input,
    ).then(
        fn=start_feed,
        inputs=start_inputs,
        outputs=feed_outputs,
    )

    tag_input.submit(
        fn=start_feed,
        inputs=start_inputs,
        outputs=feed_outputs,
    )

    # NOTE: with visible=False recent Gradio versions do not render these buttons at
    # all, so the feed's onclick handler could never find them by id. Hide them with
    # CSS instead so they exist in the page but are not shown.
    next_button = gr.Button("Next", elem_id="next-button", elem_classes="hidden-nav")
    previous_button = gr.Button("Previous", elem_id="previous-button", elem_classes="hidden-nav")

    nav_inputs = [current_tag, current_index, feed_items, is_loading]

    next_button.click(
        fn=load_next,
        inputs=nav_inputs,
        outputs=feed_outputs,
    )

    previous_button.click(
        fn=load_previous,
        inputs=nav_inputs,
        outputs=feed_outputs,
    )


# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch(share=True)