# ========== Standard Library ==========
import collections
import os
import tempfile
import zipfile
from typing import List, Optional, Tuple, Union

# ========== Third-Party Libraries ==========
import gradio as gr
from groq import Groq
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader, UnstructuredFileLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings

# ========== Configs ==========
TITLE = """<h1 align="center">🗨️🦙 Llama 4 Docx Chatter</h1>"""
AVATAR_IMAGES = (
    None,
    "./logo.png",
)

# Acceptable file extensions
TEXT_EXTENSIONS = [".docx", ".zip"]

# ========== Models & Clients ==========
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

client = Groq(api_key=GROQ_API_KEY)
llm = ChatGroq(model="meta-llama/llama-4-scout-17b-16e-instruct", api_key=GROQ_API_KEY)
embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
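
# Optional sanity check (an assumption, not part of the original flow): fail fast when
# the key is missing instead of surfacing a Groq authentication error on the first request.
# if not GROQ_API_KEY:
#     raise RuntimeError("GROQ_API_KEY is not set; export it before launching the app.")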

# ========== Core Components ==========
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    separators=["\n\n", "\n"],
)
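# With these settings a ~2,500-character document yields roughly three chunks of at
# most 1,000 characters, adjacent chunks sharing about 100 characters; splits are
# attempted on blank lines first, then on single newlines.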

rag_template = """You are an expert assistant tasked with answering questions based on the provided documents.
Use only the given context to generate your answer.
If the answer cannot be found in the context, clearly state that you do not know.
Be detailed and precise in your response, but avoid mentioning or referencing the context itself.

Context:
{context}

Question:
{question}

Answer:"""
rag_prompt = PromptTemplate.from_template(rag_template)
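# Illustrative example of how the chain fills this template at query time:
# rag_prompt.format(
#     context="<chunks returned by the retriever>",
#     question="What does the report say about Q3 revenue?",
# )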

# ========== App State ==========
class AppState:
    vectorstore: Optional[InMemoryVectorStore] = None
    rag_chain = None


state = AppState()

# ========== Utility Functions ==========
def load_documents_from_files(files: List[str]) -> List:
    """Load documents from uploaded files directly, without moving them."""
    all_documents = []
    # Temporary directory in case a ZIP needs extraction
    with tempfile.TemporaryDirectory() as temp_dir:
        for file_path in files:
            ext = os.path.splitext(file_path)[1].lower()
            if ext == ".zip":
                # Extract each ZIP into its own subdirectory so files from
                # previously extracted archives are not loaded twice
                extract_dir = tempfile.mkdtemp(dir=temp_dir)
                with zipfile.ZipFile(file_path, "r") as zip_ref:
                    zip_ref.extractall(extract_dir)
                # Load every .docx extracted from the ZIP
                loader = DirectoryLoader(
                    path=extract_dir,
                    glob="**/*.docx",
                    use_multithreading=True,
                )
                docs = loader.load()
                all_documents.extend(docs)
            elif ext == ".docx":
                # Load a single .docx directly
                loader = UnstructuredFileLoader(file_path)
                docs = loader.load()
                all_documents.extend(docs)
    return all_documents
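
# Note: this helper is not wired into the UI below (upload_files repeats the same
# loading logic inline), but it can be called standalone, e.g. with hypothetical paths:
# docs = load_documents_from_files(["report.docx", "archive.zip"])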

def get_last_user_message(chatbot: List[Union[gr.ChatMessage, dict]]) -> Optional[str]:
    """Get the last user prompt, if any."""
    for message in reversed(chatbot):
        role = message.get("role") if isinstance(message, dict) else message.role
        if role == "user":
            return message.get("content") if isinstance(message, dict) else message.content
    return None
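
# (For reference: get_last_user_message([{"role": "user", "content": "Hi"}]) returns
# "Hi"; an empty history returns None.)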

# ========== Main Logic ==========
def upload_files(
    files: Optional[List[str]], chatbot: List[Union[gr.ChatMessage, dict]]
):
    """Handle file upload: .docx files or .zip archives containing them."""
    if not files:
        return chatbot

    file_summaries = []  # Collected, formatted file/folder info
    documents = []

    with tempfile.TemporaryDirectory() as temp_dir:
        for file_path in files:
            filename = os.path.basename(file_path)
            ext = os.path.splitext(file_path)[1].lower()
            if ext == ".zip":
                file_summaries.append(f"📦 **{filename}** (ZIP file) contains:")
                try:
                    # Extract each ZIP into its own subdirectory so files from
                    # previously processed archives are not loaded a second time
                    extract_dir = tempfile.mkdtemp(dir=temp_dir)
                    with zipfile.ZipFile(file_path, "r") as zip_ref:
                        zip_ref.extractall(extract_dir)
                        zip_contents = zip_ref.namelist()

                    # Group files by folder
                    folder_map = collections.defaultdict(list)
                    for item in zip_contents:
                        if item.endswith("/"):
                            continue  # skip folder entries themselves
                        folder = os.path.dirname(item)
                        file_name = os.path.basename(item)
                        folder_map[folder].append(file_name)

                    # Format nicely
                    for folder, files_in_folder in folder_map.items():
                        if folder:
                            file_summaries.append(f"📁 {folder}/")
                        else:
                            file_summaries.append("📁 (root)")
                        for f in files_in_folder:
                            file_summaries.append(f"  - {f}")

                    # Load .docx files extracted from the ZIP
                    loader = DirectoryLoader(
                        path=extract_dir,
                        glob="**/*.docx",
                        use_multithreading=True,
                    )
                    docs = loader.load()
                    documents.extend(docs)
                except zipfile.BadZipFile:
                    chatbot.append(
                        gr.ChatMessage(
                            role="assistant",
                            content=f"❌ Failed to open ZIP file: {filename}",
                        )
                    )
            elif ext == ".docx":
                file_summaries.append(f"📄 **{filename}**")
                loader = UnstructuredFileLoader(file_path)
                docs = loader.load()
                documents.extend(docs)
            else:
                file_summaries.append(f"❌ Unsupported file type: {filename}")

    if not documents:
        chatbot.append(
            gr.ChatMessage(
                role="assistant", content="No valid .docx files found in upload."
            )
        )
        return chatbot

    # Split documents
    chunks = text_splitter.split_documents(documents)
    if not chunks:
        chatbot.append(
            gr.ChatMessage(
                role="assistant", content="Failed to split documents into chunks."
            )
        )
        return chatbot

    # Create vector store
    state.vectorstore = InMemoryVectorStore.from_documents(
        documents=chunks,
        embedding=embed_model,
    )
    retriever = state.vectorstore.as_retriever()

    # Build RAG chain
    state.rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | rag_prompt
        | llm
        | StrOutputParser()
    )
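
    # At query time the dict above fills the prompt's two variables: the retriever maps
    # the incoming question to relevant chunks for {context}, while RunnablePassthrough
    # forwards the question itself into {question}; the completed prompt then goes to
    # the Groq-hosted Llama model and the output parser reduces the reply to a string.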

    # Final display
    chatbot.append(
        gr.ChatMessage(
            role="assistant",
            content="**Uploaded Files:**\n"
            + "\n".join(file_summaries)
            + "\n\n✅ Ready to chat!",
        )
    )
    return chatbot

def user_message(
    text_prompt: str, chatbot: List[Union[gr.ChatMessage, dict]]
) -> Tuple[str, List[Union[gr.ChatMessage, dict]]]:
    """Add the user's text input to the conversation."""
    if text_prompt.strip():
        chatbot.append(gr.ChatMessage(role="user", content=text_prompt))
    return "", chatbot

def process_query(
    chatbot: List[Union[gr.ChatMessage, dict]],
) -> List[Union[gr.ChatMessage, dict]]:
    """Process the user's query through the RAG pipeline."""
    prompt = get_last_user_message(chatbot)
    if not prompt:
        chatbot.append(
            gr.ChatMessage(role="assistant", content="Please type a question first.")
        )
        return chatbot

    if state.rag_chain is None:
        chatbot.append(
            gr.ChatMessage(role="assistant", content="Please upload documents first.")
        )
        return chatbot

    chatbot.append(gr.ChatMessage(role="assistant", content="Thinking..."))
    try:
        response = state.rag_chain.invoke(prompt)
        chatbot[-1].content = response
    except Exception as e:
        chatbot[-1].content = f"Error: {str(e)}"
    return chatbot

def reset_app(
    chatbot: List[Union[gr.ChatMessage, dict]],
) -> List[Union[gr.ChatMessage, dict]]:
    """Reset the application state."""
    state.vectorstore = None
    state.rag_chain = None
    return [
        gr.ChatMessage(
            role="assistant", content="App reset! Upload new documents to start."
        )
    ]

# ========== UI Layout ==========
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML(TITLE)
    chatbot = gr.Chatbot(
        label="Llama 4 RAG",
        type="messages",
        bubble_full_width=False,
        avatar_images=AVATAR_IMAGES,
        scale=2,
        height=350,
    )
    with gr.Row(equal_height=True):
        text_prompt = gr.Textbox(
            placeholder="Ask a question...", show_label=False, autofocus=True, scale=28
        )
        send_button = gr.Button(
            value="Send",
            variant="primary",
            scale=1,
            min_width=80,
        )
        upload_button = gr.UploadButton(
            label="Upload",
            file_count="multiple",
            file_types=TEXT_EXTENSIONS,
            scale=1,
            min_width=80,
        )
        reset_button = gr.Button(
            value="Reset",
            variant="stop",
            scale=1,
            min_width=80,
        )

    send_button.click(
        fn=user_message,
        inputs=[text_prompt, chatbot],
        outputs=[text_prompt, chatbot],
        queue=False,
    ).then(fn=process_query, inputs=[chatbot], outputs=[chatbot])

    text_prompt.submit(
        fn=user_message,
        inputs=[text_prompt, chatbot],
        outputs=[text_prompt, chatbot],
        queue=False,
    ).then(fn=process_query, inputs=[chatbot], outputs=[chatbot])

    upload_button.upload(
        fn=upload_files, inputs=[upload_button, chatbot], outputs=[chatbot], queue=False
    )

    reset_button.click(fn=reset_app, inputs=[chatbot], outputs=[chatbot], queue=False)

demo.queue().launch()
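
# To run locally (assumed setup): install the dependencies, export GROQ_API_KEY, and
# start the script with Python; Gradio serves the UI at http://127.0.0.1:7860 by default.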