ganesh3 committed on
Commit
e2cbf8c
·
verified ·
1 Parent(s): 1391300

Update app/generate_ground_truth.py

Browse files
Files changed (1) hide show
  1. app/generate_ground_truth.py +21 -54
app/generate_ground_truth.py CHANGED
@@ -1,15 +1,13 @@
1
  import pandas as pd
2
  import json
3
  from tqdm import tqdm
4
- import ollama
5
- from elasticsearch import Elasticsearch
6
- import sqlite3
7
  import logging
8
  import os
9
- import re
10
  import sys
 
 
11
 
12
- # Configure logging for stdout only
13
  logging.basicConfig(
14
  level=logging.INFO,
15
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
@@ -18,27 +16,11 @@ logging.basicConfig(
18
  logger = logging.getLogger(__name__)
19
 
20
  def extract_model_name(index_name):
21
- # Extract the model name from the index name
22
  match = re.search(r'video_[^_]+_(.+)$', index_name)
23
  if match:
24
  return match.group(1)
25
  return None
26
 
27
- def get_transcript_from_elasticsearch(es, index_name, video_id):
28
- try:
29
- result = es.search(index=index_name, body={
30
- "query": {
31
- "match": {
32
- "video_id": video_id
33
- }
34
- }
35
- })
36
- if result['hits']['hits']:
37
- return result['hits']['hits'][0]['_source']['content']
38
- except Exception as e:
39
- logger.error(f"Error retrieving transcript from Elasticsearch: {str(e)}")
40
- return None
41
-
42
  def get_transcript_from_sqlite(db_path, video_id):
43
  try:
44
  conn = sqlite3.connect(db_path)
@@ -73,13 +55,12 @@ def generate_questions(transcript, max_retries=3):
73
  retries = 0
74
 
75
  while len(all_questions) < 10 and retries < max_retries:
76
- prompt = prompt_template.format(transcript=transcript)
77
  try:
78
- response = ollama.chat(
79
- model='phi3.5',
80
- messages=[{"role": "user", "content": prompt}]
81
- )
82
- questions = json.loads(response['message']['content'])['questions']
83
  all_questions.update(questions)
84
  except Exception as e:
85
  logger.error(f"Error generating questions: {str(e)}")
@@ -91,19 +72,11 @@ def generate_questions(transcript, max_retries=3):
91
  return {"questions": list(all_questions)[:10]}
92
 
93
  def generate_ground_truth(db_handler, data_processor, video_id):
94
- es = Elasticsearch([f'http://{os.getenv("ELASTICSEARCH_HOST", "localhost")}:{os.getenv("ELASTICSEARCH_PORT", "9200")}'])
95
-
96
  # Get existing questions for this video to avoid duplicates
97
  existing_questions = set(q[1] for q in db_handler.get_ground_truth_by_video(video_id))
98
 
99
- transcript = None
100
- index_name = db_handler.get_elasticsearch_index_by_youtube_id(video_id)
101
-
102
- if index_name:
103
- transcript = get_transcript_from_elasticsearch(es, index_name, video_id)
104
-
105
- if not transcript:
106
- transcript = db_handler.get_transcript_content(video_id)
107
 
108
  if not transcript:
109
  logger.error(f"Failed to retrieve transcript for video {video_id}")
@@ -141,10 +114,18 @@ def generate_ground_truth(db_handler, data_processor, video_id):
141
  logger.info(f"Ground truth data saved to {csv_path}")
142
  return df
143
 
 
 
 
 
 
 
 
 
 
 
144
  def get_ground_truth_display_data(db_handler, video_id=None, channel_name=None):
145
  """Get ground truth data from both database and CSV file"""
146
- import pandas as pd
147
-
148
  # Try to get data from database first
149
  if video_id:
150
  data = db_handler.get_ground_truth_by_video(video_id)
@@ -203,18 +184,4 @@ def generate_ground_truth_for_all_videos(db_handler, data_processor):
203
  return df
204
  else:
205
  logger.error("Failed to generate questions for any video.")
206
- return None
207
-
208
- def get_evaluation_display_data(video_id=None):
209
- """Get evaluation data from both database and CSV file"""
210
- import pandas as pd
211
-
212
- # Try to get data from CSV
213
- try:
214
- csv_df = pd.read_csv('data/evaluation_results.csv')
215
- if video_id:
216
- csv_df = csv_df[csv_df['video_id'] == video_id]
217
- except FileNotFoundError:
218
- csv_df = pd.DataFrame()
219
-
220
- return csv_df
 
1
  import pandas as pd
2
  import json
3
  from tqdm import tqdm
 
 
 
4
  import logging
5
  import os
 
6
  import sys
7
+ import re
8
+ import sqlite3
9
 
10
+ # Configure logging
11
  logging.basicConfig(
12
  level=logging.INFO,
13
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
 
16
  logger = logging.getLogger(__name__)
17
 
18
  def extract_model_name(index_name):
 
19
  match = re.search(r'video_[^_]+_(.+)$', index_name)
20
  if match:
21
  return match.group(1)
22
  return None
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def get_transcript_from_sqlite(db_path, video_id):
25
  try:
26
  conn = sqlite3.connect(db_path)
 
55
  retries = 0
56
 
57
  while len(all_questions) < 10 and retries < max_retries:
 
58
  try:
59
+ model = pipeline("text-generation", model="google/flan-t5-base", device=-1)
60
+ response = model(prompt_template.format(transcript=transcript),
61
+ max_length=1024,
62
+ num_return_sequences=1)[0]['generated_text']
63
+ questions = json.loads(response)['questions']
64
  all_questions.update(questions)
65
  except Exception as e:
66
  logger.error(f"Error generating questions: {str(e)}")
 
72
  return {"questions": list(all_questions)[:10]}
73
 
74
  def generate_ground_truth(db_handler, data_processor, video_id):
 
 
75
  # Get existing questions for this video to avoid duplicates
76
  existing_questions = set(q[1] for q in db_handler.get_ground_truth_by_video(video_id))
77
 
78
+ # Get transcript from SQLite
79
+ transcript = get_transcript_from_sqlite(db_handler.db_path, video_id)
 
 
 
 
 
 
80
 
81
  if not transcript:
82
  logger.error(f"Failed to retrieve transcript for video {video_id}")
 
114
  logger.info(f"Ground truth data saved to {csv_path}")
115
  return df
116
 
117
def get_evaluation_display_data(video_id=None, csv_path='data/evaluation_results.csv'):
    """Load evaluation results from a CSV file, optionally for one video.

    Args:
        video_id: When truthy, only rows whose ``video_id`` column equals this
            value are returned. When falsy (the default), all rows are returned.
        csv_path: Location of the evaluation results CSV. Defaults to the
            path this function has always read from.

    Returns:
        A pandas DataFrame with the (optionally filtered) evaluation rows, or
        an empty DataFrame when the file does not exist or has no content.
    """
    try:
        csv_df = pd.read_csv(csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # A missing file and a zero-byte file both mean "no results yet";
        # read_csv raises EmptyDataError for the latter.
        return pd.DataFrame()

    if video_id:
        csv_df = csv_df[csv_df['video_id'] == video_id]
    return csv_df
126
+
127
  def get_ground_truth_display_data(db_handler, video_id=None, channel_name=None):
128
  """Get ground truth data from both database and CSV file"""
 
 
129
  # Try to get data from database first
130
  if video_id:
131
  data = db_handler.get_ground_truth_by_video(video_id)
 
184
  return df
185
  else:
186
  logger.error("Failed to generate questions for any video.")
187
+ return None