# Hugging Face Space page header at capture time: "Spaces: Sleeping".
import logging

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from peft import PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
app = FastAPI()

# Hugging Face Hub repositories for the base model and the PEFT adapter.
base_model_path = "NousResearch/Hermes-3-Llama-3.2-3B"
adapter_path = "zach9111/llama_startup_adapter"

# Use the GPU when one is visible; request tensors are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision only on GPU: many CPU kernels lack float16 support (or are
# far slower in it), so fall back to float32 when running on CPU.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# `device_map="auto"` lets accelerate decide where every weight lives.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_path, torch_dtype=torch_dtype, device_map="auto"
)

# Attach the adapter to the already-placed base model. No extra
# `.to(device)`: the weights are already placed, and moving a dispatched
# model fails if accelerate offloaded any of them to CPU/disk.
model = PeftModel.from_pretrained(base_model, adapter_path)

# Tokenizer comes from the base model repository.
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
class GenerateRequest(BaseModel):
    """JSON request body accepted by the text-generation endpoint."""

    # Raw user text; the generation helper wraps it in the instruction template.
    prompt: str
def generate_text_from_model(prompt: str, *, max_new_tokens: int = 512) -> str:
    """Generate a completion for ``prompt`` with the adapter-augmented model.

    Calls ``model.generate()`` directly rather than going through a
    ``pipeline()``.

    Args:
        prompt: User text, inserted into the ``[INST] ... [/INST]`` template.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt (the prompt length no longer eats into this budget).

    Returns:
        The decoded completion only. The prompt tokens are sliced off so the
        template markup and the user's own text are not echoed back.

    Raises:
        HTTPException: status 500 if tokenization or generation fails; the
            original exception is logged with its traceback and chained.
    """
    try:
        # NOTE(review): Hermes-3 ships a ChatML chat template; this
        # Llama-2-style template (including the literal "<s>") is kept
        # because the adapter may have been fine-tuned on it — confirm.
        inputs = tokenizer(
            f"<s>[INST] {prompt} [/INST]", return_tensors="pt"
        ).to(device)
        # Llama-3 tokenizers may define no pad token; fall back to EOS so
        # generate() does not have to guess (and warn).
        pad_token_id = (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        )
        output_ids = model.generate(
            **inputs,  # input_ids and attention_mask
            max_new_tokens=max_new_tokens,
            pad_token_id=pad_token_id,
        )
        # generate() returns prompt + completion for a causal LM; keep only
        # the newly generated tokens.
        prompt_length = inputs["input_ids"].shape[-1]
        return tokenizer.decode(
            output_ids[0, prompt_length:], skip_special_tokens=True
        )
    except Exception as e:
        logging.getLogger(__name__).exception("Text generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Root endpoint for testing
@app.get("/")
async def root():
    """Liveness check pointing clients at the generation endpoint.

    The model is loaded at import time, so a response here implies loading
    already finished.
    """
    return {"message": "Model is running! Use /generate/ for text generation."}
# Text generation endpoint
@app.post("/generate/")
async def generate_text(request: GenerateRequest):
    """Generate a completion for ``request.prompt``.

    ``generate_text_from_model`` blocks (tokenization plus ``model.generate``),
    so it is run in FastAPI's threadpool instead of directly on the event
    loop; otherwise one in-flight generation would stall every other request.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: propagated from ``generate_text_from_model`` (500).
    """
    response = await run_in_threadpool(generate_text_from_model, request.prompt)
    return {"response": response}