ab / app2.py
theapps's picture
Upload 36 files
ca165c7 verified
import asyncio
import csv
import os
import threading
from datetime import datetime

import gpt_2_simple as gpt2
from fastapi import FastAPI, Request, Form
from fastapi.templating import Jinja2Templates
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Name of the pretrained GPT-2 checkpoint used for every request.
MODEL_NAME = "124M"

# gpt2.download_gpt2() re-downloads all checkpoint files unconditionally, so
# only fetch the model when its directory is missing. "models" is the default
# model_dir of both download_gpt2() and load_gpt2().
# NOTE(review): an interrupted download leaves a partial directory that this
# check treats as complete; delete models/<MODEL_NAME> to force a re-download.
if not os.path.isdir(os.path.join("models", MODEL_NAME)):
    gpt2.download_gpt2(model_name=MODEL_NAME)

# Load the GPT-2 weights into one TensorFlow session shared by all requests.
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, model_name=MODEL_NAME)
# Serializes calls into gpt2.generate(). Generation runs on executor worker
# threads (see generate_conversation), and TensorFlow graphs are not
# thread-safe for graph construction, so concurrent requests must not build
# ops in the shared graph at the same time.
_GENERATION_LOCK = threading.Lock()


def _generate_blocking(prompt):
    """Run one synchronous GPT-2 generation for *prompt* and return its text."""
    # TensorFlow's (v1-style) default graph is per-thread: without entering
    # sess.graph here, ops built on a worker thread would land in a fresh graph
    # rather than the one the model was loaded into, and sess.run() would
    # reject them.
    # NOTE(review): gpt_2_simple's generate() appears to add new ops to the
    # graph on every call, so the graph grows with each request — confirm.
    with _GENERATION_LOCK, sess.graph.as_default():
        return gpt2.generate(
            sess,
            prefix=prompt,
            length=300,
            temperature=0.7,
            return_as_list=True,
        )[0]


async def generate_conversation(prompt):
    """Generate GPT-2 text from *prompt* without blocking the event loop.

    gpt2.generate() is a long synchronous TensorFlow call. Awaiting it inline
    in a coroutine would stall every other request on the event loop, so it
    runs in the loop's default thread-pool executor instead.

    Args:
        prompt: Text used as the generation prefix.

    Returns:
        The first generated sample (length=300, temperature=0.7). If generation
        raises, returns a string of the form ``"Error: <message>"`` instead.
    """
    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(None, _generate_blocking, prompt)
    except Exception as e:
        # Deliberate best-effort: the error text is rendered to the user
        # (and saved by the caller) in place of a generated conversation.
        return f"Error: {str(e)}"
def save_to_csv(prompt, conversation, filename="info.csv"):
    """Write one prompt/conversation pair to a CSV file and return its path.

    The file is opened in ``'w'`` mode, so each call overwrites the previous
    contents. The file always holds exactly one header row and one data row.

    Args:
        prompt: The user-supplied prompt text.
        conversation: The text generated for *prompt*.
        filename: Destination path. Defaults to ``info.csv`` in the current
            working directory.

    Returns:
        The path that was written, i.e. *filename*.
    """
    # NOTE(review): prompt comes straight from the POST form. If this CSV is
    # opened in a spreadsheet, cells starting with '=', '+', '-' or '@' may be
    # evaluated as formulas (CSV injection); consider sanitizing.
    # newline='' is what the csv module expects: it keeps newlines embedded in
    # the generated text inside their quoted field instead of translating them.
    with open(filename, mode="w", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["Prompt", "Generated Conversation"])
        writer.writerow([prompt, conversation])
    return filename
@app.get("/")
def read_form(request: Request):
    """Render ``index.html`` with only the request in its template context."""
    template_context = dict(request=request)
    return templates.TemplateResponse("index.html", template_context)
@app.post("/")
async def generate_and_display(request: Request, prompt: str = Form(...)):
    """Handle a submitted prompt: generate text, save it, and render the page.

    ``prompt`` is a required form field. The page receives the original
    prompt, the generated conversation, and the CSV path reported by
    ``save_to_csv``.
    """
    generated = await generate_conversation(prompt)
    saved_path = save_to_csv(prompt, generated)
    page_context = {
        "request": request,
        "prompt": prompt,
        "conversation": generated,
        "csv_filename": saved_path,
    }
    return templates.TemplateResponse("index.html", page_context)
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on the
    # loopback interface, port 8000.
    from uvicorn import run as serve

    serve(app, host="127.0.0.1", port=8000)