Skip to content

Commit

Permalink
feat: improve JSON OCR generation script
Browse files: browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Oct 17, 2022
1 parent 92b01d9 commit 603e355
Showing 1 changed file with 53 additions and 23 deletions.
76 changes: 53 additions & 23 deletions scripts/ocr/run_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@

import argparse
import base64
import glob
import gzip
import json
import os
import pathlib
import sys
import time
from datetime import datetime
from typing import List, Optional

import orjson
import requests

API_KEY = os.environ.get("CLOUD_VISION_API_KEY")
Expand Down Expand Up @@ -119,7 +120,7 @@ def run_ocr_on_image_paths(image_paths: List[pathlib.Path], override: bool = Fal
print(image_paths)
return [], True

r_json = r.json()
r_json = orjson.loads(r.content)
responses = r_json["responses"]
return (
[(images_content[i][0], responses[i]) for i in range(len(images_content))],
Expand All @@ -135,15 +136,20 @@ def dump_ocr(
for image_path, response in responses:
json_path = image_path.with_suffix(".json.gz")

with gzip.open(str(json_path), "wt") as f:
with gzip.open(str(json_path), "wb") as f:
# print("Dumping OCR JSON to {}".format(json_path))
json.dump({"responses": [response]}, f)
f.write(orjson.dumps({"responses": [response]}))

if performed_request and sleep:
time.sleep(sleep)


def add_missing_ocr(sleep: float):
def add_to_seen_set(seen_path: pathlib.Path, item: str) -> None:
    """Append ``item`` as a single line to the file at ``seen_path``.

    The file is opened in append mode with UTF-8 encoding on every call, so
    it is created if it does not exist yet and is closed again before the
    function returns.

    Args:
        seen_path: Path of the text file that records seen items, one item
            per line.
        item: Entry to record; a trailing newline is written after it.
    """
    with seen_path.open("a", encoding="utf-8") as f:
        f.write("{}\n".format(item))


def add_missing_ocr(sleep: float, seen_path: pathlib.Path):
total = 0
missing = 0
json_error = 0
Expand All @@ -152,17 +158,43 @@ def add_missing_ocr(sleep: float):
empty_images = 0
expired = 0

for i, image_path in enumerate(BASE_IMAGE_DIR.glob("**/*.jpg")):
with seen_path.open("r", encoding="utf-8") as f:
seen_set = set(map(str.strip, f))

for i, image_path_str in enumerate(
glob.iglob("{}/**/*.jpg".format(BASE_IMAGE_DIR))
):
if i % 10000 == 0:
print(
"scanned: {}, total: {}, missing: {}, json_error: {}, ocr_error: {}, empty images: {}, valid: {}, "
"expired: {}".format(
i,
total,
missing,
json_error,
ocr_error,
empty_images,
valid,
expired,
)
)

image_path = pathlib.Path(image_path_str)
if not image_path.stem.isdigit():
continue

if image_path_str in seen_set:
continue

image_size = image_path.stat().st_size

if not image_size:
empty_images += 1
add_to_seen_set(seen_path, image_path_str)
continue

if image_size >= 10485760:
add_to_seen_set(seen_path, image_path_str)
continue

json_path = image_path.with_suffix(".json.gz")
Expand All @@ -175,47 +207,45 @@ def add_missing_ocr(sleep: float):

missing += 1
dump_ocr([image_path], sleep=sleep, override=False)
add_to_seen_set(seen_path, image_path_str)
continue

modification_datetime = datetime.fromtimestamp(json_path.stat().st_mtime)
if modification_datetime < MAXIMUM_MODIFICATION_DATETIME:
dump_ocr([image_path], sleep=sleep, override=True)
expired += 1
dump_ocr([image_path], sleep=sleep, override=True)
add_to_seen_set(seen_path, image_path_str)
continue

has_json_error = False
with gzip.open(str(json_path), "rt", encoding="utf-8") as f:
with gzip.open(str(json_path), "rb") as f:
try:
data = json.load(f)
except json.JSONDecodeError:
data = orjson.loads(f.read())
except orjson.JSONDecodeError:
has_json_error = True

if has_json_error:
dump_ocr([image_path], sleep=sleep, override=True)
json_error += 1
dump_ocr([image_path], sleep=sleep, override=True)
add_to_seen_set(seen_path, image_path_str)
continue

has_error = False
for response in data["responses"]:
if "error" in response:
ocr_error += 1
has_error = True
dump_ocr([image_path], sleep=sleep, override=True)

if not has_error:
if has_error:
ocr_error += 1
dump_ocr([image_path], sleep=sleep, override=True)
add_to_seen_set(seen_path, image_path_str)
else:
valid += 1

if i % 1000 == 0:
print(
"total: {}, missing: {}, json_error: {}, ocr_error: {}, empty images: {}, valid: {}, "
"expired: {}".format(
total, missing, json_error, ocr_error, empty_images, valid, expired
)
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--sleep", type=float, default=1.0)
parser.add_argument("--seen-path", type=pathlib.Path, required=True)
args = parser.parse_args()
add_missing_ocr(sleep=args.sleep)
add_missing_ocr(sleep=args.sleep, seen_path=args.seen_path)

0 comments on commit 603e355

Please sign in to comment.