Spaces:
Sleeping
Sleeping
import io | |
import json | |
import fitz | |
import streamlit as st | |
import torch | |
from PIL import Image, ImageGrab | |
from transformers import pipeline | |
# --- Configuration and Setup ---
# Device index handed to the transformers pipeline in load_model():
# 0 when CUDA is available, otherwise -1.
# NOTE(review): 0 / -1 presumably select "first GPU" / "CPU" in the pipeline
# API — confirm against the transformers documentation.
DEVICE = 0 if torch.cuda.is_available() else -1

# Page-level settings: browser-tab title and icon, wide layout, sidebar
# collapsed on first load.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first st.* call in the script — keep it ahead of everything else.
st.set_page_config(
    page_title="Invoice AI | by Arif Dogan",
    page_icon="🧾",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# --- Styling ---
# Inject a raw <style> block into the page; unsafe_allow_html=True is what
# lets the tag through st.markdown.
# The rules cap the app width, recolour buttons and the progress bar, hide the
# footer and the toolbar, and tint expander / text-input backgrounds.
# The .high / .medium / .low classes are the ones applied to the confidence
# percentages rendered further down in this script.
st.markdown(
    """
<style>
.stApp {max-width: 1200px; margin: 0 auto}
.stButton>button {background-color: #4CAF50; color: white; border-radius: 5px;}
.stProgress>div>div {background-color: #4CAF50}
footer {visibility: hidden}
.high {color: #4CAF50; font-weight: bold}
.medium {color: #FFA726; font-weight: bold}
.low {color: #EF5350; font-weight: bold}
div[data-testid="stToolbar"] {visibility: hidden; height: 0}
[data-testid="stExpanderContent"] {background-color: rgba(67, 76, 94, 0.5);}
.stTextInput>div>div {background-color: rgba(67, 76, 94, 0.5)}
</style>
""",
    unsafe_allow_html=True,
)
# --- Functions --- | |
@st.cache_resource
def load_model():
    """Build the document-question-answering pipeline used for extraction.

    The script body runs again on every user interaction, so without caching
    the model would be re-instantiated each time. ``st.cache_resource`` keeps
    one pipeline object alive and hands it back on later calls.

    Returns:
        The transformers pipeline for the ``faisalraza/layoutlm-invoices``
        model, placed on ``DEVICE``.
    """
    return pipeline(
        "document-question-answering",
        model="faisalraza/layoutlm-invoices",
        device=DEVICE,
    )
def process_pdf(pdf_file):
    """Render the first page of a PDF as a PIL image.

    Args:
        pdf_file: Binary file-like object; its ``read()`` must return the PDF
            bytes.

    Returns:
        A tuple ``(image, page_count)``: the first page as a PIL image and the
        number of pages in the document.

    Raises:
        ValueError: If the document contains no pages.
        Exception: Any error raised by PyMuPDF or PIL while opening or
            rendering propagates to the caller unchanged.
    """
    # Both context managers release their resources on every exit path, so no
    # explicit try/finally (or catch-and-re-raise) is needed.
    with io.BytesIO(pdf_file.read()) as pdf_stream, fitz.open(
        stream=pdf_stream, filetype="pdf"
    ) as pdf_document:
        page_count = pdf_document.page_count
        if page_count < 1:
            raise ValueError("PDF has no pages")

        # Same 300/72 scale on both axes (presumably 72-dpi PDF units
        # rendered at 300 dpi — confirm against the PyMuPDF docs).
        zoom = 300 / 72
        pix = pdf_document[0].get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        img_data = pix.tobytes("png")

    # img_data is an independent bytes object, so the image stays usable
    # after the document and the input stream have been closed.
    return Image.open(io.BytesIO(img_data)), page_count
def process_image(uploaded_file):
    """Turn an uploaded file into ``(image, page_count)``.

    The file is rewound first. Uploads whose ``type`` is ``application/pdf``
    are delegated to ``process_pdf``; anything else is opened directly with
    PIL and reported as a single page.
    """
    uploaded_file.seek(0)
    is_pdf = uploaded_file.type == "application/pdf"
    if not is_pdf:
        return Image.open(uploaded_file), 1
    return process_pdf(uploaded_file)
def get_clipboard_image():
    """Fetch an image from the clipboard, best effort.

    Returns:
        ``(image, 1)`` when the clipboard holds a PIL image, otherwise
        ``(None, 0)`` — including when reading the clipboard fails.
    """
    # NOTE(review): ImageGrab reads the clipboard of the machine running this
    # script; on a hosted deployment that is presumably the server rather
    # than the visitor's browser — confirm this is the intended behaviour.
    try:
        grabbed = ImageGrab.grabclipboard()
    except Exception:
        return None, 0

    if isinstance(grabbed, Image.Image):
        return grabbed, 1
    return None, 0
def prepare_export_data(extracted_info, format_type):
    """Serialise the extracted field values for download.

    Args:
        extracted_info: Mapping of field name to a dict that has at least a
            ``"value"`` entry.
        format_type: ``"JSON"`` or ``"CSV"``; any other value produces the
            plain-text ``field: value`` layout.

    Returns:
        The export as a string. CSV output is a header line plus one data
        line with every value quoted and no trailing newline.
    """
    import csv  # local import: csv is not among the file-level imports

    if format_type == "JSON":
        values_by_field = {
            field: data["value"] for field, data in extracted_info.items()
        }
        return json.dumps(values_by_field, indent=2)

    if format_type == "CSV":
        buffer = io.StringIO()
        # Header cells are quoted only when they need it (embedded comma,
        # quote or newline); value cells are always quoted. The csv module
        # doubles embedded quotes, which the hand-built f-string did not.
        csv.writer(buffer, lineterminator="\n").writerow(list(extracted_info))
        csv.writer(buffer, quoting=csv.QUOTE_ALL, lineterminator="\n").writerow(
            [str(data["value"]) for data in extracted_info.values()]
        )
        # Drop the final line terminator so the result has no trailing newline.
        return buffer.getvalue()[:-1]

    # TXT (and any unrecognised format)
    return "\n".join(
        f"{field}: {data['value']}" for field, data in extracted_info.items()
    )
def extract_information(model, image, questions, progress_bar, status_text):
    """Ask the model each question about the image and collect the answers.

    Args:
        model: Callable invoked as ``model(image=..., question=...)``; the
            first element of its result is read as a dict with ``"answer"``
            and ``"score"`` entries.
        image: The document image passed through to ``model``.
        questions: Sequence of question strings.
        progress_bar: Object with a ``progress(fraction)`` method.
        status_text: Object with a ``text(message)`` method.

    Returns:
        Dict mapping a field name derived from the question (``"What is the "``
        and ``"?"`` removed, then title-cased) to
        ``{"value": answer, "confidence": score}``. Questions with a blank
        answer, a score of 0.1 or less, or a processing error are left out.
    """
    import logging  # local import: logging is not among the file-level imports

    logger = logging.getLogger(__name__)
    total = len(questions)
    extracted_info = {}

    for position, question in enumerate(questions, start=1):
        try:
            progress_bar.progress(position / total)
            status_text.text(f"Processing: {question} ({position}/{total})")

            response = model(image=image, question=question)
            if not response:
                continue
            best = response[0]
            answer = best.get("answer", "")
            if not answer.strip():
                continue  # blank answers are not worth showing

            confidence = best["score"]
            if confidence > 0.1:
                field = question.replace("What is the ", "").replace("?", "").title()
                extracted_info[field] = {"value": answer, "confidence": confidence}
        except Exception:
            # Best effort: one failing question must not abort the others,
            # but the failure is logged so a broken model setup is not
            # reported merely as "nothing extracted".
            logger.warning(
                "Skipping question %r after a processing error",
                question,
                exc_info=True,
            )

    return extracted_info
# --- Initialization ---
# Seed the session-state slots the rest of the script reads; a value left
# over from an earlier run of the script is kept as is.
for _slot, _initial in (("processed_image", None), ("extracted_info", {})):
    if _slot not in st.session_state:
        st.session_state[_slot] = _initial
# --- UI Layout ---
# Centered page header, rendered as raw HTML.
st.markdown(
    """
<div style='text-align: center; padding: 1rem;'>
<h1>🧾 Invoice AI Extractor</h1>
<p style='font-size: 1.2em; color: #999;'>Powered by LayoutLM</p>
</div>
""",
    unsafe_allow_html=True,
)

# The pipeline object used by the extraction step further down.
model = load_model()

# Two input routes side by side (2:1 width ratio): a file upload on the left,
# a clipboard button on the right.
col1, col2 = st.columns([2, 1])
with col1:
    # Holds the uploaded file, or a falsy value when nothing is uploaded.
    uploaded_file = st.file_uploader(
        "Drop invoice (PDF, JPG, PNG)", type=["pdf", "jpg", "jpeg", "png"]
    )
with col2:
    st.write("Or paste from clipboard (Ctrl/Cmd + V)")
    # Truthy only on the run triggered by clicking the button.
    check_clipboard = st.button("📎 Check Clipboard")
# --- Image Processing Logic ---
# The script runs again on every widget interaction, and the uploader keeps
# returning the same file on each run. Re-decoding it and clearing
# extracted_info every time would force a full re-extraction whenever the
# user touches any widget, so the upload is processed only when it changes.
if uploaded_file:
    # Prefer Streamlit's per-upload id when the installed version has one;
    # otherwise fall back to name + size.
    upload_signature = getattr(uploaded_file, "file_id", None) or (
        uploaded_file.name,
        uploaded_file.size,
    )
    if st.session_state.get("upload_signature") != upload_signature:
        try:
            image, _ = process_image(uploaded_file)
        except Exception as e:
            # The signature is not stored on failure, so the error keeps
            # showing (and the file is retried) on later runs.
            st.error(f"Error processing file: {e}")
        else:
            st.session_state.processed_image = image
            st.session_state.extracted_info = {}  # Reset on new upload
            st.session_state.upload_signature = upload_signature
elif check_clipboard:
    clipboard_image, _ = get_clipboard_image()
    if clipboard_image:
        st.session_state.processed_image = clipboard_image
        st.session_state.extracted_info = {}
        # Forget the last upload so uploading that same file again is
        # treated as new and replaces the clipboard image.
        st.session_state.upload_signature = None
        st.success("Image loaded from clipboard")
    else:
        st.warning("No image found in clipboard")
# --- Display and Information Extraction ---
# Left column: the document image. Right column: extracted fields with their
# confidence, followed by the export controls.
if st.session_state.processed_image is not None:
    try:
        image = st.session_state.processed_image.convert("RGB")
        col1, col2 = st.columns([1, 1])

        with col1:
            st.image(image, caption="Document", use_container_width=True)

        with col2:
            st.markdown("### 📊 Extracted Information")

            # Run the model only when there are no results for this image yet.
            if not st.session_state.extracted_info:
                questions = [
                    "What is the invoice number?",
                    "What is the invoice date?",
                    "What is the total amount?",
                    "What is the company name?",
                    "What is the due date?",
                    "What is the tax amount?",
                ]
                # Placeholders updated by extract_information as it goes.
                progress_bar = st.progress(0)
                status_text = st.empty()
                st.session_state.extracted_info = extract_information(
                    model, image, questions, progress_bar, status_text
                )
                # Clear status text after completion
                status_text.empty()

            if st.session_state.extracted_info:
                for field, data in st.session_state.extracted_info.items():
                    conf_col, val_col = st.columns([1, 4])
                    with val_col:
                        # An explicit key per field keeps the widgets distinct.
                        st.text_input(
                            field, data["value"], disabled=True, key=f"input_{field}"
                        )
                    with conf_col:
                        confidence = data["confidence"]
                        # CSS classes come from the stylesheet injected near
                        # the top of the script.
                        if confidence > 0.7:
                            css_class = "high"
                        elif confidence > 0.4:
                            css_class = "medium"
                        else:
                            css_class = "low"
                        st.markdown(
                            f"<p class='{css_class}'>{confidence:.1%}</p>",
                            unsafe_allow_html=True,
                        )

                st.markdown("### 📥 Export")
                export_format = st.selectbox("Format", ["JSON", "CSV", "TXT"])
                export_data = prepare_export_data(
                    st.session_state.extracted_info, export_format
                )
                file_extension = export_format.lower()
                # "text/json" and "text/txt" are not registered MIME types;
                # map each format to its real one.
                export_mime_types = {
                    "JSON": "application/json",
                    "CSV": "text/csv",
                    "TXT": "text/plain",
                }
                st.download_button(
                    "Download",
                    export_data,
                    file_name=f"invoice_data.{file_extension}",
                    mime=export_mime_types[export_format],
                )
            else:
                st.warning(
                    "Could not extract information. Please ensure the document is clear."
                )
    except Exception as e:
        st.error(f"Error during processing: {e}")
# --- Footer ---
# Horizontal rule, then a centered credit line with two external links that
# open in a new tab (target='_blank'); rendered as raw HTML.
st.markdown("---")
st.markdown(
    """
<div style='text-align: center'>
<p>Created by <a href='https://github.com/doganarif' target='_blank'>Arif Dogan</a> |
<a href='https://huggingface.co/arifdogan' target='_blank'>🤗 Hugging Face</a></p>
</div>
""",
    unsafe_allow_html=True,
)