Update RepoPipeline.py
Browse files- RepoPipeline.py +2 -5
RepoPipeline.py
CHANGED
@@ -122,13 +122,12 @@ class RepoPipeline(Pipeline):
|
|
122 |
postprocess_parameters = {}
|
123 |
return preprocess_parameters, forward_parameters, postprocess_parameters
|
124 |
|
125 |
-
def preprocess(self, input_: Any,
|
126 |
# Making input to list format
|
127 |
if isinstance(input_, str):
|
128 |
input_ = [input_]
|
129 |
|
130 |
# Building token
|
131 |
-
github_token = preprocess_parameters["github_token"]
|
132 |
headers = {"Accept": "application/vnd.github+json"}
|
133 |
token = github_token or self.github_token
|
134 |
if token:
|
@@ -162,9 +161,7 @@ class RepoPipeline(Pipeline):
|
|
162 |
if text_sets is None or len(text_sets) == 0 \
|
163 |
else torch.zeros((1, 768), device=self.device)
|
164 |
|
165 |
-
def _forward(self, extracted_infos: List,
|
166 |
-
max_length = 512 if forward_parameters["max_length"] is None else forward_parameters["max_length"]
|
167 |
-
|
168 |
model_outputs = []
|
169 |
num_repos = len(extracted_infos)
|
170 |
with tqdm(total=num_repos) as progress_bar:
|
|
|
122 |
postprocess_parameters = {}
|
123 |
return preprocess_parameters, forward_parameters, postprocess_parameters
|
124 |
|
125 |
+
def preprocess(self, input_: Any, github_token=None) -> List:
|
126 |
# Making input to list format
|
127 |
if isinstance(input_, str):
|
128 |
input_ = [input_]
|
129 |
|
130 |
# Building token
|
|
|
131 |
headers = {"Accept": "application/vnd.github+json"}
|
132 |
token = github_token or self.github_token
|
133 |
if token:
|
|
|
161 |
if text_sets is None or len(text_sets) == 0 \
|
162 |
else torch.zeros((1, 768), device=self.device)
|
163 |
|
164 |
+
def _forward(self, extracted_infos: List, max_length) -> List:
|
|
|
|
|
165 |
model_outputs = []
|
166 |
num_repos = len(extracted_infos)
|
167 |
with tqdm(total=num_repos) as progress_bar:
|