File size: 2,392 Bytes
d1ed09d
 
fa9a2fa
 
 
 
d1ed09d
190832b
d1ed09d
 
190832b
82f4377
d1ed09d
fa9a2fa
 
 
82f4377
1a55407
 
 
82f4377
 
1a55407
755230a
d1ed09d
24e24f3
82f4377
 
 
 
 
 
 
 
 
 
 
 
 
755230a
 
 
 
 
d1ed09d
1a55407
fa9a2fa
 
 
 
 
 
 
 
1a55407
 
190832b
d1ed09d
755230a
d1ed09d
755230a
 
fa9a2fa
1a55407
190832b
d1ed09d
82f4377
89bc1ae
1ef8927
d1ed09d
e178c81
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from pathlib import Path
import json
import base64
from datetime import datetime

import gradio as gr
from fastapi import FastAPI, Request, Response
from fastapi.staticfiles import StaticFiles
import uvicorn
import httpx
from fastapi.middleware.base import BaseHTTPMiddleware

import spaces
from spaces.zero.client import _get_token

# Gradio server / SSR settings, exported through the process environment.
os.environ.update(
    {
        "GRADIO_SSR_MODE": "True",
        "GRADIO_SERVER_PORT": "7860",
        "GRADIO_SERVER_NAME": "0.0.0.0",
        "GRADIO_NODE_SERVER_NAME": "127.0.0.1",  # host only
        "GRADIO_ROOT_PATH": "/",
    }
)

# Root FastAPI application; middleware, the /static mount and the Gradio UI
# are all attached to this object further down the file.
app: FastAPI = FastAPI()

# Optional middleware to fix malformed SSR asset paths.
class SSRPathRewriteMiddleware(BaseHTTPMiddleware):
    """Rewrite request paths that carry the SSR node port as a prefix.

    A request whose path is exactly ``/<node_port>`` or starts with
    ``/<node_port>/`` has that prefix replaced by ``replacement`` before the
    request is routed. Every other request passes through unchanged.

    Args:
        app: The wrapped ASGI application (supplied by ``add_middleware``).
        dispatch: Forwarded to ``BaseHTTPMiddleware``; normally left as None.
        node_port: Port number whose ``/<port>`` prefix is stripped.
        replacement: Path prefix substituted for ``/<node_port>``.

    NOTE(review): ``"/_app"`` is the original author's guess at the correct
    base path ("or adjust to the correct base path"); confirm against the
    malformed SSR asset URLs actually observed.
    """

    def __init__(self, app, dispatch=None, *, node_port=7861, replacement="/_app"):
        super().__init__(app, dispatch)
        self.prefix = f"/{node_port}"
        self.replacement = replacement

    async def dispatch(self, request: Request, call_next):
        # Read and write the same ASGI field. request.url.path is rebuilt and
        # re-parsed from the scope, so it is not guaranteed to round-trip
        # into scope["path"].
        path = request.scope["path"]
        prefix = self.prefix
        # Match on a path-segment boundary only, so unrelated paths such as
        # "/78610" are not rewritten.
        if path == prefix or path.startswith(prefix + "/"):
            request.scope["path"] = self.replacement + path[len(prefix):]
        return await call_next(request)

# Every request goes through the SSR path fixer before reaching the routes.
app.add_middleware(SSRPathRewriteMiddleware)

# Static assets: make sure ./static exists, serve it under /static, then
# publish its absolute path to Gradio through GRADIO_ALLOWED_PATHS.
static_dir = Path("./static")
static_dir.mkdir(exist_ok=True, parents=True)
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
os.environ["GRADIO_ALLOWED_PATHS"] = str(static_dir.resolve())

@spaces.GPU(duration=240)  # 4 * 60, i.e. a four-minute ZeroGPU window
def process_text(text):
    """Return *text* converted to upper case.

    Wrapped in ``spaces.GPU`` so it runs under a ZeroGPU allocation; the body
    itself does no GPU work.
    """
    upper_text = text.upper()
    return upper_text

def _decode_jwt_payload(token):
    """Return the decoded JSON payload of a dot-separated, JWT-style token.

    Only the second ``.``-separated segment is base64url-decoded and parsed
    as JSON. No signature is verified, so the result is for logging/debugging
    only.

    Returns None when ``token`` is empty/None, has no payload segment, or the
    segment is not valid base64url-encoded UTF-8 JSON.
    """
    if not token:
        return None
    segments = token.split(".")
    if len(segments) < 2:
        return None
    payload = segments[1]
    # base64url segments are usually emitted without "=" padding; restore it
    # so the length is a multiple of 4 before decoding.
    payload += "=" * (-len(payload) % 4)
    try:
        return json.loads(base64.urlsafe_b64decode(payload).decode())
    except ValueError:
        # binascii.Error, UnicodeDecodeError and json.JSONDecodeError are all
        # ValueError subclasses.
        return None


def process_and_save(request: gr.Request, text):
    """Log the caller's ZeroGPU token payload, then upper-case *text*.

    Args:
        request: The Gradio request, used only to look up the caller's token
            via ``spaces.zero.client._get_token`` (a private helper).
        text: Input text from the textbox.

    Returns:
        The result of ``process_text(text)``.

    NOTE(review): despite the name, nothing is saved or offered for download
    here.
    """
    # NOTE(review): assumes _get_token may return None when no token is
    # present; confirm against the `spaces` package.
    token = _get_token(request)
    payload = _decode_jwt_payload(token)
    if payload is None:
        # Debug logging only: a missing or malformed token must not fail the
        # request.
        print("Token payload: unavailable (missing or malformed token)")
    else:
        print(f"Token payload: {payload}")
    return process_text(text)


# NOTE(review): presumably a marker consumed by the `spaces` ZeroGPU
# integration; confirm before removing.
process_and_save.zerogpu = True

# UI: one textbox in, one textbox out, wired through process_and_save.
with gr.Blocks() as demo:
    text_input = gr.Textbox(label="Enter some text")
    submit_btn = gr.Button("Process and Download")
    output = gr.Textbox(label="Output")
    submit_btn.click(
        fn=process_and_save,
        inputs=[text_input],
        outputs=output,
    )

# Attach the Blocks UI at "/" on the FastAPI app, with SSR enabled and the
# Node SSR server on port 7861.
app = gr.mount_gradio_app(
    app,
    demo,
    path="/",
    ssr_mode=True,
    node_port=7861,
)

if __name__ == "__main__":
    uvicorn.run(app, port=7860, host="0.0.0.0")