# Adapted from https://github.com/openai/jukebox
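"""Thin wrappers around ``torch.distributed`` that fall back to single-process
defaults (rank 0, world size 1, no-op collectives) when no process group has
been initialized."""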

from enum import Enum

import torch.distributed as dist


class ReduceOp(Enum):
    # Backend-agnostic reduction ops, mapped to torch.distributed ops below.
    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3

    def ToDistOp(self):
        return {
            self.SUM: dist.ReduceOp.SUM,
            self.PRODUCT: dist.ReduceOp.PRODUCT,
            self.MIN: dist.ReduceOp.MIN,
            self.MAX: dist.ReduceOp.MAX
        }[self]


def is_available():
    # True only once a torch.distributed process group has been initialized.
    return dist.is_initialized()


def get_rank():
    if is_available():
        return _get_rank()
    else:
        return 0


def get_world_size():
    if is_available():
        return _get_world_size()
    else:
        return 1


def barrier():
    if is_available():
        return _barrier()
    # else: do nothing


def all_gather(tensor_list, tensor):
    if is_available():
        return _all_gather(tensor_list, tensor)
    else:
        # Single-process fallback: the only "gathered" entry is the local tensor.
        tensor_list[0] = tensor


def all_reduce(tensor, op=ReduceOp.SUM):
    if is_available():
        return _all_reduce(tensor, op)
    # else: do nothing


def reduce(tensor, dst, op=ReduceOp.SUM):
    if is_available():
        return _reduce(tensor, dst, op)
    # else: do nothing


def broadcast(tensor, src):
    if is_available():
        return _broadcast(tensor, src)
    # else: do nothing


def init_process_group(backend, init_method):
    # Guard on dist.is_available() (distributed support built into PyTorch)
    # rather than is_available() above, since the process group cannot already
    # be initialized at the point where this is called.
    if dist.is_available():
        return _init_process_group(backend, init_method)
    # else: do nothing


def _get_rank():
    return dist.get_rank()


def _barrier():
    return dist.barrier()


def _get_world_size():
    return dist.get_world_size()


def _all_gather(tensor_list, tensor):
    return dist.all_gather(tensor_list, tensor)


def _all_reduce(tensor, op):
    return dist.all_reduce(tensor, op.ToDistOp())


def _reduce(tensor, dst, op):
    return dist.reduce(tensor, dst, op.ToDistOp())


def _broadcast(tensor, src):
    return dist.broadcast(tensor, src)


def _init_process_group(backend, init_method):
    return dist.init_process_group(backend, init_method)
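

# A minimal usage sketch of the single-process fallbacks: without an initialized
# process group, rank/world-size default to 0/1, collectives are no-ops, and
# all_gather just hands back the local tensor.
if __name__ == "__main__":
    import torch

    t = torch.ones(2)
    print(get_rank(), get_world_size())  # -> 0 1
    all_reduce(t)                        # no-op when dist is not initialized
    gathered = [torch.zeros(2)]
    all_gather(gathered, t)
    print(gathered[0])                   # the local tensor, unchanged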