import io import json import fitz import streamlit as st import torch from PIL import Image, ImageGrab from transformers import pipeline # --- Configuration and Setup --- DEVICE = 0 if torch.cuda.is_available() else -1 st.set_page_config( page_title="Invoice AI | by Arif Dogan", page_icon="๐งพ", layout="wide", initial_sidebar_state="collapsed", ) # --- Styling --- st.markdown( """ """, unsafe_allow_html=True, ) # --- Functions --- @st.cache_resource def load_model(): return pipeline( "document-question-answering", model="faisalraza/layoutlm-invoices", device=DEVICE, ) def process_pdf(pdf_file): pdf_content = pdf_file.read() pdf_stream = io.BytesIO(pdf_content) try: with fitz.open(stream=pdf_stream, filetype="pdf") as pdf_document: if pdf_document.page_count > 0: page = pdf_document[0] pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72)) img_data = pix.tobytes("png") return Image.open(io.BytesIO(img_data)), pdf_document.page_count else: raise ValueError("PDF has no pages") except Exception as e: raise e finally: pdf_stream.close() def process_image(uploaded_file): uploaded_file.seek(0) if uploaded_file.type == "application/pdf": return process_pdf(uploaded_file) return Image.open(uploaded_file), 1 def get_clipboard_image(): try: img = ImageGrab.grabclipboard() return (img, 1) if isinstance(img, Image.Image) else (None, 0) except Exception: return None, 0 def prepare_export_data(extracted_info, format_type): if format_type == "JSON": return json.dumps( {field: data["value"] for field, data in extracted_info.items()}, indent=2 ) elif format_type == "CSV": header = ",".join(extracted_info.keys()) values = ",".join(f'"{data["value"]}"' for data in extracted_info.values()) return f"{header}\n{values}" else: # TXT return "\n".join( f"{field}: {data['value']}" for field, data in extracted_info.items() ) def extract_information(model, image, questions, progress_bar, status_text): extracted_info = {} for idx, question in enumerate(questions): try: # Update progress bar and status text progress_bar.progress((idx + 1) / len(questions)) status_text.text(f"Processing: {question} ({idx + 1}/{len(questions)})") response = model(image=image, question=question) if ( response and response[0].get("answer", "").strip() ): # Check for non-empty answer answer = response[0]["answer"] confidence = response[0]["score"] if confidence > 0.1: field = ( question.replace("What is the ", "").replace("?", "").title() ) extracted_info[field] = {"value": answer, "confidence": confidence} except Exception: continue # Handle potential errors during model processing return extracted_info # --- Initialization --- if "processed_image" not in st.session_state: st.session_state.processed_image = None if "extracted_info" not in st.session_state: st.session_state.extracted_info = {} # --- UI Layout --- st.markdown( """
Powered by LayoutLM
{confidence:.1%}
", unsafe_allow_html=True, ) st.markdown("### ๐ฅ Export") export_format = st.selectbox("Format", ["JSON", "CSV", "TXT"]) export_data = prepare_export_data( st.session_state.extracted_info, export_format ) file_extension = export_format.lower() st.download_button( "Download", export_data, file_name=f"invoice_data.{file_extension}", mime=f"text/{file_extension}", ) else: st.warning( "Could not extract information. Please ensure the document is clear." ) except Exception as e: st.error(f"Error during processing: {e}") # --- Footer --- st.markdown("---") st.markdown( """Created by Arif Dogan | ๐ค Hugging Face