# app.py — Hugging Face Space by m7n
# Last commit: 82f4377 "Update app.py" (verified), 2.39 kB
import os
from pathlib import Path
import json
import base64
from datetime import datetime
import gradio as gr
from fastapi import FastAPI, Request, Response
from fastapi.staticfiles import StaticFiles
import uvicorn
import httpx
from fastapi.middleware.base import BaseHTTPMiddleware
import spaces
from spaces.zero.client import _get_token
# Gradio server settings, exported through the process environment in one
# mapping (keys are applied in the order listed).
os.environ.update(
    {
        "GRADIO_SSR_MODE": "True",
        "GRADIO_SERVER_PORT": "7860",
        "GRADIO_SERVER_NAME": "0.0.0.0",
        # Node SSR helper: host only, bound to loopback.
        "GRADIO_NODE_SERVER_NAME": "127.0.0.1",
        "GRADIO_ROOT_PATH": "/",
    }
)
# Create FastAPI app
app = FastAPI()


# Optional middleware to fix malformed SSR asset paths.
class SSRPathRewriteMiddleware(BaseHTTPMiddleware):
    """Map request paths that carry the Node SSR port onto the SSR asset prefix.

    A path equal to ``/<node_port>`` or starting with ``/<node_port>/`` is
    rewritten to start with ``target_prefix`` instead; every other request
    passes through untouched.

    Args:
        app: the wrapped ASGI application.
        dispatch: forwarded to ``BaseHTTPMiddleware`` unchanged (kept so the
            base-class constructor signature still works).
        node_port: port whose ``/<port>`` path prefix is rewritten. Should match
            the ``node_port`` given to ``gr.mount_gradio_app``.
        target_prefix: replacement prefix, ``"/_app"`` by default.

    NOTE(review): the file imports BaseHTTPMiddleware from
    ``fastapi.middleware.base``; Starlette (which FastAPI builds on) exposes it
    as ``starlette.middleware.base.BaseHTTPMiddleware`` — confirm the import
    resolves in the deployed FastAPI version.
    """

    def __init__(self, app, dispatch=None, *, node_port=7861, target_prefix="/_app"):
        super().__init__(app, dispatch)
        self.source_prefix = f"/{node_port}"
        self.target_prefix = target_prefix

    async def dispatch(self, request: Request, call_next):
        # Read the same scope field that is rewritten below.
        path = request.scope["path"]
        prefix = self.source_prefix
        # Match whole path segments only: "/7861" and "/7861/..." are rewritten,
        # but "/78610" or "/7861abc" are not (a bare startswith() would match them).
        if path == prefix or path.startswith(prefix + "/"):
            # request.scope is the scope dict handed to the downstream app by
            # call_next, so this changes the path the app routes on.
            request.scope["path"] = self.target_prefix + path[len(prefix):]
        return await call_next(request)


app.add_middleware(SSRPathRewriteMiddleware)
# Create the static directory (if missing) and serve its contents at /static.
static_dir = Path("./static")
static_dir.mkdir(parents=True, exist_ok=True)
# Pass static_dir itself instead of repeating the "static" literal, so the
# directory that was created, the one that is mounted, and the one that is
# allowed below can never diverge.
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Expose the directory (as an absolute path) to Gradio's file-access allow list.
# NOTE(review): launch() is never called in this file; confirm that
# mount_gradio_app honours GRADIO_ALLOWED_PATHS, otherwise pass
# allowed_paths=[...] to mount_gradio_app explicitly.
os.environ["GRADIO_ALLOWED_PATHS"] = str(static_dir.resolve())
@spaces.GPU(duration=240)  # 240 s = 4 minutes of GPU allocation
def process_text(text):
    """Return *text* upper-cased; the call runs inside a ``spaces.GPU`` allocation."""
    shouted = text.upper()
    return shouted
def _decode_token_payload(token):
    """Return the JSON-decoded middle segment of a ``header.payload.signature`` token.

    Best effort: returns None when *token* is empty/None or cannot be decoded,
    because the payload is only used for diagnostics.
    """
    if not token:
        return None
    try:
        segment = token.split(".")[1]
        # base64url segments are sent without padding; pad to a multiple of 4.
        segment += "=" * (-len(segment) % 4)
        return json.loads(base64.urlsafe_b64decode(segment).decode())
    except (IndexError, ValueError):
        # ValueError also covers binascii.Error, UnicodeDecodeError and
        # json.JSONDecodeError.
        return None


def process_and_save(request: gr.Request, text):
    """Log the caller's token payload (if any), then upper-case *text* on the GPU.

    Args:
        request: injected by Gradio because of the ``gr.Request`` annotation
            (only the textbox is listed in the event's ``inputs``).
        text: value of the input textbox.

    Returns:
        The result of ``process_text(text)``.
    """
    token = _get_token(request)
    payload = _decode_token_payload(token)
    if payload is None:
        # A missing or malformed token must not block processing; it was only logged.
        print("Token payload: unavailable (missing or malformed token)")
    else:
        # NOTE(review): this prints the full decoded payload to the Space logs;
        # it may contain user-identifying data — trim before production use.
        print(f"Token payload: {payload}")
    return process_text(text)


# Marker attribute on the handler — presumably read by ZeroGPU tooling; TODO confirm.
process_and_save.zerogpu = True
# UI: an input textbox, a trigger button and an output textbox, in that order.
with gr.Blocks() as demo:
    text_input = gr.Textbox(label="Enter some text")
    submit_btn = gr.Button(value="Process and Download")
    output = gr.Textbox(label="Output")
    # Positional form of click(fn, inputs, outputs). The gr.Request argument of
    # process_and_save is supplied by Gradio, so only the textbox is an input.
    submit_btn.click(process_and_save, [text_input], output)
# Mount the Gradio app at "/" with server-side rendering; the Node SSR helper
# uses node_port 7861 (the port SSRPathRewriteMiddleware's prefix is based on).
app = gr.mount_gradio_app(app, demo, path="/", ssr_mode=True, node_port=7861)

if __name__ == "__main__":
    # Reuse the host/port exported as GRADIO_SERVER_NAME / GRADIO_SERVER_PORT
    # so uvicorn and the Gradio settings come from one place; the fallbacks are
    # the values previously hard-coded here.
    uvicorn.run(
        app,
        host=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )