Tugrul tomer-nv commited on
Commit
b7432dd
·
verified ·
1 Parent(s): d973526

_prepare_generation_config bugfix (failed due to version update in transformers) (#14)

Browse files

- _prepare_generation_config bugfix (failed due to version update in transformers) (cb7dc37ce9ae28726b17f79619d03519ab9551db)


Co-authored-by: Tomer Ronen <tomer-nv@users.noreply.huggingface.co>

Files changed (1) hide show
  1. modeling_decilm.py +5 -2
modeling_decilm.py CHANGED
@@ -802,10 +802,13 @@ class DeciLMPreTrainedModel(PreTrainedModel):
802
  module.weight.data[module.padding_idx].zero_()
803
 
804
  def _prepare_generation_config(
805
- self, generation_config: Optional[GenerationConfig], **kwargs: dict
 
 
 
806
  ) -> tuple[GenerationConfig, dict]:
807
  # DeciLM-specific code
808
- generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
809
  generation_config.cache_implementation = "variable"
810
  NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
811
  return generation_config, model_kwargs
 
802
  module.weight.data[module.padding_idx].zero_()
803
 
804
  def _prepare_generation_config(
805
+ self,
806
+ generation_config: Optional[GenerationConfig],
807
+ *args,
808
+ **kwargs,
809
  ) -> tuple[GenerationConfig, dict]:
810
  # DeciLM-specific code
811
+ generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
812
  generation_config.cache_implementation = "variable"
813
  NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
814
  return generation_config, model_kwargs