AustingDong
committed on
Commit
·
8b5e432
1
Parent(s):
ee8653b
add evaluate
Browse files- app.py +1 -3
- evaluate/__init__.py +0 -0
- evaluate/evaluate.py +66 -0
- evaluate/questions.py +73 -0
app.py
CHANGED
@@ -64,9 +64,7 @@ def multimodal_understanding(model_type,
|
|
64 |
torch.cuda.ipc_collect()
|
65 |
|
66 |
# set seed
|
67 |
-
|
68 |
-
np.random.seed(seed)
|
69 |
-
torch.cuda.manual_seed(seed) if torch.cuda.is_available() else None
|
70 |
|
71 |
input_text_decoded = ""
|
72 |
answer = ""
|
|
|
64 |
torch.cuda.ipc_collect()
|
65 |
|
66 |
# set seed
|
67 |
+
set_seed(model_seed=seed)
|
|
|
|
|
68 |
|
69 |
input_text_decoded = ""
|
70 |
answer = ""
|
evaluate/__init__.py
ADDED
File without changes
|
evaluate/evaluate.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from demo.model_utils import *
|
6 |
+
from evaluate.questions import questions
|
7 |
+
|
8 |
+
def set_seed(model_seed = 42):
    """Seed the torch, numpy, and (when available) CUDA RNGs for reproducibility.

    Args:
        model_seed: integer seed applied to every random number generator
            (default 42).
    """
    torch.manual_seed(model_seed)
    np.random.seed(model_seed)
    # A plain `if` statement replaces the original conditional *expression*
    # (`f(x) if cond else None`), which was used only for its side effect.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(model_seed)
12 |
+
|
13 |
+
def evaluate(model_type, num_eval = 10):
    """Run the chart-QA evaluation for one model and write answers to disk.

    For each of ``num_eval`` runs the model is re-initialized under a fresh
    random seed, every question in ``questions`` is answered, and the decoded
    answer is written to ``./evaluate/results/<model_type>/<run idx>/<chart>.txt``.

    Args:
        model_type: model identifier such as "Janus-Pro-7B", "LLaVA-1.5-7B",
            or "ChartGemma"; the prefix before the first '-' selects the family.
        num_eval: number of independent evaluation runs (default 10).

    Raises:
        ValueError: if ``model_type`` does not belong to a known model family.
    """
    RESULTS_ROOT = "./evaluate/results"
    for eval_idx in range(num_eval):
        set_seed(np.random.randint(0, 1000))

        # Compute the family prefix once instead of re-splitting model_type
        # in every branch and again inside the question loop.
        family = model_type.split('-')[0]

        if family == "Janus":
            model_utils = Janus_Utils()
            vl_gpt, tokenizer = model_utils.init_Janus(model_type.split('-')[-1])

        elif family == "LLaVA":
            model_utils = LLaVA_Utils()
            version = model_type.split('-')[1]
            vl_gpt, tokenizer = model_utils.init_LLaVA(version=version)

        elif family == "ChartGemma":
            model_utils = ChartGemma_Utils()
            vl_gpt, tokenizer = model_utils.init_ChartGemma()

        else:
            # Previously an unknown model_type fell through with
            # model_utils = None and crashed later with an opaque
            # AttributeError; fail fast with a clear message instead.
            raise ValueError(f"Unknown model_type: {model_type!r}")

        for chart_type, q, img_path in questions:
            image = np.array(Image.open(img_path).convert("RGB"))

            prepare_inputs = model_utils.prepare_inputs(q, image)
            temperature = 0.9
            top_p = 0.1

            if family == "Janus":
                # Janus requires explicit input embeddings before generation.
                inputs_embeds = model_utils.generate_inputs_embeddings(prepare_inputs)
                outputs = model_utils.generate_outputs(inputs_embeds, prepare_inputs, temperature, top_p)
            else:
                outputs = model_utils.generate_outputs(prepare_inputs, temperature, top_p)

            sequences = outputs.sequences.cpu().tolist()
            answer = tokenizer.decode(sequences[0], skip_special_tokens=True)

            files_root = f"{RESULTS_ROOT}/{model_type}/{eval_idx}"
            os.makedirs(files_root, exist_ok=True)

            # The `with` statement closes the file; the original's extra
            # f.close() inside the block was redundant.
            with open(f"{files_root}/{chart_type}.txt", "w") as f:
                f.write(answer)
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == '__main__':
    # Earlier experiment sets, kept for reference:
    #   ["ChartGemma", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B"]
    #   ["ChartGemma", "Janus-Pro-1B"]
    selected_models = ["Janus-Pro-7B", "LLaVA-1.5-7B"]
    for selected in selected_models:
        evaluate(model_type=selected, num_eval=10)
|
evaluate/questions.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Evaluation question bank.
# Each entry is a triple: [chart_type, question_text, image_path].
# `evaluate.evaluate` iterates this list, loads the image, and asks the
# model the question; chart_type names the per-run output file.
questions=[
    [
        "LineChart",
        "What was the price of a barrel of oil in February 2020?",
        "images/LineChart.png"
    ],

    [
        "BarChart",
        "What is the average internet speed in Japan?",
        "images/BarChart.png"
    ],

    [
        "StackedBar",
        "What is the cost of peanuts in Seoul?",
        "images/StackedBar.png"
    ],

    [
        "100%StackedBar",
        "Which country has the lowest proportion of Gold medals?",
        "images/Stacked100.png"
    ],

    [
        "PieChart",
        "What is the approximate global smartphone market share of Samsung?",
        "images/PieChart.png"
    ],

    [
        "Histogram",
        "What distance have customers traveled in the taxi the most?",
        "images/Histogram.png"
    ],

    [
        "Scatterplot",
        "True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
        "images/Scatterplot.png"
    ],

    [
        "AreaChart",
        # Fixed typo in the prompt: "price of pount of" -> "price of a pound of".
        "What was the average price of a pound of coffee beans in October 2019?",
        "images/AreaChart.png"
    ],

    [
        "StackedArea",
        "What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
        "images/StackedArea.png"
    ],

    [
        "BubbleChart",
        "Which city's metro system has the largest number of stations?",
        "images/BubbleChart.png"
    ],

    [
        "Choropleth",
        "True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
        "images/Choropleth_New.png"
    ],

    [
        "TreeMap",
        "True/False: eBay is nested in the Software category.",
        "images/TreeMap.png"
    ]
]
|