File size: 3,337 Bytes
2799123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ca593a
2799123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import os
from fastapi import FastAPI, Request
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from model import GPT, GPTConfig
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pathlib import Path
import tempfile
from transformers import AutoTokenizer
import uvicorn

# Get the absolute path to the templates directory
TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), "templates")

MODEL_ID = "sagargurujula/smollm2-text-generator"

# Initialize FastAPI
app = FastAPI(title="SMOLLM2 Text Generator")

# Templates with absolute path
templates = Jinja2Templates(directory=TEMPLATES_DIR)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Use system's temporary directory
cache_dir = Path(tempfile.gettempdir()) / "model_cache"
os.environ['TRANSFORMERS_CACHE'] = str(cache_dir)
os.environ['HF_HOME'] = str(cache_dir)

# Load model from Hugging Face Hub
def load_model():
    try:
        # Download the model file from HF Hub with authentication
        model_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="best_model.pth",
            cache_dir=cache_dir,
            token=os.environ.get('HF_TOKEN')  # Get token from environment variable
        )
        
        # Initialize our custom GPT model
        model = GPT(GPTConfig())
        
        # Load the state dict
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        
        model.to(device)
        model.eval()
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path ="HuggingFaceTB/cosmo2-tokenizer", cache_dir=cache_dir)
        tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer
    
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

# Load the model
model, tokenizer = load_model()

# Define the request body
class TextInput(BaseModel):
    text: str

@app.post("/generate/")
async def generate_text(input: TextInput):
    # Prepare input tensor
    input_ids = tokenizer(input.text, return_tensors='pt').input_ids.to(device)
    
    # Generate multiple tokens
    generated_tokens = []
    num_tokens_to_generate = 50  # Generate 20 new tokens
    
    with torch.no_grad():
        generated_tokens = model.generate(input_ids, max_length=50, eos_token_id = tokenizer.eos_token_id)
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    
    # Return both input and generated text
    return {
        "input_text": input.text,
        "generated_text": generated_text
    }

# Modify the root route to serve the template
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse(
        "index.html", 
        {"request": request, "title": "GPT Text Generator"}
    )

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8080)

# To run the app, use the command: uvicorn app:app --reload