import os
import torch
import base64
import pickle
import requests
import argparse
import nltk
from nltk.tokenize import word_tokenize
from functools import partial
from transformers import ViTImageProcessor, ViTModel
from transformers import AutoProcessor, LlavaForConditionalGeneration
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image
import logging
from tqdm import tqdm
from openai import OpenAI
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_owl2.conversation import conv_templates
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)

from genception.utils import find_files

logging.basicConfig(level=logging.INFO)

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
api_key = client.api_key
nltk.download("punkt")

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cudnn.enabled = False

# ViT model used to embed images for the cosine-similarity metric
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")


def image_embedding(image_file: str) -> list[float]:
    """
    Generates an image embedding using a ViT model.

    Args:
        image_file: str: The path to the image file

    Returns:
        list[float]: The image embedding (the [CLS] token of the last hidden state)
    """
    image = Image.open(image_file).convert("RGB")
    inputs = vit_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = vit_model(**inputs)
    # First token of the last hidden state is the [CLS] embedding
    return outputs.last_hidden_state.tolist()[0][0]
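
# Illustrative sketch (not part of the pipeline below): image_embedding() returns
# the 768-dim [CLS] vector of ViT-base as a plain Python list, so two images can
# be compared directly with sklearn's cosine_similarity. The file names here are
# hypothetical placeholders.
#
#   emb_a = image_embedding("seed.jpg")
#   emb_b = image_embedding("generated.jpg")
#   sim = cosine_similarity([emb_a], [emb_b])[0][0]  # float in [-1, 1]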
def save_image_from_url(url: str, filename: str):
    """
    Save an image from a given URL to a file.

    Args:
        url: str: The URL of the image
        filename: str: The name of the file to save the image to
    """
    response = requests.get(url)
    if response.status_code == 200:
        with open(filename, "wb") as file:
            file.write(response.content)
    else:
        logging.warning(
            f"Failed to download image. Status code: {response.status_code}"
        )


def find_image_files(folder_path: str) -> list[str]:
    image_extensions = {".jpg", ".png"}
    return find_files(folder_path, image_extensions)


def count_words(text: str) -> int:
    words = word_tokenize(text)
    return len(words)


def encode_image_os(image_path: str) -> Image.Image:
    """Load an image from disk for the open-source models (LLaVA, mPLUG-Owl2)."""
    image = Image.open(image_path).convert("RGB")
    return image


def encode_image_gpt4v(image_path: str) -> str:
    """Base64-encode an image for the GPT-4V HTTP API."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def generate_xt(
    image_desc: str, output_folder: str, i: int, file_name: str, file_extension: str
) -> str:
    """
    Generate an image based on a description using DALL-E 3 and save it to a file.

    Args:
        image_desc: str: The description of the image
        output_folder: str: The path to the folder to save the image to
        i: int: The iteration number
        file_name: str: The name of the file
        file_extension: str: The extension of the file

    Returns:
        str: The path to the saved image file
    """
    response = client.images.generate(
        model="dall-e-3",
        prompt="Generate an image that fully and precisely reflects this description: {}".format(
            image_desc
        ),
        size="1024x1024",
        quality="standard",
        n=1,
    )
    new_image_filename = os.path.join(
        output_folder, f"{file_name}_{i}.{file_extension}"
    )
    save_image_from_url(response.data[0].url, new_image_filename)
    return new_image_filename


def get_desc_mPLUG(image, image_processor, lmm_model, tokenizer, prompt):
    """
    Given an image, generate a description using the mPLUG-Owl2 model.

    Args:
        image: Image: The image to describe
        image_processor: callable: The image processor
        lmm_model: The language model
        tokenizer: The tokenizer
        prompt: str: The prompt for the model

    Returns:
        str: The description of the image
    """
    conv = conv_templates["mplug_owl2"].copy()
    # Pad the image to a square before preprocessing
    max_edge = max(image.size)
    image = image.resize((max_edge, max_edge))
    image_tensor = process_images([image], image_processor)
    image_tensor = image_tensor.to(lmm_model.device, dtype=torch.float16)

    inp = DEFAULT_IMAGE_TOKEN + prompt
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(lmm_model.device)
    )
    stop_str = conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    temperature = 0.001  # near-greedy decoding
    max_new_tokens = 512

    with torch.inference_mode():
        output_ids = lmm_model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            stopping_criteria=[stopping_criteria],
            attention_mask=attention_mask,
        )

    image_desc = tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
    return image_desc


def get_desc_llava(image, lmm_processor, lmm_model, prompt):
    """
    Given an image, generate a description using the LLaVA model.

    Args:
        image: Image: The image to describe
        lmm_processor: callable: The language model processor
        lmm_model: The language model
        prompt: str: The prompt for the model

    Returns:
        str: The description of the image
    """
    inputs = lmm_processor(text=prompt, images=image, return_tensors="pt").to(device)
    outputs = lmm_model.generate(**inputs, max_new_tokens=512, do_sample=False)
    answer = lmm_processor.batch_decode(outputs, skip_special_tokens=True)[0]
    # Keep only the model's reply, not the echoed prompt
    image_desc = answer.split("ASSISTANT:")[1].strip()
    return image_desc
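
# Illustrative wiring sketch (mirrors main() below; the checkpoint id and the
# sample path are assumptions for the example, not requirements of this script):
#
#   model_id = "llava-hf/llava-1.5-7b-hf"
#   lmm_model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_8bit=True)
#   lmm_processor = AutoProcessor.from_pretrained(model_id)
#   describe = partial(
#       get_desc_llava,
#       lmm_processor=lmm_processor,
#       lmm_model=lmm_model,
#       prompt="USER: <image>\nDescribe the image.\nASSISTANT:",
#   )
#   desc = describe(encode_image_os("sample.jpg"))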
def get_desc_gpt4v(image, prompt):
    """
    Given an image, generate a description using the gpt-4-vision model.

    Args:
        image: str: The base64-encoded image to describe
        prompt: str: The prompt for the model

    Returns:
        str: The description of the image
    """
    payload = {
        "model": "gpt-4-vision-preview",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{image}"},
                    },
                ],
            }
        ],
        "max_tokens": 512,
        "temperature": 0,
    }
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    response = requests.post(
        "https://api.openai.com/v1/chat/completions", headers=headers, json=payload
    )
    image_desc = response.json()["choices"][0]["message"]["content"]
    return image_desc


def test_sample(
    seed_image: str,
    n_iteration: int,
    output_folder: str,
    get_desc_function: callable,
    encode_image_function: callable,
):
    """
    Iteratively generates T (n_iteration) descriptions and images based on the seed image.

    Args:
        seed_image: str: The path to the seed image
        n_iteration: int: The number of iterations to perform
        output_folder: str: The path to the folder to save the results
        get_desc_function: callable: The function to generate the description
        encode_image_function: callable: The function to encode the image
    """
    list_of_desc = []
    list_of_image = []
    list_of_image_embedding = [image_embedding(seed_image)]
    list_of_cos_sim = [1.0]

    current_image_path = seed_image
    current_image_name = os.path.basename(current_image_path)
    # rsplit guards against file names that contain additional dots
    file_name, file_extension = current_image_name.rsplit(".", 1)
    logging.debug(f"Image: {current_image_path}")

    pkl_file = os.path.join(output_folder, f"{file_name}_result.pkl")
    if os.path.exists(pkl_file):
        logging.info("Results already exist, skipping")
        return None

    for i in range(n_iteration):
        # Encode the current image and get the description
        image = encode_image_function(current_image_path)
        image_desc = get_desc_function(image)
        list_of_desc.append(image_desc)
        logging.debug(image_desc)

        # Generate X^t, append image and embedding
        new_image_filename = generate_xt(
            image_desc, output_folder, i, file_name, file_extension
        )
        list_of_image.append(new_image_filename)
        list_of_image_embedding.append(image_embedding(new_image_filename))

        # Calculate cosine similarity to the original image
        similarity = cosine_similarity(
            [list_of_image_embedding[0]], [list_of_image_embedding[-1]]
        )[0][0]
        list_of_cos_sim.append(similarity)
        logging.info(f"({count_words(image_desc)}, {round(similarity, 2)})")

        # Save a checkpoint to avoid losing results
        data_to_save = {
            "descriptions": list_of_desc,
            "images": list_of_image,
            "image_embeddings": list_of_image_embedding,
            "cosine_similarities": list_of_cos_sim,
        }
        with open(pkl_file, "wb") as file:
            pickle.dump(data_to_save, file)

        # Update current_image_path for the next iteration
        current_image_path = new_image_filename

    return None
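
# Illustrative sketch for inspecting a checkpoint written by test_sample(); the
# pickle path below is a hypothetical example:
#
#   with open("mme_data/color/results_llava7b/000001_result.pkl", "rb") as f:
#       data = pickle.load(f)
#   # data["cosine_similarities"][0] is always 1.0 (seed vs. itself); later
#   # entries track how far each regenerated image has drifted from the seed.
#   print(data["cosine_similarities"])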
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="mme_data/color")
    parser.add_argument("--model", type=str, default="llava7b")
    parser.add_argument("--n_iter", type=int, default=5)
    args = parser.parse_args()
    logging.info(args)

    prompt = (
        "Please write a clear, precise, detailed, and concise description of all "
        "elements in the image. Focus on accurately depicting various aspects, "
        "including but not limited to the colors, shapes, positions, styles, texts "
        "and the relationships between different objects and subjects in the image. "
        "Your description should be thorough enough to guide a professional in "
        "recreating this image solely based on your textual representation. "
        "Remember, only include descriptive texts that directly pertain to the "
        "contents of the image. You must complete the description using less than "
        "500 words."
    )

    if "llava" in args.model:
        lmm_model = LlavaForConditionalGeneration.from_pretrained(
            f"llava-hf/llava-1.5-{args.model[5:]}-hf", load_in_8bit=True
        )
        lmm_processor = AutoProcessor.from_pretrained(
            f"llava-hf/llava-1.5-{args.model[5:]}-hf"
        )
        # llava-hf checkpoints expect the <image> placeholder in the prompt
        prompt = f"USER: <image>\n{prompt}\nASSISTANT:"
        # Bind by keyword so the image stays the sole positional argument
        get_desc_function = partial(
            get_desc_llava,
            lmm_processor=lmm_processor,
            lmm_model=lmm_model,
            prompt=prompt,
        )
        encode_image_function = encode_image_os
    elif args.model == "mPLUG":
        model_path = "MAGAer13/mplug-owl2-llama2-7b"
        model_name = get_model_name_from_path(model_path)
        tokenizer, lmm_model, image_processor, _ = load_pretrained_model(
            model_path,
            None,
            model_name,
            load_8bit=False,
            load_4bit=False,
            device=device,
        )
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token
        get_desc_function = partial(
            get_desc_mPLUG,
            image_processor=image_processor,
            lmm_model=lmm_model,
            tokenizer=tokenizer,
            prompt=prompt,
        )
        encode_image_function = encode_image_os
    elif args.model == "gpt4v":
        get_desc_function = partial(get_desc_gpt4v, prompt=prompt)
        encode_image_function = encode_image_gpt4v
    else:
        raise ValueError(f"Unknown model: {args.model}")

    output_folder = os.path.join(args.dataset, f"results_{args.model}")
    os.makedirs(output_folder, exist_ok=True)
    logging.debug("Loaded model. Entering main loop.")

    for img_file in tqdm(find_image_files(args.dataset)):
        try:
            logging.info(img_file)
            test_sample(
                seed_image=img_file,
                n_iteration=args.n_iter,
                output_folder=output_folder,
                get_desc_function=get_desc_function,
                encode_image_function=encode_image_function,
            )
        except Exception as e:
            logging.warning("Caught error:")
            logging.warning(e)
            continue


if __name__ == "__main__":
    main()
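
# Example invocations (the script filename is a placeholder; the dataset path is
# the argparse default above):
#   python this_script.py --dataset mme_data/color --model llava7b --n_iter 5
#   python this_script.py --dataset mme_data/color --model gpt4v --n_iter 5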