Henry65 committed
Commit e34a465 · 1 Parent(s): 066b297

Update RepoPipeline.py

Files changed (1): RepoPipeline.py (+154 -28)
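In short: the update adds README and requirements.txt extraction alongside the existing Python-file parsing, a new extract_requirements helper, per-source embeddings (code, docs, requirements, readme), and a concatenated repo-level mean embedding, plus docstrings throughout. Two illustrative sketches follow the diff below.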
RepoPipeline.py CHANGED
@@ -2,14 +2,20 @@ from typing import Dict, Any, List
 
 import ast
 import tarfile
-from ast import AsyncFunctionDef, ClassDef, FunctionDef, Module
 import torch
 import requests
+import numpy as np
+from ast import AsyncFunctionDef, ClassDef, FunctionDef, Module
 from transformers import Pipeline
 from tqdm.auto import tqdm
 
 
 def extract_code_and_docs(text: str):
+    """
+    Extract code snippets and docstrings from a Python source file.
+    :param text: the Python source text.
+    :return: a set of code snippets and a set of docstrings.
+    """
     code_set = set()
     docs_set = set()
     root = ast.parse(text)
@@ -28,7 +34,33 @@ def extract_code_and_docs(text: str):
     return code_set, docs_set
 
 
+def extract_requirements(lines):
+    """
+    Parse the lines of a requirements.txt into a set of library names.
+    :param lines: the lines of a requirements file.
+    :return: the set of required library names.
+    """
+    requirements_set = set()
+    for line in lines:
+        try:
+            if line != "\n":
+                if " == " in line:
+                    split_line = line.split(" == ")
+                else:
+                    split_line = line.split("==")
+                requirements_set.add(split_line[0].strip())
+        except Exception:
+            pass
+    return requirements_set
+
+
 def get_metadata(repo_name, headers=None):
+    """
+    Fetch repository metadata from the GitHub API.
+    :param repo_name: the repository name ("owner/repo").
+    :param headers: the request headers.
+    :return: the response JSON.
+    """
     api_url = f"https://api.github.com/repos/{repo_name}"
     tqdm.write(f"[+] Getting metadata for {repo_name}")
     try:
@@ -41,9 +73,15 @@ def get_metadata(repo_name, headers=None):
 
 
 def extract_information(repos, headers=None):
+    """
+    Extract the information of each repository.
+    :param repos: the repository names.
+    :param headers: the request headers.
+    :return: a list with one information dict per repository.
+    """
     extracted_infos = []
     for repo_name in tqdm(repos, disable=len(repos) <= 1):
-        # Get metadata
+        # 1. Extracting metadata.
         metadata = get_metadata(repo_name, headers=headers)
         repo_info = {
             "name": repo_name,
@@ -60,7 +98,7 @@ def extract_information(repos, headers=None):
         if metadata.get("license"):
             repo_info["license"] = metadata["license"]["spdx_id"]
 
-        # Download repo tarball bytes
+        # Download the repository tarball bytes.
         download_url = f"https://api.github.com/repos/{repo_name}/tarball"
         tqdm.write(f"[+] Downloading {repo_name}")
         try:
@@ -70,24 +108,51 @@
             tqdm.write(f"[-] Failed to download {repo_name}: {e}")
             continue
 
-        # Extract python files and parse them
+        # Extract repository files and parse them.
         tqdm.write(f"[+] Extracting {repo_name} info")
         with tarfile.open(fileobj=response.raw, mode="r|gz") as tar:
             for member in tar:
-                if (member.name.endswith(".py") and member.isfile()) is False:
-                    continue
-                try:
-                    file_content = tar.extractfile(member).read().decode("utf-8")
-                    code_set, docs_set = extract_code_and_docs(file_content)
-
-                    repo_info["codes"].update(code_set)
-                    repo_info["docs"].update(docs_set)
-                except UnicodeDecodeError as e:
-                    tqdm.write(
-                        f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
-                    )
-                except SyntaxError as e:
-                    tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
+                # 2. Extracting codes and docs.
+                if member.name.endswith(".py") and member.isfile():
+                    try:
+                        file_content = tar.extractfile(member).read().decode("utf-8")
+                        # Extract code and docstrings.
+                        code_set, docs_set = extract_code_and_docs(file_content)
+                        repo_info["codes"].update(code_set)
+                        repo_info["docs"].update(docs_set)
+                    except UnicodeDecodeError as e:
+                        tqdm.write(
+                            f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
+                        )
+                    except SyntaxError as e:
+                        tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
+                elif (member.name.endswith("README.md") or member.name.endswith("README.rst")) and member.isfile():
+                    # 3. Extracting readme.
+                    try:
+                        file_content = tar.extractfile(member).read().decode("utf-8")
+                        # Collect the README content.
+                        readmes_set = set()
+                        readmes_set.add(file_content)
+                        repo_info["readmes"].update(readmes_set)
+                    except UnicodeDecodeError as e:
+                        tqdm.write(
+                            f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
+                        )
+                    except SyntaxError as e:
+                        tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
+                elif member.name.endswith("requirements.txt") and member.isfile():
+                    # 4. Extracting requirements.
+                    try:
+                        lines = tar.extractfile(member).read().decode("utf-8").splitlines(keepends=True)
+                        # Extract the requirement names.
+                        requirements_set = extract_requirements(lines)
+                        repo_info["requirements"].update(requirements_set)
+                    except UnicodeDecodeError as e:
+                        tqdm.write(
+                            f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
+                        )
+                    except SyntaxError as e:
+                        tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
 
         extracted_infos.append(repo_info)
 
@@ -95,11 +160,20 @@
 
 
 class RepoPipeline(Pipeline):
+    """
+    A custom pipeline for generating a series of embeddings for a repository.
+    """
 
     def __init__(self, github_token=None, *args, **kwargs):
+        """
+        Initialize the pipeline.
+        :param github_token: an optional GitHub token.
+        :param args: positional arguments passed to Pipeline.
+        :param kwargs: keyword arguments passed to Pipeline.
+        """
         super().__init__(*args, **kwargs)
 
-        # Github token
+        # Store the GitHub token.
         self.github_token = github_token
         if self.github_token:
             print("[+] GitHub token set!")
@@ -111,36 +185,56 @@ class RepoPipeline(Pipeline):
             )
 
     def _sanitize_parameters(self, **pipeline_parameters):
+        """
+        Split the pipeline parameters by stage.
+        :param pipeline_parameters: the pipeline parameters.
+        :return: the parameters for each stage.
+        """
+        # Parameters for the "preprocess" stage.
         preprocess_parameters = {}
         if "github_token" in pipeline_parameters:
             preprocess_parameters["github_token"] = pipeline_parameters["github_token"]
 
+        # Parameters for the "forward" stage.
        forward_parameters = {}
         if "max_length" in pipeline_parameters:
             forward_parameters["max_length"] = pipeline_parameters["max_length"]
 
+        # Parameters for the "postprocess" stage.
         postprocess_parameters = {}
         return preprocess_parameters, forward_parameters, postprocess_parameters
 
     def preprocess(self, input_: Any, github_token=None) -> List:
-        # Making input to list format
+        """
+        The "preprocess" stage.
+        :param input_: one repository name or a list of them.
+        :param github_token: an optional GitHub token.
+        :return: a list of repository information dicts.
+        """
+        # Normalize the input to list format.
         if isinstance(input_, str):
             input_ = [input_]
 
-        # Building token
+        # Build the request headers.
         headers = {"Accept": "application/vnd.github+json"}
         token = github_token or self.github_token
         if token:
             headers["Authorization"] = f"Bearer {token}"
 
-        # Getting repositories' information: input_ means series of repositories
+        # Get the repositories' information: input_ is a series of repositories (possibly just one).
         extracted_infos = extract_information(input_, headers=headers)
-
         return extracted_infos
 
     def encode(self, text, max_length):
+        """
+        Encode a text into an embedding with UniXcoder.
+        :param text: the text.
+        :param max_length: the maximum token length.
+        :return: the text embedding.
+        """
         assert max_length < 1024
 
+        # Get the tokenizer.
         tokenizer = self.tokenizer
         tokens = (
             [tokenizer.cls_token, "<encoder-only>", tokenizer.sep_token]
@@ -149,20 +243,36 @@ class RepoPipeline(Pipeline):
         )
         tokens_id = tokenizer.convert_tokens_to_ids(tokens)
         source_ids = torch.tensor([tokens_id]).to(self.device)
-
         token_embeddings = self.model(source_ids)[0]
+
+        # Compute the text embedding as the mean over token embeddings.
         sentence_embeddings = token_embeddings.mean(dim=1)
 
         return sentence_embeddings
 
     def generate_embeddings(self, text_sets, max_length):
+        """
+        Generate embeddings for a set of texts.
+        :param text_sets: the set of texts.
+        :param max_length: the maximum token length.
+        :return: the stacked embeddings of the texts.
+        """
         assert max_length < 1024
+
+        # Concatenate the per-text embeddings along the vertical dimension.
         return torch.zeros((1, 768), device=self.device) \
-            if text_sets is None or len(text_sets) == 0 \
+            if not text_sets \
             else torch.cat([self.encode(text, max_length) for text in text_sets], dim=0)
 
     def _forward(self, extracted_infos: List, max_length=512) -> List:
+        """
+        The "forward" stage.
+        :param extracted_infos: the repositories' information.
+        :param max_length: the maximum token length.
+        :return: the pipeline output.
+        """
         model_outputs = []
+        # The number of repositories.
         num_repos = len(extracted_infos)
         with tqdm(total=num_repos) as progress_bar:
             # For each repository
@@ -194,14 +304,26 @@ class RepoPipeline(Pipeline):
                 info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
                 info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0).cpu().numpy()
 
-                # Requirement embeddings
+                # Readme embeddings
                 tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
                 readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
                 info["readme_embeddings"] = readme_embeddings.cpu().numpy()
                 info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0).cpu().numpy()
 
+                # Repo-level mean embedding
+                info["mean_repo_embedding"] = np.concatenate([
+                    info["mean_code_embedding"],
+                    info["mean_doc_embedding"],
+                    info["mean_requirement_embedding"],
+                    info["mean_readme_embedding"]
+                ], axis=0)
+
+                # TODO: remove these test fields
                 info["code_embeddings_shape"] = info["code_embeddings"].shape
                 info["doc_embeddings_shape"] = info["doc_embeddings"].shape
+                info["requirement_embeddings_shape"] = info["requirement_embeddings"].shape
+                info["readme_embeddings_shape"] = info["readme_embeddings"].shape
+                info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
 
                 progress_bar.update(1)
                 model_outputs.append(info)
@@ -209,6 +331,10 @@ class RepoPipeline(Pipeline):
         return model_outputs
 
     def postprocess(self, model_outputs: List, **postprocess_parameters: Dict) -> List:
+        """
+        The "postprocess" stage.
+        :param model_outputs: the pipeline output.
+        :param postprocess_parameters: the parameters for the "postprocess" stage.
+        :return: the model outputs, unchanged.
+        """
         return model_outputs
-
-
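
Two short sketches for orientation; they are illustrative and not part of the commit.

First, a quick check of the new extract_requirements helper on a made-up requirements.txt, assuming RepoPipeline.py is importable from the working directory:

    # Synthetic requirements content (illustrative only).
    from RepoPipeline import extract_requirements

    sample = "torch==2.1.0\nnumpy == 1.26.4\n\nrequests\n"
    # The helper expects a list of lines, newlines included.
    print(extract_requirements(sample.splitlines(keepends=True)))
    # -> {'torch', 'numpy', 'requests'} (set order may vary)

Second, a minimal usage sketch of the pipeline itself. It assumes the hosting model repo registers RepoPipeline as its custom pipeline so that trust_remote_code=True can load it; the checkpoint id below is a placeholder, not taken from this commit:

    from transformers import pipeline

    # Placeholder checkpoint id; replace with the actual model repo.
    repo_pipe = pipeline(
        model="<owner>/<repo-embedding-model>",
        trust_remote_code=True,
    )

    # A github_token kwarg can also be passed per call to raise the GitHub
    # API rate limit; max_length caps the tokens per encoded text.
    outputs = repo_pipe("huggingface/transformers", max_length=512)

    info = outputs[0]
    # Four 768-dim mean embeddings (code, docs, requirements, readme) are
    # concatenated by _forward into a single 3072-dim repo-level vector.
    print(info["mean_repo_embedding_shape"])

The mean_repo_embedding field is the practical payoff of this commit: downstream consumers get one fixed-size vector per repository instead of four variable-size embedding matrices.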