|
from fastapi import ( |
|
FastAPI, |
|
File, |
|
Depends, |
|
HTTPException, |
|
UploadFile |
|
) |
|
from fastapi.openapi.utils import get_openapi |
|
from fastapi.staticfiles import StaticFiles |
|
from sqlmodel import Session, select |
|
|
|
from typing import ( |
|
List, |
|
Optional, |
|
Union, |
|
Any |
|
) |
|
from datetime import datetime |
|
import requests |
|
import aiohttp |
|
import time |
|
import json |
|
import os |
|
|
|
|
|
|
|
|
|
from llm import ( |
|
chat_query |
|
) |
|
|
|
|
|
|
|
|
|
from models import ( |
|
|
|
|
|
|
|
Organization, |
|
OrganizationCreate, |
|
OrganizationRead, |
|
OrganizationUpdate, |
|
User, |
|
UserCreate, |
|
UserRead, |
|
UserReadList, |
|
UserUpdate, |
|
DocumentRead, |
|
DocumentReadList, |
|
ProjectCreate, |
|
ProjectRead, |
|
ProjectReadList, |
|
ChatSessionResponse, |
|
ChatSessionCreatePost, |
|
WebhookCreate, |
|
|
|
|
|
|
|
get_engine, |
|
get_session |
|
|
|
) |
|
from helpers import ( |
|
|
|
|
|
|
|
get_org_by_uuid_or_namespace, |
|
get_project_by_uuid, |
|
get_user_by_uuid_or_identifier, |
|
get_users, |
|
get_documents_by_project_and_org, |
|
get_document_by_uuid, |
|
create_org_by_org_or_uuid, |
|
create_project_by_org |
|
) |
|
from util import ( |
|
save_file, |
|
get_sha256, |
|
is_uuid, |
|
logger |
|
) |
|
|
|
|
|
|
|
from config import ( |
|
APP_NAME, |
|
APP_VERSION, |
|
APP_DESCRIPTION, |
|
ENTITY_STATUS, |
|
CHANNEL_TYPE, |
|
LLM_MODELS, |
|
LLM_DISTANCE_THRESHOLD, |
|
LLM_DEFAULT_DISTANCE_STRATEGY, |
|
LLM_MAX_OUTPUT_TOKENS, |
|
LLM_MIN_NODE_LIMIT, |
|
FILE_UPLOAD_PATH, |
|
RASA_WEBHOOK_URL |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ASGI application object; every route below registers itself on it.
app = FastAPI()

# Expose the local "static" directory under the /static URL prefix.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
|
|
|
|
|
|
|
|
|
@app.get("/health", include_in_schema=False)
def health_check():
    # Liveness probe: always answers with a constant payload and is kept
    # out of the generated OpenAPI schema.
    return dict(status='ok')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/org", response_model=List[OrganizationRead])
def read_organizations():
    '''
    ## Get all active organizations

    Returns:
        List[OrganizationRead]: List of organizations

    '''
    # Only organizations whose status equals the ACTIVE enum value are listed.
    is_active = Organization.status == ENTITY_STATUS.ACTIVE.value
    statement = select(Organization).where(is_active)

    # This endpoint opens its own short-lived session on the shared engine.
    with Session(get_engine()) as db:
        return db.exec(statement).all()
|
|
|
|
|
|
|
|
|
|
|
@app.post("/org", response_model=Union[OrganizationRead, Any])
def create_organization(
    *,
    session: Session = Depends(get_session),
    organization: Optional[OrganizationCreate] = None,
    namespace: Optional[str] = None,
    display_name: Optional[str] = None
):
    '''

    ### Creates a new organization
    ### <u>Args:</u>
    - **namespace**: Unique namespace for the organization (ex. openai)
    - **name**: Name of the organization (ex. OpenAI)
    - **bot_url**: URL of the bot (ex. https://t.me/your_bot)

    ### <u>Returns:</u>
    - OrganizationRead
    ---
    <details><summary>👨‍💻 Code examples:</summary>
    ### 🖥️ Curl
    ```bash
    curl -X POST "http://localhost:8888/org" -H "accept: application/json" -H "Content-Type: application/json" -d '{\"namespace\":\"openai\",\"name\":\"OpenAI\",\"bot_url\":\"https://t.me/your_bot\"}'
    ```
    <br/>
    ### 🐍 Python
    ```python
    import requests
    response = requests.post("http://localhost:8888/org", json={"namespace":"openai","name":"OpenAI","bot_url":"https://t.me/your_bot"})
    print(response.json())
    ```
    </details>
    '''
    # The docstring above is rendered as the endpoint description in the API
    # docs, so the emoji in it are stored as real characters (they had been
    # corrupted into mojibake).
    #
    # All three optional inputs are forwarded untouched; the helper decides
    # how the body and the `namespace` / `display_name` values are combined.
    return create_org_by_org_or_uuid(
        organization=organization,
        namespace=namespace,
        display_name=display_name,
        session=session
    )
|
|
|
|
|
|
|
|
|
|
|
@app.get("/org/{organization_id}", response_model=Union[OrganizationRead, Any])
def read_organization(
    *,
    session: Session = Depends(get_session),
    organization_id: str
):
    # Look up a single organization; `organization_id` is handed as-is to the
    # uuid-or-namespace helper and whatever it yields is returned.
    return get_org_by_uuid_or_namespace(organization_id, session=session)
|
|
|
|
|
|
|
|
|
|
|
@app.put("/org/{organization_id}", response_model=Union[OrganizationRead, Any])
def update_organization(
    *,
    session: Session = Depends(get_session),
    organization_id: str,
    organization: OrganizationUpdate
):
    # Fetch the stored organization first so lookup errors surface before
    # any change is applied.
    target = get_org_by_uuid_or_namespace(organization_id, session=session)

    # Only the fields the client actually sent are applied.
    changes = organization.dict(exclude_unset=True)
    target.update(changes)
    return target
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/project", response_model=List[ProjectReadList])
def read_projects(
    *,
    session: Session = Depends(get_session),
    organization_id: str
):
    # List the projects attached to one organization; an organization with
    # no projects is reported as 404 rather than as an empty list.
    owner = get_org_by_uuid_or_namespace(organization_id, session=session)

    if owner.projects:
        return owner.projects

    raise HTTPException(status_code=404, detail='No projects found for organization')
|
|
|
|
|
|
|
|
|
|
|
@app.post("/project", response_model=Union[ProjectRead, Any])
def create_project(
    *,
    session: Session = Depends(get_session),
    organization_id: str,
    project: ProjectCreate
):
    # Creation is delegated entirely to the helper; its result is returned
    # unchanged.
    created = create_project_by_org(
        organization_id=organization_id, project=project, session=session
    )
    return created
|
|
|
|
|
|
|
|
|
|
|
@app.get("/project/{project_id}", response_model=Union[ProjectRead, Any])
def read_project(
    *,
    session: Session = Depends(get_session),
    organization_id: str,
    project_id: str
):
    # Fetch one project, scoped to the organization given alongside it.
    lookup = dict(uuid=project_id, organization_id=organization_id, session=session)
    return get_project_by_uuid(**lookup)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_upload_path(root_path: str, file_name: str) -> str:
    # Build the destination path for `file_name` under `root_path`; when a file
    # is already stored there, append the current Unix timestamp to the name.
    upload_path = os.path.join(root_path, file_name)
    if os.path.isfile(upload_path):
        upload_path = os.path.join(root_path, f'{file_name}_{int(time.time())}')
    return upload_path


@app.post("/document", response_model=Union[DocumentReadList, Any])
async def upload_document(
    *,
    session: Session = Depends(get_session),
    organization_id: str,
    project_id: str,
    url: Optional[str] = None,
    file: Optional[UploadFile] = File(None),
    overwrite: Optional[bool] = True
):
    '''
    ### Stores a document for a project, from an upload or from a URL

    ### <u>Args:</u>
    - **organization_id**: Organization the project belongs to
    - **project_id**: Project the document is attached to
    - **url**: Location to download the document from (mutually exclusive with `file`)
    - **file**: Uploaded document (mutually exclusive with `url`)
    - **overwrite**: Forwarded unchanged to the document-creation helper

    ### <u>Returns:</u>
    - The object produced by the document-creation helper

    ### <u>Raises:</u>
    - 400 when both or neither of `url` / `file` are given, when no file name
      can be derived, or when the URL does not answer with HTTP 200
    '''
    organization = get_org_by_uuid_or_namespace(organization_id, session=session)
    project = get_project_by_uuid(uuid=project_id, organization_id=organization_id, session=session)
    file_root_path = os.path.join(FILE_UPLOAD_PATH, str(organization.uuid), str(project.uuid))

    file_version = 1

    # `file` defaults to File(None) so that a URL-only request is possible;
    # with File(...) the upload was mandatory and this check always rejected
    # every request that carried a URL.
    if url and file:
        raise HTTPException(status_code=400, detail='You can only upload a file OR provide a URL, not both')

    if not url and not file:
        raise HTTPException(status_code=400, detail='You must upload a file or provide a URL')

    # The per-project directory may not exist yet; open() below needs it.
    os.makedirs(file_root_path, exist_ok=True)

    if url:
        file_name = url.split('/')[-1]
        if not file_name:
            # e.g. a URL ending in "/": there is nothing to name the file after.
            raise HTTPException(status_code=400, detail=f'Could not determine a file name from {url}')

        file_upload_path = _resolve_upload_path(file_root_path, file_name)

        # NOTE(review): `url` is caller-controlled and fetched server-side;
        # confirm that this endpoint is not reachable by untrusted clients.
        # The HTTP client gets its own name so it does not replace the database
        # `session` that is handed to the document helper further down.
        async with aiohttp.ClientSession() as http_session:
            async with http_session.get(url) as resp:
                if resp.status != 200:
                    raise HTTPException(status_code=400, detail=f'Could not download file from {url}')

                with open(file_upload_path, 'wb') as f:
                    while True:
                        chunk = await resp.content.read(1024)
                        if not chunk:
                            break
                        f.write(chunk)

        # Read the stored bytes back through a context manager so the handle
        # is closed deterministically.
        with open(file_upload_path, 'rb') as f:
            file_contents = f.read()
        file_hash = get_sha256(contents=file_contents)

    else:
        # The client chooses the file name: keep only its last path component
        # so it cannot point outside the project's upload directory.
        file_name = os.path.basename(file.filename or '')
        if not file_name:
            raise HTTPException(status_code=400, detail='Uploaded file has no file name')

        file_upload_path = _resolve_upload_path(file_root_path, file_name)

        file_contents = await file.read()
        file_hash = get_sha256(contents=file_contents)

        # read() left the stream at EOF; rewind so save_file sees the content.
        await file.seek(0)
        await save_file(file, file_upload_path)

    # NOTE(review): `create_document_by_file_path` is not among the names
    # imported at the top of this file - confirm where it is defined.
    document_obj = create_document_by_file_path(
        organization=organization,
        project=project,
        file_path=file_upload_path,
        file_hash=file_hash,
        file_version=file_version,
        url=url,
        overwrite=overwrite,
        session=session
    )
    return document_obj
|
|
|
|
|
|
|
|
|
|
|
@app.get("/document", response_model=List[DocumentReadList])
def read_documents(
    *,
    session: Session = Depends(get_session),
    organization_id: str,
    project_id: str
):
    # List documents for one project of one organization; the lookup itself
    # lives in the helper.
    documents = get_documents_by_project_and_org(
        project_id=project_id,
        organization_id=organization_id,
        session=session,
    )
    return documents
|
|
|
|
|
|
|
|
|
@app.get("/document/{document_id}", response_model=DocumentRead)
def read_document(
    *,
    session: Session = Depends(get_session),
    organization_id: str,
    project_id: str,
    document_id: str
):
    # Fetch a single document, scoped to the given project and organization.
    document = get_document_by_uuid(
        uuid=document_id,
        project_id=project_id,
        organization_id=organization_id,
        session=session,
    )
    return document
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/user", response_model=List[UserReadList])
def read_users(
    *,
    session: Session = Depends(get_session),
):
    # Return whatever the shared helper lists for the injected session.
    users = get_users(session=session)
    return users
|
|
|
|
|
|
|
|
|
|
|
@app.post("/user", response_model=UserRead)
def create_user(
    *,
    session: Session = Depends(get_session),
    user: UserCreate
):
    '''
    ### Creates a new user

    ### <u>Args:</u>
    - **user**: Fields of the user to create

    ### <u>Returns:</u>
    - UserRead
    '''
    # The previous body called `create_user(...)`, which resolves to this very
    # route handler (no helper of that name is imported), so every request
    # recursed until RecursionError. Persist the row directly instead.
    db_user = User.from_orm(user)
    session.add(db_user)
    session.commit()

    # Reload so database-generated values are present in the response.
    session.refresh(db_user)
    return db_user
|
|
|
|
|
|
|
|
|
|
|
@app.get("/user/{user_id}", response_model=UserRead)
def read_user(
    *,
    session: Session = Depends(get_session),
    user_id: str
):
    '''
    ### Get a single user

    ### <u>Args:</u>
    - **user_id**: Value from the URL path, passed to the uuid-or-identifier lookup

    ### <u>Returns:</u>
    - UserRead
    '''
    # The path placeholder must carry the same name as the function parameter.
    # It used to be `{user_uuid}`, which made FastAPI treat `user_id` as a
    # required query parameter and ignore the value in the URL path.
    return get_user_by_uuid_or_identifier(id=user_id, session=session)
|
|
|
|
|
|
|
|
|
|
|
@app.put("/user/{user_uuid}", response_model=UserRead)
def update_user(*, user_uuid: str, user: UserUpdate):
    '''
    ### Updates an existing user

    ### <u>Args:</u>
    - **user_uuid**: UUID of the user to update
    - **user**: Fields to change

    ### <u>Returns:</u>
    - UserRead

    ### <u>Raises:</u>
    - 404 when no user with that UUID exists
    '''
    # Keep the stored record under its own name. It used to be assigned to
    # `user`, which discarded the request payload and updated the record with
    # its own data.
    db_user = User.get(uuid=user_uuid)

    if not db_user:
        raise HTTPException(status_code=404, detail=f'User {user_uuid} not found!')

    # Apply only the fields the client actually sent, so omitted fields are
    # left as they are.
    db_user.update(**user.dict(exclude_unset=True))
    return db_user
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_webhook_telegram(webhook_data: dict) -> dict:
    """
    Flatten a Telegram webhook update into a flat dict of the fields used here.

    A missing or null `message`, `chat` or `from` object is treated as empty,
    so the corresponding output values are None instead of the call failing.

    Args:
        webhook_data: Decoded webhook payload (see the example below).

    Returns:
        dict with the keys update_id, message_id, user_id, username,
        user_language, user_firstname, user_message, message_ts and
        message_type. `message_ts` is a datetime built from the `date`
        field with datetime.fromtimestamp, or None when `date` is absent.

    Telegram example response:
    {
        "update_id": 248146407,
        "message": {
            "message_id": 299,
            "from": {
                "id": 123456789,
                "is_bot": false,
                "first_name": "Elon",
                "username": "elonmusk",
                "language_code": "en"
            },
            "chat": {
                "id": 123456789,
                "first_name": "Elon",
                "username": "elonmusk",
                "type": "private"
            },
            "date": 1683115867,
            "text": "Tell me about the company?"
        }
    }
    """
    # `or {}` covers both an absent key and an explicit null value.
    message = webhook_data.get('message') or {}
    chat = message.get('chat') or {}
    message_from = message.get('from') or {}
    message_date = message.get('date')

    return {
        'update_id': webhook_data.get('update_id'),
        'message_id': message.get('message_id'),
        'user_id': message_from.get('id'),
        'username': message_from.get('username'),
        'user_language': message_from.get('language_code'),
        # The first name is taken from the chat object, not the sender object.
        'user_firstname': chat.get('first_name'),
        'user_message': message.get('text'),
        'message_ts': datetime.fromtimestamp(message_date) if message_date else None,
        'message_type': chat.get('type'),
    }
|
|
|
|
|
@app.post("/webhooks/{channel}/webhook")
def get_webhook(
    *,
    session: Session = Depends(get_session),
    channel: str,
    webhook: WebhookCreate
):
    '''
    ### Receives a channel webhook, answers it via the LLM and relays it to Rasa

    ### <u>Args:</u>
    - **channel**: Channel name from the URL path; only `telegram` is accepted
    - **webhook**: Webhook payload sent by the channel

    ### <u>Returns:</u>
    - `{'status': 'ok'}` once the enriched payload has been posted to Rasa

    ### <u>Raises:</u>
    - 404 when the channel is not supported
    '''
    webhook_data = webhook.dict()

    if channel != 'telegram':
        raise HTTPException(status_code=404, detail=f'Channel {channel} not a valid webhook channel!')

    # The target URL uses the channel name from the path, so build it before
    # `channel` is replaced by the enum value below.
    rasa_webhook_url = f'{RASA_WEBHOOK_URL}/webhooks/{channel}/webhook'
    data = process_webhook_telegram(webhook_data)
    channel = CHANNEL_TYPE.TELEGRAM.value

    user_data = {
        'identifier': data['user_id'],
        'identifier_type': channel,
        'first_name': data['user_firstname'],
        'language': data['user_language']
    }
    session_metadata = {
        'update_id': data['update_id'],
        'username': data['username'],
        # This used to store the message text under `message_id`.
        'message_id': data['message_id'],
        'msg_ts': data['message_ts'],
        'msg_type': data['message_type'],
    }
    user_message = data['user_message']

    chat_session = chat_query(
        user_message,
        session=session,
        channel=channel,
        identifier=user_data['identifier'],
        user_data=user_data,
        meta=session_metadata
    )

    # Tolerate a chat session that carries no metadata at all.
    meta = chat_session.meta or {}

    # Attach the LLM result to the original message before relaying it.
    webhook_data['message']['meta'] = {
        'response': chat_session.response or None,
        'tags': meta.get('tags'),
        'is_escalate': meta.get('is_escalate', False),
        'session_id': meta.get('session_id')
    }

    # Serialise once and reuse the result for both the request and the log.
    payload = json.dumps(webhook_data)

    # NOTE(review): no timeout is set on this call, so a stalled Rasa server
    # blocks the request indefinitely - confirm whether that is intended.
    res = requests.post(rasa_webhook_url, data=payload)
    logger.debug(f'[🤖 RasaGPT API webhook]\nPosting data: {payload}\n\n[🤖 RasaGPT API webhook]\nRasa webhook response: {res.text}')

    return {'status': 'ok'}
|
|
|
|
|
|
|
|
|
|
|
# Generate the OpenAPI document once from the registered routes, add the
# logo extension to its "info" section and install it as the app's schema.
_schema = get_openapi(
    routes=app.routes,
    title=APP_NAME,
    version=APP_VERSION,
    description=APP_DESCRIPTION,
)
_schema['info'].update({'x-logo': {'url': '/static/img/rasagpt-logo-1.png'}})
app.openapi_schema = _schema