llm_host / utils.py
import asyncio
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load the latest available LLaMA chat model (change this if a newer LLaMA release should be used)
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# Detect device (use GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and model; device_map="auto" already places the weights,
# so no explicit .to(device) call is needed (and would error on a dispatched model)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)

# Text generation pipeline for efficient inference
# (no device argument: placement is handled by device_map above)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)


async def generate_stream(query: str):
    """Stream a LLaMA response word by word."""
    # Generate the full completion; the pipeline tokenizes the query itself,
    # and return_full_text=False drops the echoed prompt from the output
    output = generator(
        query,
        max_length=512,
        do_sample=True,
        temperature=0.7,
        return_full_text=False,
    )
    response_text = output[0]["generated_text"]

    # Simulate streaming by yielding one word at a time
    for word in response_text.split():
        yield word + " "
        await asyncio.sleep(0.05)
    yield "\n"