Commit e8f04b1
Parent(s): 5eb1b80
Update pipeline progress bar

RepoPipeline.py (+16 -4)
@@ -280,7 +280,7 @@ class RepoPipeline(Pipeline):
             if not text_sets \
             else torch.cat([self.encode(text, max_length) for text in text_sets], dim=0)
 
-    def _forward(self, extracted_infos: List, max_length=512) -> List:
+    def _forward(self, extracted_infos: List, max_length=512, st_progress=None) -> List:
         """
         The method for "forward" period.
         :param extracted_infos: the information of repositories.
@@ -289,8 +289,9 @@
         """
         model_outputs = []
         # The number of repositories.
-
-
+        num_texts = sum(
+            len(x["codes"]) + len(x["docs"]) + len(x["requirements"]) + len(x["readmes"]) for x in extracted_infos)
+        with tqdm(total=num_texts) as progress_bar:
             # For each repository
             for repo_info in extracted_infos:
                 repo_name = repo_info["name"]
@@ -307,12 +308,18 @@
                 code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
                 info["code_embeddings"] = code_embeddings.cpu().numpy()
                 info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0, keepdim=True).cpu().numpy()
+                progress_bar.update(len(repo_info["codes"]))
+                if st_progress:
+                    st_progress.progress(progress_bar.n / progress_bar.total)
 
                 # Doc embeddings
                 tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
                 doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
                 info["doc_embeddings"] = doc_embeddings.cpu().numpy()
                 info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0, keepdim=True).cpu().numpy()
+                progress_bar.update(len(repo_info["docs"]))
+                if st_progress:
+                    st_progress.progress(progress_bar.n / progress_bar.total)
 
                 # Requirement embeddings
                 tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
@@ -320,12 +327,18 @@
                 info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
                 info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0,
                                                                 keepdim=True).cpu().numpy()
+                progress_bar.update(len(repo_info["requirements"]))
+                if st_progress:
+                    st_progress.progress(progress_bar.n / progress_bar.total)
 
                 # Readme embeddings
                 tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
                 readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
                 info["readme_embeddings"] = readme_embeddings.cpu().numpy()
                 info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0, keepdim=True).cpu().numpy()
+                progress_bar.update(len(repo_info["readmes"]))
+                if st_progress:
+                    st_progress.progress(progress_bar.n / progress_bar.total)
 
                 # Repo-level mean embedding
                 info["mean_repo_embedding"] = np.concatenate([
@@ -345,7 +358,6 @@
                 info["mean_readme_embedding_shape"] = info["mean_readme_embedding"].shape
                 info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
 
-                progress_bar.update(1)
                 model_outputs.append(info)
 
         return model_outputs
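For reference, the pattern this commit introduces is a single tqdm bar sized by the total number of texts across all repositories, optionally mirrored into a Streamlit progress widget through the new st_progress parameter. The sketch below is illustrative only and not part of the commit: process_all is a hypothetical stand-in for _forward, the embedding step is elided to a comment, and st_progress is assumed to be the widget returned by Streamlit's st.progress(), whose .progress() method accepts a 0.0-1.0 fraction.

    from tqdm import tqdm

    def process_all(extracted_infos, st_progress=None):
        # Total work is every text in every repository, matching num_texts above.
        num_texts = sum(
            len(x["codes"]) + len(x["docs"]) + len(x["requirements"]) + len(x["readmes"])
            for x in extracted_infos)
        with tqdm(total=num_texts) as progress_bar:
            for repo_info in extracted_infos:
                for field in ("codes", "docs", "requirements", "readmes"):
                    # ... generate embeddings for repo_info[field] here ...
                    progress_bar.update(len(repo_info[field]))
                    if st_progress:
                        # Mirror tqdm's count into the browser as a fraction.
                        st_progress.progress(progress_bar.n / progress_bar.total)

A Streamlit caller would create the widget and pass it in, e.g. bar = st.progress(0.0) followed by process_all(extracted_infos, st_progress=bar); with st_progress left as None, only the terminal tqdm bar is shown.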