File size: 2,435 Bytes
0245be8 c095e79 0245be8 c095e79 0245be8 c095e79 0245be8 c095e79 0245be8 c095e79 0245be8 c095e79 0245be8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
from fastapi import UploadFile
from functools import partial
from hashlib import sha256
from uuid import UUID
import aiofiles
import json
import re
from config import (
logger
)
# Pass 1: put "_" in front of a Capitalized word (e.g. "Response") when the
# character immediately before it is a letter -> "HTTPResponse" -> "HTTP_Response".
_snake_1 = partial(
    re.compile(r'(.)((?<![^A-Za-z])[A-Z][a-z]+)').sub,
    r'\1_\2',
)
# Pass 2: put "_" between a lowercase letter/digit and a following capital
# -> "camelCase" -> "camel_Case", "Value2X" -> "Value2_X".
_snake_2 = partial(
    re.compile(r'([a-z0-9])([A-Z])').sub,
    r'\1_\2',
)
# ---------------------------------------
# Convert to snake casing (for DB models)
# ---------------------------------------
def snake_case(string: str) -> str:
    """Convert a CamelCase / camelCase identifier to snake_case.

    Runs the two regex substitution passes above in order, then casefolds
    the result, e.g. ``"HTTPResponse"`` -> ``"http_response"``.
    """
    split_words = _snake_1(string)
    split_boundaries = _snake_2(split_words)
    return split_boundaries.casefold()
# ------------------------------
# Check if string is UUID format
# ------------------------------
# Version-4 UUID layout: lowercase hex only, hyphens optional between groups,
# version nibble "4", variant nibble one of 8/9/a/b.
_UUID4_RE = re.compile(
    r"[0-9a-f]{8}-?[0-9a-f]{4}-?4[0-9a-f]{3}-?[89ab][0-9a-f]{3}-?[0-9a-f]{12}"
)


def is_uuid(uuid: str) -> bool:
    """Return True if ``uuid`` is formatted like a lowercase version-4 UUID.

    ``uuid.UUID`` instances are converted with ``str()`` before checking, so a
    non-v4 ``UUID`` object yields False. The check is purely syntactic.
    Uppercase hex digits are rejected, as they were before.

    Args:
        uuid: The value to check (a ``str`` or ``uuid.UUID``).

    Returns:
        True on a full-string match; False otherwise, including for values
        that are neither ``str`` nor ``UUID`` (e.g. ``None``).
    """
    if isinstance(uuid, UUID):
        uuid = str(uuid)
    if not isinstance(uuid, str):
        return False
    # fullmatch instead of match + "$": "$" also matches before a trailing
    # newline, which would let "<uuid>\n" through.
    return _UUID4_RE.fullmatch(uuid) is not None
# ---------------------------
# Writes a file to disk async
# ---------------------------
async def save_file(file: UploadFile, file_path: str, *, chunk_size: int = 1024 * 1024):
    """Write an uploaded file to ``file_path`` asynchronously via aiofiles.

    The upload is copied in pieces of at most ``chunk_size`` bytes instead of
    being read into memory in a single call, so memory use stays bounded for
    large uploads.

    Args:
        file: The uploaded file whose contents are written.
        file_path: Destination path; opened in ``'wb'`` mode, so an existing
            file is truncated and overwritten.
        chunk_size: Maximum number of bytes read from ``file`` per iteration.

    Raises:
        ValueError: if ``chunk_size`` is not positive (``read(0)`` would
            return no data and silently produce an empty file).
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    async with aiofiles.open(file_path, 'wb') as out:
        while True:
            chunk = await file.read(chunk_size)
            if not chunk:  # empty bytes => upload exhausted
                break
            await out.write(chunk)
# ---------------------------
# Get SHA256 hash of contents
# ---------------------------
def get_sha256(contents: bytes) -> str:
    """Return the SHA-256 digest of ``contents`` as a lowercase hex string."""
    hasher = sha256()
    hasher.update(contents)
    return hasher.hexdigest()
# -----------------------
# Get SHA256 hash of file
# -----------------------
def get_file_hash(
    file_path: str,
    *,
    chunk_size: int = 64 * 1024,
) -> str:
    """Return the SHA-256 digest of the file at ``file_path`` as hex.

    The file is read in binary pieces of ``chunk_size`` bytes and fed into a
    single hasher. This gives the same digest as hashing the whole file at
    once, without loading the entire file into memory.

    Args:
        file_path: Path of the file to hash.
        chunk_size: Number of bytes read per iteration.

    Raises:
        ValueError: if ``chunk_size`` is less than 1 (``read(0)`` would end
            the loop immediately and hash nothing).
        OSError: if the file cannot be opened or read.
    """
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    hasher = sha256()
    with open(file_path, 'rb') as f:
        # iter(callable, sentinel) keeps calling f.read(chunk_size) until it
        # returns b'' (end of file).
        for chunk in iter(partial(f.read, chunk_size), b''):
            hasher.update(chunk)
    return hasher.hexdigest()
# -------------------
# Clean up LLM output
# -------------------
def sanitize_output(
    str_output: str
) -> str:
    """Clean and validate a raw LLM response that should be a JSON object.

    All newline characters are removed, including any inside JSON string
    values. A single leading ``?`` is dropped. The result must parse as a
    JSON object whose top-level keys include ``message``, ``tags`` and
    ``is_escalate``.

    Args:
        str_output: Raw text returned by the LLM.

    Returns:
        The cleaned JSON *string*, not the parsed object.

    Raises:
        ValueError: if the cleaned text is not valid JSON (including when it
            is empty), is not a JSON object, or lacks a required key.
    """
    # Let's sanitize the JSON
    res = str_output.replace("\n", '')
    # If the first character is "?", remove it. Ran into this issue for some reason.
    # startswith() (rather than res[0]) avoids an IndexError on an empty response.
    if res.startswith('?'):
        res = res[1:]
    # check if response is valid json
    try:
        parsed = json.loads(res)
    except json.JSONDecodeError as err:
        raise ValueError(f'LLM response is not valid JSON: {res}') from err
    # Check the parsed object's keys; a substring test on the raw text would
    # accept e.g. {"text": "message tags is_escalate"}.
    required = ('message', 'tags', 'is_escalate')
    if not isinstance(parsed, dict) or any(key not in parsed for key in required):
        raise ValueError(f'LLM response is missing required fields: {res}')
    logger.debug(f'Output: {res}')
    return res
# ------------------
# Clean up LLM input
# ------------------
def sanitize_input(
    str_input: str
) -> str:
    """Remove every single quote from text used as LLM input.

    The quotes are deleted outright, not escaped, so ``"don't"`` becomes
    ``"dont"``. The cleaned text is logged at debug level and returned.

    Args:
        str_input: Text to clean.

    Returns:
        ``str_input`` with all ``'`` characters removed.
    """
    # Single quotes are stripped (not escaped) because they were causing
    # issues in the JSON the LLM produces as output.
    cleaned = str_input.replace("'", "")
    logger.debug(f'Input: {cleaned}')
    return cleaned
|