## Setup

# Import the necessary libraries
import json
import uuid
import os

from groq import Groq
import gradio as gr
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings
)
from langchain_community.vectorstores import Chroma
from huggingface_hub import CommitScheduler
from pathlib import Path
# Create the Groq client
# The API key is read from the environment; set GROQ_API_KEY as a secret
# (e.g. a Space secret) rather than hard-coding it in the source.
client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)
# Quick sanity-check call to confirm the client and model are reachable
chat_completion = client.chat.completions.create(
    messages=[
        {
            "role": "user",
            "content": "Explain the importance of fast language models",
        }
    ],
    model="llama3-8b-8192",
)
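# Optional: print the sanity-check response to verify connectivity.
# The Groq SDK follows the OpenAI-style response shape, so the generated
# text lives at .choices[0].message.content.
print(chat_completion.choices[0].message.content)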
# Define the embedding model and the vectorstore
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')

collection_name = 'reports_collection'

vectorstore_persisted = Chroma(
    collection_name=collection_name,
    persist_directory='./reports_db',
    embedding_function=embedding_model
)
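# The vectorstore above is opened from ./reports_db, which is assumed to have
# been built beforehand. A minimal ingestion sketch (illustrative only, not
# part of this app; the loader choice and chunk sizes are assumptions) that
# would produce a compatible collection -- each chunk keeping the 'source'
# and 'page' metadata that predict() filters and cites on -- looks like:
#
#     from langchain_community.document_loaders import PyPDFLoader
#     from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#     pdf_path = "dataset/MedicalDiagnosisManuals/..."   # same path used in predict()
#     pages = PyPDFLoader(pdf_path).load()                # one Document per page
#     chunks = RecursiveCharacterTextSplitter(
#         chunk_size=1000, chunk_overlap=200              # illustrative sizes
#     ).split_documents(pages)
#     Chroma.from_documents(
#         documents=chunks,
#         embedding=embedding_model,
#         collection_name=collection_name,
#         persist_directory='./reports_db'
#     )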
# Prepare the logging functionality
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent
log_folder.mkdir(parents=True, exist_ok=True)

# Push the contents of the log folder to a Hugging Face dataset repo
# every 2 minutes
scheduler = CommitScheduler(
    repo_id="reports-qna",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2
)
# Define the Q&A system message
qna_system_message = """
You are an assistant to a hospital. Your task is to summarize and provide relevant information for medical diagnosis questions based on the provided context.

User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
The context contains references to specific portions of documents relevant to the user's query, along with the page number from the report.
The source for the context will begin with the token ###Page.

When crafting your response:
1. Select only the context relevant to answering the question.
2. Include the source page number in your response.
3. User questions will begin with the token: ###Question.
4. If the question is irrelevant or the context is empty, respond with "Sorry, this is out of my knowledge base".

Please adhere to the following guidelines:
- Your response should only be about the question asked and nothing else.
- Answer only using the context provided.
- Do not mention anything about the context in your final answer.
- If the answer is not found in the context, it is very important for you to respond with "Sorry, this is out of my knowledge base".
- If NO CONTEXT is provided, it is very important for you to respond with "Sorry, this is out of my knowledge base".

Here is an example of how to structure your response:

Answer:
[Answer]
Sourced from the Medical Diagnosis PDF, Page No:
[Page number]

Example (when the user query is not relevant to the context):
Answer: Sorry, this is out of my knowledge base
"""
# Define the user message template
qna_user_message_template = """
###Context
Here are some documents and their page numbers that are relevant to the question mentioned below.

{context}

###Question
{question}
"""
# Define the predict function that runs when 'Submit' is clicked or when an API request is made
def predict(user_input):

    filter = "dataset/MedicalDiagnosisManuals/The_Merck_Manual_of_Diagnosis_and_Therapy_2011 - 19th Edn........pdf"

    # Retrieve the 5 most similar chunks from the indexed manual
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input, k=5, filter={"source": filter}
    )

    # Attach the page number to each chunk so the model can cite its source
    context_list = [
        d.page_content
        + "\n ###This information is taken from the PDF in the Page No: "
        + str(d.metadata['page'])
        + "\n\n "
        for d in relevant_document_chunks
    ]
    context_for_query = ".".join(context_list) + "This is all the context I have."

    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query,
            question=user_input
        )}
    ]
    try:
        response = client.chat.completions.create(
            model='llama3-8b-8192',
            messages=prompt,
            temperature=0
        )

        prediction = response.choices[0].message.content

    except Exception as e:
        prediction = str(e)
    # Once the prediction is made, log both the inputs and outputs to a local log file.
    # While writing to the log file, hold the commit scheduler's lock to avoid
    # parallel access between logging and the scheduled push to the Hub.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    return prediction
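# Quick way to exercise the retrieval + generation pipeline before wiring up
# the UI (the query below is hypothetical; uncomment to run -- it performs a
# real vectorstore lookup and one Groq call):
# print(predict("What are the typical symptoms of iron-deficiency anemia?"))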
# Set up the Gradio UI
# Add a text box for the user query
textbox = gr.Textbox(placeholder="Enter your query here", lines=6)

# Create the interface
demo = gr.Interface(
    inputs=[textbox], fn=predict, outputs="text",
    title="Medical Report",
    description="A web interface for asking questions about the medical reports.",
    concurrency_limit=16
)

demo.queue()
demo.launch()
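# Once the Space is up, the same predict endpoint can also be called
# programmatically from another machine with gradio_client (sketch only;
# "user/reports-qna-space" is a placeholder, not the real Space id):
#
#     from gradio_client import Client
#     api = Client("user/reports-qna-space")
#     answer = api.predict("What are the typical symptoms of anemia?", api_name="/predict")
#     print(answer)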