Henry65 committed on
Commit
b0707cb
·
1 Parent(s): e34a465

Update RepoPipeline.py

Browse files
Files changed (1) hide show
  1. RepoPipeline.py +12 -8
RepoPipeline.py CHANGED
@@ -126,8 +126,8 @@ def extract_information(repos, headers=None):
126
  )
127
  except SyntaxError as e:
128
  tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
 
129
  elif (member.name.endswith("README.md") or member.name.endswith("README.rst")) and member.isfile():
130
- # 3. Extracting readme.
131
  try:
132
  file_content = tar.extractfile(member).read().decode("utf-8")
133
  # extract readme
@@ -140,8 +140,8 @@ def extract_information(repos, headers=None):
140
  )
141
  except SyntaxError as e:
142
  tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
 
143
  elif member.name.endswith("requirements.txt") and member.isfile():
144
- # 4. Extracting requirements.
145
  try:
146
  lines = tar.extractfile(member).readlines().decode("utf-8")
147
  # extract readme
@@ -290,25 +290,26 @@ class RepoPipeline(Pipeline):
290
  tqdm.write(f"[*] Generating code embeddings for {repo_name}")
291
  code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
292
  info["code_embeddings"] = code_embeddings.cpu().numpy()
293
- info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0).cpu().numpy()
294
 
295
  # Doc embeddings
296
  tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
297
  doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
298
  info["doc_embeddings"] = doc_embeddings.cpu().numpy()
299
- info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0).cpu().numpy()
300
 
301
  # Requirement embeddings
302
  tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
303
  requirement_embeddings = self.generate_embeddings(repo_info["requirements"], max_length)
304
  info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
305
- info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0).cpu().numpy()
 
306
 
307
  # Readme embeddings
308
  tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
309
  readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
310
  info["readme_embeddings"] = readme_embeddings.cpu().numpy()
311
- info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0).cpu().numpy()
312
 
313
  # Repo-level mean embedding
314
  info["mean_repo_embedding"] = np.concatenate([
@@ -316,13 +317,16 @@ class RepoPipeline(Pipeline):
316
  info["mean_doc_embedding"],
317
  info["mean_requirement_embedding"],
318
  info["mean_readme_embedding"]
319
- ], axis=0)
320
 
321
- # TODO Remove test
322
  info["code_embeddings_shape"] = info["code_embeddings"].shape
 
323
  info["doc_embeddings_shape"] = info["doc_embeddings"].shape
 
324
  info["requirement_embeddings_shape"] = info["requirement_embeddings"].shape
 
325
  info["readme_embeddings_shape"] = info["readme_embeddings"].shape
 
326
  info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
327
 
328
  progress_bar.update(1)
 
126
  )
127
  except SyntaxError as e:
128
  tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
129
+ # 3. Extracting readme.
130
  elif (member.name.endswith("README.md") or member.name.endswith("README.rst")) and member.isfile():
 
131
  try:
132
  file_content = tar.extractfile(member).read().decode("utf-8")
133
  # extract readme
 
140
  )
141
  except SyntaxError as e:
142
  tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
143
+ # 4. Extracting requirements.
144
  elif member.name.endswith("requirements.txt") and member.isfile():
 
145
  try:
146
  lines = tar.extractfile(member).readlines().decode("utf-8")
147
  # extract readme
 
290
  tqdm.write(f"[*] Generating code embeddings for {repo_name}")
291
  code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
292
  info["code_embeddings"] = code_embeddings.cpu().numpy()
293
+ info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0, keepdim=True).cpu().numpy()
294
 
295
  # Doc embeddings
296
  tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
297
  doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
298
  info["doc_embeddings"] = doc_embeddings.cpu().numpy()
299
+ info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0, keepdim=True).cpu().numpy()
300
 
301
  # Requirement embeddings
302
  tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
303
  requirement_embeddings = self.generate_embeddings(repo_info["requirements"], max_length)
304
  info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
305
+ info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0,
306
+ keepdim=True).cpu().numpy()
307
 
308
  # Readme embeddings
309
  tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
310
  readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
311
  info["readme_embeddings"] = readme_embeddings.cpu().numpy()
312
+ info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0, keepdim=True).cpu().numpy()
313
 
314
  # Repo-level mean embedding
315
  info["mean_repo_embedding"] = np.concatenate([
 
317
  info["mean_doc_embedding"],
318
  info["mean_requirement_embedding"],
319
  info["mean_readme_embedding"]
320
+ ], axis=0).reshape(1, -1)
321
 
 
322
  info["code_embeddings_shape"] = info["code_embeddings"].shape
323
+ info["mean_code_embedding_shape"] = info["mean_code_embedding"].shape
324
  info["doc_embeddings_shape"] = info["doc_embeddings"].shape
325
+ info["mean_doc_embedding_shape"] = info["mean_doc_embedding"].shape
326
  info["requirement_embeddings_shape"] = info["requirement_embeddings"].shape
327
+ info["mean_requirement_embedding_shape"] = info["mean_requirement_embedding"].shape
328
  info["readme_embeddings_shape"] = info["readme_embeddings"].shape
329
+ info["mean_readme_embedding_shape"] = info["mean_readme_embedding"].shape
330
  info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
331
 
332
  progress_bar.update(1)