|
import os
import secrets

import gradio as gr
import uvicorn
from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import FastAPI, Depends
from fastapi import Request
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse
|
|
|
app = FastAPI()

# OAuth client credentials and the session-signing key come from the environment.
GOOGLE_CLIENT_ID = os.environ.get("GOOGLE_CLIENT_ID")
GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET")
SECRET_KEY = os.environ.get("SECRET_KEY")
if not SECRET_KEY:
    # SessionMiddleware stringifies its key, so a missing SECRET_KEY would sign
    # every session cookie with the guessable string "None" and let anyone
    # forge a logged-in session. Fall back to a random per-process key instead
    # (sessions then do not survive a restart and are not shared across workers).
    print("WARNING: SECRET_KEY is not set; using a random, per-process session key")
    SECRET_KEY = secrets.token_urlsafe(32)

# Authlib looks up <NAME>_CLIENT_ID / <NAME>_CLIENT_SECRET in this config for
# the client registered under NAME ("google" below).
config_data = {'GOOGLE_CLIENT_ID': GOOGLE_CLIENT_ID, 'GOOGLE_CLIENT_SECRET': GOOGLE_CLIENT_SECRET}
starlette_config = Config(environ=config_data)
oauth = OAuth(starlette_config)
oauth.register(
    name='google',
    # OpenID Connect discovery document: supplies Google's endpoints and keys.
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)

# Signed-cookie sessions; the /auth callback stores the user's userinfo here.
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
|
|
|
|
|
# Only accounts whose verified e-mail address ends with this suffix
# (compared case-insensitively) may use the protected Gradio app.
ALLOWED_EMAIL_DOMAIN = "@zalando.de"


def get_user(request: Request):
    """Return the logged-in user's display name, or None if access is denied.

    Used both as a FastAPI dependency (``Depends(get_user)``) and as Gradio's
    ``auth_dependency``; a ``None`` result means "not authenticated".

    The session's ``user`` entry is the userinfo dict stored by the ``/auth``
    callback. Access is granted only when its e-mail address is verified and
    ends with ``ALLOWED_EMAIL_DOMAIN``. Missing claims deny access instead of
    raising ``KeyError``.
    """
    user = request.session.get('user')
    if not user:
        return None

    email = (user.get('email') or '').lower()
    # The domain suffix is only meaningful for an address the identity
    # provider has verified (OIDC `email_verified` claim).
    if not user.get('email_verified') or not email.endswith(ALLOWED_EMAIL_DOMAIN):
        return None

    # Fall back to the e-mail address so a profile without a name is not
    # treated as "logged out" (which would bounce the user back to login).
    return user.get('name') or email
|
|
|
@app.get('/')
def public(request: Request, user = Depends(get_user)):
    """Landing page: send visitors to the app or to the login screen.

    Users accepted by ``get_user`` go to the Gradio app mounted at ``/gradio``;
    everyone else goes to the login UI mounted at ``/main``.
    """
    destination = '/gradio/' if user else '/main/'
    base = gr.route_utils.get_root_url(request, "/", None)
    return RedirectResponse(url=f'{base}{destination}')
|
|
|
@app.route('/logout')
async def logout(request: Request):
    """Drop the logged-in user from the session and return to the landing page.

    The redirect target is built from the request's root URL, as ``public``
    and ``login`` do, rather than a hard-coded ``'/'``. This keeps it correct
    when the app is served under a path prefix.
    """
    root_url = gr.route_utils.get_root_url(request, "/logout", None)
    request.session.pop('user', None)
    return RedirectResponse(url=f'{root_url}/')
|
|
|
@app.route('/login')
async def login(request: Request):
    """Start the Google OAuth flow.

    The callback URL is the ``/auth`` route resolved against the request's
    root URL (via Gradio's ``get_root_url``), so Google returns the browser
    to this same app.
    """
    callback = "{}/auth".format(gr.route_utils.get_root_url(request, "/login", None))
    print("Redirecting to", callback)
    return await oauth.google.authorize_redirect(request, callback)
|
|
|
@app.route('/auth')
async def auth(request: Request):
    """OAuth callback: exchange the authorization code and log the user in.

    On success the token's ``userinfo`` is stored as ``session['user']`` and
    the browser is sent to the Gradio app. If the exchange fails or the token
    carries no userinfo, the browser goes back to the landing page (``/``).
    """
    # Build redirects from the request's root URL, consistent with `public`
    # and `login`, so they stay correct under a path prefix.
    root_url = gr.route_utils.get_root_url(request, "/auth", None)
    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError as error:
        # Log the caught exception instance, not the OAuthError class.
        print("Error getting access token", str(error))
        return RedirectResponse(url=f'{root_url}/')

    userinfo = token.get("userinfo")
    if not userinfo:
        # Previously a KeyError (HTTP 500); treat it as a failed login instead.
        print("No userinfo in token response")
        return RedirectResponse(url=f'{root_url}/')

    request.session['user'] = dict(userinfo)
    print("Redirecting to /gradio")
    return RedirectResponse(url=f'{root_url}/gradio/')
|
|
|
with gr.Blocks() as login_demo:
    btn = gr.Button("Login")
    # Client-side handler: open /login in a new tab, forwarding the current
    # query string. `const` avoids leaking an implicit global `url` (and a
    # ReferenceError if the snippet is ever evaluated in strict mode).
    # NOTE(review): a new tab is presumably used because the UI may be embedded
    # in an iframe where the provider's login page cannot render — confirm.
    _js_redirect = """
    () => {
        const url = '/login' + window.location.search;
        window.open(url, '_blank');
    }
    """
    btn.click(None, js=_js_redirect)

# Public login UI; `public` sends users rejected by get_user here.
app = gr.mount_gradio_app(app, login_demo, path="/main")
|
|
|
def greet(request: gr.Request):
    """Build the greeting shown when the protected app loads.

    ``request.username`` is expected to carry the value returned by the
    ``auth_dependency`` (``get_user``) given to ``mount_gradio_app``.
    """
    display_name = request.username
    return "Welcome to Gradio, {}".format(display_name)
|
|
|
with gr.Blocks() as main_demo:
    greeting = gr.Markdown("Welcome to Gradio!")
    gr.Button("Logout", link="/logout")
    # On page load, replace the placeholder text with a personalised greeting.
    main_demo.load(greet, inputs=None, outputs=greeting)

# Protected app: Gradio consults get_user for each request, and a None
# result denies access.
app = gr.mount_gradio_app(app, main_demo, path="/gradio", auth_dependency=get_user)
|
|
|
|
|
if __name__ == '__main__':
    # Serve the combined FastAPI + Gradio app with uvicorn's default host/port.
    uvicorn.run(app=app)
|
|