duzx16
committed on
Commit
·
4ec3b65
1
Parent(s):
7025474
Add output_attentions for ChatGLMForSequenceClassification
Browse files- modeling_chatglm.py +5 -8
modeling_chatglm.py
CHANGED
@@ -21,7 +21,7 @@ from transformers.modeling_outputs import (
|
|
21 |
SequenceClassifierOutputWithPast,
|
22 |
)
|
23 |
from transformers.modeling_utils import PreTrainedModel
|
24 |
-
from transformers.utils import logging
|
25 |
from transformers.generation.logits_process import LogitsProcessor
|
26 |
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
27 |
|
@@ -29,7 +29,7 @@ from .configuration_chatglm import ChatGLMConfig
|
|
29 |
|
30 |
# flags required to enable jit fusion kernels
|
31 |
|
32 |
-
if sys.platform != 'darwin':
|
33 |
torch._C._jit_set_profiling_mode(False)
|
34 |
torch._C._jit_set_profiling_executor(False)
|
35 |
torch._C._jit_override_can_fuse_on_cpu(True)
|
@@ -40,12 +40,6 @@ logger = logging.get_logger(__name__)
|
|
40 |
_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
|
41 |
_CONFIG_FOR_DOC = "ChatGLMConfig"
|
42 |
|
43 |
-
CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
44 |
-
"THUDM/chatglm3-6b",
|
45 |
-
# See all ChatGLM models at https://huggingface.co/models?filter=chatglm
|
46 |
-
]
|
47 |
-
|
48 |
-
|
49 |
def default_init(cls, *args, **kwargs):
|
50 |
return cls(*args, **kwargs)
|
51 |
|
@@ -740,6 +734,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
740 |
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
741 |
inputs_embeds: Optional[torch.Tensor] = None,
|
742 |
use_cache: Optional[bool] = None,
|
|
|
743 |
output_hidden_states: Optional[bool] = None,
|
744 |
return_dict: Optional[bool] = None,
|
745 |
):
|
@@ -1145,6 +1140,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
|
1145 |
inputs_embeds: Optional[torch.LongTensor] = None,
|
1146 |
labels: Optional[torch.LongTensor] = None,
|
1147 |
use_cache: Optional[bool] = None,
|
|
|
1148 |
output_hidden_states: Optional[bool] = None,
|
1149 |
return_dict: Optional[bool] = None,
|
1150 |
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
|
@@ -1158,6 +1154,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
|
1158 |
past_key_values=past_key_values,
|
1159 |
inputs_embeds=inputs_embeds,
|
1160 |
use_cache=use_cache,
|
|
|
1161 |
output_hidden_states=output_hidden_states,
|
1162 |
return_dict=return_dict,
|
1163 |
)
|
|
|
21 |
SequenceClassifierOutputWithPast,
|
22 |
)
|
23 |
from transformers.modeling_utils import PreTrainedModel
|
24 |
+
from transformers.utils import logging, is_torch_npu_available
|
25 |
from transformers.generation.logits_process import LogitsProcessor
|
26 |
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
27 |
|
|
|
29 |
|
30 |
# flags required to enable jit fusion kernels
|
31 |
|
32 |
+
if sys.platform != 'darwin' and not is_torch_npu_available():
|
33 |
torch._C._jit_set_profiling_mode(False)
|
34 |
torch._C._jit_set_profiling_executor(False)
|
35 |
torch._C._jit_override_can_fuse_on_cpu(True)
|
|
|
40 |
_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
|
41 |
_CONFIG_FOR_DOC = "ChatGLMConfig"
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
def default_init(cls, *args, **kwargs):
|
44 |
return cls(*args, **kwargs)
|
45 |
|
|
|
734 |
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
735 |
inputs_embeds: Optional[torch.Tensor] = None,
|
736 |
use_cache: Optional[bool] = None,
|
737 |
+
output_attentions: Optional[bool] = None,
|
738 |
output_hidden_states: Optional[bool] = None,
|
739 |
return_dict: Optional[bool] = None,
|
740 |
):
|
|
|
1140 |
inputs_embeds: Optional[torch.LongTensor] = None,
|
1141 |
labels: Optional[torch.LongTensor] = None,
|
1142 |
use_cache: Optional[bool] = None,
|
1143 |
+
output_attentions: Optional[bool] = None,
|
1144 |
output_hidden_states: Optional[bool] = None,
|
1145 |
return_dict: Optional[bool] = None,
|
1146 |
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
|
|
|
1154 |
past_key_values=past_key_values,
|
1155 |
inputs_embeds=inputs_embeds,
|
1156 |
use_cache=use_cache,
|
1157 |
+
output_attentions=output_attentions,
|
1158 |
output_hidden_states=output_hidden_states,
|
1159 |
return_dict=return_dict,
|
1160 |
)
|