Implement local OCR and batch processing CLI flag

Implemented local optical character recognition (OCR) via pytesseract in the image_to_anki function: when enabled, the OCR text for each page is sent to the model alongside the page image to improve its output. Additionally, added a --batch-size command-line argument so the number of images processed per request can be specified explicitly, and an --ocr argument to enable local OCR and select its language. (A usage sketch follows the commit metadata below.)
Yandrik 2024-02-05 09:47:49 +01:00
parent 3d09cbbef8
commit d9eb6f1c64
4 changed files with 97 additions and 31 deletions
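
Usage sketch (not part of the commit): with these changes, a run that enables local OCR and a custom batch size could look roughly like the following; the invocation style, PDF name, page range and OCR language are illustrative, not taken from this commit.

    python main.py --pages 1-20 --ocr eng --batch-size 5 --output-file out.md some_dictionary.pdf

The value passed to --ocr is forwarded as the Tesseract language string (lang), so combined values such as eng+chi_sim are possible; omitting --ocr leaves local OCR disabled.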


@@ -1,6 +1,7 @@
 import base64
-from typing import Any
+from typing import Any, Optional
+import pytesseract
 import requests
 from PIL import Image
 from io import BytesIO
@@ -50,8 +51,18 @@ def crop_image_to_left_side(image: Image, crop_width) -> Image:
 # Resize the image and get base64 string
 # resized_image = resize_image(image_path, 1024, 512)
-def image_to_anki(image_paths: str | list[str]) -> tuple[str | None, Any]:
+# Function to perform OCR
+def ocr(image: Image, lang: Optional[str] = 'eng') -> str:
+    text = pytesseract.image_to_string(image, lang=lang)
+    return text
+
+
+def image_to_anki(image_paths: str | list[str], do_ocr: bool = False, lang: Optional[str] = None) -> tuple[
+    str | None, Any]:
     images = []
+    ocr_results = []
     if isinstance(image_paths, str):
         image_paths = [image_paths]
     for image_path in image_paths:
@@ -62,11 +73,41 @@ def image_to_anki(image_paths: str | list[str]) -> tuple[str | None, Any]:
         # exit(1)
         base64_image = encode_image(cropped_image)
         images.append(base64_image)
+        if do_ocr:
+            original_image = Image.open(image_path)
+            print("doing local ocr...", end='')
+            ocr_text = ocr(original_image, lang)
+            print(f" done. local ocr resulted in {len(ocr_text)} characters.")
+            # print(ocr_text)  # or save it somewhere, or add it to your payload for further processing
+            ocr_results.append(ocr_text)
     # print(resized_image.size)
     # exit(1)
+
+    # generate image payload
+    image_msgs = []
+    for i, base64_image in enumerate(images):
+        image_payload = {
+            "type": "image_url",
+            "image_url": {
+                "url": f"data:image/jpeg;base64,{base64_image}",
+                "detail": "high"
+            }
+        }
+        if do_ocr:
+            ocr_payload = {
+                "type": "text",
+                "text": "Here are OCR results for the following page. These might be flawed. Use them to improve your "
+                        "performance:\n " + ocr_results[i]
+            }
+            image_msgs.append(ocr_payload)
+        image_msgs.append(image_payload)
+
     headers = {
         "Content-Type": "application/json",
         "Authorization": f"Bearer {api_key}"
@@ -97,18 +138,11 @@ def image_to_anki(image_paths: str | list[str]) -> tuple[str | None, Any]:
                    # "url": f"data:image/jpeg;base64,{base64_image}"
                    # }
                    # }
-                ] + [{
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/jpeg;base64,{base64_image}",
-                        "detail": "high"
-                    }
-                }
-                    for base64_image in images]
+                ] + image_msgs
             }
         ],
         "max_tokens": 600 * len(images),  # in general, around 350 tokens per page, so around double to be safe
-        "temperature": 0.0,
+        "temperature": 0.2,
     }

     response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
@@ -133,11 +167,12 @@ def test():
     # image_path = 'tmp.jpg'
     image_path = [
-        './.img/dict.pdf_7.png',
-        './.img/dict.pdf_8.png',
+        # './.img/dict.pdf_7.png',
+        # './.img/dict.pdf_8.png',
+        './.img/dict.pdf_103.png',
     ]
-    text, meta = image_to_anki(image_path)
+    text, meta = image_to_anki(image_path, do_ocr=False, lang='eng+chi_sim')
     print(text)

@@ -148,10 +183,9 @@ def test():
     print(
         f'approx. cost: 0.0075$ per picture, {usage["prompt_tokens"] * 0.01 / 1000}$ for prompt tokens and {usage["completion_tokens"] * 0.01 / 1000}$ for completion tokens')
-    cost_this = usage["prompt_tokens"] * 0.01 / 1000 + usage["completion_tokens"] * 0.01 / 1000 + 0.0075
+    cost_this = usage["prompt_tokens"] * 0.01 / 1000 + usage["completion_tokens"] * 0.03 / 1000  # + 0.0075
     print(f'this page: {cost_this}$')


 if __name__ == '__main__':
     test()
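Note on the new OCR path: the ocr() helper added above simply hands the PIL image to pytesseract. Below is a minimal standalone sketch of that step, assuming the Tesseract binary itself is installed; 'page.png' is a placeholder path, not a file from this repo.

from typing import Optional

import pytesseract
from PIL import Image


def ocr(image: Image.Image, lang: Optional[str] = 'eng') -> str:
    # pytesseract wraps the Tesseract binary; lang must name installed traineddata,
    # e.g. 'eng' or a combined value like 'eng+chi_sim'
    return pytesseract.image_to_string(image, lang=lang)


if __name__ == '__main__':
    text = ocr(Image.open('page.png'), lang='eng')  # 'page.png' is a placeholder
    print(f'local ocr resulted in {len(text)} characters')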

main.py

@@ -18,6 +18,8 @@ def main():
     parser.add_argument('--pages', type=str, required=True, help='Specify pages to parse in format <num>-<num>')
     parser.add_argument('--output-file', type=str, default='out.md', help='Specify output file')
     parser.add_argument('--images-path', type=str, default='./.img/', help='Specify output file')
+    parser.add_argument('--ocr', type=str, default=None, help='If present, send ocr=true to the image_to_anki method, and give the string value to the lang parameter')
+    parser.add_argument('--batch-size', type=int, default=3, help='Decide how many pages are processed in parallel')
     parser.add_argument('pdf_file', type=str, help='Specify PDF file name')

     args = parser.parse_args()
@@ -62,11 +64,12 @@ def main():
     break_outer = False
-    for i in range(len(paths) // IMGS_PER_REQUEST + 1):
+    for i in range(len(paths) // args.batch_size + 1):  # the batch size argument is used here
         # print(i)
         # collect images
         while True:
-            to_process = paths[i * IMGS_PER_REQUEST:i * IMGS_PER_REQUEST + IMGS_PER_REQUEST]
+            to_process = paths[i * args.batch_size:i * args.batch_size + args.batch_size]  # the batch size argument is used here
            # print(to_process)
            if len(to_process) == 0:
                # skip if remaining list is empty (e.g. if 4 pages at package size 2)
@@ -74,7 +77,9 @@ def main():
            print(f'processing {len(to_process)} image{"s" if len(to_process) != 1 else ""}')
-            cards, meta = dict_to_anki.image_to_anki(to_process)
+            ocr = True if args.ocr else False  # set OCR to True if --ocr parameter is present
+            cards, meta = dict_to_anki.image_to_anki(to_process, do_ocr=ocr, lang=args.ocr)
+
            if not cards:
                print("Error processing! Response: " + meta)

poetry.lock (generated)

@@ -269,6 +269,17 @@ typing-extensions = ">=4.7,<5"
 [package.extras]
 datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]

+[[package]]
+name = "packaging"
+version = "23.2"
+description = "Core utilities for Python packages"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
+    {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
+]
+
 [[package]]
 name = "pdf2image"
 version = "1.17.0"
@@ -478,6 +489,21 @@
 [package.dependencies]
 typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"

+[[package]]
+name = "pytesseract"
+version = "0.3.10"
+description = "Python-tesseract is a python wrapper for Google's Tesseract-OCR"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pytesseract-0.3.10-py3-none-any.whl", hash = "sha256:8f22cc98f765bf13517ead0c70effedb46c153540d25783e04014f28b55a5fc6"},
+    {file = "pytesseract-0.3.10.tar.gz", hash = "sha256:f1c3a8b0f07fd01a1085d451f5b8315be6eec1d5577a6796d46dc7a62bd4120f"},
+]
+
+[package.dependencies]
+packaging = ">=21.3"
+Pillow = ">=8.0.0"
+
 [[package]]
 name = "requests"
 version = "2.31.0"
@@ -561,4 +587,4 @@ zstd = ["zstandard (>=0.18.0)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "1df31140161c62d430257e30b1ebbff75524b5614888dfc7809f90d5f09a5737"
+content-hash = "07e8002d23153d51441fa4c4a70af0d6022d127f2c6c9c900cb194741e9bbe6c"

pyproject.toml

@@ -12,6 +12,7 @@ openai = "^1.10.0"
 requests = "^2.31.0"
 pillow = "^10.2.0"
 pdf2image = "^1.17.0"
+pytesseract = "^0.3.10"

 [build-system]
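
One caveat with the new pytesseract dependency: it only wraps the Tesseract binary, which Poetry does not install. Tesseract itself, plus any extra language data such as chi_sim, has to be installed separately on the system. A quick sanity-check sketch:

import pytesseract

try:
    # both calls shell out to the tesseract executable and fail if it is missing
    print(f'tesseract version: {pytesseract.get_tesseract_version()}')
    print(f'available languages: {pytesseract.get_languages()}')
except pytesseract.TesseractNotFoundError:
    print('tesseract binary not found; install it or point '
          'pytesseract.pytesseract.tesseract_cmd at the executable')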