Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
import numpy as np | |
import pickle | |
import re | |
import sys | |
from datetime import datetime | |
def log(msg, accelerator=None, all_threads=False):
    """Write a timestamped message line to stderr.

    Args:
        msg: Text to emit (must be a ``str``).
        accelerator: Optional object exposing ``process_index``,
            ``num_processes`` and ``is_main_process``. Without it the
            message is always written.
        all_threads: When true, the message is written regardless of
            ``is_main_process``; if an accelerator is also given, the line
            is tagged with " (index/total)" using a 1-based process index.
    """
    has_accelerator = accelerator is not None

    proc_tag = ""
    if has_accelerator and all_threads:
        proc_tag = f" ({accelerator.process_index+1}/{accelerator.num_processes})"

    # Without all_threads, only the main process (or an accelerator-less
    # caller) is allowed to write.
    should_write = (not has_accelerator) or accelerator.is_main_process or all_threads
    if not should_write:
        return

    sys.stderr.write("".join((str(datetime.now()), proc_tag, ": ", msg, "\n")))
def _same_line_log(msg, len_to_del=0):
    """Emit ``msg`` through ``log``.

    The in-place TTY variant (erase ``len_to_del`` characters with
    backspaces, rewrite the message, return its length) is currently
    disabled, so ``len_to_del`` is ignored and nothing is returned.

    Args:
        msg: Text to log.
        len_to_del: Unused while the same-line rendering is disabled.
    """
    log(msg)
def debug(msg):
    """Debug-logging hook; currently a no-op that discards ``msg``."""
    # Disabled. To re-enable: log("\n(DEBUG) " + msg)
    return None
def maybe_convert(value):
    """Coerce a value to a number when it represents one.

    Values that are already ``bool``, ``int`` or ``float`` are returned
    unchanged. This matters because ``int()`` truncates floats
    (``int(2.7) == 2``, ``int(1e-4) == 0``), which would silently corrupt
    settings such as a learning rate that arrives already typed.

    Any other input (typically a string) is parsed as ``int`` first, then as
    ``float``; if neither works, the value is returned as-is.

    Args:
        value: Value to convert.

    Returns:
        An ``int`` or ``float`` when the input parses as one, otherwise the
        original value.
    """
    # Already numeric: do not let int() truncate a float or turn a bool into 0/1.
    if isinstance(value, (bool, int, float)):
        return value

    for caster in (int, float):
        try:
            return caster(value)
        except (ValueError, TypeError):
            continue
    return value
def get_changed_config(conf, args):
    """Overlay command-line values onto an existing config object.

    For every entry of ``args.to_dict()`` whose key is already an attribute
    of ``conf`` and whose value is not ``None``, the attribute is overwritten
    with the value passed through ``maybe_convert``. ``conf`` is modified in
    place.

    Args:
        conf: Config object to update.
        args: Object providing ``to_dict()``.

    Returns:
        The same ``conf`` object, after modification.
    """
    for key, value in args.to_dict().items():
        # Skip options the config does not know and options left unset.
        if not hasattr(conf, key) or value is None:
            continue
        setattr(conf, key, maybe_convert(value))
    return conf
class SameLineLogger:
    """Progress reporter for a training loop with a time-to-finish estimate.

    The estimate is based on the wall-clock time since construction and the
    number of batches processed since ``start_global_step``.
    """

    def __init__(self, epoch_len, epoch_num, data_state):
        """Record the run's dimensions and the starting position.

        Args:
            epoch_len: Number of batches per epoch.
            epoch_num: Total number of epochs.
            data_state: Object with ``epoch_idx`` and ``elem_idx``; together
                they give the global step at which this run starts.
        """
        self.epoch_len = epoch_len
        self.epoch_num = epoch_num

        self.start_global_step = epoch_len * data_state.epoch_idx + data_state.elem_idx
        self.totalx = epoch_len * epoch_num

        self.log_after = []
        self.log_len = 0

        self.start_time = datetime.now()

    def line_start(self):
        """Log the timestamped prefix that opens a progress line."""
        _same_line_log(str(datetime.now()) + ": training batches ")

    def step(self, global_batch_idx, epoch_batch_idx, epoch_idx, loss, lr, grad):
        """Log progress, average time per batch and estimated time remaining.

        Args:
            global_batch_idx: Batch index counted over the whole run.
            epoch_batch_idx: Batch index within the current epoch.
            epoch_idx: Zero-based epoch index (displayed 1-based).
            loss: Current loss value, printed as-is.
            lr: Current learning rate, printed in scientific notation.
            grad: Gradient statistic, printed in scientific notation.
        """
        passed_time = datetime.now() - self.start_time

        # When step() is called at the very step the run started from, no
        # batch has been completed yet; count it as one to avoid a
        # ZeroDivisionError on the timedelta division.
        batches_done = global_batch_idx - self.start_global_step
        if batches_done == 0:
            batches_done = 1

        time_per_batch = passed_time / batches_done
        prediction = time_per_batch * (self.totalx - global_batch_idx)

        msg = f"{epoch_batch_idx} / {self.epoch_len}, epoch {epoch_idx + 1} / {self.epoch_num}, loss={loss}, avg {time_per_batch}/iter, {prediction} to finish, LR={lr:.2e}, grad={grad:.2e} "

        new_len = _same_line_log(msg, self.log_len)
        self.log_len = new_len

    def line_break(self):
        """Terminate the current progress line on stderr."""
        sys.stderr.write("\n")
class CmdlineArgs:
    """Minimal command-line parser that exposes parsed values as attributes.

    Any argument containing ``=`` is treated as a keyword argument
    (``key=value``, leading dashes on the key are ignored); everything else is
    positional. Each declared positional and keyword name becomes an
    attribute of the instance. On any error, a help text is written to stderr
    and the process exits with status -1.
    """

    # Lower-cased spellings that turn a boolean keyword argument off.
    # Needed because bool("False") is True for any non-empty string.
    _FALSE_WORDS = frozenset({"", "false", "0", "no", "off"})

    def __init__(self,
                 description,
                 pos_arg_list=None,
                 pos_arg_types=None,
                 kw_arg_dict=None,
                 input_args=None):
        """Parse the arguments and populate the instance.

        Args:
            description: Text shown in the help message.
            pos_arg_list: Names of the required positional arguments.
            pos_arg_types: Callables converting each positional argument
                (``None`` keeps the raw string). Defaults to no conversion.
            kw_arg_dict: Mapping of keyword name to default value; the type
                of a non-``None`` default is used to convert the given value.
            input_args: Argument list to parse; defaults to ``sys.argv[1:]``.
        """
        self.description = description

        self.raw_pos_arg_list = pos_arg_list if pos_arg_list is not None else []
        self.raw_pos_arg_types = pos_arg_types \
            if pos_arg_types is not None \
            else [None] * len(self.raw_pos_arg_list)

        self.kw_arg_dict_with_defaults = kw_arg_dict if kw_arg_dict is not None else {}

        kw_vals, cmdline_values = self._to_kwargs(sys.argv[1:] if input_args is None else input_args)

        self._maybe_help(cmdline_values)

        self._handle_positional_args(cmdline_values)

        self._handle_keyword_args(kw_vals)

    @staticmethod
    def _to_kwargs(arg_list):
        """Split raw arguments into keyword values and positional values.

        Returns:
            A ``(dict, list)`` pair: ``key -> raw string value`` for entries
            containing ``=``, and the remaining entries in their original
            order.
        """
        key_args = {}
        positional = []

        for raw_entry in arg_list:
            if "=" in raw_entry:
                # Split on the first "=" only, so values may contain "=".
                key, _, value = raw_entry.lstrip("-").partition("=")
                key_args[key] = value
            else:
                positional.append(raw_entry)

        return key_args, positional

    def _handle_keyword_args(self, kw_vals):
        """Set every declared keyword attribute; die on unknown keywords."""
        for kw in self.kw_arg_dict_with_defaults:
            if kw in kw_vals:
                val = self._convert_kw(kw_vals, kw)
                del kw_vals[kw]
            else:
                val = self.kw_arg_dict_with_defaults[kw]

            setattr(self, kw, val)

        # Anything left over was not declared in kw_arg_dict.
        if kw_vals:
            extra_keys = ", ".join(kw_vals.keys())
            msg = f"command-line keyword arguments '{extra_keys}' are not recognized."

            self._help_message_and_die(extra=msg)

    def _convert_kw(self, kw_vals, kw):
        """Convert the raw value of ``kw`` to the type of its default.

        A ``None`` default means "no conversion". Boolean defaults are parsed
        by word ("false", "0", "no", "off" and the empty string give
        ``False``; anything else gives ``True``).
        """
        default = self.kw_arg_dict_with_defaults[kw]
        raw = kw_vals[kw]

        if default is None:
            return raw

        this_typ = type(default)

        if this_typ is bool:
            return raw.strip().lower() not in self._FALSE_WORDS

        try:
            return this_typ(raw)
        except ValueError:
            self._help_message_and_die(extra=f"could not convert '{kw_vals[kw]}' to '{this_typ}'")

    def _sanity_check_pos_args(self, cmdline_values):
        """Die unless exactly the declared number of positionals was given."""
        cmdline_len = len(cmdline_values)

        if cmdline_len < len(self.raw_pos_arg_list):
            self._help_message_and_die(
                extra=f"positional arguments missing: {', '.join(self.raw_pos_arg_list[cmdline_len:])}")

        if cmdline_len > len(self.raw_pos_arg_list):
            self._help_message_and_die(
                extra=f"superfluous positional arguments: {', '.join(cmdline_values[len(self.raw_pos_arg_list):])}")

    def _handle_positional_args(self, cmdline_values):
        """Validate, convert and store the positional arguments."""
        self._sanity_check_pos_args(cmdline_values)

        for arg, val, typ in zip(self.raw_pos_arg_list, cmdline_values, self.raw_pos_arg_types):
            try:
                val = val if typ is None else typ(val)
            except ValueError:
                self._help_message_and_die(extra=f"could not convert '{val}' to '{typ}'")

            setattr(self, arg, val)

    def _maybe_help(self, cmdline_values):
        """Show the help text if the only positional value is a help flag."""
        if len(cmdline_values) == 1 and cmdline_values[0] in {"--help", "-h", "-?"}:
            self._help_message_and_die()

    def _help_message_and_die(self, extra=None):
        """Write the help text (and optional error) to stderr, then exit(-1)."""
        sys.stderr.write("Help message: " + self.description + "\n")

        if self.raw_pos_arg_list:
            args_descr = ", ".join([f"'{arg}' ({typ.__name__ if typ is not None else 'any'})"
                                    for arg, typ in zip(self.raw_pos_arg_list, self.raw_pos_arg_types)])

            sys.stderr.write(f"Positional arguments: {args_descr}\n")

        if self.kw_arg_dict_with_defaults:
            kw_descr = ", ".join([f"'{kw}' (default: {val})"
                                  for kw, val in self.kw_arg_dict_with_defaults.items()])

            sys.stderr.write(f"Keyword arguments: {kw_descr}\n")

        if extra is not None:
            sys.stderr.write("Error: " + extra + "\n")

        sys.stderr.write("\n")
        sys.exit(-1)

    def to_dict(self):
        """Return the parsed arguments, excluding the parser's own settings."""
        return {k: v for k, v in self.__dict__.items()
                if k not in {'description', 'raw_pos_arg_list', 'raw_pos_arg_types', 'kw_arg_dict_with_defaults'}}

    def __str__(self):
        return str(self.to_dict())

    def __repr__(self):
        return self.__str__()
if __name__ == "__main__":
    # Debug helper: for each directory given on the command line, print the
    # unpickled 'custom_checkpoint_1/data.pkl' member of its
    # custom_checkpoint_1.pkl file.
    for dname in sys.argv[1:]:
        # SECURITY: allow_pickle=True and pickle.loads execute arbitrary code
        # embedded in the file; only run this on checkpoints you trust.
        # NOTE(review): np.load is presumed to return an archive object here
        # (it is indexed by member name); the `with` closes it after reading
        # instead of leaving the file handle open.
        with np.load(dname + "/custom_checkpoint_1.pkl", allow_pickle=True) as archive:
            payload = pickle.loads(archive['custom_checkpoint_1/data.pkl'])

        print(dname, payload)