LPX55 committed on
Commit
75827f2
·
1 Parent(s): bd72bff

minor: debugging prints

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
4
  import torch
5
  import logging
6
  from diffusers import DiffusionPipeline
7
- from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
8
  from transformer_hidream_image import HiDreamImageTransformer2DModel
9
  from pipeline_hidream_image import HiDreamImagePipeline
10
  from schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
@@ -12,6 +12,8 @@ from schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
12
 
13
  import subprocess
14
 
 
 
15
  try:
16
  print(subprocess.check_output(["nvcc", "--version"]).decode("utf-8"))
17
  except:
@@ -32,6 +34,7 @@ RESOLUTION_OPTIONS = [
32
  "1248 × 832 (Landscape)",
33
  "832 × 1248 (Portrait)"
34
  ]
 
35
 
36
  MODEL_PREFIX = "azaneko"
37
  LLAMA_MODEL_NAME = "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
@@ -69,6 +72,7 @@ pipe = HiDreamImagePipeline.from_pretrained(
69
  tokenizer_4=tokenizer_4,
70
  text_encoder_4=text_encoder_4,
71
  torch_dtype=torch.bfloat16,
 
72
  )
73
  pipe.transformer = transformer
74
  log_vram("✅ Pipeline loaded!")
 
4
  import torch
5
  import logging
6
  from diffusers import DiffusionPipeline
7
+ from transformers import LlamaForCausalLM, PreTrainedTokenizerFast, BitsAndBytesConfig
8
  from transformer_hidream_image import HiDreamImageTransformer2DModel
9
  from pipeline_hidream_image import HiDreamImagePipeline
10
  from schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
 
12
 
13
  import subprocess
14
 
15
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
16
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
  try:
18
  print(subprocess.check_output(["nvcc", "--version"]).decode("utf-8"))
19
  except:
 
34
  "1248 × 832 (Landscape)",
35
  "832 × 1248 (Portrait)"
36
  ]
37
+ quantization_config = BitsAndBytesConfig(load_in_4bit=True)
38
 
39
  MODEL_PREFIX = "azaneko"
40
  LLAMA_MODEL_NAME = "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
 
72
  tokenizer_4=tokenizer_4,
73
  text_encoder_4=text_encoder_4,
74
  torch_dtype=torch.bfloat16,
75
+ quantization_config=quantization_config
76
  )
77
  pipe.transformer = transformer
78
  log_vram("✅ Pipeline loaded!")