Update app.py
Browse files
app.py
CHANGED
@@ -629,6 +629,10 @@ def update_context_display(provider, model_name):
|
|
629 |
|
630 |
def is_vision_model(provider, model_name):
|
631 |
"""Check if a model supports vision/images"""
|
|
|
|
|
|
|
|
|
632 |
if provider in VISION_MODELS:
|
633 |
if model_name in VISION_MODELS[provider]:
|
634 |
return True
|
@@ -1132,6 +1136,7 @@ def extract_ai_response(result, provider):
|
|
1132 |
# ==========================================================
|
1133 |
|
1134 |
def openrouter_streaming_handler(response, history, message):
|
|
|
1135 |
try:
|
1136 |
updated_history = history + [{"role": "user", "content": message}]
|
1137 |
assistant_response = ""
|
@@ -1163,66 +1168,54 @@ def openrouter_streaming_handler(response, history, message):
|
|
1163 |
# Add error message to the current response
|
1164 |
yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
|
1165 |
|
1166 |
-
def openai_streaming_handler(response,
|
|
|
1167 |
try:
|
1168 |
-
|
1169 |
-
|
1170 |
-
|
1171 |
-
|
1172 |
-
full_response = ""
|
1173 |
for chunk in response:
|
1174 |
if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
|
1175 |
content = chunk.choices[0].delta.content
|
1176 |
-
|
1177 |
-
|
1178 |
-
yield chatbot
|
1179 |
-
|
1180 |
except Exception as e:
|
1181 |
logger.error(f"Error in OpenAI streaming handler: {str(e)}")
|
1182 |
# Add error message to the current response
|
1183 |
-
|
1184 |
-
yield chatbot
|
1185 |
|
1186 |
-
def groq_streaming_handler(response,
|
|
|
1187 |
try:
|
1188 |
-
|
1189 |
-
|
1190 |
-
|
1191 |
-
|
1192 |
-
full_response = ""
|
1193 |
for chunk in response:
|
1194 |
if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
|
1195 |
content = chunk.choices[0].delta.content
|
1196 |
-
|
1197 |
-
|
1198 |
-
yield chatbot
|
1199 |
-
|
1200 |
except Exception as e:
|
1201 |
logger.error(f"Error in Groq streaming handler: {str(e)}")
|
1202 |
# Add error message to the current response
|
1203 |
-
|
1204 |
-
yield chatbot
|
1205 |
|
1206 |
-
def together_streaming_handler(response,
|
|
|
1207 |
try:
|
1208 |
-
|
1209 |
-
|
1210 |
-
|
1211 |
-
|
1212 |
-
full_response = ""
|
1213 |
for chunk in response:
|
1214 |
if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
|
1215 |
content = chunk.choices[0].delta.content
|
1216 |
-
|
1217 |
-
|
1218 |
-
yield chatbot
|
1219 |
-
|
1220 |
except Exception as e:
|
1221 |
logger.error(f"Error in Together streaming handler: {str(e)}")
|
1222 |
# Add error message to the current response
|
1223 |
-
|
1224 |
-
|
1225 |
-
|
1226 |
# ==========================================================
|
1227 |
# MAIN FUNCTION TO ASK AI
|
1228 |
# ==========================================================
|
@@ -1236,11 +1229,8 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1236 |
if not message.strip() and not images and not documents:
|
1237 |
return history
|
1238 |
|
1239 |
-
#
|
1240 |
-
|
1241 |
-
|
1242 |
-
# Create messages from chat history
|
1243 |
-
messages = format_to_message_dict(chat_history)
|
1244 |
|
1245 |
# Add system message if provided
|
1246 |
if system_message and system_message.strip():
|
@@ -1252,7 +1242,7 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1252 |
# Prepare message with images and documents if any
|
1253 |
content = prepare_message_with_media(message, images, documents)
|
1254 |
|
1255 |
-
# Add current message
|
1256 |
messages.append({"role": "user", "content": content})
|
1257 |
|
1258 |
# Common parameters for all providers
|
@@ -1272,8 +1262,11 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1272 |
model_id, _ = get_model_info(provider, model_choice)
|
1273 |
if not model_id:
|
1274 |
error_message = f"Error: Model '{model_choice}' not found in OpenRouter"
|
1275 |
-
|
1276 |
-
return
|
|
|
|
|
|
|
1277 |
|
1278 |
# Build OpenRouter payload
|
1279 |
payload = {
|
@@ -1319,13 +1312,35 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1319 |
|
1320 |
# Handle streaming response
|
1321 |
if stream_output and response.status_code == 200:
|
1322 |
-
# Add
|
1323 |
-
|
1324 |
|
1325 |
# Set up generator for streaming updates
|
1326 |
def streaming_generator():
|
1327 |
-
|
1328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1329 |
|
1330 |
return streaming_generator()
|
1331 |
|
@@ -1337,9 +1352,11 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1337 |
# Extract AI response
|
1338 |
ai_response = extract_ai_response(result, provider)
|
1339 |
|
1340 |
-
# Add response to history
|
1341 |
-
|
1342 |
-
|
|
|
|
|
1343 |
|
1344 |
# Handle error response
|
1345 |
else:
|
@@ -1351,16 +1368,20 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1351 |
error_message += f"\n\nResponse: {response.text}"
|
1352 |
|
1353 |
logger.error(error_message)
|
1354 |
-
|
1355 |
-
|
|
|
|
|
1356 |
|
1357 |
elif provider == "OpenAI":
|
1358 |
# Get model ID from registry
|
1359 |
model_id, _ = get_model_info(provider, model_choice)
|
1360 |
if not model_id:
|
1361 |
error_message = f"Error: Model '{model_choice}' not found in OpenAI"
|
1362 |
-
|
1363 |
-
|
|
|
|
|
1364 |
|
1365 |
# Build OpenAI payload
|
1366 |
payload = {
|
@@ -1381,34 +1402,44 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1381 |
|
1382 |
# Handle streaming response
|
1383 |
if stream_output:
|
1384 |
-
# Add
|
1385 |
-
|
1386 |
|
1387 |
# Set up generator for streaming updates
|
1388 |
def streaming_generator():
|
1389 |
-
|
1390 |
-
|
|
|
|
|
|
|
|
|
1391 |
|
1392 |
return streaming_generator()
|
1393 |
|
1394 |
# Handle normal response
|
1395 |
else:
|
1396 |
ai_response = extract_ai_response(response, provider)
|
1397 |
-
|
1398 |
-
|
|
|
|
|
1399 |
except Exception as e:
|
1400 |
error_message = f"OpenAI API Error: {str(e)}"
|
1401 |
logger.error(error_message)
|
1402 |
-
|
1403 |
-
|
|
|
|
|
1404 |
|
1405 |
elif provider == "HuggingFace":
|
1406 |
# Get model ID from registry
|
1407 |
model_id, _ = get_model_info(provider, model_choice)
|
1408 |
if not model_id:
|
1409 |
error_message = f"Error: Model '{model_choice}' not found in HuggingFace"
|
1410 |
-
|
1411 |
-
|
|
|
|
|
1412 |
|
1413 |
# Build HuggingFace payload
|
1414 |
payload = {
|
@@ -1426,21 +1457,27 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1426 |
|
1427 |
# Extract response
|
1428 |
ai_response = extract_ai_response(response, provider)
|
1429 |
-
|
1430 |
-
|
|
|
|
|
1431 |
except Exception as e:
|
1432 |
error_message = f"HuggingFace API Error: {str(e)}"
|
1433 |
logger.error(error_message)
|
1434 |
-
|
1435 |
-
|
|
|
|
|
1436 |
|
1437 |
elif provider == "Groq":
|
1438 |
# Get model ID from registry
|
1439 |
model_id, _ = get_model_info(provider, model_choice)
|
1440 |
if not model_id:
|
1441 |
error_message = f"Error: Model '{model_choice}' not found in Groq"
|
1442 |
-
|
1443 |
-
|
|
|
|
|
1444 |
|
1445 |
# Build Groq payload
|
1446 |
payload = {
|
@@ -1460,34 +1497,44 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1460 |
|
1461 |
# Handle streaming response
|
1462 |
if stream_output:
|
1463 |
-
# Add
|
1464 |
-
|
1465 |
|
1466 |
# Set up generator for streaming updates
|
1467 |
def streaming_generator():
|
1468 |
-
|
1469 |
-
|
|
|
|
|
|
|
|
|
1470 |
|
1471 |
return streaming_generator()
|
1472 |
|
1473 |
# Handle normal response
|
1474 |
else:
|
1475 |
ai_response = extract_ai_response(response, provider)
|
1476 |
-
|
1477 |
-
|
|
|
|
|
1478 |
except Exception as e:
|
1479 |
error_message = f"Groq API Error: {str(e)}"
|
1480 |
logger.error(error_message)
|
1481 |
-
|
1482 |
-
|
|
|
|
|
1483 |
|
1484 |
elif provider == "Cohere":
|
1485 |
# Get model ID from registry
|
1486 |
model_id, _ = get_model_info(provider, model_choice)
|
1487 |
if not model_id:
|
1488 |
error_message = f"Error: Model '{model_choice}' not found in Cohere"
|
1489 |
-
|
1490 |
-
|
|
|
|
|
1491 |
|
1492 |
# Build Cohere payload (doesn't support streaming the same way)
|
1493 |
payload = {
|
@@ -1505,21 +1552,27 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1505 |
|
1506 |
# Extract response
|
1507 |
ai_response = extract_ai_response(response, provider)
|
1508 |
-
|
1509 |
-
|
|
|
|
|
1510 |
except Exception as e:
|
1511 |
error_message = f"Cohere API Error: {str(e)}"
|
1512 |
logger.error(error_message)
|
1513 |
-
|
1514 |
-
|
|
|
|
|
1515 |
|
1516 |
elif provider == "Together":
|
1517 |
# Get model ID from registry
|
1518 |
model_id, _ = get_model_info(provider, model_choice)
|
1519 |
if not model_id:
|
1520 |
error_message = f"Error: Model '{model_choice}' not found in Together"
|
1521 |
-
|
1522 |
-
|
|
|
|
|
1523 |
|
1524 |
# Build Together payload
|
1525 |
payload = {
|
@@ -1538,34 +1591,44 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1538 |
|
1539 |
# Handle streaming response
|
1540 |
if stream_output:
|
1541 |
-
# Add
|
1542 |
-
|
1543 |
|
1544 |
# Set up generator for streaming updates
|
1545 |
def streaming_generator():
|
1546 |
-
|
1547 |
-
|
|
|
|
|
|
|
|
|
1548 |
|
1549 |
return streaming_generator()
|
1550 |
|
1551 |
# Handle normal response
|
1552 |
else:
|
1553 |
ai_response = extract_ai_response(response, provider)
|
1554 |
-
|
1555 |
-
|
|
|
|
|
1556 |
except Exception as e:
|
1557 |
error_message = f"Together API Error: {str(e)}"
|
1558 |
logger.error(error_message)
|
1559 |
-
|
1560 |
-
|
|
|
|
|
1561 |
|
1562 |
elif provider == "OVH":
|
1563 |
# Get model ID from registry
|
1564 |
model_id, _ = get_model_info(provider, model_choice)
|
1565 |
if not model_id:
|
1566 |
error_message = f"Error: Model '{model_choice}' not found in OVH"
|
1567 |
-
|
1568 |
-
|
|
|
|
|
1569 |
|
1570 |
# Build OVH payload
|
1571 |
payload = {
|
@@ -1583,21 +1646,27 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1583 |
|
1584 |
# Extract response
|
1585 |
ai_response = extract_ai_response(response, provider)
|
1586 |
-
|
1587 |
-
|
|
|
|
|
1588 |
except Exception as e:
|
1589 |
error_message = f"OVH API Error: {str(e)}"
|
1590 |
logger.error(error_message)
|
1591 |
-
|
1592 |
-
|
|
|
|
|
1593 |
|
1594 |
elif provider == "Cerebras":
|
1595 |
# Get model ID from registry
|
1596 |
model_id, _ = get_model_info(provider, model_choice)
|
1597 |
if not model_id:
|
1598 |
error_message = f"Error: Model '{model_choice}' not found in Cerebras"
|
1599 |
-
|
1600 |
-
|
|
|
|
|
1601 |
|
1602 |
# Build Cerebras payload
|
1603 |
payload = {
|
@@ -1615,21 +1684,27 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1615 |
|
1616 |
# Extract response
|
1617 |
ai_response = extract_ai_response(response, provider)
|
1618 |
-
|
1619 |
-
|
|
|
|
|
1620 |
except Exception as e:
|
1621 |
error_message = f"Cerebras API Error: {str(e)}"
|
1622 |
logger.error(error_message)
|
1623 |
-
|
1624 |
-
|
|
|
|
|
1625 |
|
1626 |
elif provider == "GoogleAI":
|
1627 |
# Get model ID from registry
|
1628 |
model_id, _ = get_model_info(provider, model_choice)
|
1629 |
if not model_id:
|
1630 |
error_message = f"Error: Model '{model_choice}' not found in GoogleAI"
|
1631 |
-
|
1632 |
-
|
|
|
|
|
1633 |
|
1634 |
# Build GoogleAI payload
|
1635 |
payload = {
|
@@ -1648,24 +1723,32 @@ def ask_ai(message, history, provider, model_choice, temperature, max_tokens, to
|
|
1648 |
|
1649 |
# Extract response
|
1650 |
ai_response = extract_ai_response(response, provider)
|
1651 |
-
|
1652 |
-
|
|
|
|
|
1653 |
except Exception as e:
|
1654 |
error_message = f"GoogleAI API Error: {str(e)}"
|
1655 |
logger.error(error_message)
|
1656 |
-
|
1657 |
-
|
|
|
|
|
1658 |
|
1659 |
else:
|
1660 |
error_message = f"Error: Unsupported provider '{provider}'"
|
1661 |
-
|
1662 |
-
|
|
|
|
|
1663 |
|
1664 |
except Exception as e:
|
1665 |
error_message = f"Error: {str(e)}"
|
1666 |
logger.error(f"Exception during API call: {error_message}")
|
1667 |
-
|
1668 |
-
|
|
|
|
|
1669 |
|
1670 |
def clear_chat():
|
1671 |
"""Reset all inputs"""
|
@@ -2160,14 +2243,20 @@ def create_app():
|
|
2160 |
|
2161 |
def update_vision_indicator(provider, model_choice):
|
2162 |
"""Update the vision capability indicator"""
|
|
|
|
|
|
|
2163 |
return is_vision_model(provider, model_choice)
|
2164 |
|
2165 |
def update_image_upload_visibility(provider, model_choice):
|
2166 |
"""Show/hide image upload based on model vision capabilities"""
|
|
|
|
|
|
|
2167 |
is_vision = is_vision_model(provider, model_choice)
|
2168 |
return gr.update(visible=is_vision)
|
2169 |
|
2170 |
-
# Search model function
|
2171 |
def search_openrouter_models(search_term):
|
2172 |
"""Filter OpenRouter models based on search term"""
|
2173 |
all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
|
@@ -2588,9 +2677,11 @@ def create_app():
|
|
2588 |
|
2589 |
# Check if model is selected
|
2590 |
if not model_choice:
|
2591 |
-
|
2592 |
-
|
2593 |
-
|
|
|
|
|
2594 |
|
2595 |
# Select the appropriate API key based on the provider
|
2596 |
api_key_override = None
|
|
|
629 |
|
630 |
def is_vision_model(provider, model_name):
|
631 |
"""Check if a model supports vision/images"""
|
632 |
+
# Safety check for None model name
|
633 |
+
if model_name is None:
|
634 |
+
return False
|
635 |
+
|
636 |
if provider in VISION_MODELS:
|
637 |
if model_name in VISION_MODELS[provider]:
|
638 |
return True
|
|
|
1136 |
# ==========================================================
|
1137 |
|
1138 |
def openrouter_streaming_handler(response, history, message):
|
1139 |
+
"""Handle streaming responses from OpenRouter"""
|
1140 |
try:
|
1141 |
updated_history = history + [{"role": "user", "content": message}]
|
1142 |
assistant_response = ""
|
|
|
1168 |
# Add error message to the current response
|
1169 |
yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
|
1170 |
|
1171 |
+
def openai_streaming_handler(response, history, message):
|
1172 |
+
"""Handle streaming responses from OpenAI"""
|
1173 |
try:
|
1174 |
+
updated_history = history + [{"role": "user", "content": message}]
|
1175 |
+
assistant_response = ""
|
1176 |
+
|
|
|
|
|
1177 |
for chunk in response:
|
1178 |
if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
|
1179 |
content = chunk.choices[0].delta.content
|
1180 |
+
assistant_response += content
|
1181 |
+
yield updated_history + [{"role": "assistant", "content": assistant_response}]
|
|
|
|
|
1182 |
except Exception as e:
|
1183 |
logger.error(f"Error in OpenAI streaming handler: {str(e)}")
|
1184 |
# Add error message to the current response
|
1185 |
+
yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
|
|
|
1186 |
|
1187 |
+
def groq_streaming_handler(response, history, message):
|
1188 |
+
"""Handle streaming responses from Groq"""
|
1189 |
try:
|
1190 |
+
updated_history = history + [{"role": "user", "content": message}]
|
1191 |
+
assistant_response = ""
|
1192 |
+
|
|
|
|
|
1193 |
for chunk in response:
|
1194 |
if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
|
1195 |
content = chunk.choices[0].delta.content
|
1196 |
+
assistant_response += content
|
1197 |
+
yield updated_history + [{"role": "assistant", "content": assistant_response}]
|
|
|
|
|
1198 |
except Exception as e:
|
1199 |
logger.error(f"Error in Groq streaming handler: {str(e)}")
|
1200 |
# Add error message to the current response
|
1201 |
+
yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
|
|
|
1202 |
|
1203 |
+
def together_streaming_handler(response, history, message):
|
1204 |
+
"""Handle streaming responses from Together"""
|
1205 |
try:
|
1206 |
+
updated_history = history + [{"role": "user", "content": message}]
|
1207 |
+
assistant_response = ""
|
1208 |
+
|
|
|
|
|
1209 |
for chunk in response:
|
1210 |
if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
|
1211 |
content = chunk.choices[0].delta.content
|
1212 |
+
assistant_response += content
|
1213 |
+
yield updated_history + [{"role": "assistant", "content": assistant_response}]
|
|
|
|
|
1214 |
except Exception as e:
|
1215 |
logger.error(f"Error in Together streaming handler: {str(e)}")
|
1216 |
# Add error message to the current response
|
1217 |
+
yield updated_history + [{"role": "assistant", "content": f"Error during streaming: {str(e)}"}]
|
1218 |
+
|
|
|
1219 |
# ==========================================================
|
1220 |
# MAIN FUNCTION TO ASK AI
|
1221 |
# ==========================================================
|
|
|
1229 |
if not message.strip() and not images and not documents:
|
1230 |
return history
|
1231 |
|
1232 |
+
# Create messages from chat history for API requests
|
1233 |
+
messages = format_to_message_dict(history)
|
|
|
|
|
|
|
1234 |
|
1235 |
# Add system message if provided
|
1236 |
if system_message and system_message.strip():
|
|
|
1242 |
# Prepare message with images and documents if any
|
1243 |
content = prepare_message_with_media(message, images, documents)
|
1244 |
|
1245 |
+
# Add current message to API messages
|
1246 |
messages.append({"role": "user", "content": content})
|
1247 |
|
1248 |
# Common parameters for all providers
|
|
|
1262 |
model_id, _ = get_model_info(provider, model_choice)
|
1263 |
if not model_id:
|
1264 |
error_message = f"Error: Model '{model_choice}' not found in OpenRouter"
|
1265 |
+
# Use proper message format
|
1266 |
+
return history + [
|
1267 |
+
{"role": "user", "content": message},
|
1268 |
+
{"role": "assistant", "content": error_message}
|
1269 |
+
]
|
1270 |
|
1271 |
# Build OpenRouter payload
|
1272 |
payload = {
|
|
|
1312 |
|
1313 |
# Handle streaming response
|
1314 |
if stream_output and response.status_code == 200:
|
1315 |
+
# Add message to history
|
1316 |
+
updated_history = history + [{"role": "user", "content": message}]
|
1317 |
|
1318 |
# Set up generator for streaming updates
|
1319 |
def streaming_generator():
|
1320 |
+
assistant_response = ""
|
1321 |
+
for line in response.iter_lines():
|
1322 |
+
if not line:
|
1323 |
+
continue
|
1324 |
+
|
1325 |
+
line = line.decode('utf-8')
|
1326 |
+
if not line.startswith('data: '):
|
1327 |
+
continue
|
1328 |
+
|
1329 |
+
data = line[6:]
|
1330 |
+
if data.strip() == '[DONE]':
|
1331 |
+
break
|
1332 |
+
|
1333 |
+
try:
|
1334 |
+
chunk = json.loads(data)
|
1335 |
+
if "choices" in chunk and len(chunk["choices"]) > 0:
|
1336 |
+
delta = chunk["choices"][0].get("delta", {})
|
1337 |
+
if "content" in delta and delta["content"]:
|
1338 |
+
# Update the current response
|
1339 |
+
assistant_response += delta["content"]
|
1340 |
+
# Yield updated history with current assistant response
|
1341 |
+
yield updated_history + [{"role": "assistant", "content": assistant_response}]
|
1342 |
+
except json.JSONDecodeError:
|
1343 |
+
logger.error(f"Failed to parse JSON from chunk: {data}")
|
1344 |
|
1345 |
return streaming_generator()
|
1346 |
|
|
|
1352 |
# Extract AI response
|
1353 |
ai_response = extract_ai_response(result, provider)
|
1354 |
|
1355 |
+
# Add response to history with proper format
|
1356 |
+
return history + [
|
1357 |
+
{"role": "user", "content": message},
|
1358 |
+
{"role": "assistant", "content": ai_response}
|
1359 |
+
]
|
1360 |
|
1361 |
# Handle error response
|
1362 |
else:
|
|
|
1368 |
error_message += f"\n\nResponse: {response.text}"
|
1369 |
|
1370 |
logger.error(error_message)
|
1371 |
+
return history + [
|
1372 |
+
{"role": "user", "content": message},
|
1373 |
+
{"role": "assistant", "content": error_message}
|
1374 |
+
]
|
1375 |
|
1376 |
elif provider == "OpenAI":
|
1377 |
# Get model ID from registry
|
1378 |
model_id, _ = get_model_info(provider, model_choice)
|
1379 |
if not model_id:
|
1380 |
error_message = f"Error: Model '{model_choice}' not found in OpenAI"
|
1381 |
+
return history + [
|
1382 |
+
{"role": "user", "content": message},
|
1383 |
+
{"role": "assistant", "content": error_message}
|
1384 |
+
]
|
1385 |
|
1386 |
# Build OpenAI payload
|
1387 |
payload = {
|
|
|
1402 |
|
1403 |
# Handle streaming response
|
1404 |
if stream_output:
|
1405 |
+
# Add message to history
|
1406 |
+
updated_history = history + [{"role": "user", "content": message}]
|
1407 |
|
1408 |
# Set up generator for streaming updates
|
1409 |
def streaming_generator():
|
1410 |
+
assistant_response = ""
|
1411 |
+
for chunk in response:
|
1412 |
+
if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
|
1413 |
+
content = chunk.choices[0].delta.content
|
1414 |
+
assistant_response += content
|
1415 |
+
yield updated_history + [{"role": "assistant", "content": assistant_response}]
|
1416 |
|
1417 |
return streaming_generator()
|
1418 |
|
1419 |
# Handle normal response
|
1420 |
else:
|
1421 |
ai_response = extract_ai_response(response, provider)
|
1422 |
+
return history + [
|
1423 |
+
{"role": "user", "content": message},
|
1424 |
+
{"role": "assistant", "content": ai_response}
|
1425 |
+
]
|
1426 |
except Exception as e:
|
1427 |
error_message = f"OpenAI API Error: {str(e)}"
|
1428 |
logger.error(error_message)
|
1429 |
+
return history + [
|
1430 |
+
{"role": "user", "content": message},
|
1431 |
+
{"role": "assistant", "content": error_message}
|
1432 |
+
]
|
1433 |
|
1434 |
elif provider == "HuggingFace":
|
1435 |
# Get model ID from registry
|
1436 |
model_id, _ = get_model_info(provider, model_choice)
|
1437 |
if not model_id:
|
1438 |
error_message = f"Error: Model '{model_choice}' not found in HuggingFace"
|
1439 |
+
return history + [
|
1440 |
+
{"role": "user", "content": message},
|
1441 |
+
{"role": "assistant", "content": error_message}
|
1442 |
+
]
|
1443 |
|
1444 |
# Build HuggingFace payload
|
1445 |
payload = {
|
|
|
1457 |
|
1458 |
# Extract response
|
1459 |
ai_response = extract_ai_response(response, provider)
|
1460 |
+
return history + [
|
1461 |
+
{"role": "user", "content": message},
|
1462 |
+
{"role": "assistant", "content": ai_response}
|
1463 |
+
]
|
1464 |
except Exception as e:
|
1465 |
error_message = f"HuggingFace API Error: {str(e)}"
|
1466 |
logger.error(error_message)
|
1467 |
+
return history + [
|
1468 |
+
{"role": "user", "content": message},
|
1469 |
+
{"role": "assistant", "content": error_message}
|
1470 |
+
]
|
1471 |
|
1472 |
elif provider == "Groq":
|
1473 |
# Get model ID from registry
|
1474 |
model_id, _ = get_model_info(provider, model_choice)
|
1475 |
if not model_id:
|
1476 |
error_message = f"Error: Model '{model_choice}' not found in Groq"
|
1477 |
+
return history + [
|
1478 |
+
{"role": "user", "content": message},
|
1479 |
+
{"role": "assistant", "content": error_message}
|
1480 |
+
]
|
1481 |
|
1482 |
# Build Groq payload
|
1483 |
payload = {
|
|
|
1497 |
|
1498 |
# Handle streaming response
|
1499 |
if stream_output:
|
1500 |
+
# Add message to history
|
1501 |
+
updated_history = history + [{"role": "user", "content": message}]
|
1502 |
|
1503 |
# Set up generator for streaming updates
|
1504 |
def streaming_generator():
|
1505 |
+
assistant_response = ""
|
1506 |
+
for chunk in response:
|
1507 |
+
if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
|
1508 |
+
content = chunk.choices[0].delta.content
|
1509 |
+
assistant_response += content
|
1510 |
+
yield updated_history + [{"role": "assistant", "content": assistant_response}]
|
1511 |
|
1512 |
return streaming_generator()
|
1513 |
|
1514 |
# Handle normal response
|
1515 |
else:
|
1516 |
ai_response = extract_ai_response(response, provider)
|
1517 |
+
return history + [
|
1518 |
+
{"role": "user", "content": message},
|
1519 |
+
{"role": "assistant", "content": ai_response}
|
1520 |
+
]
|
1521 |
except Exception as e:
|
1522 |
error_message = f"Groq API Error: {str(e)}"
|
1523 |
logger.error(error_message)
|
1524 |
+
return history + [
|
1525 |
+
{"role": "user", "content": message},
|
1526 |
+
{"role": "assistant", "content": error_message}
|
1527 |
+
]
|
1528 |
|
1529 |
elif provider == "Cohere":
|
1530 |
# Get model ID from registry
|
1531 |
model_id, _ = get_model_info(provider, model_choice)
|
1532 |
if not model_id:
|
1533 |
error_message = f"Error: Model '{model_choice}' not found in Cohere"
|
1534 |
+
return history + [
|
1535 |
+
{"role": "user", "content": message},
|
1536 |
+
{"role": "assistant", "content": error_message}
|
1537 |
+
]
|
1538 |
|
1539 |
# Build Cohere payload (doesn't support streaming the same way)
|
1540 |
payload = {
|
|
|
1552 |
|
1553 |
# Extract response
|
1554 |
ai_response = extract_ai_response(response, provider)
|
1555 |
+
return history + [
|
1556 |
+
{"role": "user", "content": message},
|
1557 |
+
{"role": "assistant", "content": ai_response}
|
1558 |
+
]
|
1559 |
except Exception as e:
|
1560 |
error_message = f"Cohere API Error: {str(e)}"
|
1561 |
logger.error(error_message)
|
1562 |
+
return history + [
|
1563 |
+
{"role": "user", "content": message},
|
1564 |
+
{"role": "assistant", "content": error_message}
|
1565 |
+
]
|
1566 |
|
1567 |
elif provider == "Together":
|
1568 |
# Get model ID from registry
|
1569 |
model_id, _ = get_model_info(provider, model_choice)
|
1570 |
if not model_id:
|
1571 |
error_message = f"Error: Model '{model_choice}' not found in Together"
|
1572 |
+
return history + [
|
1573 |
+
{"role": "user", "content": message},
|
1574 |
+
{"role": "assistant", "content": error_message}
|
1575 |
+
]
|
1576 |
|
1577 |
# Build Together payload
|
1578 |
payload = {
|
|
|
1591 |
|
1592 |
# Handle streaming response
|
1593 |
if stream_output:
|
1594 |
+
# Add message to history
|
1595 |
+
updated_history = history + [{"role": "user", "content": message}]
|
1596 |
|
1597 |
# Set up generator for streaming updates
|
1598 |
def streaming_generator():
|
1599 |
+
assistant_response = ""
|
1600 |
+
for chunk in response:
|
1601 |
+
if hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content is not None:
|
1602 |
+
content = chunk.choices[0].delta.content
|
1603 |
+
assistant_response += content
|
1604 |
+
yield updated_history + [{"role": "assistant", "content": assistant_response}]
|
1605 |
|
1606 |
return streaming_generator()
|
1607 |
|
1608 |
# Handle normal response
|
1609 |
else:
|
1610 |
ai_response = extract_ai_response(response, provider)
|
1611 |
+
return history + [
|
1612 |
+
{"role": "user", "content": message},
|
1613 |
+
{"role": "assistant", "content": ai_response}
|
1614 |
+
]
|
1615 |
except Exception as e:
|
1616 |
error_message = f"Together API Error: {str(e)}"
|
1617 |
logger.error(error_message)
|
1618 |
+
return history + [
|
1619 |
+
{"role": "user", "content": message},
|
1620 |
+
{"role": "assistant", "content": error_message}
|
1621 |
+
]
|
1622 |
|
1623 |
elif provider == "OVH":
|
1624 |
# Get model ID from registry
|
1625 |
model_id, _ = get_model_info(provider, model_choice)
|
1626 |
if not model_id:
|
1627 |
error_message = f"Error: Model '{model_choice}' not found in OVH"
|
1628 |
+
return history + [
|
1629 |
+
{"role": "user", "content": message},
|
1630 |
+
{"role": "assistant", "content": error_message}
|
1631 |
+
]
|
1632 |
|
1633 |
# Build OVH payload
|
1634 |
payload = {
|
|
|
1646 |
|
1647 |
# Extract response
|
1648 |
ai_response = extract_ai_response(response, provider)
|
1649 |
+
return history + [
|
1650 |
+
{"role": "user", "content": message},
|
1651 |
+
{"role": "assistant", "content": ai_response}
|
1652 |
+
]
|
1653 |
except Exception as e:
|
1654 |
error_message = f"OVH API Error: {str(e)}"
|
1655 |
logger.error(error_message)
|
1656 |
+
return history + [
|
1657 |
+
{"role": "user", "content": message},
|
1658 |
+
{"role": "assistant", "content": error_message}
|
1659 |
+
]
|
1660 |
|
1661 |
elif provider == "Cerebras":
|
1662 |
# Get model ID from registry
|
1663 |
model_id, _ = get_model_info(provider, model_choice)
|
1664 |
if not model_id:
|
1665 |
error_message = f"Error: Model '{model_choice}' not found in Cerebras"
|
1666 |
+
return history + [
|
1667 |
+
{"role": "user", "content": message},
|
1668 |
+
{"role": "assistant", "content": error_message}
|
1669 |
+
]
|
1670 |
|
1671 |
# Build Cerebras payload
|
1672 |
payload = {
|
|
|
1684 |
|
1685 |
# Extract response
|
1686 |
ai_response = extract_ai_response(response, provider)
|
1687 |
+
return history + [
|
1688 |
+
{"role": "user", "content": message},
|
1689 |
+
{"role": "assistant", "content": ai_response}
|
1690 |
+
]
|
1691 |
except Exception as e:
|
1692 |
error_message = f"Cerebras API Error: {str(e)}"
|
1693 |
logger.error(error_message)
|
1694 |
+
return history + [
|
1695 |
+
{"role": "user", "content": message},
|
1696 |
+
{"role": "assistant", "content": error_message}
|
1697 |
+
]
|
1698 |
|
1699 |
elif provider == "GoogleAI":
|
1700 |
# Get model ID from registry
|
1701 |
model_id, _ = get_model_info(provider, model_choice)
|
1702 |
if not model_id:
|
1703 |
error_message = f"Error: Model '{model_choice}' not found in GoogleAI"
|
1704 |
+
return history + [
|
1705 |
+
{"role": "user", "content": message},
|
1706 |
+
{"role": "assistant", "content": error_message}
|
1707 |
+
]
|
1708 |
|
1709 |
# Build GoogleAI payload
|
1710 |
payload = {
|
|
|
1723 |
|
1724 |
# Extract response
|
1725 |
ai_response = extract_ai_response(response, provider)
|
1726 |
+
return history + [
|
1727 |
+
{"role": "user", "content": message},
|
1728 |
+
{"role": "assistant", "content": ai_response}
|
1729 |
+
]
|
1730 |
except Exception as e:
|
1731 |
error_message = f"GoogleAI API Error: {str(e)}"
|
1732 |
logger.error(error_message)
|
1733 |
+
return history + [
|
1734 |
+
{"role": "user", "content": message},
|
1735 |
+
{"role": "assistant", "content": error_message}
|
1736 |
+
]
|
1737 |
|
1738 |
else:
|
1739 |
error_message = f"Error: Unsupported provider '{provider}'"
|
1740 |
+
return history + [
|
1741 |
+
{"role": "user", "content": message},
|
1742 |
+
{"role": "assistant", "content": error_message}
|
1743 |
+
]
|
1744 |
|
1745 |
except Exception as e:
|
1746 |
error_message = f"Error: {str(e)}"
|
1747 |
logger.error(f"Exception during API call: {error_message}")
|
1748 |
+
return history + [
|
1749 |
+
{"role": "user", "content": message},
|
1750 |
+
{"role": "assistant", "content": error_message}
|
1751 |
+
]
|
1752 |
|
1753 |
def clear_chat():
|
1754 |
"""Reset all inputs"""
|
|
|
2243 |
|
2244 |
def update_vision_indicator(provider, model_choice):
|
2245 |
"""Update the vision capability indicator"""
|
2246 |
+
# Safety check for None model
|
2247 |
+
if model_choice is None:
|
2248 |
+
return False
|
2249 |
return is_vision_model(provider, model_choice)
|
2250 |
|
2251 |
def update_image_upload_visibility(provider, model_choice):
|
2252 |
"""Show/hide image upload based on model vision capabilities"""
|
2253 |
+
# Safety check for None model
|
2254 |
+
if model_choice is None:
|
2255 |
+
return gr.update(visible=False)
|
2256 |
is_vision = is_vision_model(provider, model_choice)
|
2257 |
return gr.update(visible=is_vision)
|
2258 |
|
2259 |
+
# Search model function
|
2260 |
def search_openrouter_models(search_term):
|
2261 |
"""Filter OpenRouter models based on search term"""
|
2262 |
all_models = [model[0] for model in OPENROUTER_ALL_MODELS]
|
|
|
2677 |
|
2678 |
# Check if model is selected
|
2679 |
if not model_choice:
|
2680 |
+
error_message = f"Error: No model selected for provider {provider}"
|
2681 |
+
return history + [
|
2682 |
+
{"role": "user", "content": message},
|
2683 |
+
{"role": "assistant", "content": error_message}
|
2684 |
+
]
|
2685 |
|
2686 |
# Select the appropriate API key based on the provider
|
2687 |
api_key_override = None
|