#
# SPDX-FileCopyrightText: Hadad <[email protected]>
# SPDX-License-Identifier: Apache-2.0
#
# Standard library
import asyncio
import codecs
import io
import json
import os
import random
import threading
import uuid
import zipfile
from pathlib import Path

# Third-party
import docx
import gradio as gr
import httpx
import pandas as pd
import pdfplumber
import pytesseract
import requests
from openpyxl import load_workbook
from PIL import Image
from pptx import Presentation
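# Install the Tesseract OCR engine plus the English and Indonesian language
# packs at startup; pytesseract shells out to the `tesseract` binary installed here.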
os.system("apt-get update -q -y && apt-get install -q -y tesseract-ocr tesseract-ocr-eng tesseract-ocr-ind libleptonica-dev libtesseract-dev")
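# All runtime configuration is injected through environment variables:
# upstream hosts, provider API keys, prompts, model metadata, and UI strings.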
INTERNAL_AI_GET_SERVER = os.getenv("INTERNAL_AI_GET_SERVER")
INTERNAL_TRAINING_DATA = os.getenv("INTERNAL_TRAINING_DATA")
SYSTEM_PROMPT_MAPPING = json.loads(os.getenv("SYSTEM_PROMPT_MAPPING", "{}"))
SYSTEM_PROMPT_DEFAULT = os.getenv("DEFAULT_SYSTEM")
LINUX_SERVER_HOSTS = [h for h in json.loads(os.getenv("LINUX_SERVER_HOST", "[]")) if h]
LINUX_SERVER_HOSTS_MARKED = set()
LINUX_SERVER_HOSTS_ATTEMPTS = {}
LINUX_SERVER_PROVIDER_KEYS = [k for k in json.loads(os.getenv("LINUX_SERVER_PROVIDER_KEY", "[]")) if k]
LINUX_SERVER_PROVIDER_KEYS_MARKED = set()
LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS = {}
LINUX_SERVER_ERRORS = set(int(e) for e in os.getenv("LINUX_SERVER_ERROR", "").split(",") if e.strip())
AI_TYPES = {f"AI_TYPE_{i}": os.getenv(f"AI_TYPE_{i}") for i in range(1, 8)}
RESPONSES = {f"RESPONSE_{i}": os.getenv(f"RESPONSE_{i}") for i in range(1, 11)}
MODEL_MAPPING = json.loads(os.getenv("MODEL_MAPPING", "{}"))
MODEL_CONFIG = json.loads(os.getenv("MODEL_CONFIG", "{}"))
MODEL_CHOICES = list(MODEL_MAPPING.values())
DEFAULT_CONFIG = json.loads(os.getenv("DEFAULT_CONFIG", "{}"))
DEFAULT_MODEL_KEY = list(MODEL_MAPPING.keys())[0] if MODEL_MAPPING else None
META_TAGS = os.getenv("META_TAGS")
ALLOWED_EXTENSIONS = json.loads(os.getenv("ALLOWED_EXTENSIONS", "[]"))
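# requests.Session subclass that carries a unique ID so the upstream service
# can correlate requests belonging to the same chat session.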
class SessionWithID(requests.Session):
    def __init__(self):
        super().__init__()
        self.session_id = str(uuid.uuid4())
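# Factory used for per-user Gradio state; each browser session gets its own
# SessionWithID instance.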
def create_session():
    return SessionWithID()
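# Lazily attach an asyncio.Event used to signal that streaming should stop.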
def ensure_stop_event(sess):
    if not hasattr(sess, "stop_event"):
        sess.stop_event = asyncio.Event()
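# Return the items that are not currently marked as failing, shuffled so the
# load spreads across hosts and keys.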
def get_available_items(items, marked):
    available = [i for i in items if i not in marked]
    random.shuffle(available)
    return available
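# Record a failure for a host or provider key: the item joins the marked set,
# and once it accumulates three recorded failures a timer releases it again
# after five minutes.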
def marked_item(item, marked, attempts):
    marked.add(item)
    attempts[item] = attempts.get(item, 0) + 1
    if attempts[item] >= 3:
        def remove():
            marked.discard(item)
            attempts.pop(item, None)
        threading.Timer(300, remove).start()
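# Map a human-readable model name from the dropdown back to its internal key.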
def get_model_key(display):
    return next((k for k, v in MODEL_MAPPING.items() if v == display), DEFAULT_MODEL_KEY)
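# Extract text from a PDF: page text, OCR of embedded images, and tables.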
def extract_pdf_content(fp):
    content = ""
    try:
        with pdfplumber.open(fp) as pdf:
            for page in pdf.pages:
                text = page.extract_text() or ""
                content += text + "\n"
                if page.images:
                    img_obj = page.to_image(resolution=300)
                    # Image coordinates are in PDF points (72 per inch); scale
                    # them to pixels in the 300 dpi rendering before cropping.
                    scale = 300 / 72
                    for img in page.images:
                        bbox = (int(img["x0"] * scale), int(img["top"] * scale), int(img["x1"] * scale), int(img["bottom"] * scale))
                        cropped = img_obj.original.crop(bbox)
                        ocr_text = pytesseract.image_to_string(cropped)
                        if ocr_text.strip():
                            content += ocr_text + "\n"
                tables = page.extract_tables()
                for table in tables:
                    for row in table:
                        cells = [str(cell) for cell in row if cell is not None]
                        if cells:
                            content += "\t".join(cells) + "\n"
    except Exception as e:
        content += f"{fp}: {e}"
    return content.strip()
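# Extract text from a DOCX: paragraphs, tables, and OCR of embedded media.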
def extract_docx_content(fp):
    content = ""
    try:
        doc = docx.Document(fp)
        for para in doc.paragraphs:
            content += para.text + "\n"
        for table in doc.tables:
            for row in table.rows:
                cells = [cell.text for cell in row.cells]
                content += "\t".join(cells) + "\n"
        # A DOCX is a ZIP archive; OCR any images stored under word/media/.
        with zipfile.ZipFile(fp) as z:
            for file in z.namelist():
                if file.startswith("word/media/"):
                    data = z.read(file)
                    try:
                        img = Image.open(io.BytesIO(data))
                        ocr_text = pytesseract.image_to_string(img)
                        if ocr_text.strip():
                            content += ocr_text + "\n"
                    except Exception:
                        pass
    except Exception as e:
        content += f"{fp}: {e}"
    return content.strip()
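# Extract text from a spreadsheet: every sheet as CSV plus OCR of embedded images.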
def extract_excel_content(fp):
    content = ""
    try:
        sheets = pd.read_excel(fp, sheet_name=None)
        for name, df in sheets.items():
            content += f"Sheet: {name}\n"
            content += df.to_csv(index=False) + "\n"
        # openpyxl keeps embedded images in a private _images list.
        wb = load_workbook(fp, data_only=True)
        if wb._images:
            for image in wb._images:
                img = image.ref
                if isinstance(img, bytes):
                    try:
                        pil_img = Image.open(io.BytesIO(img))
                        ocr_text = pytesseract.image_to_string(pil_img)
                        if ocr_text.strip():
                            content += ocr_text + "\n"
                    except Exception:
                        pass
    except Exception as e:
        content += f"{fp}: {e}"
    return content.strip()
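# Extract text from a PPTX: shape text, OCR of picture shapes, and tables.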
def extract_pptx_content(fp):
    content = ""
    try:
        prs = Presentation(fp)
        for slide in prs.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text") and shape.text:
                    content += shape.text + "\n"
                # shape_type 13 is MSO_SHAPE_TYPE.PICTURE.
                if shape.shape_type == 13 and hasattr(shape, "image") and shape.image:
                    try:
                        img = Image.open(io.BytesIO(shape.image.blob))
                        ocr_text = pytesseract.image_to_string(img)
                        if ocr_text.strip():
                            content += ocr_text + "\n"
                    except Exception:
                        pass
            for shape in slide.shapes:
                if shape.has_table:
                    table = shape.table
                    for row in table.rows:
                        cells = [cell.text for cell in row.cells]
                        content += "\t".join(cells) + "\n"
    except Exception as e:
        content += f"{fp}: {e}"
    return content.strip()
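# Dispatch to the extractor matching the file extension; anything else is
# read as plain UTF-8 text.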
def extract_file_content(fp):
    ext = Path(fp).suffix.lower()
    if ext == ".pdf":
        return extract_pdf_content(fp)
    elif ext in [".doc", ".docx"]:
        return extract_docx_content(fp)
    elif ext in [".xlsx", ".xls"]:
        return extract_excel_content(fp)
    elif ext in [".ppt", ".pptx"]:
        return extract_pptx_content(fp)
    else:
        try:
            return Path(fp).read_text(encoding="utf-8").strip()
        except Exception as e:
            return f"{fp}: {e}"
async def fetch_response_stream_async(host, key, model, msgs, cfg, sid, stop_event):
    # Two attempts with increasing (but still short) timeouts; on failure the
    # caller moves on to the next host/key pair.
    for t in [0.5, 1]:
        try:
            async with httpx.AsyncClient(timeout=t) as client:
                async with client.stream("POST", host, json={"model": model, "messages": msgs, "session_id": sid, "stream": True, **cfg}, headers={"Authorization": f"Bearer {key}"}) as response:
                    if response.status_code in LINUX_SERVER_ERRORS:
                        marked_item(key, LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS)
                        return
                    async for line in response.aiter_lines():
                        if stop_event.is_set():
                            return
                        if not line:
                            continue
                        if line.startswith("data: "):
                            data = line[6:]
                            if data.strip() == RESPONSES["RESPONSE_10"]:
                                return
                            try:
                                j = json.loads(data)
                                if isinstance(j, dict) and j.get("choices"):
                                    for ch in j["choices"]:
                                        delta = ch.get("delta", {})
                                        if delta.get("reasoning"):
                                            decoded_reasoning = delta["reasoning"].encode("utf-8").decode("unicode_escape")
                                            yield ("reasoning", decoded_reasoning)
                                        if delta.get("content"):
                                            yield ("content", delta["content"])
                            except Exception:
                                continue
                    # Stream closed without the terminator sentinel; treat it
                    # as complete rather than retrying.
                    return
        except Exception:
            continue
    marked_item(key, LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS)
    return
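# Stream a reply for the current turn, sticking to the last working host/key
# pair when possible and otherwise probing the available combinations.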
async def chat_with_model_async(history, user_input, model_display, sess, custom_prompt):
    ensure_stop_event(sess)
    sess.stop_event.clear()
    # Give up immediately when every provider key or host is marked as failing.
    if not get_available_items(LINUX_SERVER_PROVIDER_KEYS, LINUX_SERVER_PROVIDER_KEYS_MARKED) or not get_available_items(LINUX_SERVER_HOSTS, LINUX_SERVER_HOSTS_MARKED):
        yield ("content", RESPONSES["RESPONSE_3"])
        return
    if not hasattr(sess, "session_id") or not sess.session_id:
        sess.session_id = str(uuid.uuid4())
        sess.stop_event = asyncio.Event()
    if not hasattr(sess, "active_candidate"):
        sess.active_candidate = None
    model_key = get_model_key(model_display)
    cfg = MODEL_CONFIG.get(model_key, DEFAULT_CONFIG)
    # Rebuild the conversation in chronological order: each user turn is
    # immediately followed by its assistant reply.
    msgs = []
    for u, a in history:
        msgs.append({"role": "user", "content": u})
        if a:
            msgs.append({"role": "assistant", "content": a})
    prompt = INTERNAL_TRAINING_DATA if model_key == DEFAULT_MODEL_KEY and INTERNAL_TRAINING_DATA else (custom_prompt or SYSTEM_PROMPT_MAPPING.get(model_key, SYSTEM_PROMPT_DEFAULT))
    msgs.insert(0, {"role": "system", "content": prompt})
    msgs.append({"role": "user", "content": user_input})
    # Prefer the host/key pair that successfully served this session before.
    if sess.active_candidate:
        async for chunk in fetch_response_stream_async(sess.active_candidate[0], sess.active_candidate[1], model_key, msgs, cfg, sess.session_id, sess.stop_event):
            if sess.stop_event.is_set():
                return
            yield chunk
        return
    # Otherwise probe host/key combinations (already shuffled by
    # get_available_items) until one starts streaming.
    keys = get_available_items(LINUX_SERVER_PROVIDER_KEYS, LINUX_SERVER_PROVIDER_KEYS_MARKED)
    hosts = get_available_items(LINUX_SERVER_HOSTS, LINUX_SERVER_HOSTS_MARKED)
    for k in keys:
        for h in hosts:
            stream_gen = fetch_response_stream_async(h, k, model_key, msgs, cfg, sess.session_id, sess.stop_event)
            full_text = ""
            got_any = False
            async for chunk in stream_gen:
                if sess.stop_event.is_set():
                    return
                if not got_any:
                    got_any = True
                    # Remember the first pair that responds so later turns reuse it.
                    sess.active_candidate = (h, k)
                full_text += chunk[1]
                yield chunk
            if got_any and full_text:
                return
    yield ("content", RESPONSES["RESPONSE_2"])
async def respond_async(multi, history, model_display, sess, custom_prompt):
    ensure_stop_event(sess)
    sess.stop_event.clear()
    msg_input = {"text": multi.get("text", "").strip(), "files": multi.get("files", [])}
    if not msg_input["text"] and not msg_input["files"]:
        yield history, gr.update(value="", interactive=True, submit_btn=True, stop_btn=False), sess
        return
    # Prepend the extracted contents of any attached files to the user message.
    inp = ""
    for f in msg_input["files"]:
        fp = f.get("data", f.get("name", "")) if isinstance(f, dict) else f
        inp += f"{Path(fp).name}\n\n{extract_file_content(fp)}\n\n"
    if msg_input["text"]:
        inp += msg_input["text"]
    history.append([inp, RESPONSES["RESPONSE_8"]])
    yield history, gr.update(interactive=False, submit_btn=False, stop_btn=True), sess
    queue = asyncio.Queue()
    async def background():
        display_text = ""
        content_started = False
        # Pass history[:-1] so the placeholder entry appended above does not
        # duplicate the current turn in the upstream conversation.
        async for typ, chunk in chat_with_model_async(history[:-1], inp, model_display, sess, custom_prompt):
            if sess.stop_event.is_set():
                break
            if typ == "reasoning":
                # Show streamed reasoning only until real content arrives.
                if content_started:
                    continue
                display_text += chunk
                await queue.put(("set", display_text))
            elif typ == "content":
                if not content_started:
                    # The first content chunk replaces the reasoning text.
                    content_started = True
                    display_text = chunk
                    await queue.put(("replace", display_text))
                else:
                    display_text += chunk
                    await queue.put(("append", display_text))
        await queue.put(None)
        return display_text
    bg_task = asyncio.create_task(background())
    stop_task = asyncio.create_task(sess.stop_event.wait())
    try:
        while True:
            queue_task = asyncio.create_task(queue.get())
            done, _ = await asyncio.wait({stop_task, queue_task}, return_when=asyncio.FIRST_COMPLETED)
            if stop_task in done:
                queue_task.cancel()
                bg_task.cancel()
                history[-1][1] = RESPONSES["RESPONSE_1"]
                yield history, gr.update(value="", interactive=True, submit_btn=True, stop_btn=False), sess
                sess.stop_event.clear()
                return
            result = queue_task.result()
            if result is None:
                break
            action, text = result
            history[-1][1] = text
            yield history, gr.update(interactive=False, submit_btn=False, stop_btn=True), sess
    finally:
        stop_task.cancel()
    full_response = await bg_task
    yield history, gr.update(value="", interactive=True, submit_btn=True, stop_btn=False), sess
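# Reset the conversation when the model selection changes and show the
# system-prompt box for non-default models.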
def change_model(new):
    visible = new != MODEL_CHOICES[0]
    default = SYSTEM_PROMPT_MAPPING.get(get_model_key(new), SYSTEM_PROMPT_DEFAULT)
    return [], create_session(), new, default, gr.update(value=default, visible=visible)
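# Stop handler: signal the running stream, mark the last reply as stopped,
# and start a fresh session.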
def stop_response(history, sess):
    ensure_stop_event(sess)
    sess.stop_event.set()
    if history:
        history[-1][1] = RESPONSES["RESPONSE_1"]
    new_sess = create_session()
    return history, None, new_sess
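# Gradio UI wiring.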
with gr.Blocks(fill_height=True, fill_width=True, title=AI_TYPES["AI_TYPE_4"], head=META_TAGS) as jarvis:
    user_history = gr.State([])
    user_session = gr.State(create_session())
    selected_model = gr.State(MODEL_CHOICES[0] if MODEL_CHOICES else "")
    custom_prompt_state = gr.State("")
    chatbot = gr.Chatbot(label=AI_TYPES["AI_TYPE_1"], show_copy_button=True, scale=1, elem_id=AI_TYPES["AI_TYPE_2"])
    msg = gr.MultimodalTextbox(show_label=False, placeholder=RESPONSES["RESPONSE_5"], interactive=True, file_count="single", file_types=ALLOWED_EXTENSIONS)
    with gr.Accordion(AI_TYPES["AI_TYPE_6"], open=False):
        model_dropdown = gr.Dropdown(show_label=False, choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
        system_prompt = gr.Textbox(label=AI_TYPES["AI_TYPE_7"], lines=2, interactive=True, visible=False)
    model_dropdown.change(fn=change_model, inputs=[model_dropdown], outputs=[user_history, user_session, selected_model, custom_prompt_state, system_prompt])
    system_prompt.change(fn=lambda x: x, inputs=[system_prompt], outputs=[custom_prompt_state])
    msg.submit(fn=respond_async, inputs=[msg, user_history, selected_model, user_session, custom_prompt_state], outputs=[chatbot, msg, user_session], api_name=INTERNAL_AI_GET_SERVER)
    msg.stop(fn=stop_response, inputs=[user_history, user_session], outputs=[chatbot, msg, user_session])
jarvis.queue(default_concurrency_limit=2).launch(max_file_size="1mb")