import logging
import os

import gradio as gr
import torch
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
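
# Basic logging so each startup step is visible in the console / Space logs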
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
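
# Load environment variables and authenticate with the Hugging Face Hub
# (the token can come from a local .env file or a Space secret)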
load_dotenv()

hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(hf_token)
else:
    logger.warning("HF_TOKEN is not set; downloading gated models may fail")
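
# Model checkpoint to load from the Hugging Face Hub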
model_path = "mistralai/Mistral-Large-Instruct-2411"

# bfloat16 needs compute capability 8.0 or newer (Ampere+); otherwise fall back to float16
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16
logger.info(f"Using dtype: {dtype}")
logger.info("Configuring 4-bit quantization")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=dtype,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
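
# Load the tokenizer and the quantized model; device_map="auto" lets accelerate
# place the layers across the available devices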
logger.info(f"Loading tokenizer from {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
logger.info("Tokenizer loaded successfully")

logger.info(f"Loading model from {model_path} with 4-bit quantization")
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    quantization_config=quantization_config,
)
logger.info("Model loaded successfully")

logger.info("Creating inference pipeline")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
logger.info("Inference pipeline created successfully")
def generate_response(message, temperature=0.7, max_new_tokens=256):
    try:
        logger.info(f"Generating response for message: {message[:50]}...")
        parameters = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            # Sampling must be enabled, otherwise the temperature setting is ignored
            "do_sample": True,
        }
        logger.info(f"Parameters: {parameters}")

        response = pipe(message, **parameters)
        logger.info("Response generated successfully")
        return response[0]['generated_text']
    except Exception as e:
        logger.error(f"Error during generation: {str(e)}")
        return f"Une erreur s'est produite : {str(e)}"
demo = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="Votre message", placeholder="Entrez votre message ici..."),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Température"),
        gr.Slider(minimum=10, maximum=3000, value=256, step=10, label="Nombre de tokens"),
    ],
    outputs=gr.Textbox(label="Réponse"),
    title="Chat avec Sacha-Mistral",
    description="Un assistant conversationnel en français basé sur le modèle Sacha-Mistral",
)
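
# Launch the app; share=True also exposes a temporary public URL alongside the local server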
if __name__ == "__main__":
    logger.info("Starting Gradio interface")
    demo.launch(share=True)