File size: 4,790 Bytes
9b90da9
279374e
9b90da9
 
 
 
 
 
 
e730056
9b90da9
 
 
 
4a8bcb2
 
 
 
9b90da9
 
4a8bcb2
d6f0f11
9b90da9
4a8bcb2
9b90da9
 
4a8bcb2
 
9b90da9
 
 
4a8bcb2
 
 
 
 
 
 
 
 
9b90da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import base64
import html
import os
from io import BytesIO

import gradio as gr
from google import genai
from google.genai import types
from PIL import Image

# Read the Gemini API key from the environment (a missing GEMINI_API_KEY raises
# KeyError here, at import time) and build the module-wide Google Gen AI client.
_gemini_api_key = os.environ['GEMINI_API_KEY']
client = genai.Client(api_key=_gemini_api_key)

def generate_item(tag):
    """
    Generate one feed item for a tag: a caption plus a base64-encoded PNG.

    The caption is requested from the Gemini text model, then used as the
    prompt for the Imagen model. A gray placeholder image is used whenever
    no usable image comes back.

    Args:
        tag (str): The topic to generate content about.

    Returns:
        dict: {'text': caption string, 'image_base64': base64 text of a PNG}.
    """
    # Generate text using Gemini LLM
    prompt = f"Generate a short, engaging post about {tag} in the style of a TikTok caption."
    text_response = client.models.generate_content(
        model='gemini-2.5-flash-preview-04-17',
        contents=[prompt]
    )
    # The response's `.text` may be None when it carries no text part (e.g. a
    # blocked or empty candidate). Fall back to the tag itself so that neither
    # `.strip()` nor the image prompt below receives None / an empty string.
    text = (text_response.text or "").strip() or str(tag)

    # Generate an image, using the caption as the prompt
    image_response = client.models.generate_images(
        model='imagen-3.0-generate-002',
        prompt=text,
        config=types.GenerateImagesConfig(
            number_of_images=1,
            aspect_ratio="9:16",
            person_generation="DONT_ALLOW"
        )
    )

    # Use the first generated image only if it actually carries bytes; an
    # entry without image data would otherwise crash BytesIO / Image.open.
    image = None
    generated_images = image_response.generated_images
    if generated_images:
        payload = generated_images[0].image
        image_bytes = payload.image_bytes if payload is not None else None
        if image_bytes:
            image = Image.open(BytesIO(image_bytes))
    if image is None:
        # Fallback placeholder; 300x533 approximates the 9:16 aspect ratio
        image = Image.new('RGB', (300, 533), color='gray')

    # Re-encode as PNG so the result matches the `data:image/png` URI it is
    # embedded in, then convert to base64 text.
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode('ascii')

    return {'text': text, 'image_base64': img_str}

def start_feed(tag):
    """
    Start a new feed with the given tag by generating one initial item.

    Args:
        tag (str): The tag to generate content for. Surrounding whitespace
            is ignored.

    Returns:
        tuple: (current_tag, feed_items, html_content), where current_tag is
        the whitespace-stripped tag.

    Raises:
        gr.Error: If the tag is empty or only whitespace, so that no
            generation request is sent without a topic.
    """
    tag = (tag or "").strip()
    if not tag:
        raise gr.Error("Please enter a tag to start the feed.")

    feed_items = [generate_item(tag)]
    html_content = generate_html(feed_items)
    return tag, feed_items, html_content

def load_more(current_tag, feed_items):
    """
    Extend the existing feed with one newly generated item for the current tag.

    The incoming list is not modified; a new list containing the old items
    plus the new one is returned.

    Args:
        current_tag (str): The tag currently being used for the feed.
        feed_items (list): The current list of feed items (None is treated
            as an empty list).

    Returns:
        tuple: (current_tag, updated_feed_items, updated_html_content)

    Raises:
        gr.Error: If current_tag is empty, i.e. no feed has been started yet.
    """
    # An empty tag means there is no feed to extend; refuse rather than send
    # a generation request with no topic.
    if not current_tag:
        raise gr.Error("Start a feed before loading more items.")

    # Build a fresh list instead of appending in place so the caller's list
    # is left untouched.
    updated_items = [*(feed_items or []), generate_item(current_tag)]
    html_content = generate_html(updated_items)
    return current_tag, updated_items, html_content

def generate_html(feed_items):
    """
    Generate an HTML string to display the feed items.

    Each item's text is HTML-escaped before being embedded, so markup in the
    (model-generated) caption is shown literally instead of being rendered.

    Args:
        feed_items (list): List of dictionaries containing 'text' and
            'image_base64'. None is treated as an empty list.

    Returns:
        str: HTML string representing the feed: a scrollable container with
        one caption-plus-image entry per item.
    """
    # Collect fragments and join once rather than growing a string with +=.
    parts = ['<div style="max-height: 600px; overflow-y: auto; border: 1px solid #ccc; padding: 10px;">']
    for item in feed_items or ():
        # Escape &, <, > and quotes so the caption cannot inject markup.
        safe_text = html.escape(item['text'])
        parts.append(f"""
        <div style="margin-bottom: 20px; border-bottom: 1px solid #eee; padding-bottom: 20px;">
            <p style="font-size: 16px; margin-bottom: 10px;">{safe_text}</p>
            <img src="data:image/png;base64,{item['image_base64']}" style="width: 100%; max-width: 300px; height: auto;">
        </div>
        """)
    parts.append('</div>')
    return ''.join(parts)

# Define the Gradio interface: a tag picker, two buttons, and an HTML pane that
# shows the feed built by start_feed / load_more.
with gr.Blocks(title="TikTok-Style Infinite Feed") as demo:
    # Header
    gr.Markdown("# TikTok-Style Infinite Feed Generator")
    gr.Markdown("Enter a tag or select a suggested one to generate a scrollable feed of AI-generated content!")

    # Input components: the dropdown only pre-fills the textbox (see set_tag);
    # the textbox value is what is actually passed to start_feed.
    with gr.Row():
        suggested_tags = gr.Dropdown(
            choices=["technology", "nature", "art", "food"],
            label="Suggested Tags",
            value="nature"  # Default value
        )
        tag_input = gr.Textbox(label="Enter a Custom Tag", value="nature")

    # Buttons
    with gr.Row():
        start_button = gr.Button("Start Feed")
        load_more_button = gr.Button("Load More")

    # Output display: receives the HTML string returned by the handlers
    feed_html = gr.HTML(label="Your Feed")

    # State variables to maintain feed and tag between button clicks.
    # Both start empty ("" and []) until "Start Feed" has been clicked.
    current_tag = gr.State(value="")
    feed_items = gr.State(value=[])

    # Event handlers
    def set_tag(selected_tag):
        """Return the selected suggested tag unchanged so it can be written into the tag textbox."""
        return selected_tag

    # Picking a suggested tag copies it into the custom-tag textbox.
    suggested_tags.change(fn=set_tag, inputs=suggested_tags, outputs=tag_input)
    # "Start Feed" replaces the stored tag and items with a fresh one-item feed.
    start_button.click(
        fn=start_feed,
        inputs=tag_input,
        outputs=[current_tag, feed_items, feed_html]
    )
    # "Load More" reads the stored tag and items and writes back the extended feed.
    load_more_button.click(
        fn=load_more,
        inputs=[current_tag, feed_items],
        outputs=[current_tag, feed_items, feed_html]
    )

# Launch the app. There is no __main__ guard, so this runs whenever the module
# is executed or imported.
demo.launch()