File size: 1,204 Bytes
8b883c8
7c59172
 
e4f5d4a
7c59172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4f5d4a
8b883c8
7c59172
 
 
 
 
 
 
 
 
 
 
 
 
e4f5d4a
7c59172
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import asyncio
import re

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Hugging Face Hub id of the causal-LM checkpoint to serve. Point this at a
# newer chat/instruct checkpoint to switch models; nothing else depends on it.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# Use the GPU when PyTorch can see one. Kept as a module-level name for code
# that places tensors explicitly; weight placement itself is delegated to
# device_map="auto" below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# fp16 halves GPU memory; half precision on CPU is slow / poorly supported, so
# fall back to fp32 there. device_map="auto" (requires the `accelerate`
# package) places -- and if necessary offloads -- the weights itself, so the
# model must NOT be moved again with `.to(device)`: moving an
# accelerate-dispatched model is unsupported.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto",
)

# Text-generation pipeline over the already-placed model. No `device` argument:
# transformers' Pipeline refuses an explicit device for a model loaded with a
# device_map, and the weights are already where they need to be.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

async def generate_stream(
    query: str,
    *,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    delay: float = 0.05,
    text_generator=None,
):
    """Generate a LLaMA completion for *query* and yield it piece by piece.

    Streaming is simulated: the whole completion is generated first (in a
    worker thread, so the event loop stays responsive), then emitted one
    word at a time with a short pause between pieces. Each piece keeps the
    whitespace that followed the word in the model output, so newlines and
    indentation survive. A final newline piece is always yielded, even when
    the model produced no text.

    Args:
        query: Prompt text passed to the text-generation pipeline. It is not
            echoed back in the yielded output.
        max_new_tokens: Maximum number of generated tokens, excluding the
            prompt tokens.
        temperature: Sampling temperature (sampling is always enabled).
        delay: Seconds to sleep between yielded pieces.
        text_generator: Callable with the same call/return convention as the
            module-level ``generator`` pipeline (prompt positional, generation
            kwargs, returns ``[{"generated_text": str}, ...]``). Defaults to
            ``generator``.

    Yields:
        str: Successive pieces of the completion, then a newline.
    """
    pipe = generator if text_generator is None else text_generator

    def _run_pipeline():
        # return_full_text=False: by default the pipeline prepends the prompt
        # to generated_text, which would echo the user's query back.
        # max_new_tokens (unlike max_length) does not count the prompt, so
        # long prompts still leave room for a reply.
        return pipe(
            query,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            return_full_text=False,
        )

    # Generation is CPU/GPU-bound and blocking; run it off the event loop so
    # other coroutines keep being served while the model works.
    loop = asyncio.get_running_loop()
    outputs = await loop.run_in_executor(None, _run_pipeline)
    response_text = outputs[0]["generated_text"]

    # r"\S+\s*" pairs every word with its trailing whitespace; str.split()
    # would collapse newlines/indentation into single spaces.
    for piece in re.findall(r"\S+\s*", response_text):
        yield piece
        await asyncio.sleep(delay)

    yield "\n"