import os
import re
import json
import time
import gradio as gr
import tempfile
from typing import Dict, Any, List, Optional
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel, Field
from anthropic import Anthropic
from huggingface_hub import login
from src.prompts import SYSTEM_PROMPT, ANALYSIS_PROMPT_TEMPLATE_CLAUDE, ANALYSIS_PROMPT_TEMPLATE_GEMINI

CLAUDE_MODEL = "claude-3-5-sonnet-20241022"
OPENAI_MODEL = "gpt-4o"
GEMINI_MODEL = "gemini-2.0-flash"
DEFAULT_TEMPERATURE = 0.7
TOKENIZER_MODEL = "answerdotai/ModernBERT-base"
SENTENCE_TRANSFORMER_MODEL = "all-MiniLM-L6-v2"

hf_token = os.environ.get('HF_TOKEN', None)
# Only log in when a token is actually configured so local runs without
# HF_TOKEN do not fail at import time.
if hf_token:
    login(token=hf_token)

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
sentence_model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)


class CourseInfo(BaseModel):
    course_name: str = Field(description="Name of the course")
    section_name: str = Field(description="Name of the course section")
    lesson_name: str = Field(description="Name of the lesson")


class QuizOption(BaseModel):
    text: str = Field(description="The text of the answer option")
    correct: bool = Field(description="Whether this option is correct")


class QuizQuestion(BaseModel):
    question: str = Field(description="The text of the quiz question")
    options: List[QuizOption] = Field(description="List of answer options")


class Segment(BaseModel):
    segment_number: int = Field(description="The segment number")
    topic_name: str = Field(description="Unique and specific topic name that clearly differentiates it from other segments")
    key_concepts: List[str] = Field(description="3-5 key concepts discussed in the segment")
    summary: str = Field(description="Brief summary of the segment (3-5 sentences)")
    quiz_questions: List[QuizQuestion] = Field(description="5 quiz questions based on the segment content")


class TextSegmentAnalysis(BaseModel):
    course_info: CourseInfo = Field(description="Information about the course")
    segments: List[Segment] = Field(description="List of text segments with analysis")


def clean_text(text):
    # Strip speaker tags (e.g. "[speaker_1]") and collapse runs of whitespace.
    text = re.sub(r'\[speaker_\d+\]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def split_text_by_tokens(text, max_tokens=12000):
    # Return the text as a single chunk if it fits within the token budget;
    # otherwise split it into two halves on sentence boundaries, balanced by token count.
    text = clean_text(text)
    tokens = tokenizer.encode(text)
    if len(tokens) <= max_tokens:
        return [text]

    split_point = len(tokens) // 2
    sentences = re.split(r'(?<=[.!?])\s+', text)

    first_half = []
    second_half = []
    current_tokens = 0
    for sentence in sentences:
        sentence_tokens = len(tokenizer.encode(sentence))
        # Fill the first half until the midpoint is reached, then send every
        # remaining sentence to the second half so sentence order is preserved.
        if not second_half and current_tokens + sentence_tokens <= split_point:
            first_half.append(sentence)
            current_tokens += sentence_tokens
        else:
            second_half.append(sentence)

    return [" ".join(first_half), " ".join(second_half)]


def generate_with_claude(text, api_key, course_name="", section_name="", lesson_name=""):
    client = Anthropic(api_key=api_key)
    segment_analysis_schema = TextSegmentAnalysis.model_json_schema()
    tools = [
        {
            "name": "build_segment_analysis",
            "description": "Build the text segment analysis with quiz questions",
            "input_schema": segment_analysis_schema
        }
    ]
    prompt = ANALYSIS_PROMPT_TEMPLATE_CLAUDE.format(
        course_name=course_name,
        section_name=section_name,
        lesson_name=lesson_name,
        text=text
    )
    try:
        response = client.messages.create(
            model=CLAUDE_MODEL,
            max_tokens=8192,
            temperature=DEFAULT_TEMPERATURE,
            system=SYSTEM_PROMPT,
            messages=[
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            tools=tools,
            tool_choice={"type": "tool", "name": "build_segment_analysis"}
        )
        # Extract the structured input of the forced tool call
        if response.content and len(response.content) > 0 and hasattr(response.content[0], 'input'):
            function_call = response.content[0].input
            return function_call
        else:
            raise Exception("No valid tool call found in the response")
    except Exception as e:
        raise Exception(f"Error calling Anthropic API: {str(e)}")


def get_active_api_key(gemini_key, claude_key, language):
    # Claude is preferred for Uzbek when a key is provided; otherwise fall back to Gemini.
    if language == "Uzbek" and claude_key:
        return claude_key, "claude"
    else:
        return gemini_key, "gemini"


def segment_and_analyze_text(text: str, gemini_api_key: str, claude_api_key: str, language: str,
                             course_name="", section_name="", lesson_name="") -> Dict[str, Any]:
    active_key, api_type = get_active_api_key(gemini_api_key, claude_api_key, language)
    if api_type == "claude":
        return generate_with_claude(text, active_key, course_name, section_name, lesson_name)

    from langchain_google_genai import ChatGoogleGenerativeAI
    os.environ["GOOGLE_API_KEY"] = active_key
    llm = ChatGoogleGenerativeAI(
        model=GEMINI_MODEL,
        temperature=DEFAULT_TEMPERATURE,
        max_retries=3
    )
    base_prompt = ANALYSIS_PROMPT_TEMPLATE_GEMINI.format(
        course_name=course_name,
        section_name=section_name,
        lesson_name=lesson_name,
        text=text
    )
    language_instruction = f"\nIMPORTANT: Generate ALL content (including topic names, key concepts, summaries, and quiz questions) in {language}."
    prompt = base_prompt + language_instruction
    try:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt}
        ]
        response = llm.invoke(messages)
        try:
            content = response.content
            # Prefer a fenced ```json block; fall back to the first JSON object, then the raw content.
            json_match = re.search(r'```json\s*([\s\S]*?)\s*```', content)
            if json_match:
                json_str = json_match.group(1)
            else:
                json_match = re.search(r'(\{[\s\S]*\})', content)
                if json_match:
                    json_str = json_match.group(1)
                else:
                    json_str = content
            # Parse the JSON
            function_call = json.loads(json_str)
            return function_call
        except json.JSONDecodeError:
            raise Exception("Could not parse JSON from LLM response")
    except Exception as e:
        raise Exception(f"Error calling API: {str(e)}")


def format_quiz_for_display(results, language="English"):
    output = []
    # Localized section headers for the formatted output.
    if language == "Uzbek":
        course_header = "KURS"
        section_header = "BO'LIM"
        lesson_header = "DARS"
        segment_header = "QISM"
        key_concepts_header = "ASOSIY TUSHUNCHALAR"
        summary_header = "QISQACHA MAZMUN"
        quiz_questions_header = "TEST SAVOLLARI"
    elif language == "Russian":
        course_header = "КУРС"
        section_header = "РАЗДЕЛ"
        lesson_header = "УРОК"
        segment_header = "СЕГМЕНТ"
        key_concepts_header = "КЛЮЧЕВЫЕ ПОНЯТИЯ"
        summary_header = "КРАТКОЕ СОДЕРЖАНИЕ"
        quiz_questions_header = "ТЕСТОВЫЕ ВОПРОСЫ"
    else:
        course_header = "COURSE"
        section_header = "SECTION"
        lesson_header = "LESSON"
        segment_header = "SEGMENT"
        key_concepts_header = "KEY CONCEPTS"
        summary_header = "SUMMARY"
        quiz_questions_header = "QUIZ QUESTIONS"

    if "course_info" in results:
        course_info = results["course_info"]
        output.append(f"{'='*40}")
        output.append(f"{course_header}: {course_info.get('course_name', 'N/A')}")
        output.append(f"{section_header}: {course_info.get('section_name', 'N/A')}")
        output.append(f"{lesson_header}: {course_info.get('lesson_name', 'N/A')}")
        output.append(f"{'='*40}\n")

    segments = results.get("segments", [])
    for i, segment in enumerate(segments):
        topic = segment["topic_name"]
        segment_num = i + 1
        output.append(f"\n\n{'='*40}")
        output.append(f"{segment_header} {segment_num}: {topic}")
        output.append(f"{'='*40}\n")
        output.append(f"{key_concepts_header}:")
        for concept in segment["key_concepts"]:
            output.append(f"• {concept}")
        output.append(f"\n{summary_header}:")
        output.append(segment["summary"])
        output.append(f"\n{quiz_questions_header}:")
        # Use a dedicated index for questions so the segment index above is not shadowed.
        for q_idx, q in enumerate(segment["quiz_questions"]):
            output.append(f"\n{q_idx + 1}. {q['question']}")
            for j, option in enumerate(q['options']):
                letter = chr(97 + j).upper()
                correct_marker = " ✓" if option["correct"] else ""
                output.append(f"  {letter}. {option['text']}{correct_marker}")
    return "\n".join(output)


def analyze_document(text, gemini_api_key, claude_api_key, course_name, section_name, lesson_name, language):
    try:
        start_time = time.time()
        text_parts = split_text_by_tokens(text)
        input_tokens = 0
        output_tokens = 0

        all_results = {
            "course_info": {
                "course_name": course_name,
                "section_name": section_name,
                "lesson_name": lesson_name
            },
            "segments": []
        }
        segment_counter = 1

        # Process each part of the text
        for part in text_parts:
            if language == "Uzbek" and claude_api_key:
                prompt_template = ANALYSIS_PROMPT_TEMPLATE_CLAUDE
            else:
                prompt_template = ANALYSIS_PROMPT_TEMPLATE_GEMINI

            # Format the prompt with actual values to count input tokens
            actual_prompt = prompt_template.format(
                course_name=course_name,
                section_name=section_name,
                lesson_name=lesson_name,
                text=part
            )
            prompt_tokens = len(tokenizer.encode(actual_prompt))
            input_tokens += prompt_tokens

            # Analyze the current part (not the full text) so each chunk is processed once
            analysis = segment_and_analyze_text(
                part,
                gemini_api_key,
                claude_api_key,
                language,
                course_name=course_name,
                section_name=section_name,
                lesson_name=lesson_name
            )

            if "segments" in analysis:
                for segment in analysis["segments"]:
                    segment["segment_number"] = segment_counter
                    all_results["segments"].append(segment)
                    segment_counter += 1

        end_time = time.time()
        total_time = end_time - start_time
        print(f"Total quiz processing time: {total_time:.2f}s")

        formatted_output = format_quiz_for_display(all_results, language)
        output_tokens = len(tokenizer.encode(formatted_output))
        token_info = f"Input tokens: {input_tokens}\nOutput tokens: {output_tokens}\nTotal tokens: {input_tokens + output_tokens}\n"
        formatted_text = f"Total quiz processing time: {total_time:.2f}s\n{token_info}\n" + formatted_output

        # Write the results to temporary JSON and TXT files for download
        json_path = tempfile.mktemp(suffix='.json')
        with open(json_path, 'w', encoding='utf-8') as json_file:
            # ensure_ascii=False keeps Uzbek/Russian text readable in the JSON file
            json.dump(all_results, json_file, indent=2, ensure_ascii=False)

        txt_path = tempfile.mktemp(suffix='.txt')
        with open(txt_path, 'w', encoding='utf-8') as txt_file:
            txt_file.write(formatted_text)

        return formatted_text, json_path, txt_path
    except Exception as e:
        error_message = f"Error processing document: {str(e)}"
        return error_message, None, None
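

# The section above stops before the UI wiring. The block below is a minimal,
# hypothetical sketch (not the original interface) of how analyze_document could
# be exposed with Gradio; component labels and layout are assumptions.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=analyze_document,
        inputs=[
            gr.Textbox(label="Lesson transcript", lines=20),
            gr.Textbox(label="Gemini API key", type="password"),
            gr.Textbox(label="Claude API key", type="password"),
            gr.Textbox(label="Course name"),
            gr.Textbox(label="Section name"),
            gr.Textbox(label="Lesson name"),
            gr.Dropdown(choices=["English", "Russian", "Uzbek"], value="English", label="Language"),
        ],
        outputs=[
            gr.Textbox(label="Formatted quiz"),
            gr.File(label="JSON download"),
            gr.File(label="TXT download"),
        ],
        title="Lesson Quiz Generator",
    )
    demo.launch()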