"""Numerically compare the HF-exported NVILA-Lite-2B checkpoint against the
original llava/VILA implementation: per-sample logit differences and
per-parameter gradient differences on UniMM-Chat conversations."""

import json
import random

import torch
import torch.nn.functional as F
import tqdm
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoConfig, AutoModel

import llava
from llava.mm_utils import process_image

# Ships with the downloaded nvila_lite_2b_dev checkpoint repo.
from nvila_lite_2b_dev.tokenizer_utils import tokenize_conversation
DEFAULT_IMAGE_TOKEN = "<image>"

class EasyDataset(Dataset):
    """Yield (image crops, input_ids) pairs in the format NVILA expects."""

    def __init__(self, dataset, config, tokenizer, device="cuda", dtype=torch.float16):
        super().__init__()
        self.dataset = dataset
        self.config = config
        self.device = device
        self.dtype = dtype
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image = self.dataset[index]["image"]
        # Keep only the first human/assistant exchange.
        conversation = json.loads(self.dataset[index]["conversation"])[:2]
        # Dynamic-resolution preprocessing can tile one image into several
        # crops, so the prompt needs one <image> token per crop.
        images = process_image(image, self.config, None, enable_dynamic_res=True)
        conversation[0]["value"] = conversation[0]["value"].replace(
            DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
        )
        input_ids = tokenize_conversation(conversation, self.tokenizer).unsqueeze(0)
        return list(images), input_ids

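# Each UniMM-Chat row is assumed (inferred from __getitem__ above) to look like:
#   {"image": <PIL.Image.Image>,
#    "conversation": '[{"from": "human", "value": "<image>\\n<question>"},
#                      {"from": "gpt", "value": "<answer>"}, ...]'}
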
def main():
    model_path = "Efficient-Large-Model/nvila_lite_2b_dev"
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    device1 = torch.device("cuda:0")
    device2 = torch.device("cuda:1")
    # Both models need the pretrained checkpoint weights (a from_config model
    # would be randomly initialized), and fp16 matches the .half() image inputs.
    model_hf = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to(device1)
    model_vila = llava.load("Efficient-Large-Model/NVILA-Lite-2B", device="cuda:1").to(device2)

    # Sample a fixed subset of parameters for the gradient comparison.
    parameters_hf = dict(model_hf.named_parameters())
    parameters_vila = dict(model_vila.named_parameters())
    parameter_names_select = random.sample(list(parameters_hf), 10)
    grad_diff = {name: {"Grad_L1": [], "Grad_L2": []} for name in parameter_names_select}

    config = model_hf.config
    config.image_processor = model_hf.vision_tower.image_processor
    dataset = load_dataset("Yirany/UniMM-Chat", split="train")
    image_text_dataset = EasyDataset(dataset, config, model_hf.tokenizer)
    results = {"L1_diff": [], "L2_diff": [], "Cosine_similarity": []}

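    # For each sample: run both implementations on identical inputs, record the
    # logit divergence, then backprop the same loss through each model and
    # record the gradient divergence on the sampled parameters.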
    for item in tqdm.tqdm(image_text_dataset):
        images, input_ids = item
        media1 = {"image": [image.to(device1).half() for image in images]}
        media2 = {"image": [image.to(device2).half() for image in images]}

        # Draw one set of random labels and share it across both models so the
        # two losses (and hence the gradients) are directly comparable.
        labels = torch.randint(0, len(model_hf.tokenizer), input_ids.shape, dtype=input_ids.dtype)

        input_ids1, labels1 = input_ids.to(device1), labels.to(device1)
        output1 = model_hf(input_ids=input_ids1, media=media1, labels=labels1)

        input_ids2, labels2 = input_ids.to(device2), labels.to(device2)
        output2 = model_vila(input_ids=input_ids2, media=media2, labels=labels2)

        logits1 = output1.logits
        logits2 = output2.logits.to(device1)

        l1_diff = F.l1_loss(logits1, logits2).item()
        l2_diff = F.mse_loss(logits1, logits2).item()
        cosine_sim = F.cosine_similarity(logits1, logits2, dim=-1).mean().item()

        results["L1_diff"].append(l1_diff)
        results["L2_diff"].append(l2_diff)
        results["Cosine_similarity"].append(cosine_sim)

        # Each loss has its own graph, so neither backward needs retain_graph.
        output1.loss.backward()
        output2.loss.backward()

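        # Gradient agreement on the sampled parameters; this assumes parameter
        # names line up between the HF port and the original llava model.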
        for name in parameter_names_select:
            grad1 = parameters_hf[name].grad
            grad2 = parameters_vila[name].grad.to(device1)
            grad_diff[name]["Grad_L1"].append(F.l1_loss(grad1, grad2).item())
            grad_diff[name]["Grad_L2"].append(F.mse_loss(grad1, grad2).item())
            del grad1, grad2

        # Release per-sample tensors and graphs before the next iteration.
        del output1, output2, logits1, logits2, images, input_ids
        del input_ids1, input_ids2, media1, media2, labels, labels1, labels2
        torch.cuda.empty_cache()

        model_hf.zero_grad()
        model_vila.zero_grad()

        if len(results["L1_diff"]) > 100:
            break

    # Average the per-sample metrics.
    for name in parameter_names_select:
        grad_diff[name] = {key: sum(values) / len(values) for key, values in grad_diff[name].items()}
    final_results = {key: sum(values) / len(values) for key, values in results.items()}

    for key, value in final_results.items():
        print(f"{key}: {value:.6f}")

    for name in parameter_names_select:
        for key, value in grad_diff[name].items():
            print(f"{name} {key}: {value:.6f}")


if __name__ == "__main__":
    main()
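# Example invocation (the script name is illustrative; it assumes two CUDA
# devices and a local copy of the nvila_lite_2b_dev repo on PYTHONPATH so that
# its tokenizer_utils module resolves):
#   python compare_nvila_hf_vs_vila.py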