"""Poll a remote machine for new checkpoint files and download them via scp.

The remote side signals that a fresh set of checkpoints is ready by creating
a lock file. This script polls for that lock file over SSH, downloads the
checkpoint files under timestamped local names, and then deletes the lock
file on the remote side.
"""

import argparse
import datetime
import os
import shlex
import subprocess
import sys
import time

# Options shared by every ssh/scp invocation. Host-key verification is
# disabled; NOTE(review): presumably because the remote host key is not
# stable between sessions -- confirm this trade-off is acceptable.
_SSH_OPTIONS = [
    "-o", "StrictHostKeyChecking=no",
    "-o", "UserKnownHostsFile=/dev/null",
]

# Base names (without the ".safetensors" suffix) of the files fetched on
# every cycle. NOTE: "thresholds.safetensors" is intentionally not listed;
# its download was disabled in the original script.
_CHECKPOINT_NAMES = ("grads_a", "grads_b", "log_scalars")

_TIMESTAMP_FORMAT = "%Y%m%d%H%M%S"


def main():
    """Parse CLI arguments and run the poll/download/unlock loop forever."""
    parser = argparse.ArgumentParser(
        description="Download checkpoint pair from remote runpod machine with unique filenames."
    )
    parser.add_argument("--remote-ip", required=True, help="Remote machine IP address")
    parser.add_argument("--remote-port", required=True, type=int, help="Remote SSH port")
    parser.add_argument("--remote-user", required=True, help="Username for remote SSH")
    parser.add_argument(
        "--remote-base-path",
        default="/workspace",
        help="Directory on remote machine containing checkpoints and lock file",
    )
    parser.add_argument(
        "--local-dest", required=True, help="Local directory where checkpoints should be saved"
    )
    parser.add_argument(
        "--rsa-key", required=True, help="Path to your RSA private key for authentication"
    )
    parser.add_argument(
        "--poll-interval", type=float, default=10, help="Polling interval in seconds"
    )
    args = parser.parse_args()

    remote_lock_file = f"{args.remote_base_path}/safetensors.lock"
    connection = (args.remote_user, args.remote_ip, args.remote_port)

    # scp fails if the destination directory is missing, so create it up front.
    os.makedirs(args.local_dest, exist_ok=True)

    print("Starting remote checkpoint monitor...")
    while True:
        if remote_file_exists(*connection, remote_lock_file, args.rsa_key):
            print("New checkpoints detected. Downloading...")

            # One timestamp per batch so all files of a batch share a suffix.
            batch_timestamp = datetime.datetime.now().strftime(_TIMESTAMP_FORMAT)
            transfers = [
                (
                    f"{args.remote_base_path}/{name}.safetensors",
                    get_unique_filename(args.local_dest, name, timestamp=batch_timestamp),
                )
                for name in _CHECKPOINT_NAMES
            ]

            failed_remote = None
            for remote_file, local_file in transfers:
                if not download_file(*connection, remote_file, local_file, args.rsa_key):
                    failed_remote = remote_file
                    break

            if failed_remote is not None:
                # download_file reports failure via its return value, not an
                # exception. Keep the lock file so the batch is retried.
                print(
                    f"Download error: failed to download {failed_remote}; "
                    "remote lock file left in place, will retry."
                )
                time.sleep(args.poll_interval)
                continue

            # After a successful download, delete only the lock file on the
            # remote side. Retry until it succeeds (otherwise the same batch
            # would be downloaded again), pausing between attempts.
            while not delete_remote_lock(*connection, remote_lock_file, args.rsa_key):
                print("Failed to delete remote lock file; retrying...")
                time.sleep(args.poll_interval)

            print("Download complete. Checkpoints saved as:")
            for _, local_file in transfers:
                print(f" {local_file}")
            print("Remote lock file deleted.")
        else:
            print("No checkpoints found.")
        time.sleep(args.poll_interval)


def _run_remote_command(cmd, timeout, action, remote_host):
    """Run an ssh/scp command, echo its output, and report success.

    Args:
        cmd: Argument list passed to ``subprocess.run`` (no shell involved
            locally).
        timeout: Seconds to wait before giving up on the command.
        action: Short phrase used in the timeout message, e.g.
            ``"check /path"``.
        remote_host: Host name or IP, used only in the timeout message.

    Returns:
        True if the command exited with status 0; False on a non-zero exit
        status or on timeout.
    """
    try:
        result = subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout
        )
    except subprocess.TimeoutExpired:
        print(
            f"TimeoutExpired: SSH command to {action} on {remote_host} "
            f"timed out after {timeout} seconds."
        )
        return False

    # errors="replace" keeps unexpected non-UTF-8 bytes from crashing the loop.
    if result.stdout:
        print("stdout", result.stdout.decode("utf-8", errors="replace"), end="")
    if result.stderr:
        print(
            "stderr",
            result.stderr.decode("utf-8", errors="replace"),
            file=sys.stderr,
            end="",
        )
    return result.returncode == 0


def remote_file_exists(remote_user, remote_host, remote_port, remote_path, rsa_key, timeout=10):
    """Check if a file exists on the remote machine.

    Returns:
        True if ``test -f`` succeeds remotely. False if the file is missing,
        the SSH command fails, or it times out.
    """
    cmd = [
        "ssh", "-i", rsa_key, "-p", str(remote_port),
        *_SSH_OPTIONS,
        f"{remote_user}@{remote_host}",
        # The command string is interpreted by the remote shell, so quote it.
        f"test -f {shlex.quote(remote_path)}",
    ]
    return _run_remote_command(cmd, timeout, f"check {remote_path}", remote_host)


def download_file(remote_user, remote_host, remote_port, remote_file, local_file, rsa_key,
                  timeout=1200):
    """Download a file from the remote machine using scp and save it with a specific name.

    Returns:
        True if scp exited with status 0; False on a non-zero exit status or
        on timeout. No exception is raised for a failed transfer, so callers
        must check the return value.
    """
    cmd = [
        "scp", "-i", rsa_key, "-P", str(remote_port),
        *_SSH_OPTIONS,
        # Deliberately not shell-quoted: how scp interprets the remote path
        # depends on the transfer protocol in use, and added quotes could be
        # taken literally.
        f"{remote_user}@{remote_host}:{remote_file}",
        str(local_file),
    ]
    return _run_remote_command(cmd, timeout, f"download {remote_file}", remote_host)


def delete_remote_lock(remote_user, remote_host, remote_port, remote_lock_file, rsa_key,
                       timeout=10):
    """Delete the lock file on the remote machine.

    Returns:
        True if the remote ``rm -f`` exited with status 0; False on a
        non-zero exit status or on timeout.
    """
    cmd = [
        "ssh", "-i", rsa_key, "-p", str(remote_port),
        *_SSH_OPTIONS,
        f"{remote_user}@{remote_host}",
        # The command string is interpreted by the remote shell, so quote it.
        f"rm -f {shlex.quote(remote_lock_file)}",
    ]
    return _run_remote_command(cmd, timeout, f"delete {remote_lock_file}", remote_host)


def get_unique_filename(local_dest, base_name, timestamp=None):
    """Generate a timestamped ``.safetensors`` filename and return the full path.

    Args:
        local_dest: Directory the file will be placed in.
        base_name: Filename prefix, e.g. ``"grads_a"``.
        timestamp: Optional pre-formatted timestamp string. Pass the same
            value for every file of a batch so they share a suffix. When
            omitted, the current local time (second resolution) is used.

    Returns:
        ``<local_dest>/<base_name>_<timestamp>.safetensors``.
    """
    if timestamp is None:
        timestamp = datetime.datetime.now().strftime(_TIMESTAMP_FORMAT)
    filename = f"{base_name}_{timestamp}.safetensors"
    return os.path.join(local_dest, filename)


if __name__ == "__main__":
    main()