AustingDong committed on
Commit
6d117d1
·
1 Parent(s): 7e57874

Added a custom loss (useless; left commented out)

Browse files
demo/visualization.py CHANGED
@@ -406,6 +406,16 @@ class VisualizationChartGemma(Visualization):
406
  super().__init__(model, register=True)
407
  self._modify_layers()
408
  self._register_hooks_activations()
 
 
 
 
 
 
 
 
 
 
409
 
410
  def forward_backward(self, inputs, focus, start_idx, target_token_idx, visual_method="softmax"):
411
  outputs_raw = self.model(**inputs, output_hidden_states=True)
@@ -421,6 +431,7 @@ class VisualizationChartGemma(Visualization):
421
  print("logits shape:", outputs_raw.logits.shape)
422
  if target_token_idx == -1:
423
  loss = outputs_raw.logits.max(dim=-1).values.sum()
 
424
  else:
425
  loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + target_token_idx]
426
  loss.backward()
 
406
  super().__init__(model, register=True)
407
  self._modify_layers()
408
  self._register_hooks_activations()
409
+
410
+ # def custom_loss(self, start_idx, input_ids, logits):
411
+ # Q = logits.shape[1]
412
+ # loss = 0
413
+ # q = 0
414
+ # while start_idx + q < Q - 1:
415
+ # loss += F.cross_entropy(logits[0, start_idx + q], input_ids[0, start_idx + q + 1])
416
+ # q += 1
417
+ # return loss
418
+
419
 
420
  def forward_backward(self, inputs, focus, start_idx, target_token_idx, visual_method="softmax"):
421
  outputs_raw = self.model(**inputs, output_hidden_states=True)
 
431
  print("logits shape:", outputs_raw.logits.shape)
432
  if target_token_idx == -1:
433
  loss = outputs_raw.logits.max(dim=-1).values.sum()
434
+ # loss = self.custom_loss(start_idx, inputs['input_ids'], outputs_raw.logits)
435
  else:
436
  loss = outputs_raw.logits.max(dim=-1).values[0, start_idx + target_token_idx]
437
  loss.backward()
evaluate/evaluate.py CHANGED
@@ -7,9 +7,9 @@ from openai import OpenAI
7
  from demo.model_utils import *
8
  from evaluate.questions import questions
9
 
10
- def set_seed(model_seed = 42):
11
  torch.manual_seed(model_seed)
12
- np.random.seed(model_seed)
13
  torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None
14
 
15
  def clean():
@@ -52,7 +52,7 @@ def evaluate(model_type, num_eval = 10):
52
  client = OpenAI(api_key=os.environ["GEMINI_HCI_API_KEY"],
53
  base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
54
 
55
- for question in questions:
56
  chart_type = question[0]
57
  q = question[1]
58
  img_path = question[2]
@@ -104,8 +104,8 @@ def evaluate(model_type, num_eval = 10):
104
 
105
  else:
106
  prepare_inputs = model_utils.prepare_inputs(q, image)
107
- temperature = 0.9
108
- top_p = 0.1
109
 
110
  if model_type.split('-')[0] == "Janus":
111
  inputs_embeds = model_utils.generate_inputs_embeddings(prepare_inputs)
@@ -120,7 +120,7 @@ def evaluate(model_type, num_eval = 10):
120
  FILES_ROOT = f"{RESULTS_ROOT}/{model_type}/{eval_idx}"
121
  os.makedirs(FILES_ROOT, exist_ok=True)
122
 
123
- with open(f"{FILES_ROOT}/{chart_type}.txt", "w") as f:
124
  f.write(answer)
125
  f.close()
126
 
@@ -129,8 +129,6 @@ def evaluate(model_type, num_eval = 10):
129
  if __name__ == '__main__':
130
 
131
  # models = ["ChartGemma", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B", "GPT-4o", "Gemini-2.0-flash"]
132
- # models = ["ChartGemma", "Janus-Pro-1B"]
133
- # models = ["Janus-Pro-7B", "LLaVA-1.5-7B"]
134
- models = ["GPT-4o", "Gemini-2.0-flash"]
135
  for model_type in models:
136
  evaluate(model_type=model_type, num_eval=10)
 
7
  from demo.model_utils import *
8
  from evaluate.questions import questions
9
 
10
+ def set_seed(model_seed = 70):
11
  torch.manual_seed(model_seed)
12
+ # np.random.seed(model_seed)
13
  torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None
14
 
15
  def clean():
 
52
  client = OpenAI(api_key=os.environ["GEMINI_HCI_API_KEY"],
53
  base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
54
 
55
+ for question_idx, question in enumerate(questions):
56
  chart_type = question[0]
57
  q = question[1]
58
  img_path = question[2]
 
104
 
105
  else:
106
  prepare_inputs = model_utils.prepare_inputs(q, image)
107
+ temperature = 0.1
108
+ top_p = 0.95
109
 
110
  if model_type.split('-')[0] == "Janus":
111
  inputs_embeds = model_utils.generate_inputs_embeddings(prepare_inputs)
 
120
  FILES_ROOT = f"{RESULTS_ROOT}/{model_type}/{eval_idx}"
121
  os.makedirs(FILES_ROOT, exist_ok=True)
122
 
123
+ with open(f"{FILES_ROOT}/Q{question_idx + 1}-{chart_type}.txt", "w") as f:
124
  f.write(answer)
125
  f.close()
126
 
 
129
  if __name__ == '__main__':
130
 
131
  # models = ["ChartGemma", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B", "GPT-4o", "Gemini-2.0-flash"]
132
+ models = ["Janus-Pro-7B"]
 
 
133
  for model_type in models:
134
  evaluate(model_type=model_type, num_eval=10)
evaluate/questions.py CHANGED
@@ -2,72 +2,72 @@ questions=[
2
  [
3
  "LineChart",
4
  "What was the price of a barrel of oil in February 2020?",
5
- "images/LineChart.png"
6
  ],
7
 
8
  [
9
  "BarChart",
10
  "What is the average internet speed in Japan?",
11
- "images/BarChart.png"
12
  ],
13
 
14
  [
15
  "StackedBar",
16
  "What is the cost of peanuts in Seoul?",
17
- "images/StackedBar.png"
18
  ],
19
 
20
  [
21
  "100%StackedBar",
22
  "Which country has the lowest proportion of Gold medals?",
23
- "images/Stacked100.png"
24
  ],
25
 
26
  [
27
  "PieChart",
28
  "What is the approximate global smartphone market share of Samsung?",
29
- "images/PieChart.png"
30
  ],
31
 
32
  [
33
  "Histogram",
34
  "What distance have customers traveled in the taxi the most?",
35
- "images/Histogram.png"
36
  ],
37
 
38
  [
39
  "Scatterplot",
40
  "True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
41
- "images/Scatterplot.png"
42
  ],
43
 
44
  [
45
  "AreaChart",
46
  "What was the average price of pount of coffee beans in October 2019?",
47
- "images/AreaChart.png"
48
  ],
49
 
50
  [
51
  "StackedArea",
52
  "What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
53
- "images/StackedArea.png"
54
  ],
55
 
56
  [
57
  "BubbleChart",
58
  "Which city's metro system has the largest number of stations?",
59
- "images/BubbleChart.png"
60
  ],
61
 
62
  [
63
  "Choropleth",
64
  "True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
65
- "images/Choropleth_New.png"
66
  ],
67
 
68
  [
69
  "TreeMap",
70
  "True/False: eBay is nested in the Software category.",
71
- "images/TreeMap.png"
72
  ]
73
  ]
 
2
  [
3
  "LineChart",
4
  "What was the price of a barrel of oil in February 2020?",
5
+ "images/mini-VLAT/LineChart.png"
6
  ],
7
 
8
  [
9
  "BarChart",
10
  "What is the average internet speed in Japan?",
11
+ "images/mini-VLAT/BarChart.png"
12
  ],
13
 
14
  [
15
  "StackedBar",
16
  "What is the cost of peanuts in Seoul?",
17
+ "images/mini-VLAT/StackedBar.png"
18
  ],
19
 
20
  [
21
  "100%StackedBar",
22
  "Which country has the lowest proportion of Gold medals?",
23
+ "images/mini-VLAT/Stacked100.png"
24
  ],
25
 
26
  [
27
  "PieChart",
28
  "What is the approximate global smartphone market share of Samsung?",
29
+ "images/mini-VLAT/PieChart.png"
30
  ],
31
 
32
  [
33
  "Histogram",
34
  "What distance have customers traveled in the taxi the most?",
35
+ "images/mini-VLAT/Histogram.png"
36
  ],
37
 
38
  [
39
  "Scatterplot",
40
  "True/False: There is a negative linear relationship between the height and the weight of the 85 males.",
41
+ "images/mini-VLAT/Scatterplot.png"
42
  ],
43
 
44
  [
45
  "AreaChart",
46
  "What was the average price of pount of coffee beans in October 2019?",
47
+ "images/mini-VLAT/AreaChart.png"
48
  ],
49
 
50
  [
51
  "StackedArea",
52
  "What was the ratio of girls named 'Isla' to girls named 'Amelia' in 2012 in the UK?",
53
+ "images/mini-VLAT/StackedArea.png"
54
  ],
55
 
56
  [
57
  "BubbleChart",
58
  "Which city's metro system has the largest number of stations?",
59
+ "images/mini-VLAT/BubbleChart.png"
60
  ],
61
 
62
  [
63
  "Choropleth",
64
  "True/False: In 2020, the unemployment rate for Washington (WA) was higher than that of Wisconsin (WI).",
65
+ "images/mini-VLAT/Choropleth_New.png"
66
  ],
67
 
68
  [
69
  "TreeMap",
70
  "True/False: eBay is nested in the Software category.",
71
+ "images/mini-VLAT/TreeMap.png"
72
  ]
73
  ]