import datetime
import os
from dotenv import load_dotenv
import asyncio
from fastapi import FastAPI, Body, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from typing import List, AsyncIterable, Annotated, Optional
from enum import Enum
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain_openai import ChatOpenAI
from langchain import hub
from langchain_chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_nomic.embeddings import NomicEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain_core.documents import Document
from in_memory import load_all_documents
from loader import load_web_content, load_youtube_content
from get_pattern import generate_pattern
from get_agents import process_agents

# ################################### FastAPI setup ############################################
app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
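# Wildcard origins keep local development simple; restrict `origins` before
# exposing the API publicly.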
# ################################### Helper functions ############################################
# async def load_all_documents(files: List[UploadFile]) -> List[Document]:
#     documents = []
#     for file in files:
#         docs = await load_document(file)
#         documents.extend(docs)
#     return documents
# ################################### LLM, RAG and Streaming ############################################
load_dotenv()

GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
GROQ_API_BASE = os.environ.get("GROQ_API_BASE")
OPENAI_MODEL_NAME = os.environ.get("OPENAI_MODEL_NAME")

embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5")
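# NomicEmbeddings calls the hosted Nomic API by default, which generally expects
# a NOMIC_API_KEY in the environment (or a locally configured inference mode).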
def split_documents(documents: List[Document], chunk_size=1000, chunk_overlap=200) -> List[Document]:
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    print("Splitting documents into chunks...")
    return text_splitter.split_documents(documents)
def generate_embeddings(documents: List[Document]) -> NomicEmbeddings:
    embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5")
    # The per-document vectors computed here are discarded; only the embedding
    # model is returned, and Chroma embeds the documents itself when the store
    # is built in store_embeddings().
    embeddings = [embedding_model.embed(
        [document.page_content], task_type='search_document') for document in documents]
    return embedding_model
def store_embeddings(documents: List[Document], embeddings: NomicEmbeddings):
    vectorstore = Chroma.from_documents(
        documents=documents, embedding=embeddings, persist_directory="./chroma_db")
    return vectorstore


def load_embeddings(embeddings: NomicEmbeddings) -> Chroma:
    vectorstore = Chroma(persist_directory="./chroma_db",
                         embedding_function=embeddings)
    return vectorstore
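# A minimal ingest/query round trip using the helpers above, kept as a commented
# sketch (assumes `docs` is a List[Document] loaded elsewhere):
#
#     chunks = split_documents(docs)
#     store_embeddings(chunks, embedding_model)            # builds ./chroma_db
#     retriever = load_embeddings(embedding_model).as_retriever()
#     relevant = retriever.invoke("What does the document say about X?")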
# ################################### Updated generate_chunks Function ############################################
async def generate_chunks(query: str) -> AsyncIterable[str]:
    callback = AsyncIteratorCallbackHandler()
    llm = ChatOpenAI(
        openai_api_base=GROQ_API_BASE,
        api_key=GROQ_API_KEY,
        temperature=0.0,
        model_name=OPENAI_MODEL_NAME,  # "mixtral-8x7b-32768"
        streaming=True,  # ! important
        verbose=True,
        callbacks=[callback]
    )
    # Load vector store (this should be pre-populated with documents and embeddings)
    # Ensure to modify this to load your actual vector store
    vectorstore = load_embeddings(embeddings=embedding_model)
    # Retrieve relevant documents for the query
    retriever = vectorstore.as_retriever()
    # relevant_docs = retriever(query)

    # Combine the retrieved documents into a single string
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Define the RAG chain
    prompt = hub.pull("rlm/rag-prompt")
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    # Generate the response in the background and stream tokens via the callback
    task = asyncio.create_task(
        rag_chain.ainvoke(query)
    )
    index = 0
    try:
        async for token in callback.aiter():
            print(index, ": ", token, ": ", datetime.datetime.now().time())
            index = index + 1
            yield token
    except Exception as e:
        print(f"Caught exception: {e}")
    finally:
        callback.done.set()
        await task
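# generate_chunks is an async generator; outside of a StreamingResponse it could
# be consumed like this (commented sketch, assumes a populated ./chroma_db):
#
#     async def _demo():
#         async for token in generate_chunks("Summarize the indexed documents"):
#             print(token, end="", flush=True)
#     asyncio.run(_demo())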
# ################################### Models ########################################
class QuestionType(str, Enum):
    PATTERN = "PATTERN"
    AGENTS = "AGENTS"
    RAG = "RAG"


class Input(BaseModel):
    question: str
    type: QuestionType
    pattern: Optional[str] = None  # only used for PATTERN questions
    chat_history: List[str]


class Metadata(BaseModel):
    conversation_id: str


class Config(BaseModel):
    metadata: Metadata


class RequestBody(BaseModel):
    input: Input
    config: Config
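# Example JSON body matching RequestBody (the conversation_id value is illustrative):
#
#     {
#         "input": {
#             "question": "Summarize the uploaded document",
#             "type": "RAG",
#             "pattern": null,
#             "chat_history": []
#         },
#         "config": {"metadata": {"conversation_id": "demo-123"}}
#     }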
# ################################### Routes ############################################
# NOTE: the route paths below are assumptions; adjust them to match the deployed API surface.
@app.get("/")
def read_root():
    return {"Hello": "World from Marigen"}
@app.post("/chat")
async def chat(query: RequestBody = Body(...)):
    print(query.input.question)
    print(query.input.type)
    if query.input.type == QuestionType.PATTERN:
        print(query.input.pattern)
        pattern = query.input.pattern
        gen = generate_pattern(pattern=pattern, query=query.input.question)
        return StreamingResponse(gen, media_type="text/event-stream")
    elif query.input.type == QuestionType.AGENTS:
        gen = process_agents(query.input.question)
        return StreamingResponse(gen, media_type="text/event-stream")
    elif query.input.type == QuestionType.RAG:
        gen = generate_chunks(query.input.question)
        return StreamingResponse(gen, media_type="text/event-stream")
    raise HTTPException(status_code=400, detail="Unsupported question type for the given query")
@app.post("/uploadfiles")
async def create_upload_files(
    files: Annotated[List[UploadFile], File(description="Multiple files as UploadFile")],
):
    try:
        # Load documents from files
        documents = await load_all_documents(files)
        print(f"Loaded {len(documents)} documents")
        print(f"----------> {documents} documents <-----------")
        chunks = []
        # Split documents into chunks
        # NOTE: split_documents expects a List[Document]; docs[0] must match
        # whatever shape load_all_documents returns for each file.
        for docs in documents:
            print(docs)
            chunk = split_documents(docs[0])
            chunks.extend(chunk)
        print(f"Split into {len(chunks)} chunks")
        # Generate embeddings for chunks
        # embeddings_model = generate_embeddings(chunks)
        # print(f"Generated {len(embeddings)} embeddings")
        # Store embeddings in vector store
        vectorstore = store_embeddings(chunks, embedding_model)
        print("Embeddings stored in vector store")
        return {
            "filenames": [file.filename for file in files],
            "chunks": chunks,
            "message": "Files processed and embeddings generated.",
        }
    except Exception as e:
        print(f"Error loading documents: {e}")
        return {"message": f"Error loading documents: {e}"}
# New routes for YouTube and website content loading
@app.post("/load_youtube")
async def load_youtube(youtube_url: str):
    try:
        documents = load_youtube_content(youtube_url)
        chunks = split_documents(documents)
        store_embeddings(chunks, embedding_model)
        return {"message": "YouTube video loaded and processed successfully.", "documents": documents}
    except Exception as e:
        print(f"Error loading YouTube video: {e}")
        return {"message": f"Error loading YouTube video: {e}"}
@app.post("/load_website")
async def load_website(website_url: str):
    try:
        documents = load_web_content(website_url)
        chunks = split_documents(documents)
        store_embeddings(chunks, embedding_model)
        return {"message": "Website loaded and processed successfully.", "documents": documents}
    except Exception as e:
        print(f"Error loading website: {e}")
        return {"message": f"Error loading website: {e}"}
@app.get("/query")
async def query_vector_store(query: str):
    # Load the vector store (ensure you maintain a reference to it, possibly in memory or a persistent store)
    vectorstore = load_embeddings(embeddings=embedding_model)
    # Perform a similarity search to retrieve relevant documents
    relevant_docs = vectorstore.similarity_search(query)
    return {"query": query, "results": [doc.page_content for doc in relevant_docs]}