# chatbot-space / app.py
import os
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
from pydantic import BaseModel
# Set writable cache directories inside the container.
# Note: TRANSFORMERS_CACHE is deprecated in recent transformers releases in
# favour of HF_HOME, so both are set here for compatibility.
os.environ['HF_HOME'] = '/app/hf_home'
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/app/hf_home'
os.environ['TRANSFORMERS_CACHE'] = '/app/hf_home'

# Ensure the cache directory exists before any model download
os.makedirs(os.environ['HF_HOME'], exist_ok=True)
# Base model and PEFT adapter
base_model_name = "facebook/opt-2.7b"
adapter_name = "mynuddin/chatbot"  # fine-tuned adapter on the Hugging Face Hub

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 halves memory on GPU; fall back to float32 on CPU, where
# half-precision generation is unsupported or very slow.
dtype = torch.float16 if device == "cuda" else torch.float32

# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype)

# Attach the PEFT adapter and put the model in inference mode
model = PeftModel.from_pretrained(base_model, adapter_name)
model = model.to(device)
model.eval()
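
# Optional: if the adapter is a LoRA, merging its weights into the base model
# removes the PEFT indirection at inference time. Left commented out as a
# sketch, since it assumes a LoRA-style adapter:
# model = model.merge_and_unload()
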
app = FastAPI()

# Request body schema
class PromptInput(BaseModel):
    prompt: str

@app.post("/generate")
def generate_text(input: PromptInput):
    prompt = input.prompt  # Access prompt from the request body

    # Format the prompt in the style the adapter was fine-tuned on
    input_text = f"### Prompt: {prompt}\n### Completion:"
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    # Greedy decoding; max_new_tokens bounds the completion length
    # independently of the prompt length (unlike max_length, which counts both)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode the output and remove special tokens
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Return only the completion part of the generated text
    if "### Completion:" in generated_text:
        query_output = generated_text.split("### Completion:")[1].strip()
    else:
        # Fallback if the model did not reproduce the expected structure
        query_output = generated_text.replace(input_text, "").strip()

    return {"generated_query": query_output}
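
# Local entry point for testing outside the Spaces runtime. A minimal sketch,
# assuming uvicorn is installed and port 7860 (the Hugging Face Spaces default)
# is free; the Space itself may launch the app differently.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request (hypothetical prompt shown):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "list all customers from Dhaka"}'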