Dolphin / app.py
Abhi0028's picture
Update app.py
178f86d verified
raw
history blame contribute delete
1.59 kB
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Hub repository id shared by the tokenizer and the model weights.
_MODEL_ID = "cognitivecomputations/Dolphin3.0-Mistral-24B"

# Load the tokenizer and the causal-LM weights once, at import time, so that
# every request handler below reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch.float16,  # half-precision weights
)
# Move the weights onto the CUDA device; this line requires a GPU.
model = model.cuda()

# FastAPI application that the route decorators in this file register on.
app = FastAPI()
# Request body for the POST /generate endpoint.
class InputText(BaseModel):
    # Text that is tokenized and used as the generation prompt.
    prompt: str
    # Forwarded to ``model.generate`` as ``max_length``; 100 when omitted.
    max_length: int = 100
def _generate_sync(prompt: str, max_length: int) -> str:
    """Tokenize *prompt*, run ``model.generate`` and decode the first sequence.

    This is the blocking part of the ``/generate`` endpoint; it is kept as a
    plain function so it can be executed in a worker thread.
    """
    # Follow the device the model actually lives on instead of hard-coding
    # "cuda", so the inputs always end up next to the weights.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # NOTE(review): in transformers, ``max_length`` bounds prompt + generated
    # tokens together, so a prompt longer than ``max_length`` leaves no room
    # for new tokens -- confirm this is the intended meaning of the field.
    output = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(output[0], skip_special_tokens=True)


@app.post("/generate")
async def generate_text(input_data: InputText):
    """Generate a completion for ``input_data.prompt``.

    Args:
        input_data: Request body carrying the prompt and the ``max_length``
            value passed to ``model.generate``.

    Returns:
        A dict ``{"response": <decoded text>}``. The text is the decoded
        first output sequence with special tokens stripped.
    """
    # Local import keeps this block self-contained.
    import asyncio

    # ``model.generate`` is a long, blocking call. Running it directly inside
    # an ``async def`` handler would freeze the event loop (and every other
    # request) until generation finishes, so hand it to a worker thread.
    generated_text = await asyncio.to_thread(
        _generate_sync, input_data.prompt, input_data.max_length
    )
    return {"response": generated_text}
# Run the server using: uvicorn app:app --host 0.0.0.0 --port 8000
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
def _build_chat_messages(message, history, system_message):
    """Flatten the chat state into a list of ``{"role", "content"}`` dicts.

    Args:
        message: The newest user message.
        history: Earlier turns. Each item may be a role/content dict or a
            ``(user_text, assistant_text)`` pair; both layouts are accepted
            because the layout depends on how the chat UI is configured.
        system_message: Optional system prompt; skipped when empty.

    Returns:
        The conversation in chronological order, ending with *message*.
    """
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    for turn in history or []:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        # Pair layout: either side can be empty (e.g. an unanswered turn).
        user_text, assistant_text = turn
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})
    return messages


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Chat callback for ``gr.ChatInterface``.

    The original file passed ``respond`` to ``gr.ChatInterface`` without ever
    defining it, so importing the module raised ``NameError``. This
    implementation answers with the module-level ``tokenizer`` / ``model``.

    Args:
        message: Latest user message.
        history: Previous conversation turns supplied by the chat UI.
        system_message: Value of the "System message" textbox.
        max_tokens: Value of the "Max new tokens" slider.
        temperature: Value of the "Temperature" slider.
        top_p: Value of the "Top-p (nucleus sampling)" slider.

    Returns:
        The assistant reply as a string (newly generated tokens only).
    """
    chat = _build_chat_messages(message, history, system_message)
    inputs = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=int(max_tokens),
        do_sample=True,  # temperature / top_p only apply when sampling
        temperature=float(temperature),
        top_p=float(top_p),
    )
    # Drop the prompt tokens so only the newly generated reply is returned.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)


# Chat UI: the additional inputs are passed to ``respond`` after
# ``message`` and ``history``, in the order listed here.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
# Start the Gradio UI only when this file is executed as a script; importing
# the module (for example via ``uvicorn app:app``) does not launch it.
if __name__ == "__main__":
    demo.launch()