# Sonofica/utils/japanese_ocr.py
# Author: janmayjay — "Add application file" (commit 39a7537)
import re
from pathlib import Path
import jaconv
import torch
from PIL import Image
from loguru import logger
from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel, GenerationMixin
class MangaOcrModel(VisionEncoderDecoderModel, GenerationMixin):
    """VisionEncoderDecoderModel with ``GenerationMixin`` explicitly mixed in.

    NOTE(review): presumably required because recent transformers releases
    stopped providing ``generate()`` on every ``PreTrainedModel`` and warn/fail
    unless the model class inherits ``GenerationMixin`` — confirm against the
    pinned transformers version.
    """

    pass
class MangaOcr:
    """Japanese OCR using the kha-white/manga-ocr ViT encoder-decoder model."""

    def __init__(self, pretrained_model_name_or_path="kha-white/manga-ocr-base", force_cpu=False):
        """Load the processor, tokenizer and model, then pick a device.

        Args:
            pretrained_model_name_or_path: HuggingFace hub id or local path
                of the manga-ocr model.
            force_cpu: if True, skip CUDA/MPS detection and stay on CPU.
        """
        logger.info(f"Loading OCR model from {pretrained_model_name_or_path}")
        self.processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        self.model = MangaOcrModel.from_pretrained(pretrained_model_name_or_path)
        # Prefer CUDA, then Apple MPS, falling back to CPU.
        if not force_cpu and torch.cuda.is_available():
            logger.info("Using CUDA")
            self.model.cuda()
        elif not force_cpu and torch.backends.mps.is_available():
            logger.info("Using MPS")
            self.model.to("mps")
        else:
            logger.info("Using CPU")
        logger.info("OCR ready")

    def __call__(self, img_or_path):
        """Run OCR on an image and return the recognized, post-processed text.

        Args:
            img_or_path: a PIL image, or a path (str / pathlib.Path) to an
                image file.

        Returns:
            The recognized Japanese text after post-processing.

        Raises:
            ValueError: if ``img_or_path`` is neither a path nor a PIL image.
        """
        # Fix: accept a filesystem path as the parameter name promises;
        # previously a path argument crashed on the .convert() call below.
        if isinstance(img_or_path, (str, Path)):
            img = Image.open(img_or_path)
        elif isinstance(img_or_path, Image.Image):
            img = img_or_path
        else:
            raise ValueError(f"img_or_path must be a path or PIL.Image, instead got: {img_or_path}")
        # Grayscale then back to RGB: discards color while keeping the
        # 3-channel input the ViT processor expects.
        img = img.convert("L").convert("RGB")
        x = self._preprocess(img)
        # Add a batch dimension, generate token ids, and move them off-device.
        x = self.model.generate(x[None].to(self.model.device), max_length=300)[0].cpu()
        x = self.tokenizer.decode(x, skip_special_tokens=True)
        x = post_process(x)
        return x

    def _preprocess(self, img):
        """Convert a PIL image into the model's pixel-value tensor.

        The processor returns a (1, C, H, W) batch; squeeze drops the batch
        dimension so __call__ can re-add it explicitly.
        """
        pixel_values = self.processor(img, return_tensors="pt").pixel_values
        return pixel_values.squeeze()
def post_process(text):
    """Normalize raw OCR output.

    Strips all whitespace, unifies ellipsis characters into dot runs,
    collapses mixed middle-dot/period runs, and converts halfwidth ASCII
    and digits to their fullwidth forms (Japanese typography convention).
    """
    # OCR can emit spurious spaces/newlines between characters — drop them all.
    normalized = "".join(text.split())
    # Expand the single-character ellipsis into three plain dots.
    normalized = normalized.replace("…", "...")
    # Replace any run of middle dots / periods with the same number of periods.
    normalized = re.sub("[・.]{2,}", lambda m: "." * len(m.group()), normalized)
    # Halfwidth -> fullwidth for ASCII letters and digits.
    return jaconv.h2z(normalized, ascii=True, digit=True)