| import ast | |
| import importlib | |
| import os | |
| from typing import Optional, Sequence | |
class DeleteSpecificNodes(ast.NodeTransformer):
    """AST transformer that drops a caller-supplied collection of nodes.

    Every node encountered during traversal is checked against
    ``nodes_to_remove``; listed nodes are removed from the tree and all
    other nodes are traversed as usual.
    """

    def __init__(self, nodes_to_remove: list[ast.AST]):
        """Store the nodes that must be deleted.

        Args:
            nodes_to_remove (list[ast.AST]): Nodes to drop during traversal.
        """
        self.nodes_to_remove = nodes_to_remove

    def visit(self, node: ast.AST) -> Optional[ast.AST]:
        """Visit ``node``, deleting it when it was marked for removal.

        Args:
            node (ast.AST): The node currently being visited.

        Returns:
            Optional[ast.AST]: ``None`` when the node is listed in
            ``nodes_to_remove`` (which tells ``NodeTransformer`` to remove
            it), otherwise whatever the default traversal returns.
        """
        marked_for_removal = node in self.nodes_to_remove
        return None if marked_for_removal else super().visit(node)
def convert_to_relative_import(module_name: str, original_parent_module_name: Optional[str]) -> str:
    """Turn a dotted module name into a single-level relative import target.

    Only the final component of ``module_name`` is kept.

    Args:
        module_name (str): Dotted module name, e.g. ``'pkg.sub.mod'``.
        original_parent_module_name (Optional[str]): Name to compare the
            final component against; may be ``None``.

    Returns:
        str: ``'.'`` when the final component equals
        ``original_parent_module_name``, otherwise ``'.'`` followed by the
        final component.
    """
    # rpartition yields the text after the last dot, or the whole string
    # when there is no dot at all.
    _, _, leaf = module_name.rpartition('.')
    return '.' if leaf == original_parent_module_name else f'.{leaf}'
def find_module_file(module_name: str) -> str:
    """Return the file path that backs an importable module.

    The module is imported as part of the lookup.

    Args:
        module_name (str): Importable module name.

    Returns:
        str: The value of the imported module's ``__file__`` attribute.

    Raises:
        ValueError: If ``module_name`` is empty, or if the imported module
            has no file (``__file__`` is missing or ``None``).
    """
    if not module_name:
        raise ValueError(f'Invalid input: module_name={module_name!r}')
    module = importlib.import_module(module_name)
    # ``__file__`` is not guaranteed to exist: built-in modules (e.g. ``sys``)
    # lack the attribute entirely, so a plain attribute access would raise
    # AttributeError instead of the ValueError reported below.
    module_file = getattr(module, '__file__', None)
    if module_file is None:
        raise ValueError(f'Could not find file for module: {module_name}')
    return module_file
def _flatten_import(node: ast.ImportFrom, flatten_imports_prefix: Sequence[str]) -> bool:
    """Decide whether an import-from statement should be flattened.

    Args:
        node (ast.ImportFrom): The import statement to inspect.
        flatten_imports_prefix (Sequence[str]): Module-name prefixes that
            mark an import for flattening.

    Returns:
        bool: True when ``node.module`` is set and begins with at least one
        of the prefixes, False otherwise.
    """
    imported_from = node.module
    if imported_from is None:
        return False
    return any(imported_from.startswith(prefix) for prefix in flatten_imports_prefix)
def _remove_import(node: ast.ImportFrom, remove_imports_prefix: Sequence[str]) -> bool:
    """Decide whether an import-from statement should be removed.

    Args:
        node (ast.ImportFrom): The import statement to inspect.
        remove_imports_prefix (Sequence[str]): Module-name prefixes that
            mark an import for removal.

    Returns:
        bool: True when ``node.module`` is set and begins with at least one
        of the prefixes, False otherwise.
    """
    source_module = node.module
    return source_module is not None and any(map(source_module.startswith, remove_imports_prefix))
def process_file(
    file_path: str,
    folder_path: str,
    flatten_imports_prefix: Sequence[str],
    remove_imports_prefix: Sequence[str],
) -> list[str]:
    """Rewrite one Python file into ``folder_path`` and report follow-up files.

    The file is parsed and its syntax tree is edited as follows:

    * ``from ... import`` statements whose module matches
      ``remove_imports_prefix`` are deleted;
    * remaining ``from ... import`` statements whose module matches
      ``flatten_imports_prefix`` are rewritten as relative imports, and the
      file of the originally imported module is reported to the caller;
    * classes whose name starts with ``Composer`` are deleted;
    * single-target ``__all__ = ...`` assignments are deleted.

    The edited tree is rendered with ``ast.unparse`` and written into
    ``folder_path``. An ``__init__.py`` is written as
    ``<containing directory name>.py``; any other file keeps its base name.

    Args:
        file_path (str): Path of the Python file to process.
        folder_path (str): Directory that receives the rewritten file.
        flatten_imports_prefix (Sequence[str]): Module prefixes whose imports
            are converted to relative imports.
        remove_imports_prefix (Sequence[str]): Module prefixes whose imports
            are deleted. Checked before ``flatten_imports_prefix``.

    Returns:
        list[str]: Paths of the module files behind every flattened import.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        source = f.read()

    is_package_init = os.path.basename(file_path) == '__init__.py'
    parent_module_name = None
    if is_package_init:
        # Use os.path (not a hard-coded '/') so the directory name is found
        # regardless of the platform separator; abspath makes a bare
        # '__init__.py' resolve to its real containing directory.
        parent_module_name = os.path.basename(os.path.dirname(os.path.abspath(file_path)))

    tree = ast.parse(source)
    new_files_to_process = []
    nodes_to_remove = []
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and node.module is not None:
            # Removal takes precedence over flattening.
            if _remove_import(node, remove_imports_prefix):
                nodes_to_remove.append(node)
            elif _flatten_import(node, flatten_imports_prefix):
                # Resolve the file before the module name is overwritten.
                module_path = find_module_file(node.module)
                node.module = convert_to_relative_import(node.module, parent_module_name)
                new_files_to_process.append(module_path)
        elif isinstance(node, ast.ClassDef) and node.name.startswith('Composer'):
            nodes_to_remove.append(node)
        elif (
            isinstance(node, ast.Assign)
            and len(node.targets) == 1
            and isinstance(node.targets[0], ast.Name)
            and node.targets[0].id == '__all__'
        ):
            nodes_to_remove.append(node)

    new_tree = DeleteSpecificNodes(nodes_to_remove).visit(tree)
    assert new_tree is not None
    # Render before opening the destination: opening with 'w' truncates the
    # file, and the destination may be the very file that was just read.
    rendered_source = ast.unparse(new_tree)

    if is_package_init:
        new_filename = f'{parent_module_name}.py'
    else:
        new_filename = os.path.basename(file_path)
    new_file_path = os.path.join(folder_path, new_filename)
    with open(new_file_path, 'w', encoding='utf-8') as f:
        f.write(rendered_source)
    return new_files_to_process
def edit_files_for_hf_compatibility(
    folder: str,
    flatten_imports_prefix: Sequence[str] = ('llmfoundry',),
    remove_imports_prefix: Sequence[str] = ('composer', 'omegaconf', 'llmfoundry.metrics'),
) -> None:
    """Edit files to be compatible with Hugging Face Hub.

    Starts from every ``.py`` file directly inside ``folder`` and keeps
    processing the additional files reported by ``process_file`` until no
    unseen file remains. Each file path is queued at most once.

    Args:
        folder (str): The folder to process.
        flatten_imports_prefix (Sequence[str], optional): Sequence of prefixes to flatten. Defaults to ('llmfoundry',).
        remove_imports_prefix (Sequence[str], optional): Sequence of prefixes to remove. Takes precedence over flattening.
            Defaults to ('composer', 'omegaconf', 'llmfoundry.metrics').
    """
    pending = [os.path.join(folder, entry) for entry in os.listdir(folder) if entry.endswith('.py')]
    # Everything that has ever been queued, so no path is handled twice.
    seen = set(pending)
    while pending:
        current = pending.pop()
        if not (os.path.isfile(current) and current.endswith('.py')):
            continue
        discovered_files = process_file(current, folder, flatten_imports_prefix, remove_imports_prefix)
        for discovered in discovered_files:
            if discovered in seen:
                continue
            pending.append(discovered)
            seen.add(discovered)
