# ai/model.py
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "bigcode/starcoderbase-1b"
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
device = "cpu"
# Load the tokenizer; fall back to EOS as the padding token if none is defined
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Load the model in full precision on CPU
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=HF_TOKEN,
    torch_dtype=torch.float32,
    trust_remote_code=True,
).to(device)
def generate_code(prompt: str, max_tokens: int = 256) -> str:
    """Generate code for a natural-language prompt using StarCoder."""
    formatted_prompt = f"{prompt}\n### Code:\n"  # Hint to the model that code follows
    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512,  # Cap prompt length to speed up processing
    ).to(device)
    output = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        top_p=0.90,       # Nucleus sampling: restrict to the most probable tokens
        temperature=0.6,  # Lower temperature for more deterministic output
    )
    # Decode only the newly generated tokens. Slicing by the input length is
    # more robust than string-matching the prompt, since tokenization
    # round-trips can alter whitespace.
    prompt_length = inputs["input_ids"].shape[1]
    generated_code = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return generated_code.strip()
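

# A minimal usage sketch when the module is run as a script. The prompt below
# is illustrative only and not part of the original file.
if __name__ == "__main__":
    example = generate_code("Write a Python function that reverses a string")
    print(example)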