# src/quiz_processing.py
import os
import re
import json
import time
import tempfile
from typing import Dict, Any, List, Optional
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from huggingface_hub import login
from src.prompts import SUMMARY_PROMPT_TEMPLATE, QUIZ_PROMPT_TEMPLATE
GEMINI_MODEL = "gemini-2.0-flash"
DEFAULT_TEMPERATURE = 0.7
TOKENIZER_MODEL = "answerdotai/ModernBERT-base"
SENTENCE_TRANSFORMER_MODEL = "all-MiniLM-L6-v2"
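# Shared NLP resources, loaded once at import time. Note that sentence_model
# is not referenced in this module; presumably it is used elsewhere in the app.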
hf_token = os.environ.get('HF_TOKEN', None)
if hf_token:
    # Only authenticate when a token is provided; login() without a token
    # may try to prompt interactively, which fails in a non-interactive Space.
    login(token=hf_token)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
sentence_model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)
def clean_text(text):
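    """Remove diarization tags like ``[speaker_3]`` and collapse runs of whitespace."""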
    text = re.sub(r'\[speaker_\d+\]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
def split_text_by_tokens(text, max_tokens=12000):
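    """Split ``text`` into at most two chunks of roughly equal token count.

    Texts that fit within ``max_tokens`` are returned as a single-element list;
    longer texts are split on sentence boundaries near the token midpoint. Note
    that the split is only one level deep, so for very long inputs each half
    may itself still exceed ``max_tokens``.
    """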
    text = clean_text(text)
    tokens = tokenizer.encode(text)
    if len(tokens) <= max_tokens:
        return [text]
    split_point = len(tokens) // 2
    sentences = re.split(r'(?<=[.!?])\s+', text)
    first_half = []
    second_half = []
    current_tokens = 0
    for sentence in sentences:
        sentence_tokens = len(tokenizer.encode(sentence))
        # Once a sentence overflows the midpoint, every later sentence goes to
        # the second half as well, so the original sentence order is preserved.
        if not second_half and current_tokens + sentence_tokens <= split_point:
            first_half.append(sentence)
            current_tokens += sentence_tokens
        else:
            second_half.append(sentence)
    return [" ".join(first_half), " ".join(second_half)]
def generate_with_gemini(text, api_key, language, content_type="summary"):
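    """Call Gemini through LangChain and return the parsed JSON payload.

    ``content_type`` selects between the summary and quiz prompt templates,
    and ``language`` is appended as an instruction so all generated content
    is produced in that language.
    """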
    from langchain_google_genai import ChatGoogleGenerativeAI
    os.environ["GOOGLE_API_KEY"] = api_key
    llm = ChatGoogleGenerativeAI(
        model=GEMINI_MODEL,
        temperature=DEFAULT_TEMPERATURE,
        max_retries=3
    )
    if content_type == "summary":
        base_prompt = SUMMARY_PROMPT_TEMPLATE.format(text=text)
    else:
        base_prompt = QUIZ_PROMPT_TEMPLATE.format(text=text)
    language_instruction = f"\nIMPORTANT: Generate ALL content in {language} language."
    prompt = base_prompt + language_instruction
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant that creates high-quality text summaries and quizzes."},
        {"role": "user", "content": prompt}
    ]
    try:
        response = llm.invoke(messages)
    except Exception as e:
        raise Exception(f"Error calling API: {str(e)}")
    content = response.content
    # Prefer a fenced ```json block; fall back to the first {...} span, then the raw reply.
    json_match = re.search(r'```json\s*([\s\S]*?)\s*```', content)
    if json_match:
        json_str = json_match.group(1)
    else:
        json_match = re.search(r'(\{[\s\S]*\})', content)
        json_str = json_match.group(1) if json_match else content
    # Parse the JSON; a parse failure is reported as such, not as an API error.
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        raise Exception("Could not parse JSON from LLM response")
def format_summary_for_display(results, language="English"):
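    """Render a summary result dict as display text with localized section headers."""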
    output = []
    if language == "Uzbek":
        title_header = "SARLAVHA"
        overview_header = "UMUMIY KO'RINISH"
        key_points_header = "ASOSIY NUQTALAR"
        key_entities_header = "ASOSIY SHAXSLAR VA TUSHUNCHALAR"
        conclusion_header = "XULOSA"
    elif language == "Russian":
        title_header = "ЗАГОЛОВОК"
        overview_header = "ОБЗОР"
        key_points_header = "КЛЮЧЕВЫЕ МОМЕНТЫ"
        key_entities_header = "КЛЮЧЕВЫЕ ОБЪЕКТЫ"
        conclusion_header = "ЗАКЛЮЧЕНИЕ"
    else:
        title_header = "TITLE"
        overview_header = "OVERVIEW"
        key_points_header = "KEY POINTS"
        key_entities_header = "KEY ENTITIES"
        conclusion_header = "CONCLUSION"
    if "summary" not in results:
        if "segments" in results:
            segments = results.get("segments", [])
            for i, segment in enumerate(segments):
                topic = segment.get("topic_name", f"Section {i+1}")
                segment_num = i + 1
                output.append(f"\n\n{'='*40}")
                output.append(f"SEGMENT {segment_num}: {topic}")
                output.append(f"{'='*40}\n")
                if "key_concepts" in segment:
                    output.append("KEY CONCEPTS:")
                    for concept in segment["key_concepts"]:
                        output.append(f"• {concept}")
                if "summary" in segment:
                    output.append("\nSUMMARY:")
                    output.append(segment["summary"])
            return "\n".join(output)
        else:
            return "Error: Could not parse summary results. Invalid format received."
    summary = results["summary"]
    # Title
    if "title" in summary:
        output.append(f"\n\n{'='*40}")
        output.append(f"{title_header}: {summary['title']}")
        output.append(f"{'='*40}\n")
    # Overview
    if "overview" in summary:
        output.append(f"{overview_header}:")
        output.append(f"{summary['overview']}\n")
    # Key Points
    if "key_points" in summary and summary["key_points"]:
        output.append(f"{key_points_header}:")
        for theme_group in summary["key_points"]:
            if "theme" in theme_group:
                output.append(f"\n{theme_group['theme']}:")
            if "points" in theme_group:
                for point in theme_group["points"]:
                    output.append(f"• {point}")
    # Key Entities
    if "key_entities" in summary and summary["key_entities"]:
        output.append(f"\n{key_entities_header}:")
        for entity in summary["key_entities"]:
            if "name" in entity and "description" in entity:
                output.append(f"• **{entity['name']}**: {entity['description']}")
    # Conclusion
    if "conclusion" in summary:
        output.append(f"\n{conclusion_header}:")
        output.append(summary["conclusion"])
    return "\n".join(output)
def format_quiz_for_display(results, language="English"):
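    """Render quiz questions as display text, marking correct options with a check mark."""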
    output = []
    if language == "Uzbek":
        quiz_questions_header = "TEST SAVOLLARI"
    elif language == "Russian":
        quiz_questions_header = "ТЕСТОВЫЕ ВОПРОСЫ"
    else:
        quiz_questions_header = "QUIZ QUESTIONS"
    output.append(f"{'='*40}")
    output.append(f"{quiz_questions_header}")
    output.append(f"{'='*40}\n")
    quiz_questions = results.get("quiz_questions", [])
    for i, q in enumerate(quiz_questions):
        output.append(f"\n{i+1}. {q['question']}")
        for j, option in enumerate(q['options']):
            letter = chr(65 + j)  # A, B, C, ...
            correct_marker = " ✓" if option.get("correct") else ""
            output.append(f"   {letter}. {option['text']}{correct_marker}")
    return "\n".join(output)
def analyze_document(text, gemini_api_key, language, content_type="summary"):
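    """End-to-end pipeline: split the text, query Gemini per chunk, merge results.

    Returns the formatted display text plus paths to JSON and TXT files with
    the results; on failure returns an error message and ``None`` for both paths.
    """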
    try:
        start_time = time.time()
        text_parts = split_text_by_tokens(text)
        input_tokens = 0
        output_tokens = 0
        if content_type == "summary":
            all_results = {}
            for part in text_parts:
                actual_prompt = SUMMARY_PROMPT_TEMPLATE.format(text=part)
                input_tokens += len(tokenizer.encode(actual_prompt))
                analysis = generate_with_gemini(part, gemini_api_key, language, "summary")
                if not all_results and "summary" in analysis:
                    all_results = analysis
                elif "summary" in analysis:
                    # Merge later chunks into the first chunk's summary.
                    if "key_points" in analysis["summary"] and "key_points" in all_results["summary"]:
                        all_results["summary"]["key_points"].extend(analysis["summary"]["key_points"])
                    if "key_entities" in analysis["summary"] and "key_entities" in all_results["summary"]:
                        all_results["summary"]["key_entities"].extend(analysis["summary"]["key_entities"])
            formatted_output = format_summary_for_display(all_results, language)
        else:
            all_results = {"quiz_questions": []}
            for part in text_parts:
                actual_prompt = QUIZ_PROMPT_TEMPLATE.format(text=part)
                input_tokens += len(tokenizer.encode(actual_prompt))
                analysis = generate_with_gemini(part, gemini_api_key, language, "quiz")
                if "quiz_questions" in analysis:
                    # Cap the quiz at 10 questions across all chunks.
                    remaining_slots = 10 - len(all_results["quiz_questions"])
                    if remaining_slots > 0:
                        all_results["quiz_questions"].extend(analysis["quiz_questions"][:remaining_slots])
            formatted_output = format_quiz_for_display(all_results, language)
        total_time = time.time() - start_time
        output_tokens = len(tokenizer.encode(formatted_output))
        token_info = f"Input tokens: {input_tokens}\nOutput tokens: {output_tokens}\nTotal tokens: {input_tokens + output_tokens}\n"
        formatted_text = f"Total Processing time: {total_time:.2f}s\n{token_info}\n" + formatted_output
        # tempfile.mktemp is deprecated and race-prone; NamedTemporaryFile is the
        # safe equivalent. ensure_ascii=False keeps Uzbek/Russian text readable.
        with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False, encoding='utf-8') as json_file:
            json.dump(all_results, json_file, indent=2, ensure_ascii=False)
            json_path = json_file.name
        with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False, encoding='utf-8') as txt_file:
            txt_file.write(formatted_text)
            txt_path = txt_file.name
        return formatted_text, json_path, txt_path
    except Exception as e:
        error_message = f"Error processing document: {str(e)}"
        return error_message, None, None
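
# Minimal usage sketch (not part of the original module): shows how
# analyze_document might be called. The GEMINI_API_KEY variable and the sample
# transcript below are illustrative assumptions, not values the module defines.
if __name__ == "__main__":
    demo_key = os.environ.get("GEMINI_API_KEY", "")
    demo_text = (
        "[speaker_1] Photosynthesis converts light energy into chemical energy. "
        "[speaker_2] Chlorophyll absorbs mostly red and blue light."
    )
    formatted, json_path, txt_path = analyze_document(
        demo_text, demo_key, language="English", content_type="quiz"
    )
    print(formatted)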