Implement local OCR and batch processing CLI flag
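Added optional local optical character recognition (OCR) to the image_to_anki function: when enabled via the --ocr argument, the OCR text of each page is sent along with its image to improve the model's performance. Additionally, the number of images processed per request can now be set explicitly with the --batch-size command-line argument.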
This commit is contained in:
parent
3d09cbbef8
commit
d9eb6f1c64
@ -1,6 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import pytesseract
|
||||||
import requests
|
import requests
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -50,8 +51,18 @@ def crop_image_to_left_side(image: Image, crop_width) -> Image:
|
|||||||
# Resize the image and get base64 string
|
# Resize the image and get base64 string
|
||||||
# resized_image = resize_image(image_path, 1024, 512)
|
# resized_image = resize_image(image_path, 1024, 512)
|
||||||
|
|
||||||
def image_to_anki(image_paths: str | list[str]) -> tuple[str | None, Any]:
|
|
||||||
|
# Function to perform OCR
|
||||||
|
def ocr(image: Image, lang: Optional[str] = 'eng') -> str:
    """Run local OCR on an image and return the recognized text.

    Args:
        image: The image handed to ``pytesseract.image_to_string``.
        lang: Language specification forwarded unchanged as the ``lang``
            keyword of ``pytesseract.image_to_string``; defaults to ``'eng'``.

    Returns:
        Whatever ``pytesseract.image_to_string`` returns for the image.
    """
    return pytesseract.image_to_string(image, lang=lang)
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_anki(image_paths: str | list[str], do_ocr: bool = False, lang: Optional[str] = None) -> tuple[
|
||||||
|
str | None, Any]:
|
||||||
images = []
|
images = []
|
||||||
|
ocr_results = []
|
||||||
|
|
||||||
if isinstance(image_paths, str):
|
if isinstance(image_paths, str):
|
||||||
image_paths = [image_paths]
|
image_paths = [image_paths]
|
||||||
for image_path in image_paths:
|
for image_path in image_paths:
|
||||||
@ -62,11 +73,41 @@ def image_to_anki(image_paths: str | list[str]) -> tuple[str | None, Any]:
|
|||||||
# exit(1)
|
# exit(1)
|
||||||
base64_image = encode_image(cropped_image)
|
base64_image = encode_image(cropped_image)
|
||||||
images.append(base64_image)
|
images.append(base64_image)
|
||||||
|
if do_ocr:
|
||||||
|
original_image = Image.open(image_path)
|
||||||
|
print("doing local ocr...", end='')
|
||||||
|
ocr_text = ocr(original_image, lang)
|
||||||
|
print(f" done. local ocr resulted in {len(ocr_text)} characters.")
|
||||||
|
# print(ocr_text) # or save it somewhere, or add it to your payload for further processing
|
||||||
|
ocr_results.append(ocr_text)
|
||||||
# print(resized_image.size)
|
# print(resized_image.size)
|
||||||
|
|
||||||
|
|
||||||
# exit(1)
|
# exit(1)
|
||||||
|
|
||||||
|
# generate image payload
|
||||||
|
|
||||||
|
image_msgs = []
|
||||||
|
|
||||||
|
for i, base64_image in enumerate(images):
|
||||||
|
image_payload = {
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||||
|
"detail": "high"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if do_ocr:
|
||||||
|
ocr_payload = {
|
||||||
|
"type": "text",
|
||||||
|
"text": "Here are OCR results for the following page. These might be flawed. Use them to improve your "
|
||||||
|
"performance:\n " +
|
||||||
|
ocr_results[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
image_msgs.append(ocr_payload)
|
||||||
|
|
||||||
|
image_msgs.append(image_payload)
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {api_key}"
|
"Authorization": f"Bearer {api_key}"
|
||||||
@ -87,28 +128,21 @@ def image_to_anki(image_paths: str | list[str]) -> tuple[str | None, Any]:
|
|||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "Transform this image into Anki cards."
|
"text": "Transform this image into Anki cards."
|
||||||
},
|
},
|
||||||
# {
|
# {
|
||||||
# "type": "image_url",
|
# "type": "image_url",
|
||||||
# "image_url": {
|
# "image_url": {
|
||||||
# "url": f"data:image/jpeg;base64,{base64_image}"
|
# "url": f"data:image/jpeg;base64,{base64_image}"
|
||||||
# }
|
# }
|
||||||
# }
|
# }
|
||||||
] + [{
|
] + image_msgs
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
|
||||||
"detail": "high"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for base64_image in images]
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"max_tokens": 600 * len(images), # in general, around 350 tokens per page, so around double to be safe
|
"max_tokens": 600 * len(images), # in general, around 350 tokens per page, so around double to be safe
|
||||||
"temperature": 0.0,
|
"temperature": 0.2,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
|
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
|
||||||
@ -133,11 +167,12 @@ def test():
|
|||||||
# image_path = 'tmp.jpg'
|
# image_path = 'tmp.jpg'
|
||||||
|
|
||||||
image_path = [
|
image_path = [
|
||||||
'./.img/dict.pdf_7.png',
|
# './.img/dict.pdf_7.png',
|
||||||
'./.img/dict.pdf_8.png',
|
# './.img/dict.pdf_8.png',
|
||||||
|
'./.img/dict.pdf_103.png',
|
||||||
]
|
]
|
||||||
|
|
||||||
text, meta = image_to_anki(image_path)
|
text, meta = image_to_anki(image_path, do_ocr=False, lang='eng+chi_sim')
|
||||||
|
|
||||||
print(text)
|
print(text)
|
||||||
|
|
||||||
@ -148,10 +183,9 @@ def test():
|
|||||||
print(
|
print(
|
||||||
f'approx. cost: 0.0075$ per picture, {usage["prompt_tokens"] * 0.01 / 1000}$ for prompt tokens and {usage["completion_tokens"] * 0.01 / 1000}$ for completion tokens')
|
f'approx. cost: 0.0075$ per picture, {usage["prompt_tokens"] * 0.01 / 1000}$ for prompt tokens and {usage["completion_tokens"] * 0.01 / 1000}$ for completion tokens')
|
||||||
|
|
||||||
cost_this = usage["prompt_tokens"] * 0.01 / 1000 + usage["completion_tokens"] * 0.01 / 1000 + 0.0075
|
cost_this = usage["prompt_tokens"] * 0.01 / 1000 + usage["completion_tokens"] * 0.03 / 1000 # + 0.0075
|
||||||
print(f'this page: {cost_this}$')
|
print(f'this page: {cost_this}$')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test()
|
test()
|
||||||
|
11
main.py
11
main.py
@ -18,6 +18,8 @@ def main():
|
|||||||
parser.add_argument('--pages', type=str, required=True, help='Specify pages to parse in format <num>-<num>')
|
parser.add_argument('--pages', type=str, required=True, help='Specify pages to parse in format <num>-<num>')
|
||||||
parser.add_argument('--output-file', type=str, default='out.md', help='Specify output file')
|
parser.add_argument('--output-file', type=str, default='out.md', help='Specify output file')
|
||||||
parser.add_argument('--images-path', type=str, default='./.img/', help='Specify output file')
|
parser.add_argument('--images-path', type=str, default='./.img/', help='Specify output file')
|
||||||
|
parser.add_argument('--ocr', type=str, default=None, help='If present, send ocr=true to the image_to_anki method, and give the string value to the lang parameter')
|
||||||
|
parser.add_argument('--batch-size', type=int, default=3, help='Decide how many pages are processed in parallel')
|
||||||
parser.add_argument('pdf_file', type=str, help='Specify PDF file name')
|
parser.add_argument('pdf_file', type=str, help='Specify PDF file name')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -62,11 +64,12 @@ def main():
|
|||||||
|
|
||||||
break_outer = False
|
break_outer = False
|
||||||
|
|
||||||
for i in range(len(paths) // IMGS_PER_REQUEST + 1):
|
for i in range(len(paths) // args.batch_size + 1): # the batch size argument is used here
|
||||||
# print(i)
|
# print(i)
|
||||||
|
|
||||||
# collect images
|
# collect images
|
||||||
while True:
|
while True:
|
||||||
to_process = paths[i * IMGS_PER_REQUEST:i * IMGS_PER_REQUEST + IMGS_PER_REQUEST]
|
to_process = paths[i * args.batch_size:i * args.batch_size + args.batch_size] # the batch size argument is used here
|
||||||
# print(to_process)
|
# print(to_process)
|
||||||
if len(to_process) == 0:
|
if len(to_process) == 0:
|
||||||
# skip if remaining list is empty (e.g. if 4 pages at package size 2)
|
# skip if remaining list is empty (e.g. if 4 pages at package size 2)
|
||||||
@ -74,7 +77,9 @@ def main():
|
|||||||
|
|
||||||
print(f'processing {len(to_process)} image{"s" if len(to_process) != 1 else ""}')
|
print(f'processing {len(to_process)} image{"s" if len(to_process) != 1 else ""}')
|
||||||
|
|
||||||
cards, meta = dict_to_anki.image_to_anki(to_process)
|
ocr = True if args.ocr else False # set OCR to True if --ocr parameter is present
|
||||||
|
|
||||||
|
cards, meta = dict_to_anki.image_to_anki(to_process, do_ocr=ocr, lang=args.ocr)
|
||||||
|
|
||||||
if not cards:
|
if not cards:
|
||||||
print("Error processing! Response: " + meta)
|
print("Error processing! Response: " + meta)
|
||||||
|
28
poetry.lock
generated
28
poetry.lock
generated
@ -269,6 +269,17 @@ typing-extensions = ">=4.7,<5"
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "packaging"
|
||||||
|
version = "23.2"
|
||||||
|
description = "Core utilities for Python packages"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
|
||||||
|
{file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pdf2image"
|
name = "pdf2image"
|
||||||
version = "1.17.0"
|
version = "1.17.0"
|
||||||
@ -478,6 +489,21 @@ files = [
|
|||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
|
typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytesseract"
|
||||||
|
version = "0.3.10"
|
||||||
|
description = "Python-tesseract is a python wrapper for Google's Tesseract-OCR"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "pytesseract-0.3.10-py3-none-any.whl", hash = "sha256:8f22cc98f765bf13517ead0c70effedb46c153540d25783e04014f28b55a5fc6"},
|
||||||
|
{file = "pytesseract-0.3.10.tar.gz", hash = "sha256:f1c3a8b0f07fd01a1085d451f5b8315be6eec1d5577a6796d46dc7a62bd4120f"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
packaging = ">=21.3"
|
||||||
|
Pillow = ">=8.0.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "requests"
|
name = "requests"
|
||||||
version = "2.31.0"
|
version = "2.31.0"
|
||||||
@ -561,4 +587,4 @@ zstd = ["zstandard (>=0.18.0)"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "1df31140161c62d430257e30b1ebbff75524b5614888dfc7809f90d5f09a5737"
|
content-hash = "07e8002d23153d51441fa4c4a70af0d6022d127f2c6c9c900cb194741e9bbe6c"
|
||||||
|
@ -12,6 +12,7 @@ openai = "^1.10.0"
|
|||||||
requests = "^2.31.0"
|
requests = "^2.31.0"
|
||||||
pillow = "^10.2.0"
|
pillow = "^10.2.0"
|
||||||
pdf2image = "^1.17.0"
|
pdf2image = "^1.17.0"
|
||||||
|
pytesseract = "^0.3.10"
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
Loading…
Reference in New Issue
Block a user