import os

from dotenv import load_dotenv
from openai import OpenAI

# Short aliases for the SambaNova-hosted Llama instruct models.
MODEL_ALIAS = {
    "llama3_8b": "Meta-Llama-3.1-8B-Instruct",
    "llama3_70b": "Meta-Llama-3.1-70B-Instruct",
    "llama3_3_70b": "Meta-Llama-3.3-70B-Instruct",
    "llama3_405b": "Meta-Llama-3.1-405B-Instruct",
    "llama3_1b": "Meta-Llama-3.2-1B-Instruct",
    "llama3_3b": "Meta-Llama-3.2-3B-Instruct",
}

# Read SAMBA_API_KEY (and any other settings) from a local .env file.
load_dotenv()

# SambaNova exposes an OpenAI-compatible endpoint, so the stock OpenAI
# client works with a swapped base_url.
client = OpenAI(
    base_url="https://api.sambanova.ai/v1",
    api_key=os.environ.get("SAMBA_API_KEY"),
)


def call_llama(system_prompt, prompt, model="Meta-Llama-3.1-8B-Instruct", **kwargs):
    """Send a single system/user prompt pair and return the complete streamed
    reply as one string.

    Extra kwargs are forwarded to the completion call, e.g.:
        temperature=0.1,
        top_p=0.1,
        max_tokens=50
    """
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            stream=True,
            **kwargs,
        )
        # Concatenate the streamed deltas into a single response string.
        response = ""
        for chunk in completion:
            response += chunk.choices[0].delta.content or ""
        return response
    except Exception as e:
        print(f"API Error = {e}")
        return ""


def call_llama_chat(messages, model="Meta-Llama-3.1-8B-Instruct", **kwargs):
    """Send a full chat history and return the complete streamed reply as one
    string.

    messages is a list of {"role": ..., "content": ...} dicts. Extra kwargs
    are forwarded to the completion call, e.g.:
        temperature=0.1,
        top_p=0.1
    """
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
            **kwargs,
        )
        # Concatenate the streamed deltas into a single response string.
        response = ""
        for chunk in completion:
            response += chunk.choices[0].delta.content or ""
        return response
    except Exception as e:
        print(f"API Error = {e}")
        return ""
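

# Minimal smoke test: a sketch of multi-turn use of call_llama_chat. The
# prompts and model alias are illustrative assumptions; running it requires
# SAMBA_API_KEY to be set in the environment or .env.
if __name__ == "__main__":
    history = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Name three prime numbers."},
    ]
    first = call_llama_chat(history, model=MODEL_ALIAS["llama3_8b"], temperature=0.1)
    print(first)

    # Extend the history with the assistant's reply and ask a follow-up.
    history.append({"role": "assistant", "content": first})
    history.append({"role": "user", "content": "Now name one more."})
    print(call_llama_chat(history, model=MODEL_ALIAS["llama3_8b"], temperature=0.1))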