m7n's picture
Update app.py
190832b verified
raw
history blame
2.2 kB
import os
from pathlib import Path
import json
import base64
from datetime import datetime
import gradio as gr
from fastapi import FastAPI, Request, Response
from fastapi.staticfiles import StaticFiles
import uvicorn
import httpx
import spaces
from spaces.zero.client import _get_token
# Environment variables for SSR mode: the SSR flag, the server port, and the
# host names for the Python server and the Node server.
os.environ.update(
    {
        "GRADIO_SSR_MODE": "True",
        "GRADIO_SERVER_PORT": "7860",
        "GRADIO_SERVER_NAME": "0.0.0.0",
        "GRADIO_NODE_SERVER_NAME": "0.0.0.0",
    }
)

# The FastAPI application that everything else in this file is attached to.
app = FastAPI()

# Make sure ./static exists before StaticFiles is pointed at it, then serve
# it under /static.
static_dir = Path("./static")
static_dir.mkdir(parents=True, exist_ok=True)
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# Publish the absolute path of the static directory via GRADIO_ALLOWED_PATHS.
os.environ["GRADIO_ALLOWED_PATHS"] = str(static_dir.resolve())
@spaces.GPU(duration=4 * 60)
def process_text(text):
    """Return ``text`` converted to upper case.

    The function is wrapped with ``spaces.GPU`` using ``duration=4 * 60``
    (presumably seconds -- confirm against the ``spaces`` documentation).
    """
    upper_cased = text.upper()
    return upper_cased
def process_and_save(request: gr.Request, text):
    """Print the decoded token payload for ``request`` and process ``text``.

    The token returned by ``_get_token(request)`` is treated as a
    dot-separated string whose second segment is unpadded URL-safe base64
    holding a JSON document (a JWT-like layout). That segment is decoded and
    printed; the return value is ``process_text(text)``.

    Args:
        request: The Gradio request object handed to ``_get_token``.
        text: The text forwarded to ``process_text``.

    Returns:
        Whatever ``process_text(text)`` returns.
    """
    token = _get_token(request)

    # Second dot-separated segment of the token.
    segment = token.split('.')[1]

    # Restore the '=' padding so the length is a multiple of four.
    padded = segment + '=' * (-len(segment) % 4)

    decoded = base64.urlsafe_b64decode(padded).decode()
    payload = json.loads(decoded)
    print(f"Token payload: {payload}")

    return process_text(text)


# Tag the handler with a ``zerogpu`` attribute (presumably inspected by the
# ``spaces`` package -- confirm).
process_and_save.zerogpu = True
# UI: one input textbox, one button, and one output textbox. Clicking the
# button runs process_and_save on the input and shows the result.
with gr.Blocks() as demo:
    text_input = gr.Textbox(label="Enter some text")
    submit_btn = gr.Button("Process and Download")
    output = gr.Textbox(label="Output")

    submit_btn.click(
        fn=process_and_save,
        inputs=[text_input],
        outputs=output,
    )

# Mount the Gradio app at the root path with SSR mode enabled, using
# node_port 7861.
app = gr.mount_gradio_app(
    app,
    demo,
    path="/",
    ssr_mode=True,
    node_port=7861,
)
# Manual reverse proxy for SSR assets
# NOTE(review): this route is registered after the Gradio app has been mounted
# at "/"; routes are matched in registration order, so confirm that requests
# for /_app/... actually reach this handler.
@app.get("/_app/{full_path:path}")
async def proxy_to_node(full_path: str, request: Request):
    """Forward a GET for ``/_app/<full_path>`` to the local Node server.

    The upstream URL is ``http://127.0.0.1:7861/_app/<full_path>``, with the
    query string of the incoming request appended when there is one.

    Args:
        full_path: Everything after ``/_app/`` in the requested path.
        request: The incoming request; only its query string is used.

    Returns:
        A ``Response`` carrying the upstream body, status code and content
        type, or a plain-text 502 response when the upstream server cannot
        be reached.
    """
    node_server = "http://127.0.0.1:7861"
    url = f"{node_server}/_app/{full_path}"

    # Preserve the caller's query string instead of silently dropping it.
    query = request.url.query
    if query:
        url = f"{url}?{query}"

    try:
        async with httpx.AsyncClient() as client:
            node_response = await client.get(url)
    except httpx.RequestError:
        # Connection refused, timeout, etc.: report a gateway failure rather
        # than letting the exception surface as an unhandled 500.
        return Response(
            content="Bad Gateway: SSR node server is unreachable",
            status_code=502,
            media_type="text/plain",
        )

    return Response(
        content=node_response.content,
        status_code=node_response.status_code,
        media_type=node_response.headers.get("content-type"),
    )
if __name__ == "__main__":
    # When executed as a script, serve the FastAPI app (with the Gradio app
    # mounted on it) on all interfaces, port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )