# import os
# os.system("sudo apt-get install xclip")
# import nltk
# from fastapi import FastAPI, File, Request, UploadFile, Body, Depends, HTTPException
# from fastapi.security.api_key import APIKeyHeader
# from typing import Optional, Annotated
# from fastapi.encoders import jsonable_encoder
# from PIL import Image
# from io import BytesIO
# import pytesseract
# from nltk.tokenize import sent_tokenize
# from transformers import MarianMTModel, MarianTokenizer

# nltk.download('punkt')

# API_KEY = os.environ.get("API_KEY")

# app = FastAPI()
# api_key_header = APIKeyHeader(name="api_key", auto_error=False)

# def get_api_key(api_key: Optional[str] = Depends(api_key_header)):
#     if api_key is None or api_key != API_KEY:
#         raise HTTPException(status_code=401, detail="Unauthorized access")
#     return api_key

# # Image path
# img_dir = "./data"

# # Get tesseract language list
# choices = os.popen('tesseract --list-langs').read().split('\n')[1:-1]

# # Convert tesseract language list to pytesseract language
# def ocr_lang(lang_list):
#     lang_str = ""
#     lang_len = len(lang_list)
#     if lang_len == 1:
#         return lang_list[0]
#     else:
#         for i in range(lang_len):
#             lang_list.insert(lang_len - i, "+")
#         lang_str = "".join(lang_list[:-1])
#         return lang_str

# # ocr tesseract
# def ocr_tesseract(img, languages):
#     print("[img]", img)
#     print("[languages]", languages)
#     ocr_str = pytesseract.image_to_string(img, lang=ocr_lang(languages))
#     return ocr_str

# @app.post("/api/ocr", response_model=dict)
# async def ocr(
#     api_key: str = Depends(get_api_key),
#     image: UploadFile = File(...),
#     # languages: list = Body(["eng"])
# ):
#     try:
#         content = await image.read()
#         image = Image.open(BytesIO(content))
#         print("[image]", image)
#         if hasattr(pytesseract, "image_to_string"):
#             print("Image to string function is available")
#             # print(pytesseract.image_to_string(image, lang = 'eng'))
#             text = ocr_tesseract(image, ['eng'])
#         else:
#             print("Image to string function is not available")
#             # text = pytesseract.image_to_string(image, lang="+".join(languages))
#     except Exception as e:
#         return {"error": str(e)}, 500
#     return {"ImageText": "text"}

# @app.post("/api/translate", response_model=dict)
# async def translate(
#     api_key: str = Depends(get_api_key),
#     text: str = Body(...),
#     src: str = "en",
#     trg: str = "zh",
# ):
#     if api_key != API_KEY:
#         return {"error": "Invalid API key"}, 401
#     tokenizer, model = get_model(src, trg)
#     translated_text = ""
#     for sentence in sent_tokenize(text):
#         translated_sub = model.generate(**tokenizer(sentence, return_tensors="pt"))[0]
#         translated_text += tokenizer.decode(translated_sub, skip_special_tokens=True) + "\n"
#     return jsonable_encoder({"translated_text": translated_text})

# def get_model(src: str, trg: str):
#     model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
#     tokenizer = MarianTokenizer.from_pretrained(model_name)
#     model = MarianMTModel.from_pretrained(model_name)
#     return tokenizer, model
# OCR Translate v0.2
import os
os.system("sudo apt-get install xclip")
# import gradio as gr
import nltk
import pyclip
import pytesseract
from nltk.tokenize import sent_tokenize
from transformers import MarianMTModel, MarianTokenizer
# Added below code
from fastapi import FastAPI, File, Request, UploadFile, Body, Depends, HTTPException
from fastapi.security.api_key import APIKeyHeader
from typing import Optional, Annotated
from fastapi.encoders import jsonable_encoder
from PIL import Image
from io import BytesIO
API_KEY = os.environ.get("API_KEY")
app = FastAPI()
api_key_header = APIKeyHeader(name="api_key", auto_error=False)
def get_api_key(api_key: Optional[str] = Depends(api_key_header)):
    if api_key is None or api_key != API_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return api_key
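
# Illustrative only (not part of the original code): clients must send the key in a
# request header literally named "api_key", matching APIKeyHeader(name="api_key") above.
# Host, port, and key below are placeholders:
#   curl -H "api_key: $API_KEY" http://localhost:7860/api/ocr ...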
@app.post("/api/ocr", response_model=dict)
async def ocr(
    api_key: str = Depends(get_api_key),
    image: UploadFile = File(...),
    # languages: list = Body(["eng"])
):
    try:
        content = await image.read()
        image = Image.open(BytesIO(content))
        print("[image]", image)
        if hasattr(pytesseract, "image_to_string"):
            print("Image to string function is available")
            # print(pytesseract.image_to_string(image, lang='eng'))
            text = ocr_tesseract(image, ['eng'])
        else:
            print("Image to string function is not available")
            # text = pytesseract.image_to_string(image, lang="+".join(languages))
            raise RuntimeError("pytesseract.image_to_string is not available")
    except Exception as e:
        # Raise an HTTP error instead of returning a (dict, status) tuple,
        # which FastAPI does not interpret as a status code.
        raise HTTPException(status_code=500, detail=str(e))
    return {"ImageText": text}
nltk.download('punkt')
OCR_TR_DESCRIPTION = '''# OCR Translate v0.2
<div id="content_align">OCR translation system based on Tesseract</div>'''
# Image path
img_dir = "./data"
# Get tesseract language list
choices = os.popen('tesseract --list-langs').read().split('\n')[1:-1]
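# `choices` then holds the installed Tesseract languages, e.g. roughly
# ["chi_sim", "eng", "osd"] (illustrative; the exact list depends on which
# language packs are installed).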
# Translation model selection
def model_choice(src="en", trg="zh"):
    # https://huggingface.co/Helsinki-NLP/opus-mt-zh-en
    # https://huggingface.co/Helsinki-NLP/opus-mt-en-zh
    model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"  # Model name
    tokenizer = MarianTokenizer.from_pretrained(model_name)  # Tokenizer
    model = MarianMTModel.from_pretrained(model_name)  # Model
    return tokenizer, model
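
# Usage sketch (illustrative, not in the original file): assumes the corresponding
# Helsinki-NLP/opus-mt-* checkpoint exists on the Hub and can be downloaded.
#   tokenizer, model = model_choice("en", "zh")
#   batch = tokenizer(["Hello world"], return_tensors="pt")
#   print(tokenizer.decode(model.generate(**batch)[0], skip_special_tokens=True))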
# Convert tesseract language list to pytesseract language
def ocr_lang(lang_list):
    lang_str = ""
    lang_len = len(lang_list)
    if lang_len == 1:
        return lang_list[0]
    else:
        for i in range(lang_len):
            lang_list.insert(lang_len - i, "+")
        lang_str = "".join(lang_list[:-1])
        return lang_str
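
# Expected behaviour (illustrative examples, not in the original file):
#   ocr_lang(["eng"])            -> "eng"
#   ocr_lang(["chi_sim", "eng"]) -> "chi_sim+eng"   (note: mutates the input list)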
# ocr tesseract
def ocr_tesseract(img, languages):
    print("[img]", img)
    print("[languages]", languages)
    ocr_str = pytesseract.image_to_string(img, lang=ocr_lang(languages))
    return ocr_str
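
# Standalone usage sketch (illustrative; ./data/test.png is one of the bundled example
# images, and the requested language packs must be installed for Tesseract):
#   from PIL import Image
#   print(ocr_tesseract(Image.open("./data/test.png"), ["eng"]))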
# Clear
def clear_content():
    return None
# copy to clipboard
def cp_text(input_text):
    # sudo apt-get install xclip
    try:
        pyclip.copy(input_text)
    except Exception as e:
        print("sudo apt-get install xclip")
        print(e)
# clear clipboard
def cp_clear():
    pyclip.clear()
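
# Usage sketch (illustrative): copying requires a working clipboard backend,
# e.g. xclip on Linux, hence the apt-get hint above.
#   cp_text("translated text")
#   cp_clear()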
# translate
def translate(input_text, inputs_transStyle):
    # Reference: https://huggingface.co/docs/transformers/model_doc/marian
    if input_text is None or input_text == "":
        return "System prompt: There is no content to translate!"

    # Select the translation model
    trans_src, trans_trg = inputs_transStyle.split("-")[0], inputs_transStyle.split("-")[1]
    tokenizer, model = model_choice(trans_src, trans_trg)

    translate_text = ""
    input_text_list = input_text.split("\n\n")

    translate_text_list_tmp = []
    for i in range(len(input_text_list)):
        if input_text_list[i] != "":
            translate_text_list_tmp.append(input_text_list[i])

    for i in range(len(translate_text_list_tmp)):
        translated_sub = model.generate(
            **tokenizer(sent_tokenize(translate_text_list_tmp[i]), return_tensors="pt", truncation=True, padding=True))
        tgt_text_sub = [tokenizer.decode(t, skip_special_tokens=True) for t in translated_sub]
        translate_text_sub = "".join(tgt_text_sub)
        translate_text = translate_text + "\n\n" + translate_text_sub

    return translate_text[2:]
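
# Usage sketch (illustrative, not in the original file; the first call downloads the
# en-zh model from the Hub):
#   print(translate("Hello world.\n\nHow are you?", "en-zh"))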
# def main():

#     with gr.Blocks(css='style.css') as ocr_tr:
#         gr.Markdown(OCR_TR_DESCRIPTION)

#         # -------------- OCR text extraction --------------
#         with gr.Box():

#             with gr.Row():
#                 gr.Markdown("### Step 01: Text Extraction")

#             with gr.Row():
#                 with gr.Column():
#                     with gr.Row():
#                         inputs_img = gr.Image(image_mode="RGB", source="upload", type="pil", label="image")
#                     with gr.Row():
#                         inputs_lang = gr.CheckboxGroup(choices=["chi_sim", "eng"],
#                                                        type="value",
#                                                        value=['eng'],
#                                                        label='language')

#                     with gr.Row():
#                         clear_img_btn = gr.Button('Clear')
#                         ocr_btn = gr.Button(value='OCR Extraction', variant="primary")

#                 with gr.Column():
#                     with gr.Row():
#                         outputs_text = gr.Textbox(label="Extract content", lines=20)
#                     with gr.Row():
#                         inputs_transStyle = gr.Radio(choices=["zh-en", "en-zh"],
#                                                      type="value",
#                                                      value="zh-en",
#                                                      label='translation mode')
#                     with gr.Row():
#                         clear_text_btn = gr.Button('Clear')
#                         translate_btn = gr.Button(value='Translate', variant="primary")

#             with gr.Row():
#                 example_list = [["./data/test.png", ["eng"]], ["./data/test02.png", ["eng"]],
#                                 ["./data/test03.png", ["chi_sim"]]]
#                 gr.Examples(example_list, [inputs_img, inputs_lang], outputs_text, ocr_tesseract, cache_examples=False)

#         # -------------- translate --------------
#         with gr.Box():
#             with gr.Row():
#                 gr.Markdown("### Step 02: Translation")
#             with gr.Row():
#                 outputs_tr_text = gr.Textbox(label="Translate Content", lines=20)
#             with gr.Row():
#                 cp_clear_btn = gr.Button(value='Clear Clipboard')
#                 cp_btn = gr.Button(value='Copy to clipboard', variant="primary")

#         # ---------------------- OCR Tesseract ----------------------
#         ocr_btn.click(fn=ocr_tesseract, inputs=[inputs_img, inputs_lang], outputs=[outputs_text])
#         clear_img_btn.click(fn=clear_content, inputs=[], outputs=[inputs_img])

#         # ---------------------- translate ----------------------
#         translate_btn.click(fn=translate, inputs=[outputs_text, inputs_transStyle], outputs=[outputs_tr_text])
#         clear_text_btn.click(fn=clear_content, inputs=[], outputs=[outputs_text])

#         # ---------------------- copy to clipboard ----------------------
#         cp_btn.click(fn=cp_text, inputs=[outputs_tr_text], outputs=[])
#         cp_clear_btn.click(fn=cp_clear, inputs=[], outputs=[])

#     ocr_tr.launch(inbrowser=True)


# if __name__ == '__main__':
#     main()
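
# Minimal local-launch sketch (not in the original file): assumes `uvicorn` is installed.
# Port 7860 is the convention for Hugging Face Spaces; adjust host/port as needed.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)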