Spaces:
Running
Running
import gradio as gr | |
import pandas as pd | |
import json | |
import os | |
import re | |
from PyPDF2 import PdfReader | |
from collections import defaultdict | |
from typing import Dict, List, Optional, Tuple, Union | |
import html | |
from pathlib import Path | |
import fitz # PyMuPDF for better PDF text extraction | |
import pytesseract | |
from PIL import Image | |
import io | |
import secrets | |
import string | |
from huggingface_hub import HfApi, HfFolder | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import time | |
# ========== CONFIGURATION ==========
PROFILES_DIR = "student_profiles"
ALLOWED_FILE_TYPES = [".pdf", ".png", ".jpg", ".jpeg"]
MAX_FILE_SIZE_MB = 5
MIN_AGE = 5
MAX_AGE = 120
SESSION_TOKEN_LENGTH = 32
HF_TOKEN = os.getenv("HF_TOKEN")

# Model configuration: UI label -> Hugging Face Hub repository id.
MODEL_CHOICES = {
    "TinyLlama (Fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Phi-2 (Balanced)": "microsoft/phi-2",
    # NOTE: the label says "DeepSeek-V3" but the checkpoint is DeepSeek LLM 7B.
    # The label is kept as-is so existing lookups by key keep working; the repo
    # id must name a published variant (the bare "deepseek-llm-7b" id does not exist).
    "DeepSeek-V3 (Most Powerful)": "deepseek-ai/deepseek-llm-7b-chat",
}
DEFAULT_MODEL = "TinyLlama (Fastest)"

# Initialize Hugging Face API; hf_api is always defined (None without a token)
# so later references never hit a NameError.
hf_api = HfApi(token=HF_TOKEN) if HF_TOKEN else None
if HF_TOKEN:
    HfFolder.save_token(HF_TOKEN)
# ========== OPTIMIZED MODEL LOADING ========== | |
class ModelLoader:
    """Lazily load and cache one Hugging Face causal LM plus its tokenizer.

    Attributes:
        model / tokenizer: the currently loaded objects, or None.
        loaded: True only while ``model`` and ``tokenizer`` are both usable.
        loading: True while ``load_model`` is running.
        error: message from the last failed load, else None.
        current_model: MODEL_CHOICES key of the loaded model, else None.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.loaded = False
        self.loading = False
        self.error = None
        self.current_model = None

    def _unload(self) -> None:
        """Drop the current model/tokenizer and reset the loaded state."""
        # Rebind to None rather than `del`: deleting the attributes would make
        # every later `self.model` access raise AttributeError if the next
        # load fails.
        self.model = None
        self.tokenizer = None
        self.loaded = False
        self.current_model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_model(self, model_name, progress=gr.Progress()):
        """Load ``model_name`` (a MODEL_CHOICES key), reusing it if already loaded.

        Args:
            model_name: key into MODEL_CHOICES.
            progress: Gradio progress callback.

        Returns:
            ``(model, tokenizer)`` on success, ``(None, None)`` on failure
            (the error message is stored in ``self.error``).
        """
        if self.loaded and self.current_model == model_name:
            return self.model, self.tokenizer
        self.loading = True
        self.error = None
        try:
            progress(0, desc=f"Loading {model_name}...")
            # Clear previous model if any
            if self.model is not None or self.tokenizer is not None:
                self._unload()
            repo_id = MODEL_CHOICES[model_name]
            use_cuda = torch.cuda.is_available()
            tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
            progress(0.3, desc="Loaded tokenizer...")
            model = AutoModelForCausalLM.from_pretrained(
                repo_id,
                trust_remote_code=True,
                # float16 is only worthwhile on GPU; many CPU kernels lack
                # half-precision support, so fall back to float32 there.
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                low_cpu_mem_usage=True,
            )
            progress(0.9, desc="Finalizing...")
            # Publish both objects together so a half-finished load is never visible.
            self.model, self.tokenizer = model, tokenizer
            self.loaded = True
            self.current_model = model_name
            return self.model, self.tokenizer
        except Exception as e:
            self.error = str(e)
            print(f"Error loading model: {self.error}")
            return None, None
        finally:
            self.loading = False


# Initialize model loader
model_loader = ModelLoader()
# ========== UTILITY FUNCTIONS ========== | |
def generate_session_token() -> str:
    """Return a random alphanumeric token, SESSION_TOKEN_LENGTH characters long.

    Characters are drawn with ``secrets.choice`` so the token is suitable for
    identifying a session.
    """
    charset = string.ascii_letters + string.digits
    picks = [secrets.choice(charset) for _ in range(SESSION_TOKEN_LENGTH)]
    return "".join(picks)
def sanitize_input(text: str) -> str:
    """Trim surrounding whitespace, then HTML-escape the result (quotes included)."""
    trimmed = text.strip()
    escaped = html.escape(trimmed)
    return escaped
def validate_name(name: str) -> str:
    """Return ``name`` stripped of surrounding whitespace after validating it.

    Raises:
        gr.Error: if the stripped name is empty, longer than 100 characters,
            or contains any digit (checked in that order).
    """
    cleaned = name.strip()
    # Ordered (failed?, message) pairs; the first failing check wins.
    problems = [
        (not cleaned, "Name cannot be empty"),
        (len(cleaned) > 100, "Name is too long (max 100 characters)"),
        (any(ch.isdigit() for ch in cleaned), "Name cannot contain numbers"),
    ]
    for failed, message in problems:
        if failed:
            raise gr.Error(message)
    return cleaned
def validate_age(age: Union[int, float, str]) -> int:
    """Convert ``age`` to int and check it lies in [MIN_AGE, MAX_AGE].

    Floats are truncated by ``int()``; strings must be integer literals.

    Raises:
        gr.Error: if ``age`` is not convertible (including None, NaN and
            infinity) or is outside the allowed range.
    """
    try:
        age_int = int(age)
    except (ValueError, TypeError, OverflowError):
        # OverflowError covers float('inf'), which previously escaped uncaught.
        raise gr.Error("Please enter a valid age number") from None
    if not MIN_AGE <= age_int <= MAX_AGE:
        raise gr.Error(f"Age must be between {MIN_AGE} and {MAX_AGE}")
    return age_int
def validate_file(
    file_obj,
    allowed_types: Optional[List[str]] = None,
    max_size_mb: Optional[float] = None,
) -> None:
    """Validate an uploaded file's extension and size.

    Args:
        file_obj: a path (``str`` / ``os.PathLike``) or an object exposing the
            file path as ``.name``. Depending on the Gradio version and the
            ``gr.File`` ``type`` setting, uploads may arrive in either form.
        allowed_types: lowercase extensions to accept; defaults to
            ALLOWED_FILE_TYPES.
        max_size_mb: maximum size in MiB; defaults to MAX_FILE_SIZE_MB.

    Raises:
        gr.Error: if no file is given, the extension is not allowed, or the
            file is too large.
    """
    if not file_obj:
        raise gr.Error("No file uploaded")
    if allowed_types is None:
        allowed_types = ALLOWED_FILE_TYPES
    if max_size_mb is None:
        max_size_mb = MAX_FILE_SIZE_MB
    # A Path's `.name` is only the basename, so handle path-likes before
    # falling back to the wrapper's `.name` attribute.
    if isinstance(file_obj, (str, os.PathLike)):
        file_path = os.fspath(file_obj)
    else:
        file_path = file_obj.name
    file_ext = os.path.splitext(file_path)[1].lower()
    if file_ext not in allowed_types:
        raise gr.Error(f"Invalid file type. Allowed: {', '.join(allowed_types)}")
    file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
    if file_size > max_size_mb:
        raise gr.Error(f"File too large. Max size: {max_size_mb}MB")
# ========== TEXT EXTRACTION FUNCTIONS ========== | |
def extract_text_from_file(file_path: str, file_ext: str) -> str:
    """Extract and clean text from a PDF or image file.

    PDFs are read with PyMuPDF's text layer first; if that fails or yields no
    text, the pages are OCR'd instead. Images are OCR'd directly.

    Args:
        file_path: path of the file to read.
        file_ext: the file's extension including the dot (case-insensitive).

    Returns:
        The cleaned, non-empty text.

    Raises:
        gr.Error: wrapping any extraction failure, an unsupported extension,
            or an empty result.
    """
    file_ext = file_ext.lower()
    try:
        if file_ext == '.pdf':
            try:
                # Context manager closes the document even on failure.
                with fitz.open(file_path) as doc:
                    text = "".join(page.get_text("text") + '\n' for page in doc)
                if not text.strip():
                    raise ValueError("PyMuPDF returned empty text")
            except Exception as e:
                print(f"PyMuPDF failed, trying OCR fallback: {str(e)}")
                text = extract_text_from_pdf_with_ocr(file_path)
        elif file_ext in ['.png', '.jpg', '.jpeg']:
            text = extract_text_with_ocr(file_path)
        else:
            raise ValueError(f"Unsupported file type: {file_ext}")
        # Clean up the extracted text
        text = clean_extracted_text(text)
        if not text.strip():
            raise ValueError("No text could be extracted from the file")
        return text
    except Exception as e:
        raise gr.Error(f"Text extraction error: {str(e)}") from e
def extract_text_from_pdf_with_ocr(file_path: str, dpi: int = 300) -> str:
    """OCR every page of a PDF (fallback when the PDF has no usable text layer).

    Args:
        file_path: path to the PDF.
        dpi: render resolution for each page. PyMuPDF's default pixmap is
            72 dpi, which is too coarse for reliable OCR, so pages are scaled
            by ``dpi / 72``.

    Returns:
        The concatenated OCR text, one newline after each page.

    Raises:
        ValueError: if opening, rendering or OCR fails.
    """
    pages = []
    try:
        zoom = dpi / 72
        matrix = fitz.Matrix(zoom, zoom)
        with fitz.open(file_path) as doc:
            for page in doc:
                pix = page.get_pixmap(matrix=matrix)
                with Image.open(io.BytesIO(pix.tobytes("png"))) as img:
                    pages.append(pytesseract.image_to_string(img) + '\n')
    except Exception as e:
        raise ValueError(f"PDF OCR failed: {str(e)}") from e
    return "".join(pages)
def extract_text_with_ocr(file_path: str) -> str:
    """OCR an image file after converting it to black-and-white.

    The image is converted to grayscale and thresholded at 128, then passed
    to Tesseract with ``--oem 3 --psm 6``.

    Raises:
        ValueError: if loading, preprocessing or OCR fails.
    """
    tesseract_config = r'--oem 3 --psm 6'
    try:
        grayscale = Image.open(file_path).convert('L')
        binarised = grayscale.point(lambda px: 0 if px < 128 else 255, '1')
        return pytesseract.image_to_string(binarised, config=tesseract_config)
    except Exception as e:
        raise ValueError(f"OCR processing failed: {str(e)}")
def clean_extracted_text(text: str) -> str:
    """Normalize extracted text.

    Collapses every whitespace run (including newlines) to one space, then
    maps typographic quotes to ASCII and expands Latin ligature code points.

    '|' is deliberately left alone: TranscriptParser's patterns rely on '|'
    as the column delimiter, so rewriting it to 'I' (a common OCR fix-up)
    made structured parsing impossible.
    """
    text = re.sub(r'\s+', ' ', text).strip()
    fixes = str.maketrans({
        '\u2018': "'",   # left single quotation mark
        '\u2019': "'",   # right single quotation mark
        '\u201c': '"',   # left double quotation mark
        '\u201d': '"',   # right double quotation mark
        '\ufb00': 'ff',  # ligatures: the previous 'fi' -> 'fi' entries were no-ops
        '\ufb01': 'fi',
        '\ufb02': 'fl',
        '\ufb03': 'ffi',
        '\ufb04': 'ffl',
    })
    return text.translate(fixes)
def remove_sensitive_info(text: str) -> str:
    """Mask emails, SSNs and 6-9 digit numbers (assumed student IDs) in ``text``.

    Emails are masked first. Otherwise an all-digit local part such as
    "1234567@school.org" would become "[ID]@school.org", which the email
    pattern no longer matches, so the domain would leak.
    """
    # Remove email addresses (TLD class is letters only; "[A-Z|a-z]" also matched '|')
    text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', '[EMAIL]', text)
    # Remove social security numbers
    text = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[REDACTED]', text)
    # Remove student IDs (assuming 6-9 digit numbers)
    text = re.sub(r'\b\d{6,9}\b', '[ID]', text)
    return text
# ========== TRANSCRIPT PARSING ========== | |
class TranscriptParser:
    """Regex parser for '|'-delimited transcript text.

    ``parse_transcript`` fills ``student_data``, ``requirements``,
    ``course_history`` and ``current_courses`` and returns them as one dict.
    All state is reset on each call, so an instance can safely be reused.
    """

    # "<7 digits> - <name> | Cohort <x> | Un-weighted GPA <n> | Comm Serv Hours <n>"
    _HEADER_RE = re.compile(
        r"(\d{7}) - ([\w\s,]+)\s*\|\s*Cohort \w+\s*\|\s*Un-weighted GPA ([\d.]+)\s*\|\s*Comm Serv Hours (\d+)"
    )
    # "Current Grade: <g> | YOG <yyyy> | Weighted GPA <n> | Comm Serv Date mm/dd/yyyy | Total Credits Earned <n>"
    _GRADE_RE = re.compile(
        r"Current Grade: (\d+)\s*\|\s*YOG (\d{4})\s*\|\s*Weighted GPA ([\d.]+)\s*\|\s*Comm Serv Date \d{2}/\d{2}/\d{4}\s*\|\s*Total Credits Earned ([\d.]+)"
    )
    # Six-column requirement row; only columns 1, 3, 5 and 6 are used below.
    _REQUIREMENT_RE = re.compile(
        r"\|([A-Z]-[\w\s]+)\s*\|([^\|]+)\|([\d.]+)\s*\|([\d.]+)\s*\|([\d.]+)\s*\|([^\|]+)\|"
    )
    # Ten-column course row, in the order stored by _extract_course_history.
    _COURSE_RE = re.compile(
        r"\|([A-Z]-[\w\s&\(\)]+)\s*\|(\d{4}-\d{4})\s*\|(\d{2})\s*\|([A-Z0-9]+)\s*\|([^\|]+)\|([^\|]+)\|([^\|]+)\|([A-Z])\s*\|([YRXW]?)\s*\|([^\|]+)\|"
    )

    def __init__(self):
        self._reset()

    def _reset(self) -> None:
        """Clear all parsed state."""
        self.student_data: Dict = {}
        self.requirements: Dict = {}
        self.current_courses: List[Dict] = []
        self.course_history: List[Dict] = []

    def parse_transcript(self, text: str) -> Dict:
        """Parse ``text`` and return the structured transcript.

        Raises:
            ValueError: if neither requirement rows nor course rows are found.
                Callers use this to fall back to another parser.
        """
        self._reset()
        self._extract_student_info(text)
        self._extract_requirements(text)
        self._extract_course_history(text)
        self._extract_current_courses(text)
        if not self.requirements and not self.course_history:
            raise ValueError("No recognizable transcript structure found")
        return self._as_dict()

    def _as_dict(self) -> Dict:
        """Bundle the parsed state plus completion summary into one dict."""
        return {
            "student_info": self.student_data,
            "requirements": self.requirements,
            "current_courses": self.current_courses,
            "course_history": self.course_history,
            "completion_status": self._calculate_completion(),
        }

    def _extract_student_info(self, text: str):
        """Extract student id, name, GPAs, service hours, grade and credits."""
        header_match = self._HEADER_RE.search(text)
        if header_match:
            self.student_data = {
                "id": header_match.group(1),
                "name": header_match.group(2).strip(),
                "unweighted_gpa": float(header_match.group(3)),
                "community_service_hours": int(header_match.group(4)),
            }
        grade_match = self._GRADE_RE.search(text)
        if grade_match:
            self.student_data.update({
                "current_grade": grade_match.group(1),
                "graduation_year": grade_match.group(2),
                "weighted_gpa": float(grade_match.group(3)),
                "total_credits": float(grade_match.group(4)),
            })

    def _extract_requirements(self, text: str):
        """Parse graduation-requirement rows into ``self.requirements``."""
        for row in self._REQUIREMENT_RE.findall(text):
            self.requirements[row[0].strip()] = {
                "required": float(row[2]),
                "completed": float(row[4]),
                "status": f"{row[5].strip()}%",
            }

    def _extract_course_history(self, text: str):
        """Parse course rows into ``self.course_history``."""
        for course in self._COURSE_RE.findall(text):
            self.course_history.append({
                "requirement_category": course[0].strip(),
                "school_year": course[1],
                "grade_level": course[2],
                "course_code": course[3],
                "description": course[4].strip(),
                "term": course[5].strip(),
                "district_number": course[6].strip(),
                "grade": course[7],
                "inclusion_status": course[8],
                "credits": course[9].strip(),
            })

    def _extract_current_courses(self, text: str):
        """Select courses whose credits column contains "inProgress"."""
        self.current_courses = [
            {
                "course": c["description"],
                "category": c["requirement_category"],
                "term": c["term"],
                "credits": c["credits"],
            }
            for c in self.course_history
            if "inProgress" in c["credits"]
        ]

    def _calculate_completion(self) -> Dict:
        """Summarize requirement credits; 0% when no requirements were parsed."""
        total_required = sum(req["required"] for req in self.requirements.values())
        total_completed = sum(req["completed"] for req in self.requirements.values())
        # Guard the division: an empty requirements table used to raise ZeroDivisionError.
        percent = round((total_completed / total_required) * 100, 1) if total_required else 0.0
        return {
            "total_required": total_required,
            "total_completed": total_completed,
            "percent_complete": percent,
            "remaining_credits": total_required - total_completed,
        }

    def to_json(self) -> str:
        """Export the current parsed state as indented JSON."""
        return json.dumps(self._as_dict(), indent=2)
def parse_transcript_with_ai(text: str, progress=gr.Progress()) -> Dict:
    """Parse transcript text, preferring the regex parser over the LLM.

    The structured TranscriptParser runs first. Its result is mapped to the
    app's ``{grade_level, gpa, courses}`` shape and validated. If anything
    raises, the error is printed and the LLM fallback is used instead.
    """
    try:
        progress(0.1, desc="Parsing transcript structure...")
        structured = TranscriptParser().parse_transcript(text)
        progress(0.9, desc="Formatting results...")
        info = structured["student_info"]
        result = {
            "grade_level": info.get("current_grade", "Unknown"),
            "gpa": {
                "weighted": info.get("weighted_gpa", "N/A"),
                "unweighted": info.get("unweighted_gpa", "N/A"),
            },
            "courses": [
                {
                    "code": entry["course_code"],
                    "name": entry["description"],
                    "grade": entry["grade"],
                    "credits": entry["credits"],
                    "year": entry["school_year"],
                    "grade_level": entry["grade_level"],
                }
                for entry in structured["course_history"]
            ],
        }
        progress(1.0)
        return validate_parsed_data(result)
    except Exception as e:
        print(f"Structured parsing failed, falling back to AI: {str(e)}")
        return parse_transcript_with_ai_fallback(text, progress)
def _extract_json_object(response: str) -> Dict:
    """Return the first JSON object found in an LLM response.

    A ```json fenced block is preferred; otherwise each '{' is tried in turn
    as the start of a JSON value until one decodes.

    Raises:
        ValueError: if no JSON object can be decoded (json.JSONDecodeError is
            a ValueError subclass).
    """
    if '```json' in response:
        fenced = response.split('```json', 1)[1].split('```', 1)[0].strip()
        return json.loads(fenced)
    decoder = json.JSONDecoder()
    start = response.find('{')
    while start != -1:
        try:
            value, _ = decoder.raw_decode(response, start)
            return value
        except json.JSONDecodeError:
            start = response.find('{', start + 1)
    raise ValueError("Model response did not contain a JSON object")


def parse_transcript_with_ai_fallback(text: str, progress=gr.Progress()) -> Dict:
    """Extract transcript data with the LLM when structured parsing fails.

    The text is truncated to 15,000 characters and scrubbed with
    ``remove_sensitive_info`` before it is placed in the prompt.

    Raises:
        gr.Error: if the model cannot be loaded, runs out of GPU memory, or
            its output cannot be parsed and validated.
    """
    # Ensure model is loaded
    if not model_loader.loaded:
        model_loader.load_model(model_loader.current_model or DEFAULT_MODEL, progress)
    if model_loader.model is None or model_loader.tokenizer is None:
        raise gr.Error("AI model failed to load. Please try again or select a different model.")
    # Pre-process the text
    text = remove_sensitive_info(text[:15000])  # Limit input size
    prompt = f"""
Analyze this academic transcript and extract structured information:
- Current grade level
- Weighted GPA (if available)
- Unweighted GPA (if available)
- List of all courses with:
* Course code
* Course name
* Grade received
* Credits earned
* Year/semester taken
* Grade level when taken
Return the data in JSON format.
Transcript Text:
{text}
"""
    try:
        progress(0.1, desc="Processing transcript with AI...")
        tokenizer = model_loader.tokenizer
        model = model_loader.model
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        progress(0.4)
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id
        outputs = model.generate(
            **inputs,
            max_new_tokens=1500,
            temperature=0.1,
            do_sample=True,
            pad_token_id=pad_id,
        )
        progress(0.8)
        # For causal LMs, generate() returns the prompt tokens followed by the
        # continuation. Decoding everything would feed the prompt text into
        # the JSON parser, so keep only the newly generated tokens.
        prompt_len = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        parsed_data = _extract_json_object(response)
        progress(1.0)
        return validate_parsed_data(parsed_data)
    except torch.cuda.OutOfMemoryError:
        raise gr.Error("The model ran out of memory. Try with a smaller transcript or use a smaller model.") from None
    except Exception as e:
        raise gr.Error(f"Error processing transcript: {str(e)}") from e
def validate_parsed_data(data: Dict) -> Dict:
    """Normalize parsed transcript data in place and return it.

    Guarantees:
        - ``grade_level`` exists (default ``'Unknown'``).
        - ``gpa`` is a dict (a missing or non-dict value becomes ``N/A`` for both scales).
        - ``courses`` is a list of dicts. A dict of lists (the grade-level
          grouping that ``parse_transcript`` stores) is flattened; other
          non-list values become ``[]``; non-dict entries are dropped.
        - Non-None grades are upper-cased strings; numeric credits become strings.

    Raises:
        ValueError: if ``data`` is not a dict.
    """
    if not isinstance(data, dict):
        raise ValueError("Invalid data format")
    data.setdefault('grade_level', 'Unknown')
    if not isinstance(data.get('gpa'), dict):
        data['gpa'] = {'weighted': 'N/A', 'unweighted': 'N/A'}

    courses = data.get('courses')
    if isinstance(courses, dict):
        courses = [c for group in courses.values() if isinstance(group, list) for c in group]
    elif not isinstance(courses, list):
        courses = []

    cleaned = []
    for course in courses:
        if not isinstance(course, dict):
            continue
        # LLM output may hold numbers for grades; str() avoids .upper() crashing.
        if course.get('grade') is not None:
            course['grade'] = str(course['grade']).upper().strip()
        # Ensure numeric credits are strings
        if isinstance(course.get('credits'), (int, float)):
            course['credits'] = str(course['credits'])
        cleaned.append(course)
    data['courses'] = cleaned
    return data
def format_transcript_output(data: Dict) -> str:
    """Render parsed transcript data as a plain-text summary.

    Courses are grouped by ``grade_level``. Numeric levels come first in
    numeric order, followed by non-numeric labels (e.g. ``"Unknown"``) in
    alphabetical order.
    """

    def _grade_sort_key(grade: str):
        # Tuple keys keep ints and strings from ever being compared directly,
        # which raised TypeError for mixes such as ["9", "10", "Unknown"].
        return (0, int(grade), "") if grade.isdecimal() else (1, 0, grade)

    output = [
        f"Student Transcript Summary\n{'=' * 40}",
        f"Current Grade Level: {data.get('grade_level', 'Unknown')}",
    ]
    if 'gpa' in data:
        gpa = data['gpa'] if isinstance(data['gpa'], dict) else {}
        output.append("\nGPA:")
        output.append(f"- Weighted: {gpa.get('weighted', 'N/A')}")
        output.append(f"- Unweighted: {gpa.get('unweighted', 'N/A')}")
    if 'courses' in data:
        output.append("\nCourse History:\n" + '=' * 40)
        courses_by_grade = defaultdict(list)
        for course in data['courses']:
            if not isinstance(course, dict):
                continue
            # str() so an int level (e.g. from LLM JSON) groups with its string form.
            courses_by_grade[str(course.get('grade_level', 'Unknown'))].append(course)
        for grade in sorted(courses_by_grade, key=_grade_sort_key):
            output.append(f"\nGrade {grade}:\n{'-' * 30}")
            for course in courses_by_grade[grade]:
                course_str = f"- {course.get('code', '')} {course.get('name', 'Unnamed course')}"
                if 'grade' in course:
                    course_str += f" (Grade: {course['grade']})"
                if 'credits' in course:
                    course_str += f" | Credits: {course['credits']}"
                if 'year' in course:
                    course_str += f" | Year: {course['year']}"
                output.append(course_str)
    return '\n'.join(output)
def parse_transcript(file_obj, progress=gr.Progress()) -> Tuple[str, Optional[Dict]]:
    """Validate, extract, parse and summarize an uploaded transcript.

    Args:
        file_obj: the upload, either a path (``str`` / ``os.PathLike``) or an
            object exposing the path as ``.name``.
        progress: Gradio progress callback.

    Returns:
        ``(summary_text, transcript_data)`` on success, where
        ``transcript_data["courses"]`` maps grade level to a list of courses.
        On any failure: ``("Error processing transcript: ...", None)``.
    """
    try:
        if not file_obj:
            raise ValueError("Please upload a file first")
        validate_file(file_obj)
        # Accept plain paths as well as tempfile-like wrappers.
        if isinstance(file_obj, (str, os.PathLike)):
            file_path = os.fspath(file_obj)
        else:
            file_path = file_obj.name
        file_ext = os.path.splitext(file_path)[1].lower()
        text = extract_text_from_file(file_path, file_ext)
        # Use hybrid parsing approach
        parsed_data = parse_transcript_with_ai(text, progress)
        output_text = format_transcript_output(parsed_data)
        # Organize courses by grade level for saving
        courses_by_grade = defaultdict(list)
        for course in parsed_data.get('courses', []):
            courses_by_grade[course.get('grade_level', 'Unknown')].append(course)
        transcript_data = {
            "grade_level": parsed_data.get('grade_level', 'Unknown'),
            "gpa": parsed_data.get('gpa', {}),
            # Plain dict so the stored value carries no defaultdict behaviour.
            "courses": dict(courses_by_grade),
        }
        return output_text, transcript_data
    except Exception as e:
        return f"Error processing transcript: {str(e)}", None
# ========== LEARNING STYLE QUIZ ========== | |
class LearningStyleQuiz:
    """Ten-question quiz scoring Visual / Auditory / Kinesthetic preferences.

    For every question, answer index 0, 1 and 2 count towards Visual,
    Auditory and Kinesthetic respectively.
    """

    STYLE_BY_INDEX = ("Visual", "Auditory", "Kinesthetic")

    def __init__(self):
        self.questions = [
            "When learning something new, I prefer to:",
            "I remember information best when I:",
            "When giving directions, I:",
            "When I have to concentrate, I'm most distracted by:",
            "I prefer to get new information in:",
            "When I'm trying to recall something, I:",
            "When I'm angry, I tend to:",
            "I tend to:",
            "When I meet someone new, I remember:",
            "When I'm relaxing, I prefer to:",
        ]
        self.options = [
            ["See diagrams and charts", "Listen to an explanation", "Try it out myself"],
            ["See pictures or diagrams", "Hear someone explain it", "Do something with it"],
            ["Draw a map", "Give verbal instructions", "Show them how to get there"],
            ["Untidiness or movement", "Noises", "Other people moving around"],
            ["Written form", "Spoken form", "Demonstration form"],
            ["See a mental picture", "Repeat it to myself", "Feel it or move my hands"],
            ["Visualize the incident", "Shout and yell", "Stomp around and slam doors"],
            ["Talk to myself", "Use my hands when talking", "Move around a lot"],
            ["Their face", "Their name", "Something we did together"],
            ["Watch TV or read", "Listen to music or talk", "Do something active"],
        ]
        self.learning_styles = {
            "Visual": "You learn best through seeing. Use visual aids like diagrams, charts, and color-coding.",
            "Auditory": "You learn best through listening. Record lectures, discuss concepts, and use rhymes or songs.",
            "Kinesthetic": "You learn best through movement and touch. Use hands-on activities and take frequent breaks.",
        }

    def get_quiz_questions(self) -> List[Dict]:
        """Return ``[{"question": ..., "options": [...]}, ...]`` for the UI."""
        return [
            {"question": q, "options": opts}
            for q, opts in zip(self.questions, self.options)
        ]

    def calculate_learning_style(self, answers, *more_answers) -> Dict:
        """Score quiz answers.

        Accepts either a single list/tuple of answer indices (the original
        API), or one positional argument per question. The second form is
        how Gradio calls a handler wired to several input components.

        Returns:
            dict with ``primary`` (highest count; ties resolve in
            Visual/Auditory/Kinesthetic order), ``secondary`` (other styles
            with a non-zero count), ``description`` and ``scores``.

        Raises:
            ValueError: on a wrong number of answers or an answer outside {0, 1, 2}.
        """
        if more_answers or not isinstance(answers, (list, tuple)):
            answers = [answers, *more_answers]
        if len(answers) != len(self.questions):
            raise ValueError("Invalid number of answers")
        style_map = dict(enumerate(self.STYLE_BY_INDEX))
        style_counts = {style: 0 for style in self.STYLE_BY_INDEX}
        for answer in answers:
            if answer is None:
                raise ValueError("Please answer every question")
            if answer not in [0, 1, 2]:
                raise ValueError("Invalid answer value")
            style_counts[style_map[answer]] += 1
        primary_style = max(style_counts, key=style_counts.get)
        secondary_styles = [
            style for style, count in style_counts.items()
            if style != primary_style and count > 0
        ]
        return {
            "primary": primary_style,
            "secondary": secondary_styles,
            "description": self.learning_styles[primary_style],
            "scores": style_counts,
        }


# Initialize quiz instance
learning_style_quiz = LearningStyleQuiz()
# ========== PROFILE MANAGEMENT ========== | |
class ProfileManager:
    """Store student profiles as one JSON file per profile under PROFILES_DIR.

    A profile's id doubles as its file stem, ``<profiles_dir>/<id>.json``.
    """

    def __init__(self):
        self.profiles_dir = Path(PROFILES_DIR)
        self.profiles_dir.mkdir(exist_ok=True)

    def _profile_path(self, profile_id: str) -> Path:
        """Map ``profile_id`` to its JSON file, rejecting ids that escape profiles_dir.

        Ids come from user-entered names and from UI/API inputs, so a value
        such as "../../x" must not reach the filesystem unchecked.

        Raises:
            ValueError: for an empty id or one that resolves outside profiles_dir.
        """
        if not profile_id:
            raise ValueError("No profile selected")
        path = self.profiles_dir / f"{profile_id}.json"
        if path.resolve().parent != self.profiles_dir.resolve():
            raise ValueError("Invalid profile id")
        return path

    def create_profile(
        self,
        name: str,
        age: int,
        grade_level: str,
        learning_style: Dict,
        transcript_data: Optional[Dict] = None
    ) -> str:
        """Create a new profile file and return its id (``<name>_<age>``).

        Raises:
            gr.Error: on invalid input, an unsafe id, or an existing profile.
        """
        try:
            name = validate_name(name)
            age = validate_age(age)
            profile_id = f"{name.lower().replace(' ', '_')}_{age}"
            profile_path = self._profile_path(profile_id)
            now = time.strftime("%Y-%m-%d %H:%M:%S")
            profile_data = {
                "id": profile_id,
                "name": name,
                "age": age,
                "grade_level": grade_level,
                "learning_style": learning_style,
                "transcript": transcript_data or {},
                "created_at": now,
                "updated_at": now,
            }
            # Serialize before opening so a bad payload never leaves a partial file.
            payload = json.dumps(profile_data, indent=2)
            try:
                # 'x' mode creates the file atomically and fails if it already exists.
                with open(profile_path, 'x', encoding='utf-8') as f:
                    f.write(payload)
            except FileExistsError:
                raise ValueError("Profile already exists") from None
            return profile_id
        except Exception as e:
            raise gr.Error(f"Error creating profile: {str(e)}") from e

    def get_profile(self, profile_id: str) -> Dict:
        """Load and return the profile dict for ``profile_id``.

        Raises:
            gr.Error: if the id is invalid, the profile is missing, or the file is unreadable.
        """
        try:
            profile_path = self._profile_path(profile_id)
            if not profile_path.exists():
                raise ValueError("Profile not found")
            with open(profile_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            raise gr.Error(f"Error loading profile: {str(e)}") from e

    def update_profile(self, profile_id: str, updates: Dict) -> Dict:
        """Merge ``updates`` into a profile, refresh ``updated_at``, save and return it.

        Raises:
            gr.Error: if loading or saving fails.
        """
        try:
            profile = self.get_profile(profile_id)
            profile.update(updates)
            profile["updated_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
            payload = json.dumps(profile, indent=2)
            with open(self._profile_path(profile_id), 'w', encoding='utf-8') as f:
                f.write(payload)
            return profile
        except Exception as e:
            raise gr.Error(f"Error updating profile: {str(e)}") from e

    def list_profiles(self) -> List[Dict]:
        """Summarize every readable profile, sorted by name.

        Files that are not valid JSON or lack a summary field are reported
        and skipped, so one corrupt file no longer hides every other profile.
        """
        try:
            profiles = []
            for file in sorted(self.profiles_dir.glob("*.json")):
                try:
                    with open(file, 'r', encoding='utf-8') as f:
                        profile = json.load(f)
                    profiles.append({
                        "id": profile["id"],
                        "name": profile["name"],
                        "age": profile["age"],
                        "grade_level": profile["grade_level"],
                        "created_at": profile["created_at"],
                    })
                except (OSError, ValueError, KeyError, TypeError) as e:
                    print(f"Skipping unreadable profile {file.name}: {e}")
            return sorted(profiles, key=lambda x: x["name"])
        except Exception as e:
            raise gr.Error(f"Error listing profiles: {str(e)}") from e
# Module-wide profile store; the UI handlers in create_interface use this instance.
profile_manager: ProfileManager = ProfileManager()
# ========== AI TEACHING ASSISTANT ========== | |
class TeachingAssistant:
    """LLM-backed helper that writes study plans and answers student questions.

    Uses the module-level ``model_loader`` and loads DEFAULT_MODEL on the
    first request if nothing is loaded yet.
    """

    def __init__(self):
        self.model_loader = model_loader

    def _ensure_model(self, progress) -> None:
        """Load DEFAULT_MODEL when needed; raise if no usable model is available."""
        if not self.model_loader.loaded:
            self.model_loader.load_model(DEFAULT_MODEL, progress)
        if self.model_loader.model is None or self.model_loader.tokenizer is None:
            raise RuntimeError(
                "AI model failed to load. Please try again or select a different model."
            )

    def _generate(self, prompt: str, max_new_tokens: int, temperature: float) -> str:
        """Sample a completion for ``prompt`` and return only the new text."""
        tokenizer = self.model_loader.tokenizer
        model = self.model_loader.model
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=pad_id,
        )
        # generate() returns prompt + continuation for these causal LMs; drop
        # the prompt so the instructions are not echoed back to the student.
        prompt_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    @staticmethod
    def _count_courses(courses) -> int:
        """Count courses stored either as a flat list or grouped by grade level."""
        # parse_transcript saves transcript["courses"] as {grade_level: [course, ...]};
        # len() of that dict would count grade levels, not courses.
        if isinstance(courses, dict):
            return sum(len(group) for group in courses.values() if isinstance(group, list))
        if isinstance(courses, list):
            return len(courses)
        return 0

    def generate_study_plan(self, profile_data: Dict, progress=gr.Progress()) -> str:
        """Generate a personalized study plan for a stored profile dict.

        Raises:
            gr.Error: if no profile is given, the model is unavailable, or generation fails.
        """
        try:
            if not profile_data:
                raise ValueError("No profile selected")
            self._ensure_model(progress)
            learning_style = profile_data.get("learning_style") or {}
            transcript = profile_data.get("transcript") or {}
            gpa = transcript.get("gpa")
            if not isinstance(gpa, dict):
                gpa = {}
            course_count = self._count_courses(transcript.get("courses", []))
            prompt = f"""
Create a personalized study plan for {profile_data.get('name', 'the student')}, a {profile_data.get('age', 'N/A')}-year-old student in grade {profile_data.get('grade_level', 'N/A')}.
Learning Style:
- Primary: {learning_style.get('primary', 'Unknown')}
- Description: {learning_style.get('description', 'No learning style information')}
Academic History:
- Current GPA: {gpa.get('weighted', 'N/A')} (weighted)
- Courses Completed: {course_count}
Focus on study techniques that match the student's learning style and provide specific recommendations based on their academic history.
Include:
1. Daily study routine suggestions
2. Subject-specific strategies
3. Recommended resources
4. Time management tips
5. Any areas that need improvement
Format the response with clear headings and bullet points.
"""
            progress(0.2, desc="Generating study plan...")
            response = self._generate(prompt, max_new_tokens=1000, temperature=0.7)
            progress(0.8, desc="Formatting response...")
            return self._format_response(response)
        except Exception as e:
            raise gr.Error(f"Error generating study plan: {str(e)}") from e

    def answer_question(self, question: str, context: str = "", progress=gr.Progress()) -> str:
        """Answer a student question, optionally using extra context.

        Returns "Please ask a question." for an empty or whitespace-only question.

        Raises:
            gr.Error: if the model is unavailable or generation fails.
        """
        try:
            if not question or not question.strip():
                return "Please ask a question."
            self._ensure_model(progress)
            context_line = f"Context: {context}" if context else ""
            prompt = f"""
Answer the following student question in a helpful, educational manner.
{context_line}
Question: {question}
Provide a clear, concise answer with examples if helpful. Break down complex concepts.
If the question is unclear, ask for clarification.
"""
            progress(0.3, desc="Processing question...")
            response = self._generate(prompt, max_new_tokens=500, temperature=0.5)
            progress(0.8, desc="Formatting answer...")
            return self._format_response(response)
        except Exception as e:
            raise gr.Error(f"Error answering question: {str(e)}") from e

    def _format_response(self, text: str) -> str:
        """Strip end-of-text markers and bold paragraph headings ending with ':'.

        Bolding is only applied when the text contains no '#' or '**' markdown yet.
        """
        text = text.replace("<|endoftext|>", "").strip()
        if "#" not in text and "**" not in text:
            formatted = []
            for section in text.split("\n\n"):
                if section.strip().endswith(":"):
                    formatted.append(f"**{section}**")
                else:
                    formatted.append(section)
            text = "\n\n".join(formatted)
        return text


# Initialize teaching assistant
teaching_assistant = TeachingAssistant()
# ========== GRADIO INTERFACE ========== | |
def create_interface():
    """Build the Student Profile Assistant Gradio app and wire its event handlers.

    Returns:
        The ``gr.Blocks`` application (not launched).
    """

    def _refresh_profiles():
        """Return fresh table rows plus choice updates for both profile dropdowns."""
        profiles = profile_manager.list_profiles()
        ids = [p["id"] for p in profiles]
        # gr.Dataframe expects rows in header order, not dicts.
        rows = [[p["name"], p["age"], p["grade_level"], p["created_at"]] for p in profiles]
        # Returning a bare list to a Dropdown sets its value; update its choices instead.
        return rows, gr.update(choices=ids), gr.update(choices=ids)

    def _refresh_selector():
        """Refresh the choices of the study-tools profile dropdown."""
        return gr.update(choices=[p["id"] for p in profile_manager.list_profiles()])

    def _submit_quiz(*answers):
        """Score the quiz; Gradio passes one positional argument per Radio input."""
        if any(answer is None for answer in answers):
            raise gr.Error("Please answer every question before submitting the quiz.")
        return learning_style_quiz.calculate_learning_style(list(answers))

    def _generate_plan(profile_id, progress=gr.Progress()):
        """Load the selected profile and generate its study plan."""
        profile = profile_manager.get_profile(profile_id)
        return teaching_assistant.generate_study_plan(profile, progress)

    def _update_profile(profile_id, grade, file_obj, progress=gr.Progress()):
        """Apply a grade change and/or a re-parsed transcript.

        Returns (profile JSON update, status message).
        """
        if not profile_id:
            raise gr.Error("Please select a profile first")
        updates = {}
        if grade:
            updates["grade_level"] = grade
        if file_obj:
            summary, transcript_data = parse_transcript(file_obj, progress)
            if transcript_data is None:
                # parse_transcript reports failures through its summary text.
                return gr.update(), summary
            updates["transcript"] = transcript_data
        if not updates:
            return gr.update(), "Nothing to update: choose a grade level or upload a transcript."
        profile = profile_manager.update_profile(profile_id, updates)
        return profile, "Profile updated successfully!"

    initial_ids = [p["id"] for p in profile_manager.list_profiles()]
    profile_views = None  # set below once the components exist

    with gr.Blocks(title="Student Profile Assistant", theme="soft") as app:
        session_token = gr.State(generate_session_token())
        # Parsed transcript carried from "Parse Transcript" to "Create Profile".
        # A single shared State is required: each inline gr.State() is a separate value.
        transcript_state = gr.State()

        # Tab navigation
        with gr.Tabs():
            with gr.Tab("Profile Creation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("## Student Information")
                        name_input = gr.Textbox(label="Full Name", placeholder="Enter student's full name")
                        age_input = gr.Number(label="Age", minimum=MIN_AGE, maximum=MAX_AGE, step=1)
                        grade_level = gr.Dropdown(
                            label="Grade Level",
                            choices=["9", "10", "11", "12", "Other"],
                            value="9",
                        )
                        gr.Markdown("## Transcript Upload")
                        file_upload = gr.File(label="Upload Transcript", file_types=ALLOWED_FILE_TYPES)
                        parse_btn = gr.Button("Parse Transcript")
                        transcript_output = gr.Textbox(label="Transcript Summary", interactive=False, lines=10)
                    with gr.Column(scale=1):
                        gr.Markdown("## Learning Style Quiz")
                        quiz_components = [
                            gr.Radio(label=question, choices=options, type="index")
                            for question, options in zip(
                                learning_style_quiz.questions, learning_style_quiz.options
                            )
                        ]
                        quiz_submit = gr.Button("Submit Quiz")
                        learning_style_output = gr.JSON(label="Learning Style Results")
                gr.Markdown("## Complete Profile")
                create_profile_btn = gr.Button("Create Profile")
                profile_status = gr.Textbox(label="Profile Status", interactive=False)

            with gr.Tab("Study Tools"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("## Study Plan Generator")
                        profile_selector = gr.Dropdown(
                            label="Select Profile",
                            choices=initial_ids,
                            interactive=True,
                        )
                        refresh_profiles = gr.Button("Refresh Profiles")
                        study_plan_btn = gr.Button("Generate Study Plan")
                        study_plan_output = gr.Markdown(label="Personalized Study Plan")
                    with gr.Column(scale=1):
                        gr.Markdown("## Ask the Teaching Assistant")
                        question_input = gr.Textbox(label="Your Question", lines=3)
                        context_input = gr.Textbox(label="Additional Context (optional)", lines=2)
                        ask_btn = gr.Button("Ask Question")
                        answer_output = gr.Markdown(label="Answer")

            with gr.Tab("Profile Management"):
                gr.Markdown("## Existing Profiles")
                profile_table = gr.Dataframe(
                    headers=["Name", "Age", "Grade Level", "Created At"],
                    datatype=["str", "number", "str", "str"],
                    interactive=False,
                )
                refresh_table = gr.Button("Refresh Profiles")
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("## Profile Details")
                        selected_profile = gr.Dropdown(
                            label="Select Profile",
                            choices=initial_ids,
                            interactive=True,
                        )
                        view_profile_btn = gr.Button("View Profile")
                        profile_display = gr.JSON(label="Profile Data")
                    with gr.Column():
                        gr.Markdown("## Update Profile")
                        update_grade = gr.Dropdown(
                            label="Update Grade Level",
                            choices=["9", "10", "11", "12", "Other"],
                            interactive=True,
                        )
                        update_transcript = gr.File(label="Update Transcript", file_types=ALLOWED_FILE_TYPES)
                        update_btn = gr.Button("Update Profile")
                        update_status = gr.Textbox(label="Update Status", interactive=False)

        # Components refreshed whenever the set of profiles may have changed.
        profile_views = [profile_table, profile_selector, selected_profile]

        # ========== EVENT HANDLERS ==========
        # Transcript parsing
        parse_btn.click(
            parse_transcript,
            inputs=[file_upload],
            outputs=[transcript_output, transcript_state],
            show_progress=True,
        )
        # Learning style quiz
        quiz_submit.click(
            _submit_quiz,
            inputs=quiz_components,
            outputs=learning_style_output,
        )
        # Profile creation
        create_profile_btn.click(
            profile_manager.create_profile,
            inputs=[name_input, age_input, grade_level, learning_style_output, transcript_state],
            outputs=profile_status,
        ).then(_refresh_profiles, outputs=profile_views)
        # Study tools
        refresh_profiles.click(_refresh_selector, outputs=profile_selector)
        study_plan_btn.click(
            _generate_plan,
            inputs=profile_selector,
            outputs=study_plan_output,
            show_progress=True,
        )
        # Teaching assistant
        ask_btn.click(
            teaching_assistant.answer_question,
            inputs=[question_input, context_input],
            outputs=answer_output,
            show_progress=True,
        )
        # Profile management
        refresh_table.click(_refresh_profiles, outputs=profile_views)
        view_profile_btn.click(
            profile_manager.get_profile,
            inputs=selected_profile,
            outputs=profile_display,
        )
        update_btn.click(
            _update_profile,
            inputs=[selected_profile, update_grade, update_transcript],
            outputs=[profile_display, update_status],
            show_progress=True,
        ).then(_refresh_profiles, outputs=profile_views)
        # Initialization
        app.load(_refresh_profiles, outputs=profile_views)
    return app
# Build the Blocks app once at import time; `app` is the module-level handle.
app = create_interface()

if __name__ == "__main__":
    # Running this file directly starts the Gradio server (Hugging Face Spaces deployment).
    app.launch()