from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
from pathlib import Path
from collections import Counter, defaultdict
import asyncio
import httpx
import json
import random
import time
from groq import Groq
from config import cookies, headers, groqapi
from prompts import ChiplingPrompts
from utils.logger import log_request

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

templates = Jinja2Templates(directory="templates")
LOG_FILE = Path("logs.json")


# Dashboard view (route path and method assumed)
@app.get("/dashboard", response_class=HTMLResponse)
async def dashboard(request: Request, endpoint: Optional[str] = None):
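    """Render the request-log dashboard from logs.json, optionally filtered by endpoint."""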
    try:
        with open(LOG_FILE) as f:
            logs = json.load(f)
    except FileNotFoundError:
        logs = []

    # Filter logs
    if endpoint:
        logs = [log for log in logs if log["endpoint"] == endpoint]

    # Summary stats
    total_requests = len(logs)
    endpoint_counts = Counter(log["endpoint"] for log in logs)
    query_counts = Counter(log["query"] for log in logs)

    # Requests per date
    date_counts = defaultdict(int)
    for log in logs:
        date = log["timestamp"].split("T")[0]
        date_counts[date] += 1

    # Sort logs by timestamp (desc)
    logs_sorted = sorted(logs, key=lambda x: x["timestamp"], reverse=True)

    return templates.TemplateResponse("dashboard.html", {
        "request": request,
        "logs": logs_sorted[:100],  # show top 100
        "total_requests": total_requests,
        "endpoint_counts": dict(endpoint_counts),
        "query_counts": query_counts.most_common(5),
        "date_counts": dict(date_counts),
        "filter_endpoint": endpoint or "",
    })


# Define request model
class ChatRequest(BaseModel):
    message: str
    messages: List[Dict[Any, Any]]
    model: Optional[str] = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
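

# Shared Groq SDK client, created once with the API key from config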
client = Groq(api_key=groqapi)


async def generate(json_data: Dict[str, Any]):
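    """Proxy a streaming completion request to the Together AI inference endpoint.

    Retries on HTTP 429 up to max_retries and re-emits each streamed line to the caller.
    """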
    max_retries = 5
    for attempt in range(max_retries):
        # Local name chosen so it does not shadow the module-level Groq client
        async with httpx.AsyncClient(timeout=None) as http_client:
            try:
                request_ctx = http_client.stream(
                    "POST",
                    "https://api.together.ai/inference",
                    cookies=cookies,
                    headers=headers,
                    json=json_data
                )

                async with request_ctx as response:
                    if response.status_code == 200:
                        async for line in response.aiter_lines():
                            if line:
                                yield f"{line}\n"
                        return
                    elif response.status_code == 429:
                        if attempt < max_retries - 1:
                            await asyncio.sleep(0.5)
                            continue
                        yield "data: [Rate limited, max retries]\n\n"
                        return
                    else:
                        yield f"data: [Unexpected status code: {response.status_code}]\n\n"
                        return
            except Exception as e:
                yield f"data: [Connection error: {str(e)}]\n\n"
                return

    yield "data: [Max retries reached]\n\n"


def convert_to_groq_schema(messages: List[Dict[str, Any]]) -> List[Dict[str, str]]:
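    """Flatten chat messages into the plain {role, content} shape the Groq SDK expects.

    List-style content (e.g. [{"type": "text", "text": ...}]) is joined into a single string.
    """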
    converted = []
    for message in messages:
        role = message.get("role", "user")
        content = message.get("content")

        if isinstance(content, list):
            flattened = []
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    flattened.append(item.get("text", ""))
            content = "\n".join(flattened)
        elif not isinstance(content, str):
            content = str(content)

        converted.append({"role": role, "content": content})
    return converted


async def groqgenerate(json_data: Dict[str, Any]):
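    """Stream a chat completion from Groq, re-encoded as OpenAI-style SSE chunks.

    Falls back to the Together-backed generate() if the Groq call raises.
    """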
    try:
        messages = convert_to_groq_schema(json_data["messages"])
        chunk_id = "groq-" + "".join(random.choices("0123456789abcdef", k=32))
        created = int(time.time())  # Unix timestamp for the "created" field

        # Create streaming response
        stream = client.chat.completions.create(
            messages=messages,
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            temperature=json_data.get("temperature", 0.7),
            max_completion_tokens=json_data.get("max_tokens", 1024),
            top_p=json_data.get("top_p", 1),
            stop=json_data.get("stop", None),
            stream=True,
        )

        total_tokens = 0
        # Use normal for-loop since stream is not async
        for chunk in stream:
            content = chunk.choices[0].delta.content
            if content:
                response = {
                    "id": chunk_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": json_data.get("model", "llama-3.3-70b-versatile"),
                    "choices": [{
                        "index": 0,
                        "text": content,
                        "logprobs": None,
                        "finish_reason": None
                    }],
                    "usage": None
                }
                yield f"data: {json.dumps(response)}\n\n"
                total_tokens += 1

        # Final chunk with approximate usage numbers (message count and chunk count,
        # not true token counts)
        final = {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": json_data.get("model", "llama-3.3-70b-versatile"),
            "choices": [],
            "usage": {
                "prompt_tokens": len(messages),
                "completion_tokens": total_tokens,
                "total_tokens": len(messages) + total_tokens,
            }
        }
        yield f"data: {json.dumps(final)}\n\n"
        yield "data: [DONE]\n\n"
    except Exception:
        # Groq failed: fall back to the Together-backed generator and relay its stream
        async for chunk in generate(json_data):
            yield chunk


# Health check (route path and method assumed)
@app.get("/")
async def index():
    return {"status": "ok"}


@app.post("/chat")
async def chat(request: ChatRequest):
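    """Stream a chat completion for the incoming message, picking a backend at random."""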
    current_messages = request.messages.copy()

    # Handle both single text or list content
    if request.messages and isinstance(request.messages[-1].get('content'), list):
        current_messages = request.messages
    else:
        current_messages.append({
            'content': [{
                'type': 'text',
                'text': request.message
            }],
            'role': 'user'
        })

    json_data = {
        'model': request.model,
        'max_tokens': None,
        'temperature': 0.7,
        'top_p': 0.7,
        'top_k': 50,
        'repetition_penalty': 1,
        'stream_tokens': True,
        'stop': ['<|eot_id|>', '<|eom_id|>'],
        'messages': current_messages,
        'stream': True,
    }

    # Randomly pick one of the two streaming backends and log which one was used
    selected_generator = random.choice([groqgenerate, generate])
    log_request("/chat", selected_generator.__name__)
    return StreamingResponse(selected_generator(json_data), media_type='text/event-stream')


@app.post("/generate-modules")
async def generate_modules(request: Request):
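    """Generate modules for a search query and stream the result from the Groq backend."""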
    data = await request.json()
    search_query = data.get("searchQuery")
    log_request("/generate-modules", search_query)

    if not search_query:
        return {"error": "searchQuery is required"}

    system_prompt = ChiplingPrompts.generateModules(search_query)

    current_messages = [
        {
            'role': 'system',
            'content': [{
                'type': 'text',
                'text': system_prompt
            }]
        },
        {
            'role': 'user',
            'content': [{
                'type': 'text',
                'text': search_query
            }]
        }
    ]

    json_data = {
        'model': "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        'max_tokens': None,
        'temperature': 0.7,
        'top_p': 0.7,
        'top_k': 50,
        'repetition_penalty': 1,
        'stream_tokens': True,
        'stop': ['<|eot_id|>', '<|eom_id|>'],
        'messages': current_messages,
        'stream': True,
    }

    # Only the Groq backend is used for this endpoint
    selected_generator = groqgenerate
    return StreamingResponse(selected_generator(json_data), media_type='text/event-stream')


@app.post("/generate-topics")
async def generate_topics(request: Request):
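    """Generate topics for a search query, streaming from a randomly chosen backend."""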
    data = await request.json()
    search_query = data.get("searchQuery")

    if not search_query:
        return {"error": "searchQuery is required"}

    log_request("/generate-topics", search_query)

    system_prompt = ChiplingPrompts.generateTopics(search_query)

    current_messages = [
        {
            'role': 'system',
            'content': [{
                'type': 'text',
                'text': system_prompt
            }]
        },
        {
            'role': 'user',
            'content': [{
                'type': 'text',
                'text': search_query
            }]
        }
    ]

    json_data = {
        'model': "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        'max_tokens': None,
        'temperature': 0.7,
        'top_p': 0.7,
        'top_k': 50,
        'repetition_penalty': 1,
        'stream_tokens': True,
        'stop': ['<|eot_id|>', '<|eom_id|>'],
        'messages': current_messages,
        'stream': True,
    }

    selected_generator = random.choice([groqgenerate, generate])
    return StreamingResponse(selected_generator(json_data), media_type='text/event-stream')