Valentin Buchner committed on
Commit c951815 · 1 Parent(s): dd1b905

remove genception to reduce chance of conflict
genception/evaluation.py DELETED
@@ -1,101 +0,0 @@
- import os
- import json
- import pickle
- import numpy as np
- import argparse
- from genception.utils import find_files
-
-
- def read_all_pkl(folder_path: str) -> dict:
-     """
-     Read all the pickle files in the given folder path
-
-     Args:
-         folder_path: str: The path to the folder
-
-     Returns:
-         dict: The dictionary containing the file path as key and the pickle file content as value
-     """
-     result_dict = dict()
-     file_list = find_files(folder_path, {".pkl"})
-     for file_path in file_list:
-         with open(file_path, "rb") as file:
-             result_dict[file_path] = pickle.load(file)
-     return result_dict
-
-
- def integrated_decay_area(scores: list[float]) -> float:
-     """
-     Calculate the Integrated Decay Area (IDA) for the given scores
-
-     Args:
-         scores: list[float]: The list of scores
-
-     Returns:
-         float: The IDA score
-     """
-     total_area = 0
-
-     for i, score in enumerate(scores):
-         total_area += (i + 1) * score
-
-     max_possible_area = sum(range(1, len(scores) + 1))
-     ida = total_area / max_possible_area if max_possible_area else 0
-     return ida
-
-
- def gc_score(folder_path: str, n_iter: int = None) -> tuple[float, list[float]]:
-     """
-     Calculate the GC@T score for the given folder path
-
-     Args:
-         folder_path: str: The path to the folder
-         n_iter: int: The number of iterations to consider for GC@T score
-
-     Returns:
-         tuple[float, list[float]]: The GC@T score and the list of GC scores for each file
-     """
-     test_data = read_all_pkl(folder_path)
-     all_gc_scores = []
-     for _, value in test_data.items():
-         sim_score = value["cosine_similarities"][1:]
-         if n_iter is None:
-             _gc = integrated_decay_area(sim_score)
-         else:
-             # the first entry of cosine_similarities is the seed image's
-             # self-similarity, so require n_iter *generated* iterations
-             if len(sim_score) >= n_iter:
-                 _gc = integrated_decay_area(sim_score[:n_iter])
-             else:
-                 continue
-         all_gc_scores.append(_gc)
-     return np.mean(all_gc_scores), all_gc_scores
-
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "--results_path",
-         type=str,
-         help="Path to the folder containing the pickle files",
-         required=True,
-     )
-     parser.add_argument(
-         "--t",
-         type=int,
-         help="Number of iterations to consider for GC@T score",
-         required=True,
-     )
-     args = parser.parse_args()
-
-     # calculate GC@T score and save in results directory
-     gc, all_gc_scores = gc_score(args.results_path, args.t)
-     result = {
-         "GC Score": gc,
-         "All GC Scores": all_gc_scores,
-     }
-     results_path = os.path.join(args.results_path, f"GC@{str(args.t)}.json")
-     with open(results_path, "w") as file:
-         json.dump(result, file)
-
-
- if __name__ == "__main__":
-     main()
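
For intuition: integrated_decay_area() weights the similarity score of iteration i by i (the first iteration gets weight 1, the T-th gets weight T), so later drift is penalised more heavily, then normalises by the maximum attainable area. A minimal sketch on a toy run; the similarity values below are made up for illustration:

from genception.evaluation import integrated_decay_area

# hypothetical cosine similarities of iterations 1..5 to the seed image
scores = [0.9, 0.7, 0.6, 0.4, 0.3]

# (1*0.9 + 2*0.7 + 3*0.6 + 4*0.4 + 5*0.3) / (1+2+3+4+5) = 7.2 / 15 = 0.48
print(integrated_decay_area(scores))      # GC@5 for this single sample
print(integrated_decay_area(scores[:3]))  # GC@3 uses only the first 3 iterations
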
genception/example_script.sh DELETED
@@ -1,8 +0,0 @@
- # run experiment with gpt4v on examples dataset
- python genception/experiment.py --model gpt4v --dataset datasets/examples
-
-
- # Calculate GC@T evaluation metric
- python genception/evaluation.py --results_path datasets/examples/results_gpt4v --t 1
- python genception/evaluation.py --results_path datasets/examples/results_gpt4v --t 3
- python genception/evaluation.py --results_path datasets/examples/results_gpt4v --t 5
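
Each evaluation call above writes a GC@{t}.json file with the keys "GC Score" and "All GC Scores" into the results folder. A minimal sketch for collecting the scores afterwards, with paths following the example layout above:

import json
import os

results_dir = "datasets/examples/results_gpt4v"
for t in (1, 3, 5):
    with open(os.path.join(results_dir, f"GC@{t}.json")) as f:
        print(f"GC@{t}:", json.load(f)["GC Score"])
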
genception/experiment.py DELETED
@@ -1,373 +0,0 @@
- import os
- import torch
- import base64
- import pickle
- import requests
- import argparse
- import nltk
- from nltk.tokenize import word_tokenize
- from functools import partial
- from transformers import ViTImageProcessor, ViTModel
- from transformers import AutoProcessor, LlavaForConditionalGeneration
- from sklearn.metrics.pairwise import cosine_similarity
- from PIL import Image
- import logging
- from tqdm import tqdm
- from openai import OpenAI
- from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
- from mplug_owl2.conversation import conv_templates
- from mplug_owl2.model.builder import load_pretrained_model
- from mplug_owl2.mm_utils import (
-     process_images,
-     tokenizer_image_token,
-     get_model_name_from_path,
-     KeywordsStoppingCriteria,
- )
- from genception.utils import find_files
-
- logging.basicConfig(level=logging.INFO)
- client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
- api_key = client.api_key
- nltk.download("punkt")
- device = "cuda" if torch.cuda.is_available() else "cpu"
- torch.backends.cudnn.enabled = False
-
- # VIT model
- vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
- vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
-
-
- def image_embedding(image_file: str) -> list[float]:
-     """
-     Generates an image embedding using a vit model
-
-     Args:
-         image_file: str: The path to the image file
-
-     Returns:
-         list[float]: The image embedding
-     """
-     image = Image.open(image_file).convert("RGB")
-     inputs = vit_processor(images=image, return_tensors="pt")
-     outputs = vit_model(**inputs)
-     return outputs.last_hidden_state.tolist()[0][0]
-
-
- def save_image_from_url(url: str, filename: str):
-     """
-     Save an image from a given URL to a file
-
-     Args:
-         url: str: The URL of the image
-         filename: str: The name of the file to save the image to
-     """
-     response = requests.get(url)
-     if response.status_code == 200:
-         with open(filename, "wb") as file:
-             file.write(response.content)
-     else:
-         logging.warning(
-             f"Failed to download image. Status code: {response.status_code}"
-         )
-
-
- def find_image_files(folder_path: str) -> list[str]:
-     image_extensions = {".jpg", ".png"}
-     return find_files(folder_path, image_extensions)
-
-
- def count_words(text):
-     words = word_tokenize(text)
-     return len(words)
-
-
- def encode_image_os(image_path: str):
-     image = Image.open(image_path).convert("RGB")
-     return image
-
-
- def encode_image_gpt4v(image_path: str):
-     with open(image_path, "rb") as image_file:
-         return base64.b64encode(image_file.read()).decode("utf-8")
-
-
- def generate_xt(
-     image_desc: str, output_folder: str, i: int, file_name: str, file_extension: str
- ) -> str:
-     """
-     Generate an image based on a description using dall-e and save it to a file
-
-     Args:
-         image_desc: str: The description of the image
-         output_folder: str: The path to the folder to save the image to
-         i: int: The iteration number
-         file_name: str: The name of the file
-         file_extension: str: The extension of the file
-
-     Returns:
-         str: The path to the saved image file
-     """
-     response = client.images.generate(
-         model="dall-e-3",
-         prompt="Generate an image that fully and precisely reflects this description: {}".format(
-             image_desc
-         ),
-         size="1024x1024",
-         quality="standard",
-         n=1,
-     )
-     new_image_filename = os.path.join(
-         output_folder, f"{file_name}_{i}.{file_extension}"
-     )
-     save_image_from_url(response.data[0].url, new_image_filename)
-     return new_image_filename
-
-
- def get_desc_mPLUG(image, image_processor, lmm_model, tokenizer, prompt):
-     """
-     Given an image, generate a description using the mPLUG model
-
-     Args:
-         image: Image: The image to describe
-         image_processor: callable: The image processor
-         lmm_model: The language model
-         tokenizer: The tokenizer
-         prompt: str: The prompt for the model
-
-     Returns:
-         str: The description of the image
-     """
-     conv = conv_templates["mplug_owl2"].copy()
-     max_edge = max(image.size)
-     image = image.resize((max_edge, max_edge))
-     image_tensor = process_images([image], image_processor)
-     image_tensor = image_tensor.to(lmm_model.device, dtype=torch.float16)
-
-     inp = DEFAULT_IMAGE_TOKEN + prompt
-     conv.append_message(conv.roles[0], inp)
-     conv.append_message(conv.roles[1], None)
-     prompt = conv.get_prompt()
-
-     input_ids = (
-         tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
-         .unsqueeze(0)
-         .to(lmm_model.device)
-     )
-     stop_str = conv.sep2
-     keywords = [stop_str]
-     stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
-     attention_mask = torch.ones_like(input_ids, dtype=torch.long)
-
-     temperature = 0.001
-     max_new_tokens = 512
-
-     with torch.inference_mode():
-         output_ids = lmm_model.generate(
-             input_ids,
-             images=image_tensor,
-             do_sample=True,
-             temperature=temperature,
-             max_new_tokens=max_new_tokens,
-             stopping_criteria=[stopping_criteria],
-             attention_mask=attention_mask,
-         )
-
-     image_desc = tokenizer.decode(
-         output_ids[0, input_ids.shape[1] :], skip_special_tokens=True
-     ).strip()
-     return image_desc
-
-
- def get_desc_llava(image, lmm_processor, lmm_model, prompt):
-     """
-     Given an image, generate a description using the llava model
-
-     Args:
-         image: Image: The image to describe
-         lmm_processor: callable: The language model processor
-         lmm_model: The language model
-         prompt: str: The prompt for the model
-
-     Returns:
-         str: The description of the image
-     """
-     inputs = lmm_processor(text=prompt, images=image, return_tensors="pt").to(device)
-     outputs = lmm_model.generate(**inputs, max_new_tokens=512, do_sample=False)
-     answer = lmm_processor.batch_decode(outputs, skip_special_tokens=True)[0]
-     image_desc = answer.split("ASSISTANT:")[1].strip()
-     return image_desc
-
-
- def get_desc_gpt4v(image, prompt):
-     """
-     Given an image, generate a description using the gpt-4-vision model
-
-     Args:
-         image: Image: The image to describe
-         prompt: str: The prompt for the model
-
-     Returns:
-         str: The description of the image
-     """
-     payload = {
-         "model": "gpt-4-vision-preview",
-         "messages": [
-             {
-                 "role": "user",
-                 "content": [
-                     {
-                         "type": "text",
-                         "text": prompt,
-                     },
-                     {
-                         "type": "image_url",
-                         "image_url": {"url": f"data:image/jpeg;base64,{image}"},
-                     },
-                 ],
-             }
-         ],
-         "max_tokens": 512,
-         "temperature": 0,
-     }
-
-     headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
-
-     response = requests.post(
-         "https://api.openai.com/v1/chat/completions", headers=headers, json=payload
-     )
-     image_desc = response.json()["choices"][0]["message"]["content"]
-     return image_desc
-
-
- def test_sample(
-     seed_image: str,
-     n_iteration: int,
-     output_folder: str,
-     get_desc_function: callable,
-     encode_image_function: callable,
- ):
-     """
-     Iteratively generates T (n_iterations) descriptions and images based on the seed image
-
-     Args:
-         seed_image: str: The path to the seed image
-         n_iteration: int: The number of iterations to perform
-         output_folder: str: The path to the folder to save the results
-         get_desc_function: callable: The function to generate the description
-         encode_image_function: callable: The function to encode the image
-     """
-     list_of_desc = []
-     list_of_image = []
-     list_of_image_embedding = [image_embedding(seed_image)]
-     list_of_cos_sim = [1.0]
-
-     current_image_path = seed_image
-     current_image_name = os.path.basename(current_image_path)
-     file_name, file_extension = current_image_name.split(".")
-     logging.debug(f"Image: {current_image_path}")
-     pkl_file = os.path.join(output_folder, f"{file_name}_result.pkl")
-     if os.path.exists(pkl_file):
-         logging.info("Results already exist, skipping")
-         return None
-
-     for i in range(n_iteration):
-         # Encode the current image and get the description
-         image = encode_image_function(current_image_path)
-         image_desc = get_desc_function(image)
-         list_of_desc.append(image_desc)
-         logging.debug(image_desc)
-
-         # generate X^t, append image and embedding
-         new_image_filename = generate_xt(
-             image_desc, output_folder, i, file_name, file_extension
-         )
-         list_of_image.append(new_image_filename)
-         list_of_image_embedding.append(image_embedding(new_image_filename))
-
-         # Calculate Cosine Sim to original image
-         similarity = cosine_similarity(
-             [list_of_image_embedding[0]], [list_of_image_embedding[-1]]
-         )[0][0]
-         list_of_cos_sim.append(similarity)
-         logging.info(f"({count_words(image_desc)}, {round(similarity, 2)})")
-
-         # Save checkpoint to avoid losing results
-         data_to_save = {
-             "descriptions": list_of_desc,
-             "images": list_of_image,
-             "image_embeddings": list_of_image_embedding,
-             "cosine_similarities": list_of_cos_sim,
-         }
-         with open(pkl_file, "wb") as file:
-             pickle.dump(data_to_save, file)
-
-         # Update current_image_path for the next iteration
-         current_image_path = new_image_filename
-
-     return None
-
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--dataset", type=str, default="mme_data/color")
-     parser.add_argument("--model", type=str, default="llava7b")
-     parser.add_argument("--n_iter", type=int, default=5)
-     args = parser.parse_args()
-
-     logging.info(args)
-
-     prompt = "Please write a clear, precise, detailed, and concise description of all elements in the image. Focus on accurately depicting various aspects, including but not limited to the colors, shapes, positions, styles, texts and the relationships between different objects and subjects in the image. Your description should be thorough enough to guide a professional in recreating this image solely based on your textual representation. Remember, only include descriptive texts that directly pertain to the contents of the image. You must complete the description using less than 500 words."
-
-     if "llava" in args.model:
-         lmm_model = LlavaForConditionalGeneration.from_pretrained(
-             f"llava-hf/llava-1.5-{args.model[5:]}-hf", load_in_8bit=True
-         )
-         lmm_processor = AutoProcessor.from_pretrained(
-             f"llava-hf/llava-1.5-{args.model[5:]}-hf"
-         )
-         prompt = f"<image>\nUSER: {prompt}\nASSISTANT:"
-         # bind the fixed arguments by keyword so the image passed in
-         # test_sample() lands in the `image` parameter
-         get_desc_function = partial(
-             get_desc_llava,
-             lmm_processor=lmm_processor,
-             lmm_model=lmm_model,
-             prompt=prompt,
-         )
-         encode_image_function = encode_image_os
-     elif args.model == "mPLUG":
-         model_path = "MAGAer13/mplug-owl2-llama2-7b"
-         model_name = get_model_name_from_path(model_path)
-         tokenizer, lmm_model, image_processor, _ = load_pretrained_model(
-             model_path,
-             None,
-             model_name,
-             load_8bit=False,
-             load_4bit=False,
-             device=device,
-         )
-         tokenizer.pad_token_id = tokenizer.eos_token_id
-         tokenizer.pad_token = tokenizer.eos_token
-         # bind by keyword here as well, for the same reason as above
-         get_desc_function = partial(
-             get_desc_mPLUG,
-             image_processor=image_processor,
-             lmm_model=lmm_model,
-             tokenizer=tokenizer,
-             prompt=prompt,
-         )
-         encode_image_function = encode_image_os
-     elif args.model == "gpt4v":
-         get_desc_function = partial(get_desc_gpt4v, prompt=prompt)
-         encode_image_function = encode_image_gpt4v
-     else:
-         raise ValueError(f"Unsupported model: {args.model}")
-
-     output_folder = os.path.join(args.dataset, f"results_{args.model}")
-     os.makedirs(output_folder, exist_ok=True)
-
-     logging.debug("Loaded model. Entered main loop.")
-     for img_file in tqdm(find_image_files(args.dataset)):
-         try:
-             logging.info(img_file)
-             test_sample(
-                 seed_image=img_file,
-                 n_iteration=args.n_iter,
-                 output_folder=output_folder,
-                 get_desc_function=get_desc_function,
-                 encode_image_function=encode_image_function,
-             )
-         except Exception as e:
-             logging.warning("caught error:")
-             logging.warning(e)
-             continue
-
-
- if __name__ == "__main__":
-     main()
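
The drift measurement at the core of test_sample() reduces to comparing ViT [CLS] embeddings of the seed image and the latest regenerated image. A self-contained sketch of just that step, using the same google/vit-base-patch16-224-in21k checkpoint as above; the two file names are placeholders:

import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import ViTImageProcessor, ViTModel

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

def embed(path: str) -> list[float]:
    # [CLS] token of the last hidden state, mirroring image_embedding() above
    image = Image.open(path).convert("RGB")
    with torch.no_grad():
        outputs = model(**processor(images=image, return_tensors="pt"))
    return outputs.last_hidden_state[0, 0].tolist()

# "seed.png" and "seed_0.png" are placeholder file names
similarity = cosine_similarity([embed("seed.png")], [embed("seed_0.png")])[0][0]
print(round(similarity, 2))
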
genception/utils.py DELETED
@@ -1,25 +0,0 @@
- import os
-
-
- def find_files(folder_path: str, file_extensions: set[str]) -> list[str]:
-     """
-     Find all files with the given extensions in the given folder path
-
-     Args:
-         folder_path: str: The path to the folder
-         file_extensions: set[str]: The file extensions to look for
-
-     Returns:
-         list[str]: The list of file paths
-     """
-     file_paths = []
-
-     for file in os.listdir(folder_path):
-         if (
-             os.path.isfile(os.path.join(folder_path, file))
-             and os.path.splitext(file)[1].lower() in file_extensions
-         ):
-             absolute_path = os.path.abspath(os.path.join(folder_path, file))
-             file_paths.append(absolute_path)
-
-     return file_paths
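
A quick usage sketch (the folder paths are placeholders). find_files() matches on lower-cased extensions and returns absolute paths; note that because it uses os.listdir(), the search is not recursive, so only files directly inside the folder are found:

from genception.utils import find_files

pkl_files = find_files("datasets/examples/results_gpt4v", {".pkl"})
image_files = find_files("datasets/examples", {".jpg", ".png"})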