AustingDong
committed on
Commit
·
6d117d1
1
Parent(s):
7e57874
Added a custom loss (useless)
Browse files
- demo/visualization.py +11 -0
- evaluate/evaluate.py +7 -9
- evaluate/questions.py +12 -12
demo/visualization.py
CHANGED
@@ -406,6 +406,16 @@ class VisualizationChartGemma(Visualization):
|
|
406 |
super().__init__(model, register=True)
|
407 |
self._modify_layers()
|
408 |
self._register_hooks_activations()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
|
410 |
def forward_backward(self, inputs, focus, start_idx, target_token_idx, visual_method="softmax"):
|
411 |
outputs_raw = self.model(**inputs, output_hidden_states=True)
|
@@ -421,6 +431,7 @@ class VisualizationChartGemma(Visualization):
|
|
421 |
print("logits shape:", outputs_raw.logits.shape)
|
422 |
if target_token_idx == -1:
|
423 |
loss = outputs_raw.logits.max(dim=-1).values.sum()
|
|
|
424 |
else:
|
425 |
loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + target_token_idx]
|
426 |
loss.backward()
|
|
|
406 |
super().__init__(model, register=True)
|
407 |
self._modify_layers()
|
408 |
self._register_hooks_activations()
|
409 |
+
|
410 |
+
# def custom_loss(self, start_idx, input_ids, logits):
|
411 |
+
# Q = logits.shape[1]
|
412 |
+
# loss = 0
|
413 |
+
# q = 0
|
414 |
+
# while start_idx + q < Q - 1:
|
415 |
+
# loss += F.cross_entropy(logits[0, start_idx + q], input_ids[0, start_idx + q + 1])
|
416 |
+
# q += 1
|
417 |
+
# return loss
|
418 |
+
|
419 |
|
420 |
def forward_backward(self, inputs, focus, start_idx, target_token_idx, visual_method="softmax"):
|
421 |
outputs_raw = self.model(**inputs, output_hidden_states=True)
|
|
|
431 |
print("logits shape:", outputs_raw.logits.shape)
|
432 |
if target_token_idx == -1:
|
433 |
loss = outputs_raw.logits.max(dim=-1).values.sum()
|
434 |
+
# loss = self.custom_loss(start_idx, inputs['input_ids'], outputs_raw.logits)
|
435 |
else:
|
436 |
loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + target_token_idx]
|
437 |
loss.backward()
|
evaluate/evaluate.py
CHANGED
@@ -7,9 +7,9 @@ from openai import OpenAI
|
|
7 |
from demo.model_utils import *
|
8 |
from evaluate.questions import questions
|
9 |
|
10 |
-
def set_seed(model_seed =
|
11 |
torch.manual_seed(model_seed)
|
12 |
-
np.random.seed(model_seed)
|
13 |
torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None
|
14 |
|
15 |
def clean():
|
@@ -52,7 +52,7 @@ def evaluate(model_type, num_eval = 10):
|
|
52 |
client = OpenAI(api_key=os.environ["GEMINI_HCI_API_KEY"],
|
53 |
base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
|
54 |
|
55 |
-
for question in questions:
|
56 |
chart_type = question[0]
|
57 |
q = question[1]
|
58 |
img_path = question[2]
|
@@ -104,8 +104,8 @@ def evaluate(model_type, num_eval = 10):
|
|
104 |
|
105 |
else:
|
106 |
prepare_inputs = model_utils.prepare_inputs(q, image)
|
107 |
-
temperature = 0.
|
108 |
-
top_p = 0.
|
109 |
|
110 |
if model_type.split('-')[0] == "Janus":
|
111 |
inputs_embeds = model_utils.generate_inputs_embeddings(prepare_inputs)
|
@@ -120,7 +120,7 @@ def evaluate(model_type, num_eval = 10):
|
|
120 |
FILES_ROOT = f"{RESULTS_ROOT}/{model_type}/{eval_idx}"
|
121 |
os.makedirs(FILES_ROOT, exist_ok=True)
|
122 |
|
123 |
-
with open(f"{FILES_ROOT}/{chart_type}.txt", "w") as f:
|
124 |
f.write(answer)
|
125 |
f.close()
|
126 |
|
@@ -129,8 +129,6 @@ def evaluate(model_type, num_eval = 10):
|
|
129 |
if __name__ == '__main__':
|
130 |
|
131 |
# models = ["ChartGemma", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B", "GPT-4o", "Gemini-2.0-flash"]
|
132 |
-
|
133 |
-
# models = ["Janus-Pro-7B", "LLaVA-1.5-7B"]
|
134 |
-
models = ["GPT-4o", "Gemini-2.0-flash"]
|
135 |
for model_type in models:
|
136 |
evaluate(model_type=model_type, num_eval=10)
|
|
|
7 |
from demo.model_utils import *
|
8 |
from evaluate.questions import questions
|
9 |
|
10 |
+
def set_seed(model_seed = 70):
|
11 |
torch.manual_seed(model_seed)
|
12 |
+
# np.random.seed(model_seed)
|
13 |
torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None
|
14 |
|
15 |
def clean():
|
|
|
52 |
client = OpenAI(api_key=os.environ["GEMINI_HCI_API_KEY"],
|
53 |
base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
|
54 |
|
55 |
+
for question_idx, question in enumerate(questions):
|
56 |
chart_type = question[0]
|
57 |
q = question[1]
|
58 |
img_path = question[2]
|
|
|
104 |
|
105 |
else:
|
106 |
prepare_inputs = model_utils.prepare_inputs(q, image)
|
107 |
+
temperature = 0.1
|
108 |
+
top_p = 0.95
|
109 |
|
110 |
if model_type.split('-')[0] == "Janus":
|
111 |
inputs_embeds = model_utils.generate_inputs_embeddings(prepare_inputs)
|
|
|
120 |
FILES_ROOT = f"{RESULTS_ROOT}/{model_type}/{eval_idx}"
|
121 |
os.makedirs(FILES_ROOT, exist_ok=True)
|
122 |
|
123 |
+
with open(f"{FILES_ROOT}/Q{question_idx + 1}-{chart_type}.txt", "w") as f:
|
124 |
f.write(answer)
|
125 |
f.close()
|
126 |
|
|
|
129 |
if __name__ == '__main__':
|
130 |
|
131 |
# models = ["ChartGemma", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B", "GPT-4o", "Gemini-2.0-flash"]
|
132 |
+
models = ["Janus-Pro-7B"]
|
|
|
|
|
133 |
for model_type in models:
|
134 |
evaluate(model_type=model_type, num_eval=10)
|
evaluate/questions.py
CHANGED
@@ -2,72 +2,72 @@ questions=[
|
|
2 |
[
|
3 |
"LineChart",
|
4 |
"What was the price of a barrel of oil in February 2020?",
|
5 |
-
"images/LineChart.png"
|
6 |
],
|
7 |
|
8 |
[
|
9 |
"BarChart",
|
10 |
"What is the average internet speed in Japan?",
|
11 |
-
"images/BarChart.png"
|
12 |
],
|
13 |
|
14 |
[
|
15 |
"StackedBar",
|
16 |
"What is the cost of peanuts in Seoul?",
|
17 |
-
"images/StackedBar.png"
|
18 |
],
|
19 |
|
20 |
[
|
21 |
"100%StackedBar",
|
22 |
"Which country has the lowest proportion of Gold medals?",
|
23 |
-
"images/Stacked100.png"
|
24 |
],
|
25 |
|
26 |
[
|
27 |
"PieChart",
|
28 |
"What is the approximate global smartphone market share of Samsung?",
|
29 |
-
"images/PieChart.png"
|
30 |
],
|
31 |
|
32 |
[
|
33 |
"Histogram",
|
34 |
"What distance have customers traveled in the taxi the most?",
|
35 |
-
"images/Histogram.png"
|
36 |
],
|
37 |
|
38 |
[
|
39 |
"Scatterplot",
|
40 |
"True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
|
41 |
-
"images/Scatterplot.png"
|
42 |
],
|
43 |
|
44 |
[
|
45 |
"AreaChart",
|
46 |
"What was the average price of pount of coffee beans in October 2019?",
|
47 |
-
"images/AreaChart.png"
|
48 |
],
|
49 |
|
50 |
[
|
51 |
"StackedArea",
|
52 |
"What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
|
53 |
-
"images/StackedArea.png"
|
54 |
],
|
55 |
|
56 |
[
|
57 |
"BubbleChart",
|
58 |
"Which city's metro system has the largest number of stations?",
|
59 |
-
"images/BubbleChart.png"
|
60 |
],
|
61 |
|
62 |
[
|
63 |
"Choropleth",
|
64 |
"True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
|
65 |
-
"images/Choropleth_New.png"
|
66 |
],
|
67 |
|
68 |
[
|
69 |
"TreeMap",
|
70 |
"True/False: eBay is nested in the Software category.",
|
71 |
-
"images/TreeMap.png"
|
72 |
]
|
73 |
]
|
|
|
2 |
[
|
3 |
"LineChart",
|
4 |
"What was the price of a barrel of oil in February 2020?",
|
5 |
+
"images/mini-VLAT/LineChart.png"
|
6 |
],
|
7 |
|
8 |
[
|
9 |
"BarChart",
|
10 |
"What is the average internet speed in Japan?",
|
11 |
+
"images/mini-VLAT/BarChart.png"
|
12 |
],
|
13 |
|
14 |
[
|
15 |
"StackedBar",
|
16 |
"What is the cost of peanuts in Seoul?",
|
17 |
+
"images/mini-VLAT/StackedBar.png"
|
18 |
],
|
19 |
|
20 |
[
|
21 |
"100%StackedBar",
|
22 |
"Which country has the lowest proportion of Gold medals?",
|
23 |
+
"images/mini-VLAT/Stacked100.png"
|
24 |
],
|
25 |
|
26 |
[
|
27 |
"PieChart",
|
28 |
"What is the approximate global smartphone market share of Samsung?",
|
29 |
+
"images/mini-VLAT/PieChart.png"
|
30 |
],
|
31 |
|
32 |
[
|
33 |
"Histogram",
|
34 |
"What distance have customers traveled in the taxi the most?",
|
35 |
+
"images/mini-VLAT/Histogram.png"
|
36 |
],
|
37 |
|
38 |
[
|
39 |
"Scatterplot",
|
40 |
"True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
|
41 |
+
"images/mini-VLAT/Scatterplot.png"
|
42 |
],
|
43 |
|
44 |
[
|
45 |
"AreaChart",
|
46 |
"What was the average price of pount of coffee beans in October 2019?",
|
47 |
+
"images/mini-VLAT/AreaChart.png"
|
48 |
],
|
49 |
|
50 |
[
|
51 |
"StackedArea",
|
52 |
"What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
|
53 |
+
"images/mini-VLAT/StackedArea.png"
|
54 |
],
|
55 |
|
56 |
[
|
57 |
"BubbleChart",
|
58 |
"Which city's metro system has the largest number of stations?",
|
59 |
+
"images/mini-VLAT/BubbleChart.png"
|
60 |
],
|
61 |
|
62 |
[
|
63 |
"Choropleth",
|
64 |
"True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
|
65 |
+
"images/mini-VLAT/Choropleth_New.png"
|
66 |
],
|
67 |
|
68 |
[
|
69 |
"TreeMap",
|
70 |
"True/False: eBay is nested in the Software category.",
|
71 |
+
"images/mini-VLAT/TreeMap.png"
|
72 |
]
|
73 |
]
|