# Sacha-1 / app.py
import logging
import os

import gradio as gr
import torch
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
load_dotenv()

# Log in to Hugging Face (Mistral-Large-Instruct-2411 is a gated repository,
# so a valid HF_TOKEN is required)
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise RuntimeError("HF_TOKEN is not set; add it to the environment or a .env file")
login(hf_token)
# Model configuration
model_path = "mistralai/Mistral-Large-Instruct-2411"

# Pick the best compute dtype automatically: bfloat16 on GPUs with compute
# capability >= 8 (Ampere or newer), float16 otherwise
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16
logger.info(f"Using dtype: {dtype}")
# 4-bit quantization configuration (NF4 with double quantization)
logger.info("Configuring 4-bit quantization")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=dtype,  # use the optimal compute dtype chosen above
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
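# Rough memory sketch (assumption: Mistral-Large-Instruct-2411 has ~123B
# parameters): NF4 packs each weight into about half a byte, so the quantized
# weights alone need roughly 123e9 * 0.5 / 2**30 ≈ 57 GiB, versus ~229 GiB in
# bf16/fp16, which is why the quantization above is needed at all.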
# Model initialization
logger.info(f"Loading tokenizer from {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
logger.info("Tokenizer loaded successfully")
logger.info(f"Loading model from {model_path} with 4-bit quantization")
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    quantization_config=quantization_config,
)
logger.info("Model loaded successfully")
logger.info("Creating inference pipeline")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
logger.info("Inference pipeline created successfully")
def generate_response(message, temperature=0.7, max_new_tokens=256):
    try:
        logger.info(f"Generating response for message: {message[:50]}...")
        parameters = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            # sampling must be enabled, otherwise the temperature setting is ignored
            "do_sample": True,
            # return only the completion instead of echoing the prompt back
            "return_full_text": False,
            # "top_k": 50,
            # "top_p": 0.9,
        }
        logger.info(f"Parameters: {parameters}")
        response = pipe(message, **parameters)
        logger.info("Response generated successfully")
        return response[0]["generated_text"]
    except Exception as e:
        logger.error(f"Error during generation: {str(e)}")
        return f"An error occurred: {str(e)}"
# Gradio interface
demo = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="Your message", placeholder="Type your message here..."),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature"),
        gr.Slider(minimum=10, maximum=3000, value=256, step=10, label="Max new tokens"),
    ],
    outputs=gr.Textbox(label="Response"),
    title="Chat with Sacha-Mistral",
    description="A French-language conversational assistant based on the Sacha-Mistral model",
)
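# Note: with generations this long (up to 3000 new tokens), enabling Gradio's
# request queue via demo.queue() before launching helps avoid HTTP timeouts;
# it is left disabled here to match the original behavior.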
if __name__ == "__main__":
    logger.info("Starting Gradio interface")
    demo.launch(share=True)