rui3000 commited on
Commit
1f004f6
·
verified ·
1 Parent(s): 23fc124

Create db.py

Browse files
Files changed (1) hide show
  1. db.py +89 -0
db.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import time
4
+ import spaces
5
+ import sqlite3
6
+ import os
7
+ import json
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+
10
+ # --- Database Setup ---
11
+ # Use the persistent storage location in Hugging Face Spaces
12
+ DB_PATH = os.path.join("/data", "model_tests.db")
13
+
14
+ def init_db():
15
+ """Initialize the SQLite database with required tables if they don't exist"""
16
+ with sqlite3.connect(DB_PATH) as conn:
17
+ cursor = conn.cursor()
18
+
19
+ # Create table for model inputs and outputs
20
+ cursor.execute('''
21
+ CREATE TABLE IF NOT EXISTS model_tests (
22
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
23
+ analysis_mode TEXT NOT NULL,
24
+ system_prompt TEXT NOT NULL,
25
+ input_content TEXT,
26
+ model_response TEXT NOT NULL,
27
+ generation_time REAL NOT NULL,
28
+ tokens_generated INTEGER NOT NULL,
29
+ temperature REAL NOT NULL,
30
+ top_p REAL NOT NULL,
31
+ max_length INTEGER NOT NULL,
32
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
33
+ )
34
+ ''')
35
+
36
+ conn.commit()
37
+ print("Test database initialized successfully")
38
+
39
+ # Initialize the database when the app starts
40
+ init_db()
41
+
42
+ # Function to save test results
43
+ def save_test_result(analysis_mode, system_prompt, input_content, model_response,
44
+ generation_time, tokens_generated, temperature, top_p, max_length):
45
+ with sqlite3.connect(DB_PATH) as conn:
46
+ cursor = conn.cursor()
47
+
48
+ cursor.execute('''
49
+ INSERT INTO model_tests
50
+ (analysis_mode, system_prompt, input_content, model_response,
51
+ generation_time, tokens_generated, temperature, top_p, max_length)
52
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
53
+ ''', (analysis_mode, system_prompt, input_content, model_response,
54
+ generation_time, tokens_generated, temperature, top_p, max_length))
55
+
56
+ test_id = cursor.lastrowid
57
+ conn.commit()
58
+
59
+ return test_id
60
+
61
+ # Function to retrieve test history
62
+ def get_test_history(limit=10):
63
+ with sqlite3.connect(DB_PATH) as conn:
64
+ cursor = conn.cursor()
65
+ cursor.execute('''
66
+ SELECT id, analysis_mode, timestamp, generation_time, tokens_generated
67
+ FROM model_tests
68
+ ORDER BY id DESC LIMIT ?
69
+ ''', (limit,))
70
+ history = cursor.fetchall()
71
+
72
+ return history
73
+
74
+ # Function to get test details by ID
75
+ def get_test_details(test_id):
76
+ with sqlite3.connect(DB_PATH) as conn:
77
+ cursor = conn.cursor()
78
+ cursor.execute('''
79
+ SELECT * FROM model_tests WHERE id = ?
80
+ ''', (test_id,))
81
+ test = cursor.fetchone()
82
+
83
+ if test:
84
+ # Convert to dictionary
85
+ columns = [col[0] for col in cursor.description]
86
+ test_dict = {columns[i]: test[i] for i in range(len(columns))}
87
+ return test_dict
88
+ else:
89
+ return None