Implement local OCR and batch processing CLI flag
Added local optical character recognition (OCR) via pytesseract to the image_to_anki function: when enabled through the new --ocr flag, each page's OCR text is sent alongside the image so the model can use it to improve its transcription. Also added a --batch-size command-line argument that controls how many pages are sent per request.
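A hypothetical invocation with the new flags (file name and page range are placeholders; the --ocr value is handed straight to Tesseract as the language string):

    python main.py --pages 100-110 --batch-size 3 --ocr eng+chi_sim dict.pdf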
dict_to_anki.py
@@ -1,6 +1,7 @@
 import base64
-from typing import Any
+from typing import Any, Optional
 
+import pytesseract
 import requests
 from PIL import Image
 from io import BytesIO
@@ -50,8 +51,18 @@ def crop_image_to_left_side(image: Image, crop_width) -> Image:
 # Resize the image and get base64 string
 # resized_image = resize_image(image_path, 1024, 512)
 
-def image_to_anki(image_paths: str | list[str]) -> tuple[str | None, Any]:
+# Function to perform OCR
+def ocr(image: Image, lang: Optional[str] = 'eng') -> str:
+    text = pytesseract.image_to_string(image, lang=lang)
+    return text
+
+
+def image_to_anki(image_paths: str | list[str], do_ocr: bool = False, lang: Optional[str] = None) -> tuple[
+    str | None, Any]:
     images = []
+    ocr_results = []
+
     if isinstance(image_paths, str):
         image_paths = [image_paths]
     for image_path in image_paths:
@@ -62,11 +73,41 @@ def image_to_anki(image_paths: str | list[str]) -> tuple[str | None, Any]:
         # exit(1)
         base64_image = encode_image(cropped_image)
         images.append(base64_image)
+        if do_ocr:
+            original_image = Image.open(image_path)
+            print("doing local ocr...", end='')
+            ocr_text = ocr(original_image, lang)
+            print(f" done. local ocr resulted in {len(ocr_text)} characters.")
+            # print(ocr_text)  # or save it somewhere, or add it to your payload for further processing
+            ocr_results.append(ocr_text)
     # print(resized_image.size)
 
 
     # exit(1)
 
+    # generate image payload
+
+    image_msgs = []
+
+    for i, base64_image in enumerate(images):
+        image_payload = {
+            "type": "image_url",
+            "image_url": {
+                "url": f"data:image/jpeg;base64,{base64_image}",
+                "detail": "high"
+            }
+        }
+        if do_ocr:
+            ocr_payload = {
+                "type": "text",
+                "text": "Here are OCR results for the following page. These might be flawed. Use them to improve your "
+                        "performance:\n " +
+                        ocr_results[i]
+            }
+
+            image_msgs.append(ocr_payload)
+
+        image_msgs.append(image_payload)
+
     headers = {
         "Content-Type": "application/json",
         "Authorization": f"Bearer {api_key}"
@@ -97,18 +138,11 @@ def image_to_anki(image_paths: str | list[str]) -> tuple[str | None, Any]:
                                #         "url": f"data:image/jpeg;base64,{base64_image}"
                                #     }
                                # }
-                ] + [{
-                    "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/jpeg;base64,{base64_image}",
-                        "detail": "high"
-                    }
-                }
-                    for base64_image in images]
+                           ] + image_msgs
             }
         ],
         "max_tokens": 600 * len(images),  # in general, around 350 tokens per page, so around double to be safe
-        "temperature": 0.0,
+        "temperature": 0.2,
     }
 
     response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
@@ -133,11 +167,12 @@ def test():
     # image_path = 'tmp.jpg'
 
     image_path = [
-        './.img/dict.pdf_7.png',
-        './.img/dict.pdf_8.png',
+        # './.img/dict.pdf_7.png',
+        # './.img/dict.pdf_8.png',
+        './.img/dict.pdf_103.png',
     ]
 
-    text, meta = image_to_anki(image_path)
+    text, meta = image_to_anki(image_path, do_ocr=False, lang='eng+chi_sim')
 
     print(text)
 
@@ -148,10 +183,9 @@ def test():
     print(
         f'approx. cost: 0.0075$ per picture, {usage["prompt_tokens"] * 0.01 / 1000}$ for prompt tokens and {usage["completion_tokens"] * 0.01 / 1000}$ for completion tokens')
 
-    cost_this = usage["prompt_tokens"] * 0.01 / 1000 + usage["completion_tokens"] * 0.01 / 1000 + 0.0075
+    cost_this = usage["prompt_tokens"] * 0.01 / 1000 + usage["completion_tokens"] * 0.03 / 1000 # + 0.0075
     print(f'this page: {cost_this}$')
 
 
-
 if __name__ == '__main__':
     test()
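Aside: the new ocr() helper is a thin wrapper around pytesseract, which in turn shells out to the Tesseract binary. A minimal standalone sketch of the same call (the page image path comes from the test above; the Tesseract binary and the eng/chi_sim language packs are assumed to be installed):

    # Standalone sketch: what ocr() does for a single rendered PDF page.
    # Assumes Tesseract and the requested language data packs are installed.
    import pytesseract
    from PIL import Image

    image = Image.open('./.img/dict.pdf_103.png')
    text = pytesseract.image_to_string(image, lang='eng+chi_sim')
    print(f'local ocr produced {len(text)} characters')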
							
								
								
									
main.py (11 lines changed)
@@ -18,6 +18,8 @@ def main():
     parser.add_argument('--pages', type=str, required=True, help='Specify pages to parse in format <num>-<num>')
     parser.add_argument('--output-file', type=str, default='out.md', help='Specify output file')
     parser.add_argument('--images-path', type=str, default='./.img/', help='Specify output file')
+    parser.add_argument('--ocr', type=str, default=None, help='If present, send ocr=true to the image_to_anki method, and give the string value to the lang parameter')
+    parser.add_argument('--batch-size', type=int, default=3, help='Decide how many pages are processed in parallel')
     parser.add_argument('pdf_file', type=str, help='Specify PDF file name')
 
     args = parser.parse_args()
@@ -62,11 +64,12 @@ def main():
 
     break_outer = False
 
-    for i in range(len(paths) // IMGS_PER_REQUEST + 1):
+    for i in range(len(paths) // args.batch_size + 1):  # the batch size argument is used here
         # print(i)
 
         # collect images
         while True:
-            to_process = paths[i * IMGS_PER_REQUEST:i * IMGS_PER_REQUEST + IMGS_PER_REQUEST]
+            to_process = paths[i * args.batch_size:i * args.batch_size + args.batch_size]  # the batch size argument is used here
             # print(to_process)
             if len(to_process) == 0:
                 # skip if remaining list is empty (e.g. if 4 pages at package size 2)
@@ -74,7 +77,9 @@ def main():
 
             print(f'processing {len(to_process)} image{"s" if len(to_process) != 1 else ""}')
 
-            cards, meta = dict_to_anki.image_to_anki(to_process)
+            ocr = True if args.ocr else False  # set OCR to True if --ocr parameter is present
+
+            cards, meta = dict_to_anki.image_to_anki(to_process, do_ocr=ocr, lang=args.ocr)
 
             if not cards:
                 print("Error processing! Response: " + meta)
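Aside: a standalone sketch of the batching arithmetic above (the while True retry loop is omitted; the paths and batch size are made up). The loop runs len(paths) // batch_size + 1 times, so when the page count is an exact multiple of the batch size the final slice is empty and the len(to_process) == 0 check skips it:

    # Standalone sketch of the batch slicing used above; paths and batch size are placeholders.
    paths = [f'./.img/dict.pdf_{n}.png' for n in range(100, 104)]  # 4 pages
    batch_size = 2

    for i in range(len(paths) // batch_size + 1):  # 3 iterations: two full batches, one empty
        to_process = paths[i * batch_size:i * batch_size + batch_size]
        if len(to_process) == 0:
            continue  # 4 pages at batch size 2 leave an empty trailing slice
        print(f'batch {i}: {to_process}')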
							
								
								
									
poetry.lock (generated, 28 lines changed)
@@ -269,6 +269,17 @@ typing-extensions = ">=4.7,<5"
 [package.extras]
 datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
 
+[[package]]
+name = "packaging"
+version = "23.2"
+description = "Core utilities for Python packages"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"},
+    {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
+]
+
 [[package]]
 name = "pdf2image"
 version = "1.17.0"
@@ -478,6 +489,21 @@ files = [
 [package.dependencies]
 typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
 
+[[package]]
+name = "pytesseract"
+version = "0.3.10"
+description = "Python-tesseract is a python wrapper for Google's Tesseract-OCR"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pytesseract-0.3.10-py3-none-any.whl", hash = "sha256:8f22cc98f765bf13517ead0c70effedb46c153540d25783e04014f28b55a5fc6"},
+    {file = "pytesseract-0.3.10.tar.gz", hash = "sha256:f1c3a8b0f07fd01a1085d451f5b8315be6eec1d5577a6796d46dc7a62bd4120f"},
+]
+
+[package.dependencies]
+packaging = ">=21.3"
+Pillow = ">=8.0.0"
+
 [[package]]
 name = "requests"
 version = "2.31.0"
@@ -561,4 +587,4 @@ zstd = ["zstandard (>=0.18.0)"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "1df31140161c62d430257e30b1ebbff75524b5614888dfc7809f90d5f09a5737"
+content-hash = "07e8002d23153d51441fa4c4a70af0d6022d127f2c6c9c900cb194741e9bbe6c"
pyproject.toml
@@ -12,6 +12,7 @@ openai = "^1.10.0"
 requests = "^2.31.0"
 pillow = "^10.2.0"
 pdf2image = "^1.17.0"
+pytesseract = "^0.3.10"
 
 
 [build-system]