# wikihop-server / db/wiki_db_sqlite.py
# Author: stillerman (HF Staff)
# Commit e91ced9: "sqlite backend"
import sqlite3
class WikiDBSqlite:
    """Query helpers over a SQLite database of wiki articles and their links.

    The queries below read an ``articles`` table (columns ``title`` and
    ``text``) and a ``links`` table (columns ``source_title`` and
    ``target_title``).

    The connection can be released explicitly with :meth:`close`, or by using
    the instance as a context manager::

        with WikiDBSqlite("wiki.db") as db:
            db.get_article("Python")

    NOTE(review): every method shares one connection and one cursor. The
    connection is opened with ``sqlite3.connect`` defaults, so by default it
    may only be used from the thread that created it. Confirm this matches
    how the server calls in.
    """

    def __init__(self, db_path):
        """Open the SQLite database at *db_path* and cache its article count.

        Args:
            db_path: Database target passed straight to ``sqlite3.connect``
                (normally a file path).

        Raises:
            sqlite3.Error: If the database cannot be opened or queried (e.g.
                it has no ``articles`` table). If the connection was already
                opened, it is closed before the error propagates.
        """
        self.db_path = db_path
        # NOTE: sqlite3.connect creates an empty database file when db_path
        # does not exist, so a mistyped path shows up below as a
        # "no such table" error rather than a missing-file error.
        self.conn = sqlite3.connect(db_path)
        try:
            # Lets rows be indexed by column name, e.g. row['text'].
            self.conn.row_factory = sqlite3.Row
            self.cursor = self.conn.cursor()
            self._article_count = self._get_article_count()
        except Exception:
            # Don't leave an open connection on a half-constructed object.
            self.conn.close()
            raise
        print(f"Connected to SQLite database with {self._article_count} articles")

    def close(self):
        """Close the database connection.

        Safe to call more than once. Also safe on an instance whose
        ``__init__`` failed before the connection was created.
        """
        conn = getattr(self, 'conn', None)
        if conn is not None:
            # sqlite3.Connection.close() is a no-op on an already-closed
            # connection, so repeated calls are harmless.
            conn.close()

    def __enter__(self):
        """Return this instance so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection on leaving the ``with`` block.

        Exceptions are not suppressed.
        """
        self.close()

    def __del__(self):
        """Close the connection when the object is garbage-collected.

        This is a fallback only. Python does not guarantee when (or, at
        interpreter exit, whether) ``__del__`` runs, so prefer
        :meth:`close` or a ``with`` block.
        """
        self.close()

    def _get_article_count(self):
        """Query and return the number of rows in ``articles``."""
        self.cursor.execute("SELECT COUNT(*) FROM articles")
        return self.cursor.fetchone()[0]

    def get_article_count(self):
        """Return the article count cached when the instance was created.

        The value is not refreshed if the database changes afterwards.
        """
        return self._article_count

    def get_all_article_titles(self):
        """Return a list of every title in ``articles``, in no guaranteed order."""
        self.cursor.execute("SELECT title FROM articles")
        return [row['title'] for row in self.cursor.fetchall()]

    def get_article(self, title):
        """Return an article's title, text and outgoing links.

        Args:
            title: Exact title to look up.

        Returns:
            ``{'title': ..., 'text': ..., 'links': [...]}`` for a matching
            article, or ``{}`` when no article has that title.
        """
        self.cursor.execute(
            "SELECT title, text FROM articles WHERE title = ?",
            (title,)
        )
        article = self.cursor.fetchone()
        if article is None:
            return {}
        # The Row is already materialised, so reusing the shared cursor for
        # the links query below does not invalidate it.
        return {
            'title': article['title'],
            'text': article['text'],
            'links': self.get_article_links(title),
        }

    def article_exists(self, title):
        """Return True if ``articles`` has a row with exactly this title."""
        self.cursor.execute(
            "SELECT 1 FROM articles WHERE title = ? LIMIT 1",
            (title,)
        )
        return self.cursor.fetchone() is not None

    def get_article_text(self, title):
        """Return the stored text for *title*, or ``''`` if no such article."""
        self.cursor.execute(
            "SELECT text FROM articles WHERE title = ?",
            (title,)
        )
        result = self.cursor.fetchone()
        return result['text'] if result else ''

    def get_article_links(self, title):
        """Return the link targets recorded for *title*.

        The result holds the ``target_title`` of every ``links`` row whose
        ``source_title`` equals *title*. It is empty when there are none.
        Targets are not checked against ``articles`` here.
        """
        self.cursor.execute(
            "SELECT target_title FROM links WHERE source_title = ?",
            (title,)
        )
        return [row['target_title'] for row in self.cursor.fetchall()]