import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
from pipeline_hidream_image import HiDreamImagePipeline
from transformer_hidream_image import HiDreamImageTransformer2DModel
from schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
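# Hub namespaces for the quantized weights: NF4 HiDream-I1 checkpoints under
# "azaneko" and a GPTQ INT4 Llama 3.1 8B Instruct used as the text encoder.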
MODEL_PREFIX = "azaneko"
LLAMA_MODEL_NAME = "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
# Per-variant configuration: checkpoint path, classifier-free guidance scale,
# number of inference steps, flow-matching timestep shift, and scheduler class.
MODEL_CONFIGS = {
"dev": {
"path": f"{MODEL_PREFIX}/HiDream-I1-Dev-nf4",
"guidance_scale": 0.0,
"num_inference_steps": 28,
"shift": 6.0,
"scheduler": FlashFlowMatchEulerDiscreteScheduler
},
"full": {
"path": f"{MODEL_PREFIX}/HiDream-I1-Full-nf4",
"guidance_scale": 5.0,
"num_inference_steps": 50,
"shift": 3.0,
"scheduler": FlowUniPCMultistepScheduler
},
"fast": {
"path": f"{MODEL_PREFIX}/HiDream-I1-Fast-nf4",
"guidance_scale": 0.0,
"num_inference_steps": 16,
"shift": 3.0,
"scheduler": FlashFlowMatchEulerDiscreteScheduler
}
}
def log_vram(msg: str):
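    """Print a status message along with the CUDA memory currently allocated (MB)."""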
print(f"{msg} (used {torch.cuda.memory_allocated() / 1024**2:.2f} MB VRAM)\n")
def load_models(model_type: str):
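    """Load the tokenizer, Llama text encoder, HiDream transformer, and
    pipeline for the requested variant ("dev", "full", or "fast")."""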
config = MODEL_CONFIGS[model_type]
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME)
log_vram("βœ… Tokenizer loaded!")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
LLAMA_MODEL_NAME,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
torch_dtype=torch.bfloat16,
device_map="auto",
)
log_vram("βœ… Text encoder loaded!")
transformer = HiDreamImageTransformer2DModel.from_pretrained(
config["path"],
subfolder="transformer",
torch_dtype=torch.bfloat16
)
log_vram("βœ… Transformer loaded!")
pipe = HiDreamImagePipeline.from_pretrained(
config["path"],
        scheduler=config["scheduler"](num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False),
tokenizer_4=tokenizer_4,
text_encoder_4=text_encoder_4,
torch_dtype=torch.bfloat16,
)
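    # Swap in the NF4-quantized transformer loaded above.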
pipe.transformer = transformer
log_vram("βœ… Pipeline loaded!")
pipe.enable_sequential_cpu_offload()
return pipe, config
@torch.inference_mode()
def generate_image(pipe: HiDreamImagePipeline, model_type: str, prompt: str, resolution: tuple[int, int], seed: int):
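    """Generate a single image for the prompt and return it with the seed that was actually used."""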
# Get configuration for current model
config = MODEL_CONFIGS[model_type]
guidance_scale = config["guidance_scale"]
num_inference_steps = config["num_inference_steps"]
# Parse resolution
width, height = resolution
    # Handle seed (-1 means "pick a random seed")
if seed == -1:
seed = torch.randint(0, 1000000, (1,)).item()
generator = torch.Generator("cuda").manual_seed(seed)
images = pipe(
prompt,
height=height,
width=width,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
num_images_per_prompt=1,
generator=generator
).images
return images[0], seed
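
# --- Minimal usage sketch ---
# Assumptions (not part of this module): a CUDA GPU is available, the HiDream
# pipeline/scheduler modules imported above are on the Python path, and the
# "fast" variant plus the prompt/resolution below are illustrative choices.
if __name__ == "__main__":
    pipe, _ = load_models("fast")
    image, used_seed = generate_image(
        pipe,
        model_type="fast",
        prompt="a watercolor lighthouse at dawn",
        resolution=(1024, 1024),
        seed=-1,  # -1 picks a random seed
    )
    image.save(f"output_{used_seed}.png")
    print(f"Saved output_{used_seed}.png (seed={used_seed})")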