import hashlib
import json
import os
from io import BytesIO

import requests
import streamlit as st
from groq import APIError, Groq
from PIL import ExifTags, Image, UnidentifiedImageError
from transformers import pipeline

from final_captioner import generate_final_caption

st.title("PicSamvaad : Image Conversational Chatbot")

# Groq() is constructed without arguments below, so the API key has to be
# present in the GROQ_API_KEY environment variable. The name GROQ_API_KEY was
# previously used without ever being defined in this file (NameError at
# startup); resolve it from the environment and fail with a visible message
# instead of a traceback when it is missing.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "").strip()
if not GROQ_API_KEY:
    st.error(
        "GROQ_API_KEY is not set. Export it in the environment before "
        "starting the app."
    )
    st.stop()
os.environ["GROQ_API_KEY"] = GROQ_API_KEY

# Shared Groq client used for every chat completion request in this script.
client = Groq()

# Sidebar inputs: the user either uploads an image file or supplies a URL.
# NOTE(review): indentation was lost in this file; only the header and the two
# input widgets are assumed to belong to the sidebar -- confirm.
st.sidebar.header("Upload Image or Enter URL")
uploaded_file = st.sidebar.file_uploader(
    "Upload an image to chat...",
    type=["jpg", "jpeg", "png"],
)
url = st.sidebar.text_input("Or enter a valid image URL...")

# Filled in further down: the loaded image (if any) and a user-facing error
# message when loading fails.
image, error_message = None, None

def correct_image_orientation(img):
    """Return ``img`` rotated upright according to its EXIF Orientation tag.

    Only the three pure-rotation orientations are handled (3, 6 and 8);
    mirrored orientations and images without usable EXIF data are returned
    unchanged. This is best effort: failures while reading the metadata never
    propagate to the caller.

    Args:
        img: An image object exposing ``getexif()`` (or the older
            ``_getexif()``) and ``rotate(angle, expand=...)``.

    Returns:
        A rotated copy when a rotation is required, otherwise ``img`` itself.
    """
    # 0x0112 (274) is the EXIF "Orientation" tag id; using the constant avoids
    # scanning the whole tag table on every call.
    orientation_tag = 0x0112
    # Counter-clockwise angles (as expected by rotate()) that undo each
    # stored orientation.
    rotation_for_orientation = {3: 180, 6: 270, 8: 90}

    # getexif() is the public accessor and is not limited to JPEG files;
    # _getexif() is kept as a fallback for objects that only offer the
    # older private accessor.
    read_exif = getattr(img, "getexif", None) or getattr(img, "_getexif", None)
    if read_exif is None:
        return img

    try:
        exif = read_exif()
        orientation = exif.get(orientation_tag) if exif else None
    except (AttributeError, KeyError, IndexError):
        # Unreadable metadata: keep the image as it is.
        return img

    angle = rotation_for_orientation.get(orientation)
    if angle is None:
        return img
    return img.rotate(angle, expand=True)


def get_image_hash(image):
    """Return an MD5 hex digest identifying the decoded content of ``image``.

    The digest is used only to detect that a different image was supplied
    (not for security). The mode and size are mixed in together with the raw
    pixel buffer so that two images whose buffers happen to be byte-identical
    but whose geometry or mode differ do not collide.

    Args:
        image: An image object exposing ``mode``, ``size`` and ``tobytes()``.

    Returns:
        A 32-character lowercase hexadecimal string.
    """
    width, height = image.size
    digest = hashlib.md5()
    digest.update(f"{image.mode}|{width}x{height}|".encode("utf-8"))
    digest.update(image.tobytes())
    return digest.hexdigest()


# Hash of the most recently shown image; used to detect that the user has
# switched to a different picture.
if "last_uploaded_hash" not in st.session_state:
    st.session_state.last_uploaded_hash = None


def _register_image(candidate, display_caption):
    """Record ``candidate`` as the current image, display it and return it.

    The chat history is cleared whenever the image hash differs from the one
    stored in the session, so a conversation never mixes two pictures.

    Args:
        candidate: The freshly opened image.
        display_caption: Caption shown underneath the rendered image.

    Returns:
        The orientation-corrected image that was displayed.
    """
    image_hash = get_image_hash(candidate)
    if st.session_state.last_uploaded_hash != image_hash:
        st.session_state.chat_history = []
        st.session_state.last_uploaded_hash = image_hash

    upright = correct_image_orientation(candidate)
    st.image(upright, caption=display_caption, use_column_width=True)
    return upright


if uploaded_file is not None:
    try:
        image = _register_image(Image.open(uploaded_file), "Uploaded Image.")
    except (UnidentifiedImageError, OSError):
        # A file with an image extension can still be corrupt or truncated;
        # report it instead of crashing the app.
        image = None
        error_message = "Error: The uploaded file could not be read as an image. Please upload a valid JPG or PNG file."

elif url:
    try:
        # A timeout keeps the app from hanging on an unresponsive host.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = _register_image(
            Image.open(BytesIO(response.content)), "Image from URL."
        )
    except (requests.exceptions.RequestException, UnidentifiedImageError, OSError):
        # OSError additionally covers truncated downloads that only fail once
        # the pixel data is decoded.
        image = None
        error_message = "Error: The provided URL is invalid or the image could not be loaded. Sometimes some image URLs don't work. We suggest you upload the downloaded image instead ;)"

# Caption of the current image; stays empty when no image is loaded. It is
# later appended to the system prompt of the chat request.
caption = ""
if image is not None:
    # Streamlit re-executes this script on every interaction, so without a
    # cache the captioner would run again for every chat message. The caption
    # is kept in the session, keyed by the hash that the image-loading step
    # stores in ``last_uploaded_hash``.
    cache_key = st.session_state.get("last_uploaded_hash")
    cached = st.session_state.get("caption_cache")
    if cache_key is not None and cached is not None and cached[0] == cache_key:
        caption = cached[1]
    else:
        caption += generate_final_caption(image)
        st.session_state.caption_cache = (cache_key, caption)
    st.write("ChatBot : " + caption)

# Surface any image-loading problem recorded earlier in the script.
if error_message:
    st.error(error_message)

# Make sure the conversation list exists before it is read.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Re-render every stored turn, since each script run starts from an empty
# page.
for past_turn in st.session_state["chat_history"]:
    st.chat_message(past_turn["role"]).markdown(past_turn["content"])

user_prompt = st.chat_input("Ask the Chatbot about the image...")

if user_prompt:
    # Echo the user's message and record it before calling the model.
    st.chat_message("user").markdown(user_prompt)
    st.session_state.chat_history.append({"role": "user", "content": user_prompt})

    # The system prompt carries the image caption; the full history follows
    # so the model sees the whole conversation.
    messages = [
        {
            "role": "system",
            "content": "You are a helpful, accurate image conversational assistant. You don't hallucinate, and your answers are very precise and have a positive approach.The caption of the image is: "
            + caption,
        },
        *st.session_state.chat_history,
    ]

    try:
        response = client.chat.completions.create(
            model="llama-3.1-8b-instant", messages=messages
        )
    except APIError as exc:
        # Roll back the unanswered user turn so the stored history does not
        # end with a question that never received a reply, and report the
        # failure instead of crashing the app.
        st.session_state.chat_history.pop()
        st.error(f"The chatbot could not answer right now ({exc}). Please try again.")
    else:
        assistant_response = response.choices[0].message.content
        st.session_state.chat_history.append(
            {"role": "assistant", "content": assistant_response}
        )

        with st.chat_message("assistant"):
            st.markdown(assistant_response)