|
import base64
import html
import os
from io import BytesIO

import google.generativeai as genai
import gradio as gr
from google import genai as google_genai
from google.genai import types
from PIL import Image
|
|
|
|
|
# Credentials come from the environment; the lookup yields None when
# GEMINI_API_KEY is unset, and that value is handed to the SDK as-is.
api_key = os.environ.get("GEMINI_API_KEY")
genai.configure(api_key=api_key)

# Shared text model used by generate_item() to write captions.
gemini_model = genai.GenerativeModel("gemini-2.5-flash-preview-04-17")
|
|
|
# Lazily created google.genai client, shared by all image requests.
_imagen_client = None


def _get_imagen_client():
    """Return the shared ``google.genai`` client, creating it on first use."""
    global _imagen_client
    if _imagen_client is None:
        _imagen_client = google_genai.Client(api_key=api_key)
    return _imagen_client


def generate_item(tag):
    """
    Generate a single feed item: a caption from Gemini and a matching Imagen image.

    The caption is produced first and then reused verbatim as the image prompt.

    Args:
        tag (str): The tag to base the content on.

    Returns:
        dict: A dictionary with 'text' (str) and 'image_base64' (str), where
            'image_base64' is the PNG-encoded image as base64 text.

    Raises:
        RuntimeError: If the image request comes back without any image.
    """
    prompt = f"Generate a short, engaging post about {tag} in the style of a TikTok caption."
    text_response = gemini_model.generate_content(prompt)
    text = text_response.text.strip()

    # Image generation lives in the `google.genai` SDK (client.models), not in
    # `google.generativeai`, which is what the module-level `genai` alias is.
    image_response = _get_imagen_client().models.generate_images(
        model='imagen-3.0-generate-002',
        prompt=text,
        config=types.GenerateImagesConfig(
            number_of_images=1,
            aspect_ratio="9:16",
            person_generation="DONT_ALLOW",
        ),
    )

    generated_images = image_response.generated_images
    if not generated_images:
        # Fail with a clear message instead of an IndexError/TypeError below.
        raise RuntimeError(f"Image generation returned no images for tag {tag!r}")
    image = Image.open(BytesIO(generated_images[0].image.image_bytes))

    # Re-encode as PNG so the payload matches the data:image/png URI used
    # when the feed is rendered.
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    return {'text': text, 'image_base64': img_str}
|
|
|
def start_feed(tag):
    """
    Start a new feed with the given tag by generating one initial item.

    Leading and trailing whitespace is removed from the tag. A blank (or
    missing) tag resets the feed to empty without calling the generation APIs.

    Args:
        tag (str): The tag to generate content for.

    Returns:
        tuple: (current_tag, feed_items, html_content)
    """
    tag = (tag or "").strip()
    if not tag:
        # Nothing meaningful to prompt with; show an empty feed instead.
        return "", [], generate_html([])

    feed_items = [generate_item(tag)]
    return tag, feed_items, generate_html(feed_items)
|
|
|
def load_more(current_tag, feed_items):
    """
    Append a new item to the existing feed using the current tag.

    The incoming list is copied rather than mutated. When no feed has been
    started yet (empty tag), the feed is returned unchanged and no content
    is generated.

    Args:
        current_tag (str): The tag currently being used for the feed.
        feed_items (list): The current list of feed items.

    Returns:
        tuple: (current_tag, updated_feed_items, updated_html_content)
    """
    updated_items = list(feed_items or [])
    if not current_tag:
        # "Load More" pressed before "Start Feed": there is no tag to extend.
        return current_tag, updated_items, generate_html(updated_items)

    updated_items.append(generate_item(current_tag))
    return current_tag, updated_items, generate_html(updated_items)
|
|
|
def generate_html(feed_items):
    """
    Generate an HTML string to display the feed items.

    Item text is HTML-escaped before insertion because it is model-generated
    and must not be able to inject markup into the page.

    Args:
        feed_items (list): List of dictionaries containing 'text' and 'image_base64'.

    Returns:
        str: HTML string representing the feed: a scrollable container holding
            one block (caption plus image) per item, in list order.
    """
    item_blocks = [
        f"""
        <div style="margin-bottom: 20px; border-bottom: 1px solid #eee; padding-bottom: 20px;">
            <p style="font-size: 16px; margin-bottom: 10px;">{html.escape(item['text'])}</p>
            <img src="data:image/png;base64,{html.escape(item['image_base64'])}" style="width: 100%; max-width: 300px; height: auto;">
        </div>
        """
        for item in feed_items
    ]
    return (
        '<div style="max-height: 600px; overflow-y: auto; border: 1px solid #ccc; padding: 10px;">'
        + "".join(item_blocks)
        + '</div>'
    )
|
|
|
|
|
with gr.Blocks(title="TikTok-Style Infinite Feed") as demo:
    # Page header.
    gr.Markdown("# TikTok-Style Infinite Feed Generator")
    gr.Markdown("Enter a tag or select a suggested one to generate a scrollable feed of AI-generated content!")

    # Tag selection: a dropdown of presets next to a free-text box.
    with gr.Row():
        suggested_tags = gr.Dropdown(
            label="Suggested Tags",
            choices=["technology", "nature", "art", "food"],
            value="nature",
        )
        tag_input = gr.Textbox(value="nature", label="Enter a Custom Tag")

    # Feed controls.
    with gr.Row():
        start_button = gr.Button("Start Feed")
        load_more_button = gr.Button("Load More")

    # Rendered feed.
    feed_html = gr.HTML(label="Your Feed")

    # State carried between clicks: the active tag and the items so far.
    current_tag = gr.State(value="")
    feed_items = gr.State(value=[])

    def set_tag(selected_tag):
        """Echo the chosen suggested tag so it can be copied into the tag textbox."""
        return selected_tag

    # Picking a preset overwrites the custom-tag textbox.
    suggested_tags.change(fn=set_tag, inputs=suggested_tags, outputs=tag_input)

    # "Start Feed" reads the textbox and replaces tag, items and rendered HTML.
    start_button.click(
        fn=start_feed,
        inputs=tag_input,
        outputs=[current_tag, feed_items, feed_html],
    )

    # "Load More" reads the stored tag and items and writes all three back.
    load_more_button.click(
        fn=load_more,
        inputs=[current_tag, feed_items],
        outputs=[current_tag, feed_items, feed_html],
    )

demo.launch()