Henry65 committed on
Commit ac219c3 · 1 Parent(s): e42969a

Update RepoPipeline.py

Files changed (1)
  1. RepoPipeline.py +29 -155
RepoPipeline.py CHANGED
@@ -2,20 +2,14 @@ from typing import Dict, Any, List
 
 import ast
 import tarfile
+from ast import AsyncFunctionDef, ClassDef, FunctionDef, Module
 import torch
 import requests
-import numpy as np
-from ast import AsyncFunctionDef, ClassDef, FunctionDef, Module
 from transformers import Pipeline
 from tqdm.auto import tqdm
 
 
 def extract_code_and_docs(text: str):
-    """
-    The method for extracting codes and docs in text.
-    :param text: python file.
-    :return: codes and docs set.
-    """
     code_set = set()
     docs_set = set()
     root = ast.parse(text)
@@ -34,33 +28,7 @@ def extract_code_and_docs(text: str):
     return code_set, docs_set
 
 
-def extract_requirements(lines):
-    """
-    The method for extracting requirements.
-    :param lines: requirements.
-    :return: requirement libraries.
-    """
-    requirements_set = set()
-    for line in lines:
-        try:
-            if line != "\n":
-                if " == " in line:
-                    splitLine = line.split(" == ")
-                else:
-                    splitLine = line.split("==")
-                requirements_set.add(splitLine[0])
-        except:
-            pass
-    return requirements_set
-
-
 def get_metadata(repo_name, headers=None):
-    """
-    The method for getting metadata of repository from github_api.
-    :param repo_name: repository name.
-    :param headers: request headers.
-    :return: response json.
-    """
     api_url = f"https://api.github.com/repos/{repo_name}"
     tqdm.write(f"[+] Getting metadata for {repo_name}")
     try:
@@ -73,15 +41,9 @@ def get_metadata(repo_name, headers=None):
 
 
 def extract_information(repos, headers=None):
-    """
-    The method for extracting repositories information.
-    :param repos: repositories.
-    :param headers: request header.
-    :return: a list for representing the information of each repository.
-    """
     extracted_infos = []
     for repo_name in tqdm(repos, disable=len(repos) <= 1):
-        # 1. Extracting metadata.
+        # Get metadata
         metadata = get_metadata(repo_name, headers=headers)
         repo_info = {
             "name": repo_name,
@@ -98,7 +60,7 @@ def extract_information(repos, headers=None):
         if metadata.get("license"):
             repo_info["license"] = metadata["license"]["spdx_id"]
 
-        # Download repo tarball bytes ---- Download repository.
+        # Download repo tarball bytes
         download_url = f"https://api.github.com/repos/{repo_name}/tarball"
         tqdm.write(f"[+] Downloading {repo_name}")
         try:
@@ -108,51 +70,24 @@ def extract_information(repos, headers=None):
             tqdm.write(f"[-] Failed to download {repo_name}: {e}")
             continue
 
-        # Extract repository files and parse them
+        # Extract python files and parse them
        tqdm.write(f"[+] Extracting {repo_name} info")
         with tarfile.open(fileobj=response.raw, mode="r|gz") as tar:
             for member in tar:
-                # 2. Extracting codes and docs.
-                if member.name.endswith(".py") and member.isfile():
-                    try:
-                        file_content = tar.extractfile(member).read().decode("utf-8")
-                        # extract_code_and_docs
-                        code_set, docs_set = extract_code_and_docs(file_content)
-                        repo_info["codes"].update(code_set)
-                        repo_info["docs"].update(docs_set)
-                    except UnicodeDecodeError as e:
-                        tqdm.write(
-                            f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
-                        )
-                    except SyntaxError as e:
-                        tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
-                # 3. Extracting readme.
-                elif (member.name == "README.md" or member.name == "README.rst") and member.isfile():
-                    try:
-                        file_content = tar.extractfile(member).read().decode("utf-8")
-                        # extract readme
-                        readmes_set = set()
-                        readmes_set.add(file_content)
-                        repo_info["readmes"].update(readmes_set)
-                    except UnicodeDecodeError as e:
-                        tqdm.write(
-                            f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
-                        )
-                    except SyntaxError as e:
-                        tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
-                # 4. Extracting requirements.
-                elif member.name == "requirements.txt" and member.isfile():
-                    try:
-                        lines = tar.extractfile(member).readlines().decode("utf-8")
-                        # extract readme
-                        requirements_set = extract_requirements(lines)
-                        repo_info["requirements"].update(requirements_set)
-                    except UnicodeDecodeError as e:
-                        tqdm.write(
-                            f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
-                        )
-                    except SyntaxError as e:
-                        tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
+                if (member.name.endswith(".py") and member.isfile()) is False:
+                    continue
+                try:
+                    file_content = tar.extractfile(member).read().decode("utf-8")
+                    code_set, docs_set = extract_code_and_docs(file_content)
+
+                    repo_info["codes"].update(code_set)
+                    repo_info["docs"].update(docs_set)
+                except UnicodeDecodeError as e:
+                    tqdm.write(
+                        f"[-] UnicodeDecodeError in {member.name}, skipping: \n{e}"
+                    )
+                except SyntaxError as e:
+                    tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
 
         extracted_infos.append(repo_info)
 
@@ -160,20 +95,11 @@ def extract_information(repos, headers=None):
 
 
 class RepoPipeline(Pipeline):
-    """
-    A custom pipeline for generating series of embeddings of a repository.
-    """
 
     def __init__(self, github_token=None, *args, **kwargs):
-        """
-        The initial method for pipeline.
-        :param github_token: github_token
-        :param args: args
-        :param kwargs: kwargs
-        """
         super().__init__(*args, **kwargs)
 
-        # Getting github token
+        # Github token
         self.github_token = github_token
         if self.github_token:
             print("[+] GitHub token set!")
@@ -185,56 +111,36 @@ class RepoPipeline(Pipeline):
         )
 
     def _sanitize_parameters(self, **pipeline_parameters):
-        """
-        The method for splitting parameters.
-        :param pipeline_parameters: parameters
-        :return: different parameters of different periods.
-        """
-        # The parameters of "preprocess" period.
         preprocess_parameters = {}
         if "github_token" in pipeline_parameters:
             preprocess_parameters["github_token"] = pipeline_parameters["github_token"]
 
-        # The parameters of "forward" period.
         forward_parameters = {}
         if "max_length" in pipeline_parameters:
             forward_parameters["max_length"] = pipeline_parameters["max_length"]
 
-        # The parameters of "postprocess" period.
         postprocess_parameters = {}
         return preprocess_parameters, forward_parameters, postprocess_parameters
 
     def preprocess(self, input_: Any, github_token=None) -> List:
-        """
-        The method for "preprocess" period.
-        :param input_: the input.
-        :param github_token: github_token.
-        :return: a list about repository information.
-        """
-        # Making input to list format.
+        # Making input to list format
         if isinstance(input_, str):
             input_ = [input_]
 
-        # Building headers.
+        # Building token
         headers = {"Accept": "application/vnd.github+json"}
         token = github_token or self.github_token
         if token:
             headers["Authorization"] = f"Bearer {token}"
 
-        # Getting repositories' information: input_ means series of repositories (can be only one repository).
+        # Getting repositories' information: input_ means series of repositories
         extracted_infos = extract_information(input_, headers=headers)
+
         return extracted_infos
 
     def encode(self, text, max_length):
-        """
-        The method for encoding the text to embedding by using UniXcoder.
-        :param text: text.
-        :param max_length: the max length.
-        :return: the embedding of text.
-        """
         assert max_length < 1024
 
-        # Getting the tokenizer.
         tokenizer = self.tokenizer
         tokens = (
             [tokenizer.cls_token, "<encoder-only>", tokenizer.sep_token]
@@ -243,36 +149,20 @@ class RepoPipeline(Pipeline):
         )
         tokens_id = tokenizer.convert_tokens_to_ids(tokens)
         source_ids = torch.tensor([tokens_id]).to(self.device)
-        token_embeddings = self.model(source_ids)[0]
 
-        # Getting the text embedding.
+        token_embeddings = self.model(source_ids)[0]
         sentence_embeddings = token_embeddings.mean(dim=1)
 
         return sentence_embeddings
 
     def generate_embeddings(self, text_sets, max_length):
-        """
-        The method for generating embeddings of a text set.
-        :param text_sets: text set.
-        :param max_length: max length.
-        :return: the embeddings of text set.
-        """
         assert max_length < 1024
-
-        # Concat the embeddings of each sentence/text in vertical dimension.
         return torch.zeros((1, 768), device=self.device) \
-            if not text_sets \
+            if text_sets is None or len(text_sets) == 0 \
             else torch.cat([self.encode(text, max_length) for text in text_sets], dim=0)
 
     def _forward(self, extracted_infos: List, max_length=512) -> List:
-        """
-        The method for "forward" period.
-        :param extracted_infos: the information of repositories.
-        :param max_length: max length.
-        :return: the output of this pipeline.
-        """
         model_outputs = []
-        # The number of repository.
         num_repos = len(extracted_infos)
         with tqdm(total=num_repos) as progress_bar:
             # For each repository
@@ -304,26 +194,14 @@ class RepoPipeline(Pipeline):
                 info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
                 info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0).cpu().numpy()
 
                 # Readme embeddings
                 tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
                 readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
                 info["readme_embeddings"] = readme_embeddings.cpu().numpy()
                 info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0).cpu().numpy()
 
-                # Repo-level mean embedding
-                info["mean_repo_embedding"] = np.concatenate([
-                    info["mean_code_embedding"],
-                    info["mean_doc_embedding"],
-                    info["mean_requirement_embedding"],
-                    info["mean_readme_embedding"]
-                ], axis=0)
-
-                # TODO Remove test
                 info["code_embeddings_shape"] = info["code_embeddings"].shape
                 info["doc_embeddings_shape"] = info["doc_embeddings"].shape
-                info["requirement_embeddings_shape"] = info["requirement_embeddings"].shape
-                info["readme_embeddings_shape"] = info["readme_embeddings"].shape
-                info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
 
                 progress_bar.update(1)
                 model_outputs.append(info)
@@ -331,10 +209,6 @@ class RepoPipeline(Pipeline):
         return model_outputs
 
     def postprocess(self, model_outputs: List, **postprocess_parameters: Dict) -> List:
-        """
-        The method for "postprocess" period.
-        :param model_outputs: the output of this pipeline.
-        :param postprocess_parameters: the parameters of "postprocess" period.
-        :return: model output.
-        """
         return model_outputs
+
+
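For reference, a minimal usage sketch of the pipeline after this commit. The checkpoint name, the example repository name, and the printed keys are illustrative assumptions, not part of the commit; the class only requires a UniXcoder-style encoder and tokenizer, plus an optional GitHub token for a higher API rate limit.

    from transformers import AutoModel, AutoTokenizer

    # Hypothetical checkpoint; any UniXcoder-style encoder should work in principle.
    model = AutoModel.from_pretrained("microsoft/unixcoder-base")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/unixcoder-base")

    # github_token is optional; without it, unauthenticated GitHub API limits apply.
    pipe = RepoPipeline(model=model, tokenizer=tokenizer, github_token=None)

    # Accepts one repository name or a list of them; encode() asserts max_length < 1024.
    outputs = pipe("owner/repo", max_length=512)
    info = outputs[0]
    print(info["name"], info["code_embeddings_shape"])  # e.g. owner/repo (N, 768)

Each returned dict carries the per-set embeddings ("code_embeddings", "doc_embeddings", "requirement_embeddings", "readme_embeddings") and their means; empty sets fall back to a single zero vector of shape (1, 768).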