Update main.py
main.py CHANGED

@@ -23,7 +23,7 @@ import base64
 
 from dotenv import load_dotenv
 
-# Load environment variables
+# Load environment variables from .env file
 load_dotenv()
 
 # Configure logging
@@ -145,7 +145,7 @@ class ImageResponseCustom:
         self.url = url
         self.alt = alt
 
-# Blackbox AI Integration
+# Placeholder for Blackbox AI Integration
 class Blackbox:
     url = "https://www.blackbox.ai"
     api_endpoint = "https://www.blackbox.ai/api/chat"  # Placeholder endpoint
@@ -421,7 +421,7 @@ class Blackbox:
                 if attempt == retry_attempts - 1:
                     raise HTTPException(status_code=500, detail=str(e))
 
-# FastAPI app
+# Initialize FastAPI app
 app = FastAPI()
 
 # Add the cleanup task when the app starts
@@ -453,8 +453,308 @@ async def security_middleware(request: Request, call_next):
     response = await call_next(request)
     return response
 
-#
-
+# Pydantic Models
+
+class TextContent(BaseModel):
+    type: str = "text"
+    text: str
+
+    @validator('type')
+    def type_must_be_text(cls, v):
+        if v != "text":
+            raise ValueError("Type must be 'text'")
+        return v
+
+class ImageContent(BaseModel):
+    type: str = "image_url"
+    image_url: Dict[str, str]
+
+    @validator('type')
+    def type_must_be_image_url(cls, v):
+        if v != "image_url":
+            raise ValueError("Type must be 'image_url'")
+        return v
+
+ContentItem = Union[TextContent, ImageContent]
+
+class Message(BaseModel):
+    role: str
+    content: Union[str, List[ContentItem]]
+
+    @validator('role')
+    def role_must_be_valid(cls, v):
+        if v not in {"system", "user", "assistant"}:
+            raise ValueError("Role must be 'system', 'user', or 'assistant'")
+        return v
+
+class ChatRequest(BaseModel):
+    model: str
+    messages: List[Message]
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+    max_tokens: Optional[int] = None
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+    webSearchMode: Optional[bool] = False  # Custom parameter
+
+class TokenizerRequest(BaseModel):
+    text: str
+
+# Utility Functions
+
+def calculate_estimated_cost(prompt_tokens: int, completion_tokens: int) -> float:
+    """
+    Calculate the estimated cost based on the number of tokens.
+    Replace the pricing below with your actual pricing model.
+    """
+    # Example pricing: $0.00000268 per token
+    cost_per_token = 0.00000268
+    return round((prompt_tokens + completion_tokens) * cost_per_token, 8)
+
+def count_tokens(text: str) -> int:
+    """
+    Counts the number of tokens in a given text using tiktoken.
+    """
+    try:
+        import tiktoken
+        encoding = tiktoken.get_encoding("cl100k_base")
+        return len(encoding.encode(text))
+    except ImportError:
+        # Fallback if tiktoken is not installed
+        return len(text.split())
+
+def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
+    return {
+        "id": f"chatcmpl-{uuid.uuid4()}",
+        "object": "chat.completion",
+        "created": int(datetime.now().timestamp()),
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": content
+                },
+                "finish_reason": finish_reason
+            }
+        ],
+        "usage": None,  # To be filled in non-streaming responses
+    }
+
+def extract_all_images_from_content(content: Union[str, List[ContentItem]]) -> List[Tuple[str, str]]:
+    """
+    Extracts all images from the content.
+    Returns a list of tuples containing (alt_text, image_data_uri).
+    """
+    images = []
+    if isinstance(content, list):
+        for item in content:
+            if isinstance(item, ImageContent):
+                alt_text = item.image_url.get('alt', '')  # Optional alt text
+                image_data_uri = item.image_url.get('url', '')
+                if image_data_uri:
+                    images.append((alt_text, image_data_uri))
+    return images
+
+# Image Analysis Function (Placeholder)
+async def analyze_image(image_data_uri: str) -> str:
+    """
+    Placeholder function to analyze the image.
+    Replace this with actual image analysis logic or API calls.
+    """
+    try:
+        # Extract base64 data
+        image_data = image_data_uri.split(",")[1]
+        # Decode the image
+        image_bytes = base64.b64decode(image_data)
+
+        # Here, integrate with an image analysis API or implement your own logic
+        # For demonstration, we'll simulate analysis with a dummy response.
+        await asyncio.sleep(1)  # Simulate processing delay
+        return "Image analysis result: The image depicts a beautiful sunset over the mountains."
+    except Exception as e:
+        logger.error(f"Failed to analyze image: {e}")
+        raise HTTPException(status_code=400, detail="Failed to process the provided image.")
+
+# Endpoint: POST /v1/chat/completions
+@app.post("/v1/chat/completions", dependencies=[Depends(rate_limiter_per_ip)])
+async def chat_completions(request: ChatRequest, req: Request, api_key: str = Depends(get_api_key)):
+    client_ip = req.client.host
+    # Redact user messages only for logging purposes
+    redacted_messages = [{"role": msg.role, "content": "[redacted]"} for msg in request.messages]
+
+    logger.info(f"Received chat completions request from API key: {api_key} | IP: {client_ip} | Model: {request.model} | Messages: {redacted_messages}")
+
+    try:
+        # Validate that the requested model is available
+        if request.model not in Blackbox.models and request.model not in Blackbox.model_aliases:
+            logger.warning(f"Attempt to use unavailable model: {request.model} from IP: {client_ip}")
+            raise HTTPException(status_code=400, detail="Requested model is not available.")
+
+        # Initialize response content
+        assistant_content = ""
+
+        # Iterate through messages to find and process images
+        for msg in request.messages:
+            if msg.role == "user":
+                # Extract all images from the message content
+                images = extract_all_images_from_content(msg.content)
+                for alt_text, image_data_uri in images:
+                    # Analyze the image
+                    analysis_result = await analyze_image(image_data_uri)
+                    assistant_content += analysis_result + "\n"
+
+        # Example response content
+        assistant_content += "Based on the image you provided, here are the insights..."
+
+        # Calculate token usage (simple approximation)
+        prompt_tokens = sum(count_tokens(msg.content if isinstance(msg.content, str) else " ".join(item.text if isinstance(item, TextContent) else item.image_url.get('url', '') for item in msg.content)) for msg in request.messages)
+        completion_tokens = count_tokens(assistant_content)
+        total_tokens = prompt_tokens + completion_tokens
+        estimated_cost = calculate_estimated_cost(prompt_tokens, completion_tokens)
+
+        logger.info(f"Completed response generation for API key: {api_key} | IP: {client_ip}")
+
+        if request.stream:
+            async def generate():
+                try:
+                    for msg in request.messages:
+                        if msg.role == "user":
+                            images = extract_all_images_from_content(msg.content)
+                            for alt_text, image_data_uri in images:
+                                analysis_result = await analyze_image(image_data_uri)
+                                response_chunk = {
+                                    "id": f"chatcmpl-{uuid.uuid4()}",
+                                    "object": "chat.completion.chunk",
+                                    "created": int(datetime.now().timestamp()),
+                                    "model": request.model,
+                                    "choices": [
+                                        {
+                                            "index": 0,
+                                            "delta": {"content": analysis_result + "\n", "role": "assistant"},
+                                            "finish_reason": None,
+                                        }
+                                    ],
+                                    "usage": None,
+                                }
+                                yield f"data: {json.dumps(response_chunk)}\n\n"
+
+                    # Final message
+                    final_response = {
+                        "id": f"chatcmpl-{uuid.uuid4()}",
+                        "object": "chat.completion",
+                        "created": int(datetime.now().timestamp()),
+                        "model": request.model,
+                        "choices": [
+                            {
+                                "message": {
+                                    "role": "assistant",
+                                    "content": assistant_content.strip()
+                                },
+                                "finish_reason": "stop",
+                                "index": 0
+                            }
+                        ],
+                        "usage": {
+                            "prompt_tokens": prompt_tokens,
+                            "completion_tokens": completion_tokens,
+                            "total_tokens": total_tokens,
+                            "estimated_cost": estimated_cost
+                        },
+                    }
+                    yield f"data: {json.dumps(final_response)}\n\n"
+                    yield "data: [DONE]\n\n"
+                except HTTPException as he:
+                    error_response = {"error": he.detail}
+                    yield f"data: {json.dumps(error_response)}\n\n"
+                except Exception as e:
+                    logger.exception(f"Error during streaming response generation from IP: {client_ip}.")
+                    error_response = {"error": str(e)}
+                    yield f"data: {json.dumps(error_response)}\n\n"
+
+            return StreamingResponse(generate(), media_type="text/event-stream")
+        else:
+            return {
+                "id": f"chatcmpl-{uuid.uuid4()}",
+                "object": "chat.completion",
+                "created": int(datetime.now().timestamp()),
+                "model": request.model,
+                "choices": [
+                    {
+                        "message": {
+                            "role": "assistant",
+                            "content": assistant_content.strip()
+                        },
+                        "finish_reason": "stop",
+                        "index": 0
+                    }
+                ],
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": total_tokens,
+                    "estimated_cost": estimated_cost
+                },
+            }
+    except ModelNotWorkingException as e:
+        logger.warning(f"Model not working: {e} | IP: {client_ip}")
+        raise HTTPException(status_code=503, detail=str(e))
+    except HTTPException as he:
+        logger.warning(f"HTTPException: {he.detail} | IP: {client_ip}")
+        raise he
+    except Exception as e:
+        logger.exception(f"An unexpected error occurred while processing the chat completions request from IP: {client_ip}.")
+        raise HTTPException(status_code=500, detail=str(e))
+
+# Endpoint: POST /v1/tokenizer
+@app.post("/v1/tokenizer", dependencies=[Depends(rate_limiter_per_ip)])
+async def tokenizer(request: TokenizerRequest, req: Request, api_key: str = Depends(get_api_key)):
+    client_ip = req.client.host
+    text = request.text
+    token_count = count_tokens(text)
+    logger.info(f"Tokenizer requested from IP: {client_ip} | Text length: {len(text)}")
+    return {"text": text, "tokens": token_count}
+
+# Endpoint: GET /v1/models
+@app.get("/v1/models", dependencies=[Depends(rate_limiter_per_ip)])
+async def get_models(req: Request, api_key: str = Depends(get_api_key)):
+    client_ip = req.client.host
+    logger.info(f"Fetching available models from IP: {client_ip}")
+    return {"data": [{"id": model, "object": "model"} for model in Blackbox.models]}
+
+# Endpoint: GET /v1/models/{model}/status
+@app.get("/v1/models/{model}/status", dependencies=[Depends(rate_limiter_per_ip)])
+async def model_status(model: str, req: Request, api_key: str = Depends(get_api_key)):
+    client_ip = req.client.host
+    logger.info(f"Model status requested for '{model}' from IP: {client_ip}")
+    if model in Blackbox.models:
+        return {"model": model, "status": "available"}
+    elif model in Blackbox.model_aliases and Blackbox.model_aliases[model] in Blackbox.models:
+        actual_model = Blackbox.model_aliases[model]
+        return {"model": actual_model, "status": "available via alias"}
+    else:
+        logger.warning(f"Model not found: {model} from IP: {client_ip}")
+        raise HTTPException(status_code=404, detail="Model not found")
+
+# Endpoint: GET /v1/health
+@app.get("/v1/health", dependencies=[Depends(rate_limiter_per_ip)])
+async def health_check(req: Request, api_key: str = Depends(get_api_key)):
+    client_ip = req.client.host
+    logger.info(f"Health check requested from IP: {client_ip}")
+    return {"status": "ok"}
+
+# Endpoint: GET /v1/chat/completions (GET method)
+@app.get("/v1/chat/completions")
+async def chat_completions_get(req: Request, api_key: str = Depends(get_api_key)):
+    client_ip = req.client.host
+    logger.info(f"GET request made to /v1/chat/completions from IP: {client_ip}, redirecting to 'about:blank'")
+    return RedirectResponse(url='about:blank')
 
 # Custom exception handler to match OpenAI's error format
 @app.exception_handler(HTTPException)
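
For quick verification, here is a minimal client sketch against the endpoints this commit adds. Everything beyond the diff is an assumption: the server address (localhost:8000), the Bearer-token header implied by get_api_key (the auth scheme is defined outside this hunk), and the placeholder model ID "blackbox-model" (query GET /v1/models for the real IDs).

# client_sketch.py -- illustrative only; adjust URL, key, and model to your deployment.
import json
import requests

BASE_URL = "http://localhost:8000"              # assumed local dev server
HEADERS = {"Authorization": "Bearer YOUR_KEY"}  # hypothetical auth header

# 1) Count tokens via the new POST /v1/tokenizer endpoint.
r = requests.post(f"{BASE_URL}/v1/tokenizer", headers=HEADERS,
                  json={"text": "Hello, world!"})
print(r.json())  # {"text": "Hello, world!", "tokens": ...}

# 2) Non-streaming chat completion with one text part and one image part.
payload = {
    "model": "blackbox-model",  # placeholder; not taken from the diff
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url",
             "image_url": {"url": "data:image/png;base64,iVBORw0KG...",  # truncated sample
                           "alt": "sample"}},
        ],
    }],
    "stream": False,
}
r = requests.post(f"{BASE_URL}/v1/chat/completions", headers=HEADERS, json=payload)
print(r.json()["choices"][0]["message"]["content"])

# 3) Streaming variant: consume the SSE lines emitted by StreamingResponse.
payload["stream"] = True
with requests.post(f"{BASE_URL}/v1/chat/completions", headers=HEADERS,
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        data = line[len(b"data: "):]
        if data == b"[DONE]":
            break
        chunk = json.loads(data)
        # Chunks carry a "delta" while streaming and a full "message" at the end.
        print(chunk["choices"][0])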
|