HenryStephen committed on
Commit
e8f04b1
·
1 Parent(s): 5eb1b80

Update pipeline progress bar

Browse files
Files changed (1) hide show
  1. RepoPipeline.py +16 -4
RepoPipeline.py CHANGED
@@ -280,7 +280,7 @@ class RepoPipeline(Pipeline):
280
  if not text_sets \
281
  else torch.cat([self.encode(text, max_length) for text in text_sets], dim=0)
282
 
283
- def _forward(self, extracted_infos: List, max_length=512) -> List:
284
  """
285
  The method for "forward" period.
286
  :param extracted_infos: the information of repositories.
@@ -289,8 +289,9 @@ class RepoPipeline(Pipeline):
289
  """
290
  model_outputs = []
291
  # The number of repository.
292
- num_repos = len(extracted_infos)
293
- with tqdm(total=num_repos) as progress_bar:
 
294
  # For each repository
295
  for repo_info in extracted_infos:
296
  repo_name = repo_info["name"]
@@ -307,12 +308,18 @@ class RepoPipeline(Pipeline):
307
  code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
308
  info["code_embeddings"] = code_embeddings.cpu().numpy()
309
  info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0, keepdim=True).cpu().numpy()
 
 
 
310
 
311
  # Doc embeddings
312
  tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
313
  doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
314
  info["doc_embeddings"] = doc_embeddings.cpu().numpy()
315
  info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0, keepdim=True).cpu().numpy()
 
 
 
316
 
317
  # Requirement embeddings
318
  tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
@@ -320,12 +327,18 @@ class RepoPipeline(Pipeline):
320
  info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
321
  info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0,
322
  keepdim=True).cpu().numpy()
 
 
 
323
 
324
  # Readme embeddings
325
  tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
326
  readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
327
  info["readme_embeddings"] = readme_embeddings.cpu().numpy()
328
  info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0, keepdim=True).cpu().numpy()
 
 
 
329
 
330
  # Repo-level mean embedding
331
  info["mean_repo_embedding"] = np.concatenate([
@@ -345,7 +358,6 @@ class RepoPipeline(Pipeline):
345
  info["mean_readme_embedding_shape"] = info["mean_readme_embedding"].shape
346
  info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
347
 
348
- progress_bar.update(1)
349
  model_outputs.append(info)
350
 
351
  return model_outputs
 
280
  if not text_sets \
281
  else torch.cat([self.encode(text, max_length) for text in text_sets], dim=0)
282
 
283
+ def _forward(self, extracted_infos: List, max_length=512, st_progress=None) -> List:
284
  """
285
  The method for "forward" period.
286
  :param extracted_infos: the information of repositories.
 
289
  """
290
  model_outputs = []
291
  # The total number of texts across all repositories.
292
+ num_texts = sum(
293
+ len(x["codes"]) + len(x["docs"]) + len(x["requirements"]) + len(x["readmes"]) for x in extracted_infos)
294
+ with tqdm(total=num_texts) as progress_bar:
295
  # For each repository
296
  for repo_info in extracted_infos:
297
  repo_name = repo_info["name"]
 
308
  code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
309
  info["code_embeddings"] = code_embeddings.cpu().numpy()
310
  info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0, keepdim=True).cpu().numpy()
311
+ progress_bar.update(len(repo_info["codes"]))
312
+ if st_progress:
313
+ st_progress.progress(progress_bar.n / progress_bar.total)
314
 
315
  # Doc embeddings
316
  tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
317
  doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
318
  info["doc_embeddings"] = doc_embeddings.cpu().numpy()
319
  info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0, keepdim=True).cpu().numpy()
320
+ progress_bar.update(len(repo_info["docs"]))
321
+ if st_progress:
322
+ st_progress.progress(progress_bar.n / progress_bar.total)
323
 
324
  # Requirement embeddings
325
  tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
 
327
  info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
328
  info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0,
329
  keepdim=True).cpu().numpy()
330
+ progress_bar.update(len(repo_info["requirements"]))
331
+ if st_progress:
332
+ st_progress.progress(progress_bar.n / progress_bar.total)
333
 
334
  # Readme embeddings
335
  tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
336
  readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
337
  info["readme_embeddings"] = readme_embeddings.cpu().numpy()
338
  info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0, keepdim=True).cpu().numpy()
339
+ progress_bar.update(len(repo_info["readmes"]))
340
+ if st_progress:
341
+ st_progress.progress(progress_bar.n / progress_bar.total)
342
 
343
  # Repo-level mean embedding
344
  info["mean_repo_embedding"] = np.concatenate([
 
358
  info["mean_readme_embedding_shape"] = info["mean_readme_embedding"].shape
359
  info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
360
 
 
361
  model_outputs.append(info)
362
 
363
  return model_outputs