Maouu committed · verified
Commit 99fbd2b · 1 Parent(s): 8a378d0

Update app.py

Files changed (1):
  1. app.py +93 −6
app.py CHANGED
@@ -5,9 +5,12 @@ from typing import List, Dict, Any, Optional
 from pydantic import BaseModel
 import asyncio
 import httpx
-
-from config import cookies, headers
+import random
+import time
+from config import cookies, headers, groqapi
 from prompts import ChiplingPrompts
+from groq import Groq
+import json

 app = FastAPI()

@@ -26,6 +29,8 @@ class ChatRequest(BaseModel):
     messages: List[Dict[Any, Any]]
     model: Optional[str] = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"

+client = Groq(api_key=groqapi)
+
 async def generate(json_data: Dict[str, Any]):
     max_retries = 5
     for attempt in range(max_retries):
@@ -60,6 +65,86 @@ async def generate(json_data: Dict[str, Any]):

     yield "data: [Max retries reached]\n\n"

+def convert_to_groq_schema(messages: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+    converted = []
+    for message in messages:
+        role = message.get("role", "user")
+        content = message.get("content")
+
+        if isinstance(content, list):
+            flattened = []
+            for item in content:
+                if isinstance(item, dict) and item.get("type") == "text":
+                    flattened.append(item.get("text", ""))
+            content = "\n".join(flattened)
+        elif not isinstance(content, str):
+            content = str(content)
+
+        converted.append({"role": role, "content": content})
+    return converted
+
+
+async def groqgenerate(json_data: Dict[str, Any]):
+    try:
+        messages = convert_to_groq_schema(json_data["messages"])
+        chunk_id = "groq-" + "".join(random.choices("0123456789abcdef", k=32))
+        created = int(time.time())
+
+        # Create streaming response
+        stream = client.chat.completions.create(
+            messages=messages,
+            model="meta-llama/llama-4-scout-17b-16e-instruct",
+            temperature=json_data.get("temperature", 0.7),
+            max_completion_tokens=json_data.get("max_tokens", 1024),
+            top_p=json_data.get("top_p", 1),
+            stop=json_data.get("stop", None),
+            stream=True,
+        )
+
+        total_tokens = 0
+
+        # Use a normal for-loop since the Groq stream is not async
+        for chunk in stream:
+            content = chunk.choices[0].delta.content
+            if content:
+                response = {
+                    "id": chunk_id,
+                    "object": "chat.completion.chunk",
+                    "created": created,
+                    "model": json_data.get("model", "llama-3.3-70b-versatile"),
+                    "choices": [{
+                        "index": 0,
+                        "text": content,
+                        "logprobs": None,
+                        "finish_reason": None
+                    }],
+                    "usage": None
+                }
+                yield f"data: {json.dumps(response)}\n\n"
+                total_tokens += 1
+
+        final = {
+            "id": chunk_id,
+            "object": "chat.completion.chunk",
+            "created": created,
+            "model": json_data.get("model", "llama-3.3-70b-versatile"),
+            "choices": [],
+            "usage": {
+                "prompt_tokens": len(messages),
+                "completion_tokens": total_tokens,
+                "total_tokens": len(messages) + total_tokens,
+            }
+        }
+        yield f"data: {json.dumps(final)}\n\n"
+        yield "data: [DONE]\n\n"
+
+    except Exception:
+        # Fallback: delegate to the httpx-based generator instead of failing
+        async for chunk in generate(json_data):
+            yield chunk
+
+
+
 @app.get("/")
 async def index():
     return {"status": "ok"}
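convert_to_groq_schema flattens multi-part message content into the plain strings passed to the Groq SDK. A quick sanity check of that behavior (a standalone sketch; the sample messages are made up, and importing from app.py assumes config.py provides cookies, headers, and groqapi):

    from app import convert_to_groq_schema  # assumes config.py is present

    msgs = [
        # multi-part content in the {"type": "text", ...} shape the converter handles
        {"role": "user", "content": [{"type": "text", "text": "Hi"},
                                     {"type": "text", "text": "there"}]},
        # plain-string content passes through unchanged
        {"role": "assistant", "content": "Hello!"},
    ]
    print(convert_to_groq_schema(msgs))
    # [{'role': 'user', 'content': 'Hi\nthere'}, {'role': 'assistant', 'content': 'Hello!'}]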
@@ -93,7 +178,8 @@ async def chat(request: ChatRequest):
         'stream': True,
     }

-    return StreamingResponse(generate(json_data), media_type='text/event-stream')
+    selected_generator = random.choice([groqgenerate, generate])
+    return StreamingResponse(selected_generator(json_data), media_type='text/event-stream')


 @app.post("/generate-modules")
@@ -135,8 +221,8 @@ async def generate_modules(request: Request):
         'messages': current_messages,
         'stream': True,
     }
-
-    return StreamingResponse(generate(json_data), media_type='text/event-stream')
+    selected_generator = random.choice([groqgenerate])
+    return StreamingResponse(selected_generator(json_data), media_type='text/event-stream')


 @app.post("/generate-topics")
@@ -179,4 +265,5 @@ async def generate_topics(request: Request):
         'stream': True,
     }

-    return StreamingResponse(generate(json_data), media_type='text/event-stream')
+    selected_generator = random.choice([groqgenerate, generate])
+    return StreamingResponse(selected_generator(json_data), media_type='text/event-stream')
 
 
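Each endpoint streams Server-Sent Events: a sequence of data: {...} JSON chunks, ending with data: [DONE] on the Groq path. A minimal client sketch (the localhost URL and the /chat route path are assumptions, not shown in this diff):

    import asyncio
    import httpx

    async def main():
        payload = {"messages": [{"role": "user", "content": "Hello"}]}
        async with httpx.AsyncClient(timeout=None) as http:
            # Read the SSE body line by line as chunks arrive
            async with http.stream("POST", "http://localhost:8000/chat", json=payload) as resp:
                async for line in resp.aiter_lines():
                    if line.startswith("data: "):
                        print(line[len("data: "):])

    asyncio.run(main())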
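random.choice([groqgenerate, generate]) splits traffic evenly between the new Groq path and the original httpx path, while /generate-modules pins everything to Groq via random.choice([groqgenerate]). If an uneven split were ever wanted, random.choices with weights would be a drop-in variant (a sketch, not part of this commit; the 70/30 weights are hypothetical):

    import random
    from app import generate, groqgenerate  # the two async generators above

    def pick_generator():
        # Hypothetical 70/30 split between the Groq path and the httpx fallback
        return random.choices([groqgenerate, generate], weights=[0.7, 0.3], k=1)[0]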