|
import os |
|
import openai |
|
from longcepo.main import run_longcepo |
|
|
|
|
|
|
|
|
|
# Resolve the SambaNova API key from the environment. A missing variable is
# treated the same as an empty one. Surrounding whitespace is trimmed, and a
# key that is blank after trimming aborts module import with a ValueError.
SAMBANOVA_API_KEY = (os.environ.get("SAMBANOVA_API_KEY") or "").strip()
if not SAMBANOVA_API_KEY:
    raise ValueError("Sambanova API key not found or is empty. Please set the SAMBANOVA_API_KEY environment variable or Hugging Face Secret.")
|
|
|
# Model name sent with every LongCePO request made from this module.
SAMBANOVA_MODEL = "Meta-Llama-3.1-8B-Instruct"

# OpenAI SDK client aimed at SambaNova's OpenAI-compatible endpoint,
# authenticated with the key resolved above.
client = openai.OpenAI(
    base_url="https://api.sambanova.ai/v1",
    api_key=SAMBANOVA_API_KEY,
)
|
|
|
def process_with_longcepo(system_prompt: str, initial_query: str):
    """Run a query through LongCePO using the module's SambaNova client.

    Args:
        system_prompt: System instructions forwarded to ``run_longcepo``.
        initial_query: Query text forwarded to ``run_longcepo`` unchanged.

    Returns:
        On success, the answer component of ``run_longcepo``'s
        ``(answer, total_tokens)`` result.

        On any exception, the error and its traceback are printed, and the
        string ``"An error occurred: <exception>"`` is returned instead of
        raising.
    """
    print(f"Processing query with LongCePO using model: {SAMBANOVA_MODEL}")

    try:
        longcepo_kwargs = {
            "system_prompt": system_prompt,
            "initial_query": initial_query,
            "client": client,
            "model": SAMBANOVA_MODEL,
        }
        answer_text, token_count = run_longcepo(**longcepo_kwargs)
        print(f"LongCePO finished. Total tokens used: {token_count}")
    except Exception as exc:
        # Deliberate catch-all. Callers receive a printable error string
        # rather than an exception.
        print(f"Error during LongCePO processing: {exc}")
        import traceback
        traceback.print_exc()
        return f"An error occurred: {exc}"

    return answer_text
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test. A short passage about Paris/France is joined to a
    # question about it with the literal "<CONTEXT_END>" separator. The
    # separator is presumably how LongCePO splits context from question;
    # confirm this in longcepo.main.
    test_system_prompt = "You are a helpful assistant designed to answer questions based on the provided context."

    dummy_context = """
Paris is the capital and most populous city of France. It is known for its art, fashion, gastronomy and culture.
Its 19th-century cityscape is crisscrossed by wide boulevards and the River Seine.
Beyond such landmarks as the Eiffel Tower and the 12th-century, Gothic Notre-Dame cathedral, the city is known for its cafe culture and designer boutiques along the Rue du Faubourg Saint-Honoré.
The Louvre Museum houses Da Vinci's Mona Lisa. The Musée d'Orsay has Impressionist and Post-Impressionist masterpieces.
France is a country in Western Europe. It borders Belgium, Luxembourg, Germany, Switzerland, Monaco, Italy, Andorra, and Spain.
The official language is French.
"""
    test_query = "Based on the provided text, what are the main attractions in Paris and what countries does France border?"

    test_initial_query = "<CONTEXT_END>".join((dummy_context, test_query))

    print("Running test query...")
    demo_reply = process_with_longcepo(test_system_prompt, test_initial_query)
    print(f"\nTest Result:\n{demo_reply}")
|
|
|
|