import json
import random

import torch
import torch.nn.functional as F
import tqdm
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoConfig, AutoModel

import llava
from llava.mm_utils import process_image
from nvila_lite_2b_dev.tokenizer_utils import tokenize_conversation

# llava's image placeholder token.
DEFAULT_IMAGE_TOKEN = "<image>"


class EasyDataset(Dataset):
    def __init__(self, dataset, config, tokenizer, device="cuda", dtype=torch.float16):
        super().__init__()
        self.dataset = dataset
        self.config = config
        self.device = device
        self.dtype = dtype
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image = self.dataset[index]["image"]
        conversation = json.loads(self.dataset[index]["conversation"])[:2]
        images = process_image(image, self.config, None, enable_dynamic_res=True)
        # Repeat the image token once per dynamic-resolution tile.
        conversation[0]["value"] = conversation[0]["value"].replace(
            DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
        )
        input_ids = tokenize_conversation(conversation, self.tokenizer).unsqueeze(0)
        return [image for image in images], input_ids


def main():
    model_path = "Efficient-Large-Model/nvila_lite_2b_dev"
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    # Put the two models on separate GPUs so their graphs and gradients stay independent.
    device1 = torch.device("cuda:0")
    device2 = torch.device("cuda:1")
    model_hf = AutoModel.from_config(config, trust_remote_code=True, device="cuda:0").to(device1)
    model_vila = llava.load("Efficient-Large-Model/NVILA-Lite-2B", device="cuda:0").to(device2)

    # Randomly pick 10 parameters whose gradients will be compared across the two models.
    parameter_names = list(dict(model_hf.named_parameters()))
    parameter_names_select = random.sample(parameter_names, 10)
    grad_diff = {name: {"Grad_L1": [], "Grad_L2": []} for name in parameter_names_select}

    config = model_hf.config
    config.image_processor = model_hf.vision_tower.image_processor

    dataset = load_dataset("Yirany/UniMM-Chat", split="train")
    image_text_dataset = EasyDataset(dataset, config, model_hf.tokenizer)

    results = {"L1_diff": [], "L2_diff": [], "Cosine_similarity": []}
    for item in tqdm.tqdm(image_text_dataset):
        images, input_ids = item
        media1 = {"image": [image.to(device1).half() for image in images]}
        media2 = {"image": [image.to(device2).half() for image in images]}

        # Forward pass through the HF checkpoint, using random labels so a loss is produced.
        input_ids1 = input_ids.to(device1)
        labels1 = torch.randint(0, len(model_hf.tokenizer), input_ids1.shape, dtype=input_ids.dtype).to(device1)
        output1 = model_hf(input_ids=input_ids1, media=media1, labels=labels1)

        # Forward pass through the original VILA checkpoint on the same inputs.
        input_ids2 = input_ids.to(device2)
        labels2 = torch.randint(0, len(model_hf.tokenizer), input_ids2.shape, dtype=input_ids.dtype).to(device2)
        output2 = model_vila(input_ids=input_ids2, media=media2, labels=labels2)

        # Compare the logits of the two models.
        logits1 = output1.logits
        logits2 = output2.logits.to(device1)
        results["L1_diff"].append(F.l1_loss(logits1, logits2).item())
        results["L2_diff"].append(F.mse_loss(logits1, logits2).item())
        results["Cosine_similarity"].append(F.cosine_similarity(logits1, logits2, dim=-1).mean().item())

        # Backpropagate both losses and compare gradients on the sampled parameters.
        loss1 = output1.loss
        loss2 = output2.loss
        loss1.backward(retain_graph=True)
        loss2.backward(retain_graph=True)
        for name in parameter_names_select:
            param1 = dict(model_hf.named_parameters())[name].grad
            param2 = dict(model_vila.named_parameters())[name].grad
            grad_diff[name]["Grad_L1"].append(F.l1_loss(param1, param2.to(device1)).item())
            grad_diff[name]["Grad_L2"].append(F.mse_loss(param1, param2.to(device1)).item())
            del param1, param2

        del output1, output2, logits1, logits2, input_ids, input_ids1, input_ids2
        del images, media1, media2, labels1, labels2, loss1, loss2
        torch.cuda.empty_cache()
        model_hf.zero_grad()
        model_vila.zero_grad()

        if len(results["L1_diff"]) > 100:
            break

    # Average the per-sample metrics and report them.
    for name in parameter_names_select:
        grad_diff[name] = {key: sum(values) / len(values) for key, values in grad_diff[name].items()}
    final_results = {key: sum(values) / len(values) for key, values in results.items()}
    for key, value in final_results.items():
        print(f"{key}: {value:.6f}")
    for name in parameter_names_select:
        for key, value in grad_diff[name].items():
            print(f"{name} {key}: {value:.6f}")


if __name__ == "__main__":
    main()
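

# --- Optional sanity check (not part of the original script, and not called by main()) ---
# The gradient comparison above indexes model_vila's parameters with names sampled from
# model_hf, so it assumes both checkpoints expose identical parameter names. A minimal,
# hypothetical sketch for verifying that assumption separately; `report_shared_parameters`
# is an illustrative helper name, not something defined in either repository.
def report_shared_parameters(model_a, model_b):
    """Print how many parameter names the two models share and return the shared set."""
    names_a = {name for name, _ in model_a.named_parameters()}
    names_b = {name for name, _ in model_b.named_parameters()}
    shared = names_a & names_b
    print(
        f"shared: {len(shared)}, "
        f"only in A: {len(names_a - names_b)}, "
        f"only in B: {len(names_b - names_a)}"
    )
    return shared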