Update .gitignore to include .env, add model download script, enhance save_data endpoint for batch processing, and modify model_type formatting in index.html
Browse files- .gitignore +1 -0
- Dockerfile +2 -0
- download_models.py +10 -0
- main.py +89 -76
- static/index.html +2 -2
.gitignore
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
.venv
|
2 |
__pycache__/
|
|
|
|
1 |
.venv
|
2 |
__pycache__/
|
3 |
+
.env
|
Dockerfile
CHANGED
@@ -24,4 +24,6 @@ WORKDIR $HOME/app
|
|
24 |
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
25 |
COPY --chown=user . $HOME/app
|
26 |
|
|
|
|
|
27 |
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
24 |
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
25 |
COPY --chown=user . $HOME/app
|
26 |
|
27 |
+
RUN python3 $HOME/app/download_models.py
|
28 |
+
|
29 |
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
download_models.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Pre-fetch embedding model weights at image build time (see Dockerfile RUN step),
so the first request after container start does not pay the download cost."""
from transformers import AutoModel, AutoTokenizer

# Embedding models served by the app; weights land in the HF cache dir.
model_names = [
    "WhereIsAI/UAE-Large-V1",
    "BAAI/bge-large-en-v1.5",
]

for model_name in model_names:
    # Instantiating from_pretrained is enough to populate the local cache.
    AutoModel.from_pretrained(model_name)
    AutoTokenizer.from_pretrained(model_name)
|
main.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException, Depends, status
|
2 |
from fastapi.responses import FileResponse
|
3 |
from fastapi.staticfiles import StaticFiles
|
4 |
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
@@ -19,6 +19,8 @@ import os
|
|
19 |
import logging
|
20 |
from functools import lru_cache
|
21 |
from diskcache import Cache
|
|
|
|
|
22 |
|
23 |
# Configure logging
|
24 |
logging.basicConfig(level=logging.INFO)
|
@@ -36,8 +38,8 @@ app = FastAPI()
|
|
36 |
cache = Cache('./cache')
|
37 |
|
38 |
# JWT Configuration
|
39 |
-
SECRET_KEY = os.environ.get("
|
40 |
-
REFRESH_SECRET_KEY = os.environ.get("
|
41 |
ALGORITHM = "HS256"
|
42 |
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
43 |
REFRESH_TOKEN_EXPIRE_DAYS = 7
|
@@ -322,88 +324,99 @@ async def search(
|
|
322 |
detail=f"Search failed: {str(e)}"
|
323 |
)
|
324 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
325 |
@app.post("/save")
|
326 |
async def save_data(
|
327 |
save_input: SaveBatchInput,
|
328 |
username: str = Depends(verify_access_token)
|
329 |
):
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
"
|
346 |
-
|
347 |
-
"reaction": [],
|
348 |
-
"timestamp": [],
|
349 |
-
"confidence_score": []
|
350 |
-
}
|
351 |
-
|
352 |
-
# Add each item to the data dict
|
353 |
-
for item in save_input.items:
|
354 |
-
data["user_type"].append(item.user_type)
|
355 |
-
data["username"].append(item.username)
|
356 |
-
data["query"].append(item.query)
|
357 |
-
data["retrieved_text"].append(item.retrieved_text)
|
358 |
-
data["model_type"].append(item.model_type)
|
359 |
-
data["reaction"].append(item.reaction)
|
360 |
-
data["timestamp"].append(datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z'))
|
361 |
-
data["confidence_score"].append(item.confidence_score)
|
362 |
-
|
363 |
-
try:
|
364 |
-
# Load existing dataset and merge
|
365 |
-
dataset = load_dataset(
|
366 |
-
"HumbleBeeAI/al-ghazali-rag-retrieval-evaluation",
|
367 |
-
split="train"
|
368 |
-
)
|
369 |
-
existing_data = dataset.to_dict()
|
370 |
-
|
371 |
-
# Add new data
|
372 |
-
for key in data:
|
373 |
-
if key not in existing_data:
|
374 |
-
existing_data[key] = [
|
375 |
-
"" if key in ["timestamp"] else
|
376 |
-
0.0 if key in ["confidence_score"] else None
|
377 |
-
] * len(next(iter(existing_data.values())))
|
378 |
-
existing_data[key].extend(data[key])
|
379 |
-
|
380 |
-
except Exception as e:
|
381 |
-
logging.warning(f"Could not load existing dataset, creating new one: {str(e)}")
|
382 |
-
existing_data = data
|
383 |
-
|
384 |
-
# Create and push dataset
|
385 |
-
updated_dataset = Dataset.from_dict(existing_data)
|
386 |
-
updated_dataset.push_to_hub(
|
387 |
-
"HumbleBeeAI/al-ghazali-rag-retrieval-evaluation"
|
388 |
-
)
|
389 |
-
|
390 |
-
return {"message": "Data saved successfully"}
|
391 |
-
|
392 |
-
except Exception as e:
|
393 |
-
logging.error(f"Save error: {str(e)}")
|
394 |
-
raise HTTPException(
|
395 |
-
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
396 |
-
detail=f"Failed to save data: {str(e)}"
|
397 |
-
)
|
398 |
|
399 |
# Make sure to keep the static files mounting
|
400 |
app.mount("/home", StaticFiles(directory="static", html=True), name="home")
|
401 |
|
402 |
-
# Startup event to create cache directory if it doesn't exist
|
403 |
-
@app.on_event("startup")
|
404 |
-
async def startup_event():
|
405 |
-
os.makedirs("./cache", exist_ok=True)
|
406 |
-
|
407 |
if __name__ == "__main__":
|
408 |
import uvicorn
|
409 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, Depends, status, BackgroundTasks
|
2 |
from fastapi.responses import FileResponse
|
3 |
from fastapi.staticfiles import StaticFiles
|
4 |
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
|
19 |
import logging
|
20 |
from functools import lru_cache
|
21 |
from diskcache import Cache
|
22 |
+
import json
|
23 |
+
import asyncio
|
24 |
|
25 |
# Configure logging
|
26 |
logging.basicConfig(level=logging.INFO)
|
|
|
38 |
cache = Cache('./cache')
|
39 |
|
40 |
# JWT Configuration
|
41 |
+
# JWT signing configuration.
# NOTE(review): the fallback secrets below are committed to the repository —
# anyone with read access to this repo can forge access/refresh tokens.
# Rotate both values and make the env vars mandatory in production.
SECRET_KEY = os.environ.get("PRIME_AUTH", "c0369f977b69e717dc16f6fc574039eb2b1ebde38014d2be")
REFRESH_SECRET_KEY = os.environ.get("PROLONGED_AUTH", "916018771b29084378c9362c0cd9e631fd4927b8aea07f91")
ALGORITHM = "HS256"  # symmetric HMAC-SHA256 signing
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # short-lived access tokens
REFRESH_TOKEN_EXPIRE_DAYS = 7  # refresh-token validity window
|
|
|
324 |
detail=f"Search failed: {str(e)}"
|
325 |
)
|
326 |
|
327 |
+
# --- Batch-upload queue configuration ---
QUEUE_FILE = "./save_queue.jsonl"
PUSH_INTERVAL_S = 300            # how often the sync loop wakes up (seconds)
QUEUE_THRESHOLD = 1000           # push as soon as this many records are queued
MAX_PUSH_INTERVAL_S = 47 * 3600  # 47 hours: force a push at least this often


async def _hf_sync_loop():
    """Background task: batch-push queued /save records to the HF dataset.

    Wakes every PUSH_INTERVAL_S seconds; pushes when the local JSONL queue
    reaches QUEUE_THRESHOLD records or MAX_PUSH_INTERVAL_S has elapsed since
    the last push. Errors are logged and the loop keeps running.
    """
    # Authenticate once for private repo access.
    # NOTE(review): assumes `login` is imported from huggingface_hub at the
    # top of this file — confirm.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        logging.error("HF_TOKEN not set for Hugging Face authentication")
        return
    login(token=hf_token)

    last_push_time = datetime.now(timezone.utc).timestamp()

    while True:
        await asyncio.sleep(PUSH_INTERVAL_S)
        try:
            if not os.path.exists(QUEUE_FILE):
                continue
            with open(QUEUE_FILE, "r") as f:
                raw_lines = f.read().splitlines()
            queue_len = len(raw_lines)
            now = datetime.now(timezone.utc).timestamp()
            time_since_last_push = now - last_push_time

            # Only push if threshold met or the max interval has elapsed.
            if queue_len >= QUEUE_THRESHOLD or time_since_last_push >= MAX_PUSH_INTERVAL_S:
                # Skip blank lines so a stray newline can't crash json.loads.
                payload_lines = [l for l in raw_lines if l.strip()]
                if not payload_lines:
                    last_push_time = now
                    continue
                new_records = [json.loads(l) for l in payload_lines]

                # Load the remote dataset (authenticated via login above).
                dataset = load_dataset(
                    "HumbleBeeAI/al-ghazali-rag-retrieval-evaluation",
                    split="train"
                )
                data = dataset.to_dict()

                # Append new records. Pad previously-unseen columns with None
                # so every column stays the same length — Dataset.from_dict
                # rejects ragged columns.
                existing_len = len(next(iter(data.values()), []))
                all_keys = set(data) | {k for rec in new_records for k in rec}
                for key in all_keys:
                    data.setdefault(key, [None] * existing_len)
                for rec in new_records:
                    for key in all_keys:
                        data[key].append(rec.get(key))

                updated = Dataset.from_dict(data)
                updated.push_to_hub(
                    "HumbleBeeAI/al-ghazali-rag-retrieval-evaluation",
                    token=hf_token
                )

                # Remove only the lines we just pushed; records appended by
                # /save while push_to_hub was running are preserved instead of
                # being truncated away.
                # NOTE(review): still racy without file locking if multiple
                # worker processes run this app — confirm single-process
                # deployment (uvicorn without --workers).
                with open(QUEUE_FILE, "r+") as f:
                    remaining = f.read().splitlines(keepends=True)[queue_len:]
                    f.seek(0)
                    f.writelines(remaining)
                    f.truncate()
                last_push_time = now
        except Exception as e:
            logging.error(f"Background sync failed: {e}")
|
384 |
+
|
385 |
+
# replace existing startup_event
|
386 |
+
@app.on_event("startup")
|
387 |
+
async def startup_event():
|
388 |
+
os.makedirs("./cache", exist_ok=True)
|
389 |
+
Path(QUEUE_FILE).touch(exist_ok=True)
|
390 |
+
# start background sync loop
|
391 |
+
asyncio.create_task(_hf_sync_loop())
|
392 |
+
|
393 |
+
# replace existing /save endpoint
|
394 |
@app.post("/save")
|
395 |
async def save_data(
|
396 |
save_input: SaveBatchInput,
|
397 |
username: str = Depends(verify_access_token)
|
398 |
):
|
399 |
+
records = []
|
400 |
+
for item in save_input.items:
|
401 |
+
records.append({
|
402 |
+
"user_type": item.user_type,
|
403 |
+
"username": item.username,
|
404 |
+
"query": item.query,
|
405 |
+
"retrieved_text": item.retrieved_text,
|
406 |
+
"model_type": item.model_type,
|
407 |
+
"reaction": item.reaction,
|
408 |
+
"timestamp": datetime.now(timezone.utc).isoformat().replace('+00:00','Z'),
|
409 |
+
"confidence_score": item.confidence_score
|
410 |
+
})
|
411 |
+
# append to local queue
|
412 |
+
with open(QUEUE_FILE, "a") as f:
|
413 |
+
for r in records:
|
414 |
+
f.write(json.dumps(r) + "\n")
|
415 |
+
return {"message": "Your data is queued for batch upload."}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
416 |
|
417 |
# Make sure to keep the static files mounting
|
418 |
app.mount("/home", StaticFiles(directory="static", html=True), name="home")
|
419 |
|
|
|
|
|
|
|
|
|
|
|
420 |
if __name__ == "__main__":
|
421 |
import uvicorn
|
422 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
static/index.html
CHANGED
@@ -219,7 +219,7 @@ class LoginResponse {
|
|
219 |
"username": "user1332",
|
220 |
"query": "What is the seventh test of lovers of God?",
|
221 |
"retrieved_text": "The seventh test is that lovers of God will love those who obey Him and hate the infidels and the disobedient...",
|
222 |
-
"model_type": "
|
223 |
"reaction": "positive",
|
224 |
"confidence_score": 0.95
|
225 |
},
|
@@ -228,7 +228,7 @@ class LoginResponse {
|
|
228 |
"username": "user1332",
|
229 |
"query": "What is the seventh test of lovers of God?",
|
230 |
"retrieved_text": "The seventh test is that lovers of God will love those who obey Him and hate the infidels and the disobedient...",
|
231 |
-
"model_type": "
|
232 |
"reaction": "positive",
|
233 |
"confidence_score": 0.92
|
234 |
}
|
|
|
219 |
"username": "user1332",
|
220 |
"query": "What is the seventh test of lovers of God?",
|
221 |
"retrieved_text": "The seventh test is that lovers of God will love those who obey Him and hate the infidels and the disobedient...",
|
222 |
+
"model_type": "WhereIsAI_UAE_Large_V1",
|
223 |
"reaction": "positive",
|
224 |
"confidence_score": 0.95
|
225 |
},
|
|
|
228 |
"username": "user1332",
|
229 |
"query": "What is the seventh test of lovers of God?",
|
230 |
"retrieved_text": "The seventh test is that lovers of God will love those who obey Him and hate the infidels and the disobedient...",
|
231 |
+
"model_type": "BAAI_bge_large_en_v1.5",
|
232 |
"reaction": "positive",
|
233 |
"confidence_score": 0.92
|
234 |
}
|