"""Gradio playground that streams next-word completions from the TrumpGPT model."""

import threading

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "programordie2/trumpgpt"
# Loaded once at import time so every request reuses the same model and tokenizer.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Upper bound on the number of tokens generated per request.
MAX_NEW_TOKENS = 50

CUSTOM_CSS = """
body { font-family: 'Segoe UI', sans-serif; }
.gradio-container { max-width: 700px; margin: auto; padding: 2em; }
textarea { font-size: 1rem !important; }
#output-box {
    white-space: pre-wrap;
    border-radius: 12px;
    padding: 1em;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
"""


def stream_generate(prompt):
    """Yield the prompt followed by the text generated so far.

    Generation runs in a background thread; each decoded chunk the streamer
    hands over is appended to the running text, which is re-yielded in full so
    the output box always shows the complete continuation.

    Args:
        prompt: Text the model should continue. A falsy prompt (empty string
            or ``None``) yields nothing.

    Yields:
        ``prompt`` plus everything generated up to that point.
    """
    if not prompt:
        return

    encoded = tokenizer(prompt, return_tensors="pt")
    # skip_special_tokens is forwarded to tokenizer.decode so markers such as
    # the end-of-text token never show up in the output box.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        input_ids=encoded.input_ids,
        # Forward the tokenizer's mask (None if it did not produce one) so
        # generate() does not have to infer it.
        attention_mask=encoded.get("attention_mask"),
        max_new_tokens=MAX_NEW_TOKENS,
        streamer=streamer,
    )

    def run_generation():
        """Run generate() and make sure the consumer loop is always released."""
        try:
            model.generate(**generation_kwargs)
        except Exception:
            # Without the stop signal the iteration below would block forever.
            streamer.end()
            raise

    # Daemon thread: a stuck generation must not keep the process alive.
    worker = threading.Thread(target=run_generation, daemon=True)
    worker.start()

    generated_text = prompt
    for text in streamer:
        generated_text += text
        yield generated_text

    worker.join()


# Preset prompts
example_prompts = [
    "The Fake News Media",
    "Sleepy Joe Biden",
    "MAKE AMERICA",
]


# Update input when a prompt is selected
def set_prompt(prompt):
    """Return an update that puts ``prompt`` into the wired output component.

    The component to update is chosen by the ``outputs`` argument of the event
    listener, so only the new value is needed here.
    """
    return gr.update(value=prompt)


# Interface with custom layout
with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown("## ✨ TrumpGPT Playground")
    gr.Markdown("TrumpGPT is a LLM based on GPT-2, trained on Donald Trump's tweets.")
    gr.Markdown("Please note this is a next word predictor, not a chatbot.")

    with gr.Column():
        prompt_box = gr.Textbox(
            label="Prompt",
            lines=1,
            placeholder="Type the start of a sentence",
            elem_id="prompt-box",
        )

        gr.Markdown("No inspiration? Try one of these:")
        for index, prompt in enumerate(example_prompts):
            # Index-based ids: the prompt text contains spaces, which are not
            # allowed in an HTML id.
            btn = gr.Button(prompt, elem_id=f"prompt-{index}")
            # The button is its own input, so its label is what gets passed
            # to set_prompt.
            btn.click(set_prompt, btn, prompt_box, show_progress="hidden")

        gr.Markdown("---")

        generate_btn = gr.Button("Generate", variant="primary", elem_id="generate-btn")
        output_box = gr.Textbox(
            label="Generated Text",
            lines=8,
            interactive=False,
            elem_id="output-box",
        )

    generate_btn.click(stream_generate, prompt_box, output_box)

if __name__ == "__main__":
    demo.launch()