Update app.py
app.py CHANGED
@@ -5,9 +5,11 @@ from typing import List, Dict, Any, Optional
 from pydantic import BaseModel
 import asyncio
 import httpx
-
-from config import cookies, headers
+import random
+from config import cookies, headers, groqapi
 from prompts import ChiplingPrompts
+from groq import Groq
+import json
 
 app = FastAPI()
 
@@ -26,6 +28,8 @@ class ChatRequest(BaseModel):
     messages: List[Dict[Any, Any]]
     model: Optional[str] = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
 
+client = Groq(api_key=groqapi)
+
 async def generate(json_data: Dict[str, Any]):
     max_retries = 5
     for attempt in range(max_retries):
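The new import block pulls in the official Groq SDK and an API key from config. Below is a minimal sketch of exercising that client on its own, assuming config.groqapi holds a valid key; the model id is the one used later in this diff.

# Sketch: standalone smoke test for the Groq client added above (assumes config.groqapi is valid)
from groq import Groq
from config import groqapi

client = Groq(api_key=groqapi)
completion = client.chat.completions.create(
    model="meta-llama/llama-4-scout-17b-16e-instruct",
    messages=[{"role": "user", "content": "ping"}],
)
print(completion.choices[0].message.content)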
@@ -60,6 +64,84 @@ async def generate(json_data: Dict[str, Any]):
 
     yield "data: [Max retries reached]\n\n"
 
+def convert_to_groq_schema(messages: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+    converted = []
+    for message in messages:
+        role = message.get("role", "user")
+        content = message.get("content")
+
+        if isinstance(content, list):
+            flattened = []
+            for item in content:
+                if isinstance(item, dict) and item.get("type") == "text":
+                    flattened.append(item.get("text", ""))
+            content = "\n".join(flattened)
+        elif not isinstance(content, str):
+            content = str(content)
+
+        converted.append({"role": role, "content": content})
+    return converted
+
+
+async def groqgenerate(json_data: Dict[str, Any]):
+    try:
+        messages = convert_to_groq_schema(json_data["messages"])
+        chunk_id = "groq-" + "".join(random.choices("0123456789abcdef", k=32))
+        created = int(asyncio.get_event_loop().time())
+
+        # Create streaming response
+        stream = client.chat.completions.create(
+            messages=messages,
+            model="meta-llama/llama-4-scout-17b-16e-instruct",
+            temperature=json_data.get("temperature", 0.7),
+            max_completion_tokens=json_data.get("max_tokens", 1024),
+            top_p=json_data.get("top_p", 1),
+            stop=json_data.get("stop", None),
+            stream=True,
+        )
+
+        total_tokens = 0
+
+        # Use normal for-loop since stream is not async
+        for chunk in stream:
+            content = chunk.choices[0].delta.content
+            if content:
+                response = {
+                    "id": chunk_id,
+                    "object": "chat.completion.chunk",
+                    "created": created,
+                    "model": json_data.get("model", "llama-3.3-70b-versatile"),
+                    "choices": [{
+                        "index": 0,
+                        "text": content,
+                        "logprobs": None,
+                        "finish_reason": None
+                    }],
+                    "usage": None
+                }
+                yield f"data: {json.dumps(response)}\n\n"
+                total_tokens += 1
+
+        final = {
+            "id": chunk_id,
+            "object": "chat.completion.chunk",
+            "created": created,
+            "model": json_data.get("model", "llama-3.3-70b-versatile"),
+            "choices": [],
+            "usage": {
+                "prompt_tokens": len(messages),
+                "completion_tokens": total_tokens,
+                "total_tokens": len(messages) + total_tokens,
+            }
+        }
+        yield f"data: {json.dumps(final)}\n\n"
+        yield "data: [DONE]\n\n"
+
+    except Exception as e:
+        return generate(json_data)
+
+
+
 @app.get("/")
 async def index():
     return {"status": "ok"}
@@ -93,7 +175,8 @@ async def chat(request: ChatRequest):
         'stream': True,
     }
 
-
+    selected_generator = random.choice([groqgenerate, generate])
+    return StreamingResponse(selected_generator(json_data), media_type='text/event-stream')
 
 
 @app.post("/generate-modules")
@@ -135,8 +218,8 @@ async def generate_modules(request: Request):
         'messages': current_messages,
         'stream': True,
     }
-
-    return StreamingResponse(
+    selected_generator = random.choice([groqgenerate])
+    return StreamingResponse(selected_generator(json_data), media_type='text/event-stream')
 
 
 @app.post("/generate-topics")
@@ -179,4 +262,5 @@ async def generate_topics(request: Request):
         'stream': True,
     }
 
-
+    selected_generator = random.choice([groqgenerate, generate])
+    return StreamingResponse(selected_generator(json_data), media_type='text/event-stream')