Text-to-Image
Diffusers
Safetensors
tolgacangoz commited on
Commit
760c120
·
verified ·
1 Parent(s): 17cf7e7

Upload anytext.py

Browse files
Files changed (1) hide show
  1. text_embedding_module/anytext.py +8 -13
text_embedding_module/anytext.py CHANGED
@@ -35,6 +35,7 @@ import torch
35
  import torch.nn.functional as F
36
  from bert_tokenizer import BasicTokenizer
37
  from easydict import EasyDict as edict
 
38
  from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3
39
  from ocr_recog.RecModel import RecModel
40
  from PIL import Image, ImageDraw, ImageFont
@@ -271,18 +272,12 @@ def crop_image(src_img, mask):
271
 
272
  def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
273
  if model_dir is None or not os.path.exists(model_dir):
274
- try:
275
- # Use the repo id from which the pipeline was loaded
276
- model_dir = hf_hub_download(
277
- repo_id="tolgacangoz/anytext",
278
- filename="text_embedding_module/OCR/ppv3_rec.pth",
279
- local_dir=".cache/diffusers",
280
- local_dir_use_symlinks=True
281
- )
282
- except Exception as e:
283
- raise ValueError(f"Could not download the model file: {e}")
284
-
285
- if model_dir is not None and not os.path.exists(model_dir):
286
  raise ValueError("not find model file path {}".format(model_dir))
287
 
288
  if model_lang == "ch":
@@ -476,7 +471,7 @@ class TextEmbeddingModule(nn.Module):
476
  args["rec_image_shape"] = "3, 48, 320"
477
  args["rec_batch_num"] = 6
478
  args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt"
479
- args["use_fp16"] = self.use_fp16
480
  self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
481
 
482
  @torch.no_grad()
 
35
  import torch.nn.functional as F
36
  from bert_tokenizer import BasicTokenizer
37
  from easydict import EasyDict as edict
38
+ from diffusers.utils.constants import HF_MODULES_CACHE
39
  from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3
40
  from ocr_recog.RecModel import RecModel
41
  from PIL import Image, ImageDraw, ImageFont
 
272
 
273
  def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
274
  if model_dir is None or not os.path.exists(model_dir):
275
+ model_dir = hf_hub_download(
276
+ repo_id="tolgacangoz/anytext",
277
+ filename="text_embedding_module/OCR/ppv3_rec.pth",
278
+ cache_dir=HF_MODULES_CACHE
279
+ )
280
+ if not os.path.exists(model_dir):
 
 
 
 
 
 
281
  raise ValueError("not find model file path {}".format(model_dir))
282
 
283
  if model_lang == "ch":
 
471
  args["rec_image_shape"] = "3, 48, 320"
472
  args["rec_batch_num"] = 6
473
  args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt"
474
+ args["use_fp16"] = use_fp16
475
  self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
476
 
477
  @torch.no_grad()