AustingDong committed on
Commit
8b5e432
·
1 Parent(s): ee8653b

add evaluate

Browse files
Files changed (4) hide show
  1. app.py +1 -3
  2. evaluate/__init__.py +0 -0
  3. evaluate/evaluate.py +66 -0
  4. evaluate/questions.py +73 -0
app.py CHANGED
@@ -64,9 +64,7 @@ def multimodal_understanding(model_type,
64
  torch.cuda.ipc_collect()
65
 
66
  # set seed
67
- torch.manual_seed(seed)
68
- np.random.seed(seed)
69
- torch.cuda.manual_seed(seed) if torch.cuda.is_available() else None
70
 
71
  input_text_decoded = ""
72
  answer = ""
 
64
  torch.cuda.ipc_collect()
65
 
66
  # set seed
67
+ set_seed(model_seed=seed)
 
 
68
 
69
  input_text_decoded = ""
70
  answer = ""
evaluate/__init__.py ADDED
File without changes
evaluate/evaluate.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ from demo.model_utils import *
6
+ from evaluate.questions import questions
7
+
8
def set_seed(model_seed=42):
    """Seed the global PyTorch and NumPy random number generators.

    Args:
        model_seed: Integer seed applied to ``torch``, to ``numpy.random``
            and, when CUDA is available, to every visible CUDA device.
    """
    torch.manual_seed(model_seed)
    np.random.seed(model_seed)
    # Seed all CUDA devices rather than only the current one so that
    # multi-GPU runs are covered as well.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(model_seed)
12
+
13
def evaluate(model_type, num_eval=10):
    """Ask one model every chart question and store its answers on disk.

    For each of ``num_eval`` runs the RNGs are re-seeded, the model is
    loaded, and every entry of ``questions`` is answered. Each answer is
    written to ``./evaluate/results/<model_type>/<run index>/<chart type>.txt``.

    Args:
        model_type: Model identifier whose first ``-``-separated token selects
            the model family: ``"Janus"``, ``"LLaVA"`` or ``"ChartGemma"``
            (e.g. ``"Janus-Pro-7B"``, ``"LLaVA-1.5-7B"``).
        num_eval: Number of independent evaluation runs.

    Raises:
        ValueError: If the model family is not supported, or if a LLaVA
            identifier carries no version token.
    """
    parts = model_type.split("-")
    family = parts[0]

    # Fail early with a clear message instead of hitting an AttributeError
    # on ``None`` once the first question is processed.
    if family not in ("Janus", "LLaVA", "ChartGemma"):
        raise ValueError(f"Unsupported model type: {model_type!r}")
    if family == "LLaVA" and len(parts) < 2:
        raise ValueError(
            f"LLaVA model type must include a version: {model_type!r}"
        )

    results_root = "./evaluate/results"
    # Sampling settings forwarded unchanged to ``generate_outputs``.
    temperature = 0.9
    top_p = 0.1

    for eval_idx in range(num_eval):
        set_seed(np.random.randint(0, 1000))

        # NOTE(review): the model is reloaded on every run, as in the
        # original; hoisting it is only safe if the init helpers do not
        # depend on the seed -- confirm before changing.
        if family == "Janus":
            model_utils = Janus_Utils()
            _model, tokenizer = model_utils.init_Janus(parts[-1])
        elif family == "LLaVA":
            model_utils = LLaVA_Utils()
            _model, tokenizer = model_utils.init_LLaVA(version=parts[1])
        else:  # "ChartGemma"
            model_utils = ChartGemma_Utils()
            _model, tokenizer = model_utils.init_ChartGemma()

        out_dir = f"{results_root}/{model_type}/{eval_idx}"
        os.makedirs(out_dir, exist_ok=True)

        for question in questions:
            chart_type, prompt, img_path = question[:3]

            # Close the image file promptly instead of leaving it to the GC.
            with Image.open(img_path) as img:
                image = np.array(img.convert("RGB"))

            prepare_inputs = model_utils.prepare_inputs(prompt, image)

            if family == "Janus":
                # Janus needs the input embeddings computed explicitly first.
                inputs_embeds = model_utils.generate_inputs_embeddings(prepare_inputs)
                outputs = model_utils.generate_outputs(
                    inputs_embeds, prepare_inputs, temperature, top_p
                )
            else:
                outputs = model_utils.generate_outputs(
                    prepare_inputs, temperature, top_p
                )

            sequences = outputs.sequences.cpu().tolist()
            answer = tokenizer.decode(sequences[0], skip_special_tokens=True)

            # Explicit UTF-8 so non-ASCII model output is written the same
            # way on every platform; ``with`` closes the file for us.
            with open(f"{out_dir}/{chart_type}.txt", "w", encoding="utf-8") as f:
                f.write(answer)
57
+
58
+
59
+
60
if __name__ == "__main__":
    # Script entry point: run 10 evaluation rounds for each selected model.
    # Other identifiers used previously: "ChartGemma", "Janus-Pro-1B".
    models = [
        "Janus-Pro-7B",
        "LLaVA-1.5-7B",
    ]
    for model_type in models:
        evaluate(model_type=model_type, num_eval=10)
evaluate/questions.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Evaluation prompts, one per chart type.
# Each entry is [chart type, question text, image path]. The chart type is
# also used as the stem of the result file name, so it must stay unique.
questions = [
    [
        "LineChart",
        "What was the price of a barrel of oil in February 2020?",
        "images/LineChart.png",
    ],
    [
        "BarChart",
        "What is the average internet speed in Japan?",
        "images/BarChart.png",
    ],
    [
        "StackedBar",
        "What is the cost of peanuts in Seoul?",
        "images/StackedBar.png",
    ],
    [
        "100%StackedBar",
        "Which country has the lowest proportion of Gold medals?",
        "images/Stacked100.png",
    ],
    [
        "PieChart",
        "What is the approximate global smartphone market share of Samsung?",
        "images/PieChart.png",
    ],
    [
        "Histogram",
        "What distance have customers traveled in the taxi the most?",
        "images/Histogram.png",
    ],
    [
        "Scatterplot",
        "True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
        "images/Scatterplot.png",
    ],
    [
        "AreaChart",
        # Typo fixed: the prompt previously read "price of pount of coffee beans".
        "What was the average price of a pound of coffee beans in October 2019?",
        "images/AreaChart.png",
    ],
    [
        "StackedArea",
        "What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
        "images/StackedArea.png",
    ],
    [
        "BubbleChart",
        "Which city's metro system has the largest number of stations?",
        "images/BubbleChart.png",
    ],
    [
        "Choropleth",
        "True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
        "images/Choropleth_New.png",
    ],
    [
        "TreeMap",
        "True/False: eBay is nested in the Software category.",
        "images/TreeMap.png",
    ],
]