File size: 1,854 Bytes
39a7537 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import re
from pathlib import Path
import jaconv
import torch
from PIL import Image
from loguru import logger
from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel, GenerationMixin
class MangaOcrModel(VisionEncoderDecoderModel, GenerationMixin):
    """Thin subclass combining the vision encoder-decoder model with the
    generation mixin; adds no behavior of its own.

    NOTE(review): presumably the explicit ``GenerationMixin`` base keeps
    ``.generate()`` available across transformers versions — confirm against
    the pinned transformers release.
    """
class MangaOcr:
    """End-to-end OCR for Japanese manga text.

    Wraps a ViT image processor, a tokenizer, and a ``MangaOcrModel``
    encoder-decoder; call the instance with an image (or a path to one)
    to get the recognized, post-processed text.
    """

    def __init__(self, pretrained_model_name_or_path="kha-white/manga-ocr-base", force_cpu=False):
        """Load processor, tokenizer and model, then pick the compute device.

        Args:
            pretrained_model_name_or_path: Hugging Face hub id or local path.
            force_cpu: if True, stay on CPU even when CUDA/MPS is available.
        """
        logger.info(f"Loading OCR model from {pretrained_model_name_or_path}")
        self.processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        self.model = MangaOcrModel.from_pretrained(pretrained_model_name_or_path)

        # Device preference: CUDA, then Apple MPS, then CPU fallback.
        if not force_cpu and torch.cuda.is_available():
            logger.info("Using CUDA")
            self.model.cuda()
        elif not force_cpu and torch.backends.mps.is_available():
            logger.info("Using MPS")
            self.model.to("mps")
        else:
            logger.info("Using CPU")
        logger.info("OCR ready")

    def __call__(self, img_or_path):
        """Run OCR on a single image.

        Args:
            img_or_path: a ``PIL.Image.Image``, or a ``str``/``Path`` to an
                image file.

        Returns:
            The recognized text after :func:`post_process`.

        Raises:
            ValueError: if *img_or_path* is neither a path nor a PIL image.
        """
        # FIX: the parameter accepts a path (hence the name and the Path/Image
        # imports at the top of the file), but the old code called .convert()
        # on it directly, so paths crashed with AttributeError. Dispatch on
        # type and open paths first.
        if isinstance(img_or_path, (str, Path)):
            img = Image.open(img_or_path)
        elif isinstance(img_or_path, Image.Image):
            img = img_or_path
        else:
            raise ValueError(f"img_or_path must be a path or PIL.Image, got {type(img_or_path)}")

        # Grayscale then back to RGB: discards color while keeping the
        # 3-channel input the ViT processor expects.
        img = img.convert("L").convert("RGB")
        x = self._preprocess(img)
        # Add the batch dim, move to the model's device, generate token ids,
        # then bring the single result back to CPU for decoding.
        x = self.model.generate(x[None].to(self.model.device), max_length=300)[0].cpu()
        x = self.tokenizer.decode(x, skip_special_tokens=True)
        x = post_process(x)
        return x

    def _preprocess(self, img):
        # Returns the processor's pixel tensor with the batch dim squeezed out.
        pixel_values = self.processor(img, return_tensors="pt").pixel_values
        return pixel_values.squeeze()
def post_process(text):
    """Normalize decoded OCR output.

    Strips all whitespace, canonicalizes ellipsis/dot runs to plain dots of
    the same length, and converts ASCII letters/digits to their full-width
    (zenkaku) forms via jaconv.
    """
    # Drop every whitespace character the tokenizer may have emitted.
    compact = "".join(text.split())
    # Expand the single-character ellipsis, then collapse any run of two or
    # more dot-like characters into the same number of ASCII dots.
    compact = compact.replace("…", "...")
    compact = re.sub(r"[・.]{2,}", lambda m: "." * (m.end() - m.start()), compact)
    return jaconv.h2z(compact, ascii=True, digit=True)
|