duzx16 committed on
Commit
4ec3b65
·
1 Parent(s): 7025474

Add output_attentions for ChatGLMForSequenceClassification

Browse files
Files changed (1)
  1. modeling_chatglm.py +5 -8
modeling_chatglm.py CHANGED
@@ -21,7 +21,7 @@ from transformers.modeling_outputs import (
21
  SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
- from transformers.utils import logging
25
  from transformers.generation.logits_process import LogitsProcessor
26
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
27
 
@@ -29,7 +29,7 @@ from .configuration_chatglm import ChatGLMConfig
29
 
30
  # flags required to enable jit fusion kernels
31
 
32
- if sys.platform != 'darwin':
33
  torch._C._jit_set_profiling_mode(False)
34
  torch._C._jit_set_profiling_executor(False)
35
  torch._C._jit_override_can_fuse_on_cpu(True)
@@ -40,12 +40,6 @@ logger = logging.get_logger(__name__)
40
  _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
41
  _CONFIG_FOR_DOC = "ChatGLMConfig"
42
 
43
- CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
44
- "THUDM/chatglm3-6b",
45
- # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
46
- ]
47
-
48
-
49
  def default_init(cls, *args, **kwargs):
50
  return cls(*args, **kwargs)
51
 
@@ -740,6 +734,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
740
  past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
741
  inputs_embeds: Optional[torch.Tensor] = None,
742
  use_cache: Optional[bool] = None,
 
743
  output_hidden_states: Optional[bool] = None,
744
  return_dict: Optional[bool] = None,
745
  ):
@@ -1145,6 +1140,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1145
  inputs_embeds: Optional[torch.LongTensor] = None,
1146
  labels: Optional[torch.LongTensor] = None,
1147
  use_cache: Optional[bool] = None,
 
1148
  output_hidden_states: Optional[bool] = None,
1149
  return_dict: Optional[bool] = None,
1150
  ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
@@ -1158,6 +1154,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1158
  past_key_values=past_key_values,
1159
  inputs_embeds=inputs_embeds,
1160
  use_cache=use_cache,
 
1161
  output_hidden_states=output_hidden_states,
1162
  return_dict=return_dict,
1163
  )
 
21
  SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.utils import logging, is_torch_npu_available
25
  from transformers.generation.logits_process import LogitsProcessor
26
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
27
 
 
29
 
30
  # flags required to enable jit fusion kernels
31
 
32
+ if sys.platform != 'darwin' and not is_torch_npu_available():
33
  torch._C._jit_set_profiling_mode(False)
34
  torch._C._jit_set_profiling_executor(False)
35
  torch._C._jit_override_can_fuse_on_cpu(True)
 
40
  _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
41
  _CONFIG_FOR_DOC = "ChatGLMConfig"
42
 
 
 
 
 
 
 
43
  def default_init(cls, *args, **kwargs):
44
  return cls(*args, **kwargs)
45
 
 
734
  past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
735
  inputs_embeds: Optional[torch.Tensor] = None,
736
  use_cache: Optional[bool] = None,
737
+ output_attentions: Optional[bool] = None,
738
  output_hidden_states: Optional[bool] = None,
739
  return_dict: Optional[bool] = None,
740
  ):
 
1140
  inputs_embeds: Optional[torch.LongTensor] = None,
1141
  labels: Optional[torch.LongTensor] = None,
1142
  use_cache: Optional[bool] = None,
1143
+ output_attentions: Optional[bool] = None,
1144
  output_hidden_states: Optional[bool] = None,
1145
  return_dict: Optional[bool] = None,
1146
  ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
 
1154
  past_key_values=past_key_values,
1155
  inputs_embeds=inputs_embeds,
1156
  use_cache=use_cache,
1157
+ output_attentions=output_attentions,
1158
  output_hidden_states=output_hidden_states,
1159
  return_dict=return_dict,
1160
  )