# Al-Atlas-LLM / app.py
# (Hugging Face Space file header — author nouamanetazi, commit ece7108, 14.1 kB)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import spaces
import torch
from datasets import load_dataset
from huggingface_hub import CommitScheduler
from pathlib import Path
import uuid
import json
import time
from datetime import datetime
import logging
# Configure logging: messages go both to app.log on disk and to stdout
# (the StreamHandler), so they are visible in the Space console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("darija-llm")

# Prefer the first CUDA GPU when available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.info(f'Using device: {device}')

# Hugging Face access token. A missing TOKEN env var raises KeyError here,
# which is an intentional fail-fast: both the model repo and the dataset
# repo used by the CommitScheduler below need it.
token = os.environ['TOKEN']

# Load the pretrained model and tokenizer
MODEL_NAME = "atlasia/Al-Atlas-0.5B" # "atlasia/Al-Atlas-LLM-mid-training" # "BounharAbdelaziz/Al-Atlas-LLM-0.5B" #"atlasia/Al-Atlas-LLM"
logger.info(f"Loading model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME,token=token) # , token=token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,token=token).to(device)
logger.info("Model loaded successfully")

# Fix tokenizer padding: some causal-LM tokenizers ship without a pad token,
# and model.generate() needs one for pad_token_id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token # Set pad token
    logger.info("Set pad_token to eos_token")

# Predefined examples for gr.Examples: each row is
# (prompt, max_length, temperature, top_p, top_k, num_beams, repetition_penalty).
examples = [
    ["الذكاء الاصطناعي هو فرع من علوم الكمبيوتر اللي كيركز"
     , 256, 0.7, 0.9, 150, 4, 1.5],
    ["المستقبل ديال الذكاء الصناعي فالمغرب"
     , 256, 0.7, 0.9, 150, 4, 1.5],
    [" المطبخ المغربي"
     , 256, 0.7, 0.9, 150, 4, 1.5],
    ["الماكلة المغربية كتعتبر من أحسن الماكلات فالعالم"
     , 256, 0.7, 0.9, 150, 4, 1.5],
]

# Define the file where to save the data.
# A per-process UUID in the name avoids write clashes between app restarts
# or replicas sharing the same synced folder.
submit_file = Path("user_submit/") / f"data_{uuid.uuid4()}.json"
feedback_file = submit_file  # alias used by save_feedback()

# Create directory if it doesn't exist
submit_file.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"Created feedback file: {feedback_file}")

# Periodically (every 5 minutes) commits everything under user_submit/ to the
# "data" folder of the atlasia/atlaset_inference_ds dataset repo.
# scheduler.lock is also reused below to serialize local file writes.
scheduler = CommitScheduler(
    repo_id="atlasia/atlaset_inference_ds",
    repo_type="dataset",
    folder_path=submit_file.parent,
    path_in_repo="data",
    every=5,
    token=token
)
logger.info(f"Initialized CommitScheduler for repo: atlasia/atlaset_inference_ds")

# Track usage statistics (in-memory only; reset on every restart).
usage_stats = {
    "total_generations": 0,
    "total_tokens_generated": 0,
    "start_time": time.time()
}
@spaces.GPU
def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150, num_beams=8, repetition_penalty=1.5, progress=gr.Progress()):
    """Generate Darija text continuing ``prompt`` and persist the exchange.

    Parameters mirror the UI sliders; ``progress`` is injected by Gradio.
    Returns a ``(generated_text, status_message)`` tuple for the two output
    widgets. The full decoded sequence (prompt + continuation) is returned,
    matching what the examples expect.
    """
    if not prompt.strip():
        logger.warning("Empty prompt submitted")
        return "", "الرجاء إدخال نص للتوليد (Please enter text to generate)"

    logger.info(f"Generating text for prompt: '{prompt[:50]}...' (length: {len(prompt)})")
    logger.info(f"Parameters: max_length={max_length}, temp={temperature}, top_p={top_p}, top_k={top_k}, beams={num_beams}, rep_penalty={repetition_penalty}")
    start_time = time.time()

    # Start progress
    progress(0, desc="تجهيز النموذج (Preparing model)")

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]
    progress(0.1, desc="تحليل النص (Tokenizing)")

    # Generate text with optimized parameters for speed.
    # no_grad avoids building an autograd graph during inference (same output,
    # less memory). NOTE: max_length counts prompt + generated tokens, so very
    # long prompts leave little room for generation.
    progress(0.2, desc="توليد النص (Generating text)")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=repetition_penalty,
            num_beams=1 if num_beams > 4 else num_beams,  # Reduce beam search or use greedy decoding
            top_k=top_k,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,  # Ensure cache is used
        )

    # Decode output
    progress(0.9, desc="معالجة النتائج (Processing results)")
    result = tokenizer.decode(output[0], skip_special_tokens=True)

    # Update stats
    generation_time = time.time() - start_time
    # BUG FIX: output[0] includes the prompt tokens; count only the newly
    # generated ones so the status message and usage stats are accurate.
    token_count = len(output[0]) - prompt_token_count
    with scheduler.lock:
        usage_stats["total_generations"] += 1
        usage_stats["total_tokens_generated"] += token_count
    logger.info(f"Generated {token_count} tokens in {generation_time:.2f}s")
    logger.info(f"Result: '{result[:50]}...' (length: {len(result)})")

    # Save feedback with additional metadata
    save_feedback(
        prompt,
        result,
        {
            "max_length": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "num_beams": num_beams,
            "repetition_penalty": repetition_penalty,
            "generation_time": generation_time,
            "token_count": token_count,
            "timestamp": datetime.now().isoformat()
        }
    )
    progress(1.0, desc="اكتمل (Complete)")
    return result, f"تم توليد {token_count} رمز في {generation_time:.2f} ثانية (Generated {token_count} tokens in {generation_time:.2f} seconds)"
def save_feedback(input, output, params) -> None:
    """
    Append input/outputs and parameters to a JSON Lines file using a thread lock
    to avoid concurrent writes from different users.

    NOTE: the parameter name ``input`` shadows the builtin; kept for interface
    compatibility with existing callers.
    """
    logger.info(f"Saving feedback to {feedback_file}")
    with scheduler.lock:
        try:
            # BUG FIX: force UTF-8 — the platform default encoding can fail
            # (UnicodeEncodeError) on the Arabic text written here.
            with feedback_file.open("a", encoding="utf-8") as f:
                record = {
                    "input": input,
                    "output": output,
                    "params": params
                }
                # ensure_ascii=False keeps Arabic readable in the saved file
                # (still valid JSON for any reader).
                f.write(json.dumps(record, ensure_ascii=False))
                f.write("\n")
            logger.info("Feedback saved successfully")
        except Exception as e:
            logger.error(f"Error saving feedback: {str(e)}")
def get_stats():
    """Return a snapshot of current usage statistics as a display dict."""
    with scheduler.lock:
        uptime_hours = (time.time() - usage_stats["start_time"]) / 3600
        whole_hours = int(uptime_hours)
        minutes = int((uptime_hours % 1) * 60)
        if uptime_hours > 0:
            rate = f"{usage_stats['total_generations'] / uptime_hours:.1f}"
        else:
            rate = "N/A"
        stats = {
            "Total generations": usage_stats["total_generations"],
            "Total tokens generated": usage_stats["total_tokens_generated"],
            "Uptime": f"{whole_hours}h {minutes}m",
            "Generations per hour": rate,
            "Last updated": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }
        logger.info(f"Stats requested: {stats}")
        return stats
def reset_params():
    """Restore the generation sliders to their fast-generation defaults.

    Returns (max_length, temperature, top_p, top_k, num_beams,
    repetition_penalty) in the order the sliders are wired up.
    """
    logger.info("Parameters reset to defaults")
    defaults = (128, 0.7, 0.9, 50, 1, 1.2)  # tuned for faster generation
    return defaults
def thumbs_up_callback(input_text, output_text):
    """Record a positive rating for a prompt/output pair.

    Appends a JSONL record to user_submit/positive_feedback.jsonl (synced to
    the dataset repo by the CommitScheduler) and returns a thank-you message
    for the UI.
    """
    logger.info("Received positive feedback")
    feedback_path = Path("user_submit") / "positive_feedback.jsonl"
    feedback_path.parent.mkdir(exist_ok=True, parents=True)
    with scheduler.lock:
        try:
            # BUG FIX: force UTF-8 — the platform default encoding can fail
            # on the Arabic text written here.
            with feedback_path.open("a", encoding="utf-8") as f:
                feedback_data = {
                    "input": input_text,
                    "output": output_text,
                    "rating": "positive",
                    "timestamp": datetime.now().isoformat()
                }
                # ensure_ascii=False keeps Arabic readable in the saved file.
                f.write(json.dumps(feedback_data, ensure_ascii=False))
                f.write("\n")
            logger.info(f"Positive feedback saved to {feedback_path}")
        except Exception as e:
            logger.error(f"Error saving positive feedback: {str(e)}")
    return "شكرا على التقييم الإيجابي!"
def thumbs_down_callback(input_text, output_text, feedback=""):
    """Record a negative rating plus optional free-text feedback.

    Appends a JSONL record to user_submit/negative_feedback.jsonl (synced to
    the dataset repo by the CommitScheduler) and returns a thank-you message
    for the UI.
    """
    logger.info(f"Received negative feedback: '{feedback}'")
    feedback_path = Path("user_submit") / "negative_feedback.jsonl"
    feedback_path.parent.mkdir(exist_ok=True, parents=True)
    with scheduler.lock:
        try:
            # BUG FIX: force UTF-8 — the platform default encoding can fail
            # on the Arabic text written here.
            with feedback_path.open("a", encoding="utf-8") as f:
                feedback_data = {
                    "input": input_text,
                    "output": output_text,
                    "rating": "negative",
                    "feedback": feedback,
                    "timestamp": datetime.now().isoformat()
                }
                # ensure_ascii=False keeps Arabic readable in the saved file.
                f.write(json.dumps(feedback_data, ensure_ascii=False))
                f.write("\n")
            logger.info(f"Negative feedback saved to {feedback_path}")
        except Exception as e:
            logger.error(f"Error saving negative feedback: {str(e)}")
    return "شكرا على ملاحظاتك!"
if __name__ == "__main__":
    logger.info("Starting Moroccan Darija LLM application")
    # Create the Gradio interface. The CSS rule hides Gradio's default footer.
    with gr.Blocks(css="footer {visibility: hidden}") as app:
        gr.Markdown("""
# 🇲🇦 نموذج اللغة المغربية الدارجة (Moroccan Darija LLM)
أدخل نصًا بالدارجة المغربية واحصل على نص تم إنشاؤه بواسطة نموذج اللغة الخاص بنا المدرب على الدارجة المغربية.
Enter a prompt and get AI-generated text using our pretrained LLM on Moroccan Darija.
""")
        with gr.Row():
            # Left column: prompt input, action buttons, generation parameters.
            with gr.Column(scale=6):
                prompt_input = gr.Textbox(
                    label="الدخل (Prompt): دخل النص بالدارجة",
                    placeholder="اكتب هنا...",
                    lines=4, rtl=True
                )
                with gr.Row():
                    submit_btn = gr.Button("توليد النص (Generate)", variant="primary")
                    clear_btn = gr.Button("مسح (Clear)")
                    reset_btn = gr.Button("إعادة ضبط المعلمات (Reset Parameters)")
                with gr.Accordion("معلمات التوليد (Generation Parameters)", open=False):
                    with gr.Row():
                        with gr.Column():
                            max_length = gr.Slider(8, 4096, value=128, label="Max Length (الطول الأقصى)")  # Reduced default
                            temperature = gr.Slider(0.0, 2, value=0.7, label="Temperature (درجة الحرارة)")
                            top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p (أعلى احتمال)")
                        with gr.Column():
                            top_k = gr.Slider(1, 10000, value=50, label="Top-k (أعلى ك)")  # Reduced default
                            num_beams = gr.Slider(1, 20, value=1, label="Number of Beams (عدد الأشعة)")  # Reduced default
                            repetition_penalty = gr.Slider(0.0, 100.0, value=1.2, label="Repetition Penalty (عقوبة التكرار)")  # Reduced default
            # Right column: generated output, status line, feedback widgets.
            with gr.Column(scale=6):
                output_text = gr.Textbox(label="النص المولد (Generated Text)", lines=10, rtl=True)
                generation_info = gr.Markdown("")
                with gr.Row():
                    thumbs_up = gr.Button("👍 جيد")
                    thumbs_down = gr.Button("👎 سيء")
                # Hidden until the user clicks thumbs-down (see click handler below).
                with gr.Accordion("تعليق (Feedback)", open=False, visible=False) as feedback_accordion:
                    feedback_text = gr.Textbox(label="لماذا لم يعجبك الناتج؟ (Why didn't you like the output?)", lines=2, rtl=True)
                    submit_feedback = gr.Button("إرسال التعليق (Submit Feedback)")
                feedback_result = gr.Markdown("")
        with gr.Accordion("إحصائيات الاستخدام (Usage Statistics)", open=False):
            # NOTE(review): `every=10` polls get_stats every 10 s; the keyword's
            # semantics changed across Gradio versions — confirm against the
            # installed Gradio release.
            stats_md = gr.JSON(get_stats, every=10)
            refresh_stats = gr.Button("تحديث (Refresh)")

        # Examples section with caching: runs generate_text on each example at
        # startup and caches the results.
        gr.Examples(
            examples=examples,
            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
            outputs=[output_text, generation_info],
            fn=generate_text,
            cache_examples=True
        )

        # Button actions
        submit_btn.click(
            generate_text,
            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
            outputs=[output_text, generation_info]
        )
        clear_btn.click(
            lambda: ("", ""),
            inputs=None,
            outputs=[prompt_input, output_text]
        )
        reset_btn.click(
            reset_params,
            inputs=None,
            outputs=[max_length, temperature, top_p, top_k, num_beams, repetition_penalty]
        )

        # Feedback system
        thumbs_up.click(
            thumbs_up_callback,
            inputs=[prompt_input, output_text],
            outputs=[feedback_result]
        )
        # Reveal the feedback accordion and clear any previous message.
        # NOTE(review): gr.Accordion.update was removed in Gradio 4.x (use
        # gr.update(...) or return a component instead) — confirm the Gradio
        # version pinned for this Space.
        thumbs_down.click(
            lambda: (gr.Accordion.update(visible=True, open=True), ""),
            inputs=None,
            outputs=[feedback_accordion, feedback_result]
        )
        submit_feedback.click(
            thumbs_down_callback,
            inputs=[prompt_input, output_text, feedback_text],
            outputs=[feedback_result]
        )

        # Stats refresh
        refresh_stats.click(
            get_stats,
            inputs=None,
            outputs=[stats_md]
        )

        # Keyboard shortcuts: Enter in the prompt box also triggers generation.
        prompt_input.submit(
            generate_text,
            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
            outputs=[output_text, generation_info]
        )

    logger.info("Launching Gradio interface")
    app.launch()
    logger.info("Gradio interface closed")