You can also extract and load the state_dict of the fp32 weights:
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
model = model.cpu()
model.load_state_dict(state_dict)
Offline
DeepSpeed saves a zero_to_fp32.py script at the top level of the checkpoint folder, so you can extract the fp32 weights from a checkpoint at any point.
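As a rough sketch of how the script could be invoked, the helper below only assembles the command line. It assumes the commonly documented positional form (checkpoint folder first, output target second). The helper name `build_zero_to_fp32_command`, the path `path/to/checkpoint`, and the output name `pytorch_model.bin` are illustrative placeholders, not part of DeepSpeed. Depending on the DeepSpeed version, the second argument may be treated as an output file or an output directory, so check the script's `--help` before relying on this.

```python
import sys
from pathlib import Path


def build_zero_to_fp32_command(checkpoint_dir, output_target="pytorch_model.bin"):
    """Assemble a command that runs the zero_to_fp32.py script found in checkpoint_dir.

    Hypothetical helper: the argument order (checkpoint folder, then output
    target) is an assumption based on the script's usual usage.
    """
    checkpoint_dir = Path(checkpoint_dir)
    script = checkpoint_dir / "zero_to_fp32.py"
    # Use the current interpreter so the script runs in the same environment.
    return [sys.executable, str(script), str(checkpoint_dir), output_target]


# "path/to/checkpoint" is a placeholder for your own checkpoint folder.
cmd = build_zero_to_fp32_command("path/to/checkpoint")
print(" ".join(cmd))

# To actually run the extraction (requires a real checkpoint folder):
# import subprocess
# subprocess.run(cmd, check=True)
```

Building the command separately from running it makes it easy to inspect or log the exact invocation before starting what can be a memory-heavy consolidation.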