Alic-Li commited on
Commit
b2e0455
·
verified ·
1 Parent(s): a28b463

Update world RWKV CPU

Browse files
infer/__init__.py ADDED
File without changes
infer/rwkv/__init__.py ADDED
File without changes
infer/rwkv/cuda/gemm_fp16_cublas.cpp ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <cublas_v2.h>
2
+ #include <cuda.h>
3
+ #include <cuda_fp16.h>
4
+ #include <cuda_runtime.h>
5
+ #include <torch/extension.h>
6
+ #include <c10/cuda/CUDAGuard.h>
7
+ #include <ATen/cuda/CUDAContext.h>
8
+
9
+ #define CUBLAS_CHECK(condition) \
10
+ for (cublasStatus_t _cublas_check_status = (condition); \
11
+ _cublas_check_status != CUBLAS_STATUS_SUCCESS;) \
12
+ throw std::runtime_error("cuBLAS error " + \
13
+ std::to_string(_cublas_check_status) + " at " + \
14
+ std::to_string(__LINE__));
15
+
16
+ #define CUDA_CHECK(condition) \
17
+ for (cudaError_t _cuda_check_status = (condition); \
18
+ _cuda_check_status != cudaSuccess;) \
19
+ throw std::runtime_error( \
20
+ "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
21
+ " at " + std::to_string(__LINE__));
22
+
23
+ /*
24
+ NOTE: blas gemm is column-major by default, but we need row-major output.
25
+ The data of row-major, transposed matrix is exactly the same as the
26
+ column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
27
+ */
28
+ void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
29
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
30
+ const auto cuda_data_type = CUDA_R_16F;
31
+ const auto cuda_c_data_type =
32
+ c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
33
+ const auto compute_type = CUDA_R_32F;
34
+ const float sp_alpha = 1.f;
35
+ // swap a and b, and use CUBLAS_OP_N. see the notes above
36
+ std::swap(a, b);
37
+ const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
38
+ const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
39
+ // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
40
+ // negative axis is used because of the existence of batch matmul.
41
+ const int m = a.size(-1);
42
+ const int k = a.size(-2);
43
+ const int n = b.size(-2);
44
+ const int cublas_lda = m;
45
+ const int cublas_ldb = k;
46
+ const int cublas_ldc = m;
47
+ cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
48
+
49
+ #if CUDA_VERSION >= 11000
50
+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
51
+ #else
52
+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
53
+ #endif
54
+ const float sp_beta = 0.f;
55
+ if (a.sizes().size() == 2 && b.sizes().size() == 2) {
56
+ CUBLAS_CHECK(cublasGemmEx(
57
+ cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
58
+ a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
59
+ cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
60
+ compute_type, algo));
61
+ } else {
62
+ // batch matmul
63
+ assert(a.sizes().size() == 3 && b.sizes().size() == 3);
64
+
65
+ const long long int cublas_stride_a = m * k;
66
+ const long long int cublas_stride_b = k * n;
67
+ const long long int cublas_stride_c = m * n;
68
+ CUBLAS_CHECK(cublasGemmStridedBatchedEx(
69
+ cublas_handle, cublas_trans_a, cublas_trans_b, m,
70
+ n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
71
+ cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
72
+ &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
73
+ a.size(0), compute_type, algo));
74
+ }
75
+ }
infer/rwkv/cuda/operators.cu ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+ #include <cuda_fp16.h>
5
+ #define MIN_VALUE (-1e38)
6
+ typedef at::Half fp16;
7
+ __half *cast(fp16 *ptr) {
8
+ return reinterpret_cast<__half *>(ptr);
9
+ }
10
+
11
+ template <typename F>
12
+ __global__ void kernel_wkv_forward(const int B, const int T, const int C,
13
+ const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
14
+ F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
15
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
16
+ const int _b = idx / C;
17
+ const int _c = idx % C;
18
+ const int _offset = _b * T * C + _c;
19
+ const int _state_offset = _b * C + _c;
20
+
21
+ float u = _u[_c];
22
+ float w = _w[_c];
23
+ const F *__restrict__ const k = _k + _offset;
24
+ const F *__restrict__ const v = _v + _offset;
25
+ F *__restrict__ const y = _y + _offset;
26
+
27
+ float aa = _aa[_state_offset];
28
+ float bb = _bb[_state_offset];
29
+ float pp = _pp[_state_offset];
30
+ for (int i = 0; i < T; i++) {
31
+ const int ii = i * C;
32
+ const float kk = float(k[ii]);
33
+ const float vv = float(v[ii]);
34
+ float ww = u + kk;
35
+ float p = max(pp, ww);
36
+ float e1 = exp(pp - p);
37
+ float e2 = exp(ww - p);
38
+ y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
39
+ ww = w + pp;
40
+ p = max(ww, kk);
41
+ e1 = exp(ww - p);
42
+ e2 = exp(kk - p);
43
+ aa = e1 * aa + e2 * vv;
44
+ bb = e1 * bb + e2;
45
+ pp = p;
46
+ }
47
+ _aa[_state_offset] = aa;
48
+ _bb[_state_offset] = bb;
49
+ _pp[_state_offset] = pp;
50
+ }
51
+
52
+ template <typename F>
53
+ void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
54
+ dim3 threadsPerBlock( min(C, 32) );
55
+ assert(B * C % threadsPerBlock.x == 0);
56
+ dim3 numBlocks(B * C / threadsPerBlock.x);
57
+ kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
58
+ }
59
+
60
+ template void cuda_wkv_forward<fp16>(
61
+ int B, int T, int C,
62
+ float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
63
+ float *aa, float *bb, float *pp);
64
+ template void cuda_wkv_forward<float>(
65
+ int B, int T, int C,
66
+ float *w, float *u, float *k, float *v, float *y,
67
+ float *aa, float *bb, float *pp);
68
+
69
+ __global__ void kernel_mm_seq_fp32i8(
70
+ const int B, const int N, const int M,
71
+ const float *__restrict__ const x, const int x_stride,
72
+ const uint8_t *__restrict__ const w, const int w_stride,
73
+ const float *__restrict__ const mx,
74
+ const float *__restrict__ const rx,
75
+ const float *__restrict__ const my,
76
+ const float *__restrict__ const ry,
77
+ float *__restrict__ const y, const int y_stride) {
78
+
79
+ const int i = blockIdx.x * blockDim.x + threadIdx.x;
80
+ const int k = blockIdx.y * blockDim.y + threadIdx.y;
81
+
82
+ if (i < B && k < M) {
83
+ float y_local = 0;
84
+ for (int j = 0; j < N; ++j) {
85
+ y_local += x[i * x_stride + j] * (
86
+ (float(w[j * w_stride + k]) + 0.5f)
87
+ * rx[k] * ry[j] + mx[k] + my[j]
88
+ );
89
+ }
90
+ y[i * y_stride + k] = y_local;
91
+ }
92
+ }
93
+
94
+ template <typename F>
95
+ void cuda_mm8_seq(int B, int N, int M,
96
+ F *x, int x_stride,
97
+ uint8_t *w, int w_stride,
98
+ F *mx, F *rx,
99
+ F *my, F *ry,
100
+ F *y, int y_stride);
101
+
102
+ template <>
103
+ void cuda_mm8_seq<float>(int B, int N, int M,
104
+ float *x, int x_stride,
105
+ uint8_t *w, int w_stride,
106
+ float *mx, float *rx,
107
+ float *my, float *ry,
108
+ float *y, int y_stride) {
109
+ dim3 blockSize(1, 128);
110
+ dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
111
+ kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
112
+ B, N, M, x, x_stride, w, w_stride,
113
+ mx, rx, my, ry, y, y_stride);
114
+ }
115
+
116
+ __global__ void kernel_mm_seq_fp16i8(
117
+ const int B, const int N, const int M,
118
+ const __half *__restrict__ const x, const int x_stride,
119
+ const uint8_t *__restrict__ const w, const int w_stride,
120
+ const __half *__restrict__ const mx,
121
+ const __half *__restrict__ const rx,
122
+ const __half *__restrict__ const my,
123
+ const __half *__restrict__ const ry,
124
+ __half *__restrict__ const y, const int y_stride) {
125
+
126
+ const int i = blockIdx.x * blockDim.x + threadIdx.x;
127
+ const int k = blockIdx.y * blockDim.y + threadIdx.y;
128
+
129
+ if (i < B && k < M) {
130
+ float y_local = 0;
131
+ for (int j = 0; j < N; ++j) {
132
+ y_local += __half2float(x[i * x_stride + j]) * (
133
+ (float(w[j * w_stride + k]) + 0.5f)
134
+ * __half2float(rx[k]) * __half2float(ry[j])
135
+ + __half2float(mx[k]) + __half2float(my[j])
136
+ );
137
+ }
138
+ y[i * y_stride + k] = __float2half(y_local);
139
+ }
140
+ }
141
+
142
+ template <>
143
+ void cuda_mm8_seq<fp16>(int B, int N, int M,
144
+ fp16 *x, int x_stride,
145
+ uint8_t *w, int w_stride,
146
+ fp16 *mx, fp16 *rx,
147
+ fp16 *my, fp16 *ry,
148
+ fp16 *y, int y_stride) {
149
+ dim3 blockSize(1, 128);
150
+ dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
151
+ kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
152
+ B, N, M, cast(x), x_stride, w, w_stride,
153
+ cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
154
+ }
155
+
156
+ #define MM8_ONE_JSPLIT 24
157
+ #define MM8_ONE_TILE 1024
158
+
159
+ __global__ void kernel_mm_one_fp32i8(
160
+ const int N, const int M,
161
+ const float *__restrict__ const x,
162
+ const uint8_t *__restrict__ const w, const int w_stride,
163
+ const float *__restrict__ const mx,
164
+ const float *__restrict__ const rx,
165
+ const float *__restrict__ const my,
166
+ const float *__restrict__ const ry,
167
+ float *__restrict__ const y) {
168
+
169
+ const int k = blockIdx.y * blockDim.y + threadIdx.y;
170
+ const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
171
+ const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
172
+
173
+ if (k < M) {
174
+ float y_local = 0;
175
+ for (int j = j0; j < j1; ++j) {
176
+ y_local += x[j] * (
177
+ (float(w[j * w_stride + k]) + 0.5f)
178
+ * rx[k] * ry[j] + mx[k] + my[j]
179
+ );
180
+ }
181
+ atomicAdd(&y[k], y_local);
182
+ }
183
+ }
184
+
185
+ template <typename F>
186
+ void cuda_mm8_one(int N, int M,
187
+ F *x,
188
+ uint8_t *w, int w_stride,
189
+ F *mx, F *rx,
190
+ F *my, F *ry,
191
+ float *y);
192
+
193
+ template <>
194
+ void cuda_mm8_one<float>(int N, int M,
195
+ float *x,
196
+ uint8_t *w, int w_stride,
197
+ float *mx, float *rx,
198
+ float *my, float *ry,
199
+ float *y) {
200
+ dim3 blockSize(1, MM8_ONE_TILE);
201
+ dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
202
+ kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
203
+ N, M, x, w, w_stride,
204
+ mx, rx, my, ry, y);
205
+ }
206
+
207
+ __global__ void kernel_mm_one_fp16i8(
208
+ const int N, const int M,
209
+ const __half *__restrict__ const x,
210
+ const uint8_t *__restrict__ const w, const int w_stride,
211
+ const __half *__restrict__ const mx,
212
+ const __half *__restrict__ const rx,
213
+ const __half *__restrict__ const my,
214
+ const __half *__restrict__ const ry,
215
+ float *__restrict__ const y) {
216
+
217
+ const int k = blockIdx.y * blockDim.y + threadIdx.y;
218
+ const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
219
+ const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
220
+
221
+ if (k < M) {
222
+ float y_local = 0;
223
+ for (int j = j0; j < j1; ++j) {
224
+ y_local += __half2float(x[j]) * (
225
+ (float(w[j * w_stride + k]) + 0.5f)
226
+ * __half2float(rx[k]) * __half2float(ry[j])
227
+ + __half2float(mx[k]) + __half2float(my[j])
228
+ );
229
+ }
230
+ atomicAdd(&y[k], y_local);
231
+ }
232
+ }
233
+
234
+ template <>
235
+ void cuda_mm8_one<fp16>(int N, int M,
236
+ fp16 *x,
237
+ uint8_t *w, int w_stride,
238
+ fp16 *mx, fp16 *rx,
239
+ fp16 *my, fp16 *ry,
240
+ float *y) {
241
+ dim3 blockSize(1, MM8_ONE_TILE);
242
+ dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
243
+ kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
244
+ N, M, cast(x), w, w_stride,
245
+ cast(mx), cast(rx), cast(my), cast(ry), y);
246
+ }
infer/rwkv/cuda/rwkv5.cu ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+ typedef at::BFloat16 bf16;
5
+ typedef at::Half fp16;
6
+ typedef float fp32;
7
+
8
+ template <typename F>
9
+ __global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
10
+ const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
11
+ F *__restrict__ const _y)
12
+ {
13
+ const int b = blockIdx.x / H;
14
+ const int h = blockIdx.x % H;
15
+ const int i = threadIdx.x;
16
+ _w += h*_N_;
17
+ _u += h*_N_;
18
+ _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
19
+
20
+ __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
21
+
22
+ float state[_N_];
23
+ #pragma unroll
24
+ for (int j = 0; j < _N_; j++)
25
+ state[j] = _state[j];
26
+
27
+ __syncthreads();
28
+ u[i] = float(_u[i]);
29
+ w[i] = _w[i];
30
+ __syncthreads();
31
+
32
+ for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
33
+ {
34
+ __syncthreads();
35
+ r[i] = float(_r[t]);
36
+ k[i] = float(_k[t]);
37
+ __syncthreads();
38
+
39
+ const float v = float(_v[t]);
40
+ float y = 0;
41
+
42
+ #pragma unroll
43
+ for (int j = 0; j < _N_; j+=4)
44
+ {
45
+ const float4& r_ = (float4&)(r[j]);
46
+ const float4& k_ = (float4&)(k[j]);
47
+ const float4& w_ = (float4&)(w[j]);
48
+ const float4& u_ = (float4&)(u[j]);
49
+ float4& s = (float4&)(state[j]);
50
+ float4 x;
51
+
52
+ x.x = k_.x * v;
53
+ x.y = k_.y * v;
54
+ x.z = k_.z * v;
55
+ x.w = k_.w * v;
56
+
57
+ y += r_.x * (u_.x * x.x + s.x);
58
+ y += r_.y * (u_.y * x.y + s.y);
59
+ y += r_.z * (u_.z * x.z + s.z);
60
+ y += r_.w * (u_.w * x.w + s.w);
61
+
62
+ s.x = s.x * w_.x + x.x;
63
+ s.y = s.y * w_.y + x.y;
64
+ s.z = s.z * w_.z + x.z;
65
+ s.w = s.w * w_.w + x.w;
66
+ }
67
+ _y[t] = F(y);
68
+ }
69
+ #pragma unroll
70
+ for (int j = 0; j < _N_; j++)
71
+ _state[j] = state[j];
72
+ }
73
+
74
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
75
+ {
76
+ assert(H*_N_ == C);
77
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
78
+ }
79
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
80
+ {
81
+ assert(H*_N_ == C);
82
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
83
+ }
84
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
85
+ {
86
+ assert(H*_N_ == C);
87
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
88
+ }
infer/rwkv/cuda/rwkv5_op.cpp ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "ATen/ATen.h"
3
+ #include <c10/cuda/CUDAGuard.h>
4
+ typedef at::BFloat16 bf16;
5
+ typedef at::Half fp16;
6
+ typedef float fp32;
7
+
8
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
9
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
10
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
11
+
12
+ void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
13
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
14
+ cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
15
+ }
16
+ void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
17
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
18
+ cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
19
+ }
20
+ void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
21
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
22
+ cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
23
+ }
24
+
25
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
26
+ m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
27
+ m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
28
+ m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
29
+ }
30
+ TORCH_LIBRARY(rwkv5, m) {
31
+ m.def("forward_bf16", forward_bf16);
32
+ m.def("forward_fp16", forward_fp16);
33
+ m.def("forward_fp32", forward_fp32);
34
+ }
infer/rwkv/cuda/rwkv6.cu ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+ typedef at::BFloat16 bf16;
5
+ typedef at::Half fp16;
6
+ typedef float fp32;
7
+
8
+ template <typename F>
9
+ __global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
10
+ const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
11
+ F *__restrict__ const _y)
12
+ {
13
+ const int b = blockIdx.x / H;
14
+ const int h = blockIdx.x % H;
15
+ const int i = threadIdx.x;
16
+ _u += h*_N_;
17
+ _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
18
+
19
+ __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
20
+
21
+ float state[_N_];
22
+ #pragma unroll
23
+ for (int j = 0; j < _N_; j++)
24
+ state[j] = _state[j];
25
+
26
+ __syncthreads();
27
+ u[i] = float(_u[i]);
28
+ __syncthreads();
29
+
30
+ for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
31
+ {
32
+ __syncthreads();
33
+ w[i] = _w[t];
34
+ r[i] = float(_r[t]);
35
+ k[i] = float(_k[t]);
36
+ __syncthreads();
37
+
38
+ const float v = float(_v[t]);
39
+ float y = 0;
40
+
41
+ #pragma unroll
42
+ for (int j = 0; j < _N_; j+=4)
43
+ {
44
+ const float4& r_ = (float4&)(r[j]);
45
+ const float4& k_ = (float4&)(k[j]);
46
+ const float4& w_ = (float4&)(w[j]);
47
+ const float4& u_ = (float4&)(u[j]);
48
+ float4& s = (float4&)(state[j]);
49
+ float4 x;
50
+
51
+ x.x = k_.x * v;
52
+ x.y = k_.y * v;
53
+ x.z = k_.z * v;
54
+ x.w = k_.w * v;
55
+
56
+ y += r_.x * (u_.x * x.x + s.x);
57
+ y += r_.y * (u_.y * x.y + s.y);
58
+ y += r_.z * (u_.z * x.z + s.z);
59
+ y += r_.w * (u_.w * x.w + s.w);
60
+
61
+ s.x = s.x * w_.x + x.x;
62
+ s.y = s.y * w_.y + x.y;
63
+ s.z = s.z * w_.z + x.z;
64
+ s.w = s.w * w_.w + x.w;
65
+ }
66
+ _y[t] = F(y);
67
+ }
68
+ #pragma unroll
69
+ for (int j = 0; j < _N_; j++)
70
+ _state[j] = state[j];
71
+ }
72
+
73
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
74
+ {
75
+ assert(H*_N_ == C);
76
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
77
+ }
78
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
79
+ {
80
+ assert(H*_N_ == C);
81
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
82
+ }
83
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
84
+ {
85
+ assert(H*_N_ == C);
86
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
87
+ }
infer/rwkv/cuda/rwkv6_op.cpp ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "ATen/ATen.h"
3
+ #include <c10/cuda/CUDAGuard.h>
4
+ typedef at::BFloat16 bf16;
5
+ typedef at::Half fp16;
6
+ typedef float fp32;
7
+
8
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
9
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
10
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
11
+
12
+ void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
13
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
14
+ cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
15
+ }
16
+ void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
17
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
18
+ cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
19
+ }
20
+ void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
21
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
22
+ cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
23
+ }
24
+
25
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
26
+ m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
27
+ m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
28
+ m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
29
+ }
30
+ TORCH_LIBRARY(rwkv6, m) {
31
+ m.def("forward_bf16", forward_bf16);
32
+ m.def("forward_fp16", forward_fp16);
33
+ m.def("forward_fp32", forward_fp32);
34
+ }
infer/rwkv/cuda/rwkv7.cu ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+
5
+ typedef at::Half fp16;
6
+ typedef at::BFloat16 bf16;
7
+ typedef float fp32;
8
+
9
+ template <typename F>
10
+ __global__ void kernel_forward(const int B, const int T, const int C, const int H,
11
+ float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
12
+ F *__restrict__ const _y)
13
+ {
14
+ const int e = blockIdx.x / H;
15
+ const int h = blockIdx.x % H;
16
+ const int i = threadIdx.x;
17
+ _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
18
+
19
+ float state[_N_];
20
+ #pragma unroll
21
+ for (int j = 0; j < _N_; j++)
22
+ state[j] = _state[j];
23
+
24
+ __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
25
+
26
+ for (int _t = 0; _t < T; _t++)
27
+ {
28
+ const int t = e*T*C + h*_N_ + i + _t * C;
29
+ __syncthreads();
30
+ r[i] = float(_r[t]);
31
+ w[i] = __expf(-__expf(float(_w[t])));
32
+ k[i] = float(_k[t]);
33
+ a[i] = float(_a[t]);
34
+ b[i] = float(_b[t]);
35
+ __syncthreads();
36
+
37
+ float sa = 0;
38
+ #pragma unroll
39
+ for (int j = 0; j < _N_; j++)
40
+ {
41
+ sa += a[j] * state[j];
42
+ }
43
+
44
+ float vv = float(_v[t]);
45
+ float y = 0;
46
+ #pragma unroll
47
+ for (int j = 0; j < _N_; j++)
48
+ {
49
+ float& s = state[j];
50
+ s = s * w[j] + k[j] * vv + sa * b[j];
51
+ y += s * r[j];
52
+ }
53
+ _y[t] = F(y);
54
+ }
55
+ #pragma unroll
56
+ for (int j = 0; j < _N_; j++)
57
+ _state[j] = state[j];
58
+ }
59
+
60
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
61
+ {
62
+ assert(H*_N_ == C);
63
+ assert(B == 1); // only for B=1
64
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
65
+ }
66
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16* w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y)
67
+ {
68
+ assert(H*_N_ == C);
69
+ assert(B == 1); // only for B=1
70
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
71
+ }
72
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32* w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y)
73
+ {
74
+ assert(H*_N_ == C);
75
+ assert(B == 1); // only for B=1
76
+ kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
77
+ }
infer/rwkv/cuda/rwkv7_op.cpp ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "ATen/ATen.h"
3
+
4
+ typedef at::Half fp16;
5
+ typedef at::BFloat16 bf16;
6
+ typedef float fp32;
7
+
8
+ void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
9
+ void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y);
10
+ void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y);
11
+
12
+ void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
13
+ cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
14
+ }
15
+ void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
16
+ cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), w.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), a.data_ptr<fp16>(), b.data_ptr<fp16>(), y.data_ptr<fp16>());
17
+ }
18
+ void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
19
+ cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), w.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), a.data_ptr<fp32>(), b.data_ptr<fp32>(), y.data_ptr<fp32>());
20
+ }
21
+
22
+ TORCH_LIBRARY(wkv7s, m) {
23
+ m.def("forward_bf16", forward_bf16);
24
+ m.def("forward_fp16", forward_fp16);
25
+ m.def("forward_fp32", forward_fp32);
26
+ }
infer/rwkv/cuda/wrapper.cpp ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "ATen/ATen.h"
3
+ #include <iostream>
4
+ #include <c10/cuda/CUDAGuard.h>
5
+
6
+ typedef at::Half fp16;
7
+
8
+ template <typename F>
9
+ void cuda_wkv_forward(int B, int T, int C,
10
+ float *w, float *u, F *k, F *v, F *y,
11
+ float *aa, float *bb, float *pp);
12
+ template <typename F>
13
+ void cuda_mm8_seq(int B, int N, int M,
14
+ F *x, int x_stride,
15
+ uint8_t *w, int w_stride,
16
+ F *mx, F *rx,
17
+ F *my, F *ry,
18
+ F *y, int y_stride);
19
+ template <typename F>
20
+ void cuda_mm8_one(int N, int M,
21
+ F *x,
22
+ uint8_t *w, int w_stride,
23
+ F *mx, F *rx,
24
+ F *my, F *ry,
25
+ float *y);
26
+
27
+ void wkv_forward(int64_t B, int64_t T, int64_t C,
28
+ torch::Tensor &w, torch::Tensor &u,
29
+ torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
30
+ torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
31
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
32
+ switch (k.scalar_type()) {
33
+ case c10::ScalarType::Half:
34
+ cuda_wkv_forward(B, T, C,
35
+ w.data_ptr<float>(), u.data_ptr<float>(),
36
+ k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
37
+ aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
38
+ break;
39
+ case c10::ScalarType::Float:
40
+ cuda_wkv_forward(B, T, C,
41
+ w.data_ptr<float>(), u.data_ptr<float>(),
42
+ k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
43
+ aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
44
+ break;
45
+ default:
46
+ assert(false && "Only FP16 and FP32 are currently supported");
47
+ }
48
+ }
49
+
50
+ void mm8_seq(int64_t B, int64_t N, int64_t M,
51
+ torch::Tensor &x, torch::Tensor &w,
52
+ torch::Tensor &mx, torch::Tensor &rx,
53
+ torch::Tensor &my, torch::Tensor &ry,
54
+ torch::Tensor &y) {
55
+ assert(x.stride(1) == 1);
56
+ assert(w.stride(1) == 1);
57
+ assert(mx.stride(0) == 1 && rx.stride(0) == 1);
58
+ assert(my.stride(0) == 1 && ry.stride(0) == 1);
59
+ assert(y.stride(1) == 1);
60
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
61
+ switch (x.scalar_type()) {
62
+ case c10::ScalarType::Half:
63
+ cuda_mm8_seq(
64
+ B, N, M,
65
+ x.data_ptr<fp16>(), x.stride(0),
66
+ w.data_ptr<uint8_t>(), w.stride(0),
67
+ mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
68
+ my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
69
+ y.data_ptr<fp16>(), y.stride(0));
70
+ break;
71
+ case c10::ScalarType::Float:
72
+ cuda_mm8_seq(
73
+ B, N, M,
74
+ x.data_ptr<float>(), x.stride(0),
75
+ w.data_ptr<uint8_t>(), w.stride(0),
76
+ mx.data_ptr<float>(), rx.data_ptr<float>(),
77
+ my.data_ptr<float>(), ry.data_ptr<float>(),
78
+ y.data_ptr<float>(), y.stride(0));
79
+ break;
80
+ default:
81
+ assert(false && "Only FP16 and FP32 are currently supported");
82
+ }
83
+ }
84
+ void mm8_one(int64_t N, int64_t M,
85
+ torch::Tensor &x, torch::Tensor &w,
86
+ torch::Tensor &mx, torch::Tensor &rx,
87
+ torch::Tensor &my, torch::Tensor &ry,
88
+ torch::Tensor &y) {
89
+ assert(x.stride(0) == 1);
90
+ assert(w.stride(1) == 1);
91
+ assert(mx.stride(0) == 1 && rx.stride(0) == 1);
92
+ assert(my.stride(0) == 1 && ry.stride(0) == 1);
93
+ assert(y.stride(0) == 1);
94
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
95
+ switch (x.scalar_type()) {
96
+ case c10::ScalarType::Half:
97
+ cuda_mm8_one(
98
+ N, M,
99
+ x.data_ptr<fp16>(),
100
+ w.data_ptr<uint8_t>(), w.stride(0),
101
+ mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
102
+ my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
103
+ y.data_ptr<float>());
104
+ break;
105
+ case c10::ScalarType::Float:
106
+ cuda_mm8_one(
107
+ N, M,
108
+ x.data_ptr<float>(),
109
+ w.data_ptr<uint8_t>(), w.stride(0),
110
+ mx.data_ptr<float>(), rx.data_ptr<float>(),
111
+ my.data_ptr<float>(), ry.data_ptr<float>(),
112
+ y.data_ptr<float>());
113
+ break;
114
+ default:
115
+ assert(false && "Only FP16 and FP32 are currently supported");
116
+ }
117
+ }
118
+
119
+ using torch::Tensor;
120
+
121
+ #ifndef DISABLE_CUBLAS_GEMM
122
+ void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
123
+ #endif
124
+
125
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
126
+ m.def("wkv_forward", &wkv_forward, "wkv forward");
127
+ m.def("mm8_seq", &mm8_seq, "mm8 seq");
128
+ m.def("mm8_one", &mm8_one, "mm8 one");
129
+ #ifndef DISABLE_CUBLAS_GEMM
130
+ m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
131
+ #endif
132
+ }
133
+
134
+ TORCH_LIBRARY(rwkv, m) {
135
+ m.def("wkv_forward", wkv_forward);
136
+ m.def("mm8_seq", mm8_seq);
137
+ m.def("mm8_one", mm8_one);
138
+ #ifndef DISABLE_CUBLAS_GEMM
139
+ m.def("gemm_fp16_cublas", gemm_fp16_cublas);
140
+ #endif
141
+ }
infer/rwkv/model.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+
5
+ from typing import Optional
6
+ import types, gc, os, time, re, math
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+ torch.backends.cudnn.benchmark = True
11
+ torch.backends.cudnn.allow_tf32 = True
12
+ torch.backends.cuda.matmul.allow_tf32 = True
13
+ current_path = os.path.dirname(os.path.abspath(__file__))
14
+
15
+ ########################################################################################################
16
+
17
+ if os.environ.get('RWKV_JIT_ON') != '0':
18
+ os.environ["RWKV_JIT_ON"] = '1'
19
+ MyModule = torch.jit.ScriptModule
20
+ MyFunction = torch.jit.script_method
21
+ MyStatic = torch.jit.script
22
+ else:
23
+ MyModule = torch.nn.Module
24
+ def __nop(ob):
25
+ return ob
26
+ MyFunction = __nop
27
+ MyStatic = __nop
28
+
29
+ if os.environ.get('RWKV_CUDA_ON') == '1':
30
+ from torch.utils.cpp_extension import load
31
+ try:
32
+ load(
33
+ name=f"wkv_cuda",
34
+ sources=[f"{current_path}/cuda/wrapper.cpp", f"{current_path}/cuda/operators.cu", f"{current_path}/cuda/gemm_fp16_cublas.cpp"],
35
+ verbose=True,
36
+ extra_ldflags=["cublas.lib" if os.name == "nt" else ""],
37
+ extra_cuda_cflags=["--use_fast_math", "-O3", "--extra-device-vectorization"],
38
+ is_python_module=False)
39
+ DISABLE_CUBLAS_GEMM = False
40
+ except:
41
+ print("Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow.")
42
+ load(
43
+ name=f"wkv_cuda",
44
+ sources=[f"{current_path}/cuda/wrapper.cpp", f"{current_path}/cuda/operators.cu"],
45
+ verbose=True,
46
+ extra_cuda_cflags=["--use_fast_math", "-O3", "--extra-device-vectorization"],
47
+ extra_cflags=["-DDISABLE_CUBLAS_GEMM"],
48
+ is_python_module=False)
49
+ DISABLE_CUBLAS_GEMM = True
50
+
51
+ @MyStatic
52
+ def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp):
53
+ assert 1 * C % min(C, 32) == 0
54
+ assert k.dtype == v.dtype == torch.float16 or k.dtype == v.dtype == torch.float32
55
+ assert w.dtype == u.dtype == aa.dtype == bb.dtype == pp.dtype == torch.float32
56
+ w = w.contiguous()
57
+ u = u.contiguous()
58
+ k = k.contiguous()
59
+ v = v.contiguous()
60
+ y = torch.empty((T, C), device=w.device, memory_format=torch.contiguous_format, dtype=k.dtype)
61
+ torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp)
62
+ return y, aa, bb, pp
63
+ @MyStatic
64
+ def cuda_mm8_seq(B: int, N: int, M: int, x, w, mx, rx, my, ry):
65
+ assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype
66
+ assert x.dtype == torch.float32 or x.dtype == torch.float16
67
+ assert w.dtype == torch.uint8
68
+ assert x.shape == (B, N)
69
+ assert w.shape == (N, M)
70
+ assert rx.shape == mx.shape == (M,)
71
+ assert ry.shape == my.shape == (N, 1)
72
+ y = torch.empty((B, M), device=w.device, dtype=x.dtype)
73
+ torch.ops.rwkv.mm8_seq(B, N, M, x, w, mx, rx, my, ry, y)
74
+ return y
75
+ @MyStatic
76
+ def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry):
77
+ assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype
78
+ assert x.dtype == torch.float32 or x.dtype == torch.float16
79
+ assert w.dtype == torch.uint8
80
+ assert x.shape == (N,)
81
+ assert w.shape == (N, M)
82
+ assert rx.shape == mx.shape == (M,)
83
+ assert ry.shape == my.shape == (N, 1)
84
+ y = torch.zeros((M,), device=w.device, dtype=torch.float32)
85
+ torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y)
86
+ return y.to(dtype=x.dtype)
87
+ else:
88
+ os.environ["RWKV_CUDA_ON"] = '0'
89
+
90
+
91
+ @MyStatic
92
+ def torch_mm8_seq(x, w, mx, rx, my, ry):
93
+ return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
94
+
95
+ @MyStatic
96
+ def torch_mm8_one(x, w, mx, rx, my, ry):
97
+ return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
98
+
99
+ if os.environ.get('RWKV_CUDA_ON') == '1':
100
+ @MyStatic
101
+ def mm8_seq(x, w, mx, rx, my, ry):
102
+ if w.device.type == 'cuda' and x.dtype == torch.float16:
103
+ B, N, M = x.shape[0], w.shape[0], w.shape[1]
104
+ return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry)
105
+ else:
106
+ return torch_mm8_seq(x, w, mx, rx, my, ry)
107
+ @MyStatic
108
+ def mm8_one(x, w, mx, rx, my, ry):
109
+ if w.device.type == 'cuda':
110
+ N, M = w.shape[0], w.shape[1]
111
+ return cuda_mm8_one(N, M, x, w, mx, rx, my, ry)
112
+ else:
113
+ return torch_mm8_one(x, w, mx, rx, my, ry)
114
+ else:
115
+ @MyStatic
116
+ def mm8_seq(x, w, mx, rx, my, ry):
117
+ return torch_mm8_seq(x, w, mx, rx, my, ry)
118
+ @MyStatic
119
+ def mm8_one(x, w, mx, rx, my, ry):
120
+ return torch_mm8_one(x, w, mx, rx, my, ry)
121
+
122
+ def mm8(x: torch.Tensor, w: torch.Tensor, mx: torch.Tensor, rx: torch.Tensor, my: torch.Tensor, ry: torch.Tensor):
123
+ if len(x.shape) == 1:
124
+ return mm8_one(x, w, mx, rx, my, ry)
125
+ return mm8_seq(x, w, mx, rx, my, ry)
126
+
127
+ def matmul(a, b, mx: Optional[torch.Tensor]=None, rx: Optional[torch.Tensor]=None, my: Optional[torch.Tensor]=None, ry: Optional[torch.Tensor]=None, output_dtype: Optional[torch.dtype]=None) -> torch.Tensor:
128
+ if output_dtype is None:
129
+ output_dtype = a.dtype
130
+ if b.dtype in [torch.float16, torch.bfloat16, torch.float32]:
131
+ assert a.dtype == b.dtype
132
+ return matmul_float(a, b, output_dtype=output_dtype)
133
+ elif b.dtype == torch.uint8:
134
+ assert mx is not None
135
+ assert rx is not None
136
+ assert my is not None
137
+ assert ry is not None
138
+ return mm8(a, b, mx, rx, my, ry).to(output_dtype)
139
+ else:
140
+ raise ValueError("Unsupported dtype")
141
+
142
+
143
+ if os.environ.get('RWKV_CUDA_ON') == '1' and not DISABLE_CUBLAS_GEMM:
144
+ def matmul_float(a, b, output_dtype: Optional[torch.dtype]=None):
145
+ if output_dtype is None:
146
+ output_dtype = a.dtype
147
+ if a.dtype == b.dtype == torch.float16 and a.device.type == 'cuda':
148
+ if len(a.shape) == 1:
149
+ assert len(b.shape) == 2
150
+ c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device)
151
+ a = a.unsqueeze(0)
152
+ else:
153
+ assert len(a.shape) == len(b.shape)
154
+ assert len(a.shape) == 2 or len(a.shape) == 3
155
+ # torch.empty((*a.shape[:-1], b.shape[-1])) doesn't work with jit
156
+ if len(a.shape) == 2:
157
+ c = torch.empty((a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device)
158
+ else:
159
+ c = torch.empty((a.shape[0], a.shape[1], b.shape[-1]), dtype=output_dtype, device=a.device)
160
+ torch.ops.rwkv.gemm_fp16_cublas(a, b, c)
161
+ return c
162
+ else:
163
+ return (a @ b).to(output_dtype)
164
+
165
+ else:
166
+ def matmul_float(a, b, output_dtype: Optional[torch.dtype]=None):
167
+ return (a @ b).to(output_dtype)
168
+
169
+
170
+ if os.environ.get('RWKV_DML_ON') == '1':
171
+ import torch_directml
172
+ print("PyTorch with DirectML Enabled")
173
+
174
+
175
+ print(f'\n### RWKV-7 "Goose" enabled ###\n')
176
+
177
+ torch.backends.cudnn.benchmark = True
178
+ torch.backends.cudnn.allow_tf32 = True
179
+ torch.backends.cuda.matmul.allow_tf32 = True
180
+ # torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
181
+ # torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
182
+ torch._C._jit_set_autocast_mode(False)
183
+
184
+ MyModule = torch.jit.ScriptModule
185
+ MyFunction = torch.jit.script_method
186
+ MyStatic = torch.jit.script
187
+ from typing import List
188
+
189
+ DTYPE = None
190
+ DEVICE = None
191
+ HEAD_SIZE = 64
192
+
193
+ if os.environ.get('RWKV_CUDA_ON') == '1':
194
+ from torch.utils.cpp_extension import load
195
+ load(name="wkv7s", sources=[f"{current_path}/cuda/rwkv7_op.cpp", f"{current_path}/cuda/rwkv7.cu"], is_python_module=False,
196
+ verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])
197
+ class WKV_7(torch.autograd.Function):
198
+ @staticmethod
199
+ def forward(ctx, state, r, w, k, v, a, b):
200
+ with torch.no_grad():
201
+ T, C = r.size()
202
+ H = C // HEAD_SIZE
203
+ N = HEAD_SIZE
204
+ assert HEAD_SIZE == C // H
205
+ assert all(x.dtype == DTYPE for x in [r,w,k,v,a,b])
206
+ assert all(x.is_contiguous() for x in [r,w,k,v,a,b])
207
+ y = torch.empty((T, C), device=DEVICE, dtype=r.dtype, requires_grad=False, memory_format=torch.contiguous_format)
208
+
209
+ if DTYPE == torch.float16:
210
+ torch.ops.wkv7s.forward_fp16(1, T, C, H, state, r, w, k, v, a, b, y)
211
+ elif DTYPE == torch.bfloat16:
212
+ torch.ops.wkv7s.forward_bf16(1, T, C, H, state, r, w, k, v, a, b, y)
213
+ elif DTYPE == torch.float32:
214
+ torch.ops.wkv7s.forward_fp32(1, T, C, H, state, r, w, k, v, a, b, y)
215
+
216
+ return y
217
+ def RWKV7_OP(state, r, w, k, v, a, b):
218
+ return WKV_7.apply(state, r, w, k, v, a, b)
219
+
220
+ ########################################################################################################
221
+
222
+ class RWKV(MyModule):
223
+ def __init__(self, model, strategy):
224
+ global DTYPE, DEVICE
225
+ super().__init__()
226
+ self.eval()
227
+ args = types.SimpleNamespace()
228
+ self.args = args
229
+ # args.MODEL_NAME = model
230
+
231
+ # print(f'Loading {model} ({strategy})\n')
232
+
233
+ ss = strategy.split(' ')
234
+ DEVICE = ss[0]
235
+ if ss[1] == 'fp16':
236
+ DTYPE = torch.half
237
+ elif ss[1] == 'fp32':
238
+ DTYPE = torch.float32
239
+ elif ss[1] == 'bf16':
240
+ DTYPE = torch.bfloat16
241
+ else:
242
+ assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16"
243
+
244
+ # self.z = torch.load(args.MODEL_NAME + '.pth', map_location=DEVICE)
245
+ self.z = model
246
+
247
+ # for k,v in self.z.items():
248
+ # print(k, v.shape)
249
+ z = self.z
250
+
251
+ self.n_head, self.head_size = z['blocks.0.att.r_k'].shape
252
+ args.head_size = self.head_size
253
+ args.vocab_size, args.n_embd = z['emb.weight'].shape
254
+
255
+ args.n_layer = 0
256
+ keys = list(z.keys())
257
+ for k in keys:
258
+ layer_id = int(k.split('.')[1]) if ('blocks.' in k) else 0
259
+ args.n_layer = max(args.n_layer, layer_id+1)
260
+ if 'key.weight' in k or 'value.weight' in k or 'receptance.weight' in k or 'output.weight' in k or 'head.weight' in k:
261
+ z[k] = z[k].t()
262
+ z[k] = z[k].squeeze().to(dtype=DTYPE)
263
+ if k.endswith('att.r_k'): z[k] = z[k].flatten()
264
+
265
+ self.n_embd = args.n_embd
266
+ self.n_layer = args.n_layer
267
+
268
+ # z['emb.weight'] = F.layer_norm(z['emb.weight'], (args.n_embd,), weight=z['blocks.0.ln0.weight'], bias=z['blocks.0.ln0.bias'])
269
+ z['blocks.0.att.v0'] = z['blocks.0.att.a0'] # actually ignored
270
+ z['blocks.0.att.v1'] = z['blocks.0.att.a1'] # actually ignored
271
+ z['blocks.0.att.v2'] = z['blocks.0.att.a2'] # actually ignored
272
+
273
+ def forward(self, idx, state, full_output=False, sign=None):
274
+ if state == None:
275
+ state = [None for _ in range(self.args.n_layer * 3)]
276
+ for i in range(self.args.n_layer): # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev
277
+ state[i*3+0] = torch.zeros(self.args.n_embd, dtype=DTYPE, requires_grad=False, device=DEVICE)
278
+ state[i*3+1] = torch.zeros((self.args.n_embd // self.args.head_size, self.args.head_size, self.args.head_size), dtype=torch.float, requires_grad=False, device=DEVICE)
279
+ state[i*3+2] = torch.zeros(self.args.n_embd, dtype=DTYPE, requires_grad=False, device=DEVICE)
280
+
281
+ x = self.z['emb.weight'][idx]
282
+ if isinstance(sign, torch.Tensor):
283
+ sign = sign.squeeze(0)
284
+ # sign = F.layer_norm(sign, (self.args.n_embd,), weight=self.z['blocks.0.ln0.weight'], bias=self.z['blocks.0.ln0.bias'])
285
+
286
+ x = torch.cat((sign,x.to(DEVICE)), dim=0)
287
+
288
+ x = F.layer_norm(x, (self.args.n_embd,), weight=self.z['blocks.0.ln0.weight'], bias=self.z['blocks.0.ln0.bias'])
289
+ # if isinstance(sign, torch.Tensor):
290
+ # print(x)
291
+
292
+ if type(idx) is list:
293
+ if len(idx) > 1:
294
+ return self.forward_seq(x, state, full_output)
295
+ else:
296
+ return self.forward_one(x[0], state)
297
+ else:
298
+ return self.forward_one(x, state)
299
+
300
+ @MyFunction
301
+ def forward_one(self, x, state:List[torch.Tensor]):
302
+ with torch.no_grad():
303
+ z = self.z
304
+ #x = z['emb.weight'][idx]
305
+
306
+ v_first = torch.empty_like(x)
307
+ for i in range(self.n_layer):
308
+ bbb = f'blocks.{i}.'
309
+ att = f'blocks.{i}.att.'
310
+ ffn = f'blocks.{i}.ffn.'
311
+
312
+ xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias'])
313
+
314
+ xx, state[i*3+0], state[i*3+1], v_first = RWKV_x070_TMix_one(i, self.n_head, self.head_size, xx, state[i*3+0], v_first, state[i*3+1],
315
+ z[att+'x_r'], z[att+'x_w'], z[att+'x_k'], z[att+'x_v'], z[att+'x_a'], z[att+'x_g'],
316
+ z[att+'w0'], z[att+'w1'], z[att+'w2'], z[att+'a0'], z[att+'a1'], z[att+'a2'], z[att+'v0'], z[att+'v1'], z[att+'v2'],
317
+ z[att+'g1'], z[att+'g2'], z[att+'k_k'], z[att+'k_a'], z[att+'r_k'],
318
+ z[att+'receptance.weight'], z[att+'key.weight'], z[att+'value.weight'], z[att+'output.weight'],
319
+ z[att+'ln_x.weight'], z[att+'ln_x.bias'])
320
+ x = x + xx
321
+
322
+ xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
323
+
324
+ xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
325
+ x = x + xx
326
+
327
+ # if math.isnan(torch.min(x).item()): print(idx, i)
328
+
329
+ x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias'])
330
+ x = x @ z['head.weight']
331
+ return x, state
332
+
333
+ @MyFunction
334
+ def forward_seq(self, x, state:List[torch.Tensor], full_output:bool=False):
335
+ with torch.no_grad():
336
+ z = self.z
337
+ #x = z['emb.weight'][idx]
338
+
339
+ v_first = torch.empty_like(x)
340
+ for i in range(self.n_layer):
341
+ bbb = f'blocks.{i}.'
342
+ att = f'blocks.{i}.att.'
343
+ ffn = f'blocks.{i}.ffn.'
344
+
345
+ xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias'])
346
+
347
+ xx, state[i*3+0], state[i*3+1], v_first = RWKV_x070_TMix_seq(i, self.n_head, self.head_size, xx, state[i*3+0], v_first, state[i*3+1],
348
+ z[att+'x_r'], z[att+'x_w'], z[att+'x_k'], z[att+'x_v'], z[att+'x_a'], z[att+'x_g'],
349
+ z[att+'w0'], z[att+'w1'], z[att+'w2'], z[att+'a0'], z[att+'a1'], z[att+'a2'], z[att+'v0'], z[att+'v1'], z[att+'v2'],
350
+ z[att+'g1'], z[att+'g2'], z[att+'k_k'], z[att+'k_a'], z[att+'r_k'],
351
+ z[att+'receptance.weight'], z[att+'key.weight'], z[att+'value.weight'], z[att+'output.weight'],
352
+ z[att+'ln_x.weight'], z[att+'ln_x.bias'])
353
+ x = x + xx
354
+
355
+ xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
356
+
357
+ xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
358
+ x = x + xx
359
+
360
+ if not full_output: x = x[-1,:]
361
+ x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias'])
362
+ x = x @ z['head.weight']
363
+ return x, state
364
+
365
+ ########################################################################################################
366
+
367
+ @MyStatic
368
+ def RWKV_x070_TMix_one(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x_r, x_w, x_k, x_v, x_a, x_g, w0, w1, w2, a0, a1, a2, v0, v1, v2, g1, g2, k_k, k_a, r_k, R_, K_, V_, O_, ln_w, ln_b):
369
+ xx = x_prev - x
370
+ xr, xw, xk, xv, xa, xg = x+xx*x_r, x+xx*x_w, x+xx*x_k, x+xx*x_v, x+xx*x_a, x+xx*x_g
371
+
372
+ r = xr @ R_
373
+ w = torch.tanh(xw @ w1) @ w2
374
+ k = xk @ K_
375
+ v = xv @ V_
376
+ a = torch.sigmoid(a0 + (xa @ a1) @ a2)
377
+ g = torch.sigmoid(xg @ g1) @ g2
378
+
379
+ kk = torch.nn.functional.normalize((k * k_k).view(H,N), dim=-1, p=2.0).view(H*N)
380
+ k = k * (1 + (a-1) * k_a)
381
+ if layer_id == 0: v_first = v
382
+ else: v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2)
383
+ w = torch.exp(-0.606531 * torch.sigmoid((w0 + w).float())) # 0.606531 = exp(-0.5)
384
+
385
+ vk = v.view(H,N,1) @ k.view(H,1,N)
386
+ ab = (-kk).view(H,N,1) @ (kk*a).view(H,1,N)
387
+ state = state * w.view(H,1,N) + state @ ab.float() + vk.float()
388
+ xx = (state.to(dtype=x.dtype) @ r.view(H,N,1))
389
+
390
+ xx = torch.nn.functional.group_norm(xx.view(1,H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N)
391
+ xx = xx + ((r * k * r_k).view(H,N).sum(dim=-1, keepdim=True) * v.view(H,N)).view(H*N)
392
+ return (xx * g) @ O_, x, state, v_first
393
+
394
+ if os.environ.get('RWKV_CUDA_ON') == '1':
395
+ @MyStatic
396
+ def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x_r, x_w, x_k, x_v, x_a, x_g, w0, w1, w2, a0, a1, a2, v0, v1, v2, g1, g2, k_k, k_a, r_k, R_, K_, V_, O_, ln_w, ln_b):
397
+ T = x.shape[0]
398
+ xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
399
+ xr, xw, xk, xv, xa, xg = x+xx*x_r, x+xx*x_w, x+xx*x_k, x+xx*x_v, x+xx*x_a, x+xx*x_g
400
+
401
+ r = xr @ R_
402
+ w = torch.tanh(xw @ w1) @ w2
403
+ k = xk @ K_
404
+ v = xv @ V_
405
+ a = torch.sigmoid(a0 + (xa @ a1) @ a2)
406
+ g = torch.sigmoid(xg @ g1) @ g2
407
+
408
+ kk = torch.nn.functional.normalize((k * k_k).view(T,H,N), dim=-1, p=2.0).view(T,H*N)
409
+ k = k * (1 + (a-1) * k_a)
410
+ if layer_id == 0: v_first = v
411
+ else: v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2)
412
+
413
+ w = -torch.nn.functional.softplus(-(w0 + w)) - 0.5
414
+ xx = RWKV7_OP(state, r, w, k, v, -kk, kk*a)
415
+
416
+ xx = torch.nn.functional.group_norm(xx.view(T,H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(T,H*N)
417
+ xx = xx + ((r * k * r_k).view(T,H,N).sum(dim=-1, keepdim=True) * v.view(T,H,N)).view(T,H*N)
418
+ return (xx * g) @ O_, x[-1,:], state, v_first
419
+ else:
420
+ @MyStatic
421
+ def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x_r, x_w, x_k, x_v, x_a, x_g, w0, w1, w2, a0, a1, a2, v0, v1, v2, g1, g2, k_k, k_a, r_k, R_, K_, V_, O_, ln_w, ln_b):
422
+ T = x.shape[0]
423
+ xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
424
+ xr, xw, xk, xv, xa, xg = x+xx*x_r, x+xx*x_w, x+xx*x_k, x+xx*x_v, x+xx*x_a, x+xx*x_g
425
+
426
+ r = xr @ R_
427
+ w = torch.tanh(xw @ w1) @ w2
428
+ k = xk @ K_
429
+ v = xv @ V_
430
+ a = torch.sigmoid(a0 + (xa @ a1) @ a2)
431
+ g = torch.sigmoid(xg @ g1) @ g2
432
+
433
+ kk = torch.nn.functional.normalize((k * k_k).view(T,H,N), dim=-1, p=2.0).view(T,H*N)
434
+ k = k * (1 + (a-1) * k_a)
435
+ if layer_id == 0: v_first = v
436
+ else: v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2)
437
+
438
+ w = torch.exp(-0.606531 * torch.sigmoid((w0 + w).float())) # 0.606531 = exp(-0.5)
439
+ for t in range(T):
440
+ r_, w_, k_, v_, kk_, a_ = r[t], w[t], k[t], v[t], kk[t], a[t]
441
+ vk = v_.view(H,N,1) @ k_.view(H,1,N)
442
+ ab = (-kk_).view(H,N,1) @ (kk_*a_).view(H,1,N)
443
+ state = state * w_.view(H,1,N) + state @ ab.float() + vk.float()
444
+ xx[t] = (state.to(dtype=x.dtype) @ r_.view(H,N,1)).view(H*N)
445
+
446
+ xx = torch.nn.functional.group_norm(xx.view(T,H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(T,H*N)
447
+ xx = xx + ((r * k * r_k).view(T,H,N).sum(dim=-1, keepdim=True) * v.view(T,H,N)).view(T,H*N)
448
+ return (xx * g) @ O_, x[-1,:], state, v_first
449
+
450
+ ########################################################################################################
451
+
452
+ @MyStatic
453
+ def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_):
454
+ xx = x_prev - x
455
+ k = x + xx * x_k
456
+ k = torch.relu(k @ K_) ** 2
457
+ return k @ V_, x
458
+
459
+ @MyStatic
460
+ def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_):
461
+ xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
462
+ k = x + xx * x_k
463
+ k = torch.relu(k @ K_) ** 2
464
+ return k @ V_, x[-1,:]
465
+
466
+ ########################################################################################################
467
+
468
+
469
+
infer/rwkv/rwkv_tokenizer.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+
5
+ class TRIE:
6
+ __slots__ = tuple("ch,to,values,front".split(","))
7
+ to:list
8
+ values:set
9
+ def __init__(self, front=None, ch=None):
10
+ self.ch = ch
11
+ self.to = [None for ch in range(256)]
12
+ self.values = set()
13
+ self.front = front
14
+
15
+ def __repr__(self):
16
+ fr = self
17
+ ret = []
18
+ while(fr!=None):
19
+ if(fr.ch!=None):
20
+ ret.append(fr.ch)
21
+ fr = fr.front
22
+ return "<TRIE %s %s>"%(ret[::-1], self.values)
23
+
24
+ def add(self, key:bytes, idx:int=0, val=None):
25
+ if(idx == len(key)):
26
+ if(val is None):
27
+ val = key
28
+ self.values.add(val)
29
+ return self
30
+ ch = key[idx]
31
+ if(self.to[ch] is None):
32
+ self.to[ch] = TRIE(front=self, ch=ch)
33
+ return self.to[ch].add(key, idx=idx+1, val=val)
34
+
35
+ def find_longest(self, key:bytes, idx:int=0):
36
+ u:TRIE = self
37
+ ch:int = key[idx]
38
+
39
+ while(u.to[ch] is not None):
40
+ u = u.to[ch]
41
+ idx += 1
42
+ if(u.values):
43
+ ret = idx, u, u.values
44
+ if(idx==len(key)):
45
+ break
46
+ ch = key[idx]
47
+ return ret
48
+
49
+ class TRIE_TOKENIZER():
50
+ def __init__(self, file_name):
51
+ self.idx2token = {}
52
+ sorted = [] # must be already sorted
53
+ with open(file_name, "r", encoding="utf-8") as f:
54
+ lines = f.readlines()
55
+ for l in lines:
56
+ idx = int(l[:l.index(' ')])
57
+ x = eval(l[l.index(' '):l.rindex(' ')])
58
+ x = x.encode("utf-8") if isinstance(x, str) else x
59
+ assert isinstance(x, bytes)
60
+ assert len(x) == int(l[l.rindex(' '):])
61
+ sorted += [x]
62
+ self.idx2token[idx] = x
63
+
64
+ self.token2idx = {}
65
+ for k,v in self.idx2token.items():
66
+ self.token2idx[v] = int(k)
67
+
68
+ self.root = TRIE()
69
+ for t, i in self.token2idx.items():
70
+ _ = self.root.add(t, val=(t, i))
71
+
72
+ def encodeBytes(self, src:bytes):
73
+ idx:int = 0
74
+ tokens = []
75
+ while (idx < len(src)):
76
+ _idx:int = idx
77
+ idx, _, values = self.root.find_longest(src, idx)
78
+ assert(idx != _idx)
79
+ _, token = next(iter(values))
80
+ tokens.append(token)
81
+ return tokens
82
+
83
+ def decodeBytes(self, tokens):
84
+ return b''.join(map(lambda i: self.idx2token[i], tokens))
85
+
86
+ def encode(self, src):
87
+ return self.encodeBytes(src.encode("utf-8"))
88
+
89
+ def decode(self, tokens):
90
+ try:
91
+ return self.decodeBytes(tokens).decode('utf-8')
92
+ except:
93
+ return '\ufffd' # bad utf-8
94
+
95
+ def printTokens(self, tokens):
96
+ for i in tokens:
97
+ s = self.idx2token[i]
98
+ try:
99
+ s = s.decode('utf-8')
100
+ except:
101
+ pass
102
+ print(f'{repr(s)}{i}', end=' ')
103
+ print()
infer/rwkv/rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff
 
infer/rwkv/utils.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+
5
+ import os, sys
6
+ import numpy as np
7
+ import torch
8
+ from torch.nn import functional as F
9
+
10
+ class PIPELINE_ARGS():
11
+ def __init__(self, temperature=1.0, top_p=0.85, top_k=0, alpha_frequency=0.2, alpha_presence=0.2, alpha_decay=0.996, token_ban=[], token_stop=[], chunk_len=256):
12
+ self.temperature = temperature
13
+ self.top_p = top_p
14
+ self.top_k = top_k
15
+ self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
16
+ self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
17
+ self.alpha_decay = alpha_decay # gradually decay the penalty
18
+ self.token_ban = token_ban # ban the generation of some tokens
19
+ self.token_stop = token_stop # stop generation whenever you see any token here
20
+ self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)
21
+
22
+ class PIPELINE():
23
+ def __init__(self, model, WORD_NAME):
24
+ self.model = model
25
+ if WORD_NAME == 'cl100k_base':
26
+ import tiktoken
27
+ self.tokenizer = tiktoken.get_encoding(WORD_NAME)
28
+ elif WORD_NAME == 'rwkv_vocab_v20230424':
29
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
30
+ from rwkv_tokenizer import TRIE_TOKENIZER
31
+ self.tokenizer = TRIE_TOKENIZER(os.path.dirname(os.path.abspath(__file__)) + '/rwkv_vocab_v20230424.txt')
32
+ else:
33
+ from tokenizers import Tokenizer
34
+ self.tokenizer = Tokenizer.from_file(WORD_NAME)
35
+
36
+ def refine_context(self, context):
37
+ context = context.strip().split('\n')
38
+ for c in range(len(context)):
39
+ context[c] = context[c].strip().strip('\u3000').strip('\r')
40
+ context = list(filter(lambda c: c != '', context))
41
+ context = '\n' + ('\n'.join(context)).strip()
42
+ if context == '':
43
+ context = '\n'
44
+ return context
45
+
46
+ def encode(self, x):
47
+ if 'Tokenizer' in str(type(self.tokenizer)):
48
+ return self.tokenizer.encode(x).ids
49
+ else:
50
+ return self.tokenizer.encode(x)
51
+
52
+ def decode(self, x):
53
+ return self.tokenizer.decode(x)
54
+
55
+ def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
56
+ if temperature == 0:
57
+ temperature = 1.0
58
+ top_p = 0
59
+ probs = F.softmax(logits.float(), dim=-1)
60
+ top_k = int(top_k)
61
+ # 'privateuseone' is the type of custom devices like `torch_directml.device()`
62
+ if probs.device.type in ['cpu', 'privateuseone']:
63
+ probs = probs.cpu().numpy()
64
+ sorted_ids = np.argsort(probs)
65
+ sorted_probs = probs[sorted_ids][::-1]
66
+ cumulative_probs = np.cumsum(sorted_probs)
67
+ cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
68
+ probs[probs < cutoff] = 0
69
+ if top_k < len(probs) and top_k > 0:
70
+ probs[sorted_ids[:-top_k]] = 0
71
+ if temperature != 1.0:
72
+ probs = probs ** (1.0 / temperature)
73
+ probs = probs / np.sum(probs)
74
+ out = np.random.choice(a=len(probs), p=probs)
75
+ return int(out)
76
+ else:
77
+ sorted_ids = torch.argsort(probs)
78
+ sorted_probs = probs[sorted_ids]
79
+ sorted_probs = torch.flip(sorted_probs, dims=(0,))
80
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
81
+ cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
82
+ probs[probs < cutoff] = 0
83
+ if top_k < len(probs) and top_k > 0:
84
+ probs[sorted_ids[:-top_k]] = 0
85
+ if temperature != 1.0:
86
+ probs = probs ** (1.0 / temperature)
87
+ out = torch.multinomial(probs, num_samples=1)[0]
88
+ return int(out)
89
+
90
+ def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None, sign=None):
91
+ all_tokens = []
92
+ out_last = 0
93
+ out_str = ''
94
+ occurrence = {}
95
+ for i in range(token_count):
96
+
97
+ # forward & adjust prob.
98
+ tokens = self.encode(ctx) if i == 0 else [token]
99
+ while len(tokens) > 0:
100
+ out, state = self.model.forward(tokens[:args.chunk_len], state, sign=sign)
101
+ tokens = tokens[args.chunk_len:]
102
+
103
+ for n in args.token_ban:
104
+ out[n] = -float('inf')
105
+ for n in occurrence:
106
+ out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
107
+
108
+ # sampler
109
+ token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
110
+
111
+ if token in args.token_stop:
112
+ break
113
+ all_tokens += [token]
114
+ for xxx in occurrence:
115
+ occurrence[xxx] *= args.alpha_decay
116
+
117
+ ttt = self.decode([token])
118
+
119
+ www = 1
120
+ if ttt in ' \t0123456789':
121
+ www = 0
122
+ # elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
123
+ # www = 0.5
124
+ if token not in occurrence:
125
+ occurrence[token] = www
126
+ else:
127
+ occurrence[token] += www
128
+ # print(occurrence) # debug
129
+
130
+ # output
131
+ tmp = self.decode(all_tokens[out_last:])
132
+ if '\ufffd' not in tmp: # is valid utf-8 string?
133
+ if callback:
134
+ callback(tmp)
135
+ out_str += tmp
136
+ out_last = i + 1
137
+ sign =None
138
+ return out_str, state
139
+
140
+ def prefill(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None, sign=None):
141
+ tokens = self.encode(ctx)
142
+ out, state = self.model.forward(tokens[:args.chunk_len], state, sign=sign,full_output=True)
143
+ max_indices = torch.argmax(out, dim=-1)
144
+ token = self.sample_logits(out[19,:], temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
145
+ print(token, self.decode([token]),out[19,:])
146
+ print(state[0].view(-1))
147
+ print(self.decode(max_indices.tolist()))
148
+ return self.decode(max_indices.tolist())
infer/worldmodel.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+
4
+ import os, sys, torch, time
5
+ import numpy as np
6
+ np.set_printoptions(precision=4, suppress=True, linewidth=200)
7
+ import torch
8
+ print(torch.__version__)
9
+ print(torch.version.cuda)
10
+
11
+ # set these before import RWKV
12
+ # os.environ['RWKV_JIT_ON'] = '1'
13
+ # os.environ["RWKV_CUDA_ON"] = '1' # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries
14
+ from infer.rwkv.model import RWKV # pip install rwkv
15
+ from infer.rwkv.utils import PIPELINE, PIPELINE_ARGS
16
+
17
+
18
+ from world.world_encoder import WorldEncoder
19
+
20
+ class Worldinfer():
21
+ def __init__(self, model_path, encoder_type, encoder_path, strategy='cpu bf16', args=None):
22
+
23
+ ss = strategy.split(' ')
24
+ DEVICE = ss[0]
25
+ if ss[1] == 'fp16':
26
+ self.DTYPE = torch.half
27
+ elif ss[1] == 'fp32':
28
+ self.DTYPE = torch.float32
29
+ elif ss[1] == 'bf16':
30
+ self.DTYPE = torch.bfloat16
31
+ else:
32
+ assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16"
33
+
34
+ self.model_weight = torch.load(model_path + '.pth', map_location=DEVICE)
35
+ modality_dict = {}
36
+ for key, value in self.model_weight.items():
37
+ if 'emb.weight' in key:
38
+ _, n_embd = value.shape
39
+ if 'modality' in key:
40
+ k = key.replace('modality.world_encoder.', '')
41
+ modality_dict[k] = value
42
+ model = RWKV(model=self.model_weight, strategy=strategy)
43
+ self.pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
44
+
45
+ if args==None:
46
+ self.args = PIPELINE_ARGS(temperature = 1.0, top_p = 0.0, top_k=0, # top_k = 0 then ignore
47
+ alpha_frequency = 0.0,
48
+ alpha_presence = 0.0,
49
+ token_ban = [0], # ban the generation of some tokens
50
+ token_stop = [24], # stop generation whenever you see any token here
51
+ chunk_len = 256) # split input into chunks to save VRAM (shorter -> slower)
52
+ else:
53
+ self.args=args
54
+ print('RWKV finish!!!')
55
+
56
+ config = {
57
+ 'encoder_type': encoder_type,
58
+ 'encoder_path': encoder_path,
59
+ 'project_dim' : n_embd
60
+ }
61
+ self.modality = WorldEncoder(**config).to(DEVICE, torch.bfloat16)
62
+ self.modality.load_checkpoint(modality_dict)
63
+
64
+
65
+ def generate(self, text, modality='none', state=None):
66
+ if isinstance(modality, str):
67
+ y=None
68
+ else:
69
+ y = self.modality(modality).to(self.DTYPE)
70
+ result, state = self.pipeline.generate(text, token_count=500, args=self.args, callback=None, state=state, sign=y)
71
+ return result, state
72
+
73
+ # def prefill(self, text, modality='none', state=None):
74
+ # if isinstance(modality, str):
75
+ # y=None
76
+ # else:
77
+ # y = self.modality(modality).to(self.DTYPE)
78
+ # result, state = self.pipeline.forward(text, token_count=500, args=self.args, callback=None, state=state, sign=y)
79
+ # return result, state
world/__init__.py ADDED
File without changes
world/block.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from src.infctx_module import *
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+
7
+ from src.rwkv7.Channel_mix import RWKV_CMix_x070
8
+ from src.rwkv7.Time_mix import RWKV_Tmix_x070
9
+
10
+ class Block(nn.Module):
11
+ def __init__(self, args, layer_id):
12
+ super().__init__()
13
+ self.args = args
14
+ self.layer_id = layer_id
15
+
16
+ self.ln1 = nn.LayerNorm(args.n_embd)
17
+ self.ln2 = nn.LayerNorm(args.n_embd)
18
+
19
+ if self.layer_id == 0:
20
+ self.ln0 = nn.LayerNorm(args.n_embd)
21
+
22
+ self.att = RWKV_Tmix_x070(args, layer_id)
23
+ self.ffn = RWKV_CMix_x070(args, layer_id)
24
+
25
+
26
+ def forward(self, x, v_first):
27
+ if self.layer_id == 0:
28
+ x = self.ln0(x)
29
+
30
+ x_attn, v_first = self.att(self.ln1(x), v_first)
31
+ x = x + x_attn
32
+
33
+ x = x + self.ffn(self.ln2(x))
34
+ return x, v_first
world/cat.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torch.nn import functional as F
4
+ def pad_mod(self, tensor_list, signal_list):
5
+ """
6
+ 对一个包含不同长度张量的列表进行填充,使所有张量的长度相同且为16的整数倍,并生成掩码。
7
+ 参数:
8
+ tensor_list (list of torch.Tensor): 输入的张量列表,每个张量形状为 [seq_len]。
9
+ pad_value (int, optional): 填充值,默认值为 0。
10
+ 返回:
11
+ padded_tensor (torch.Tensor): 填充后的张量,形状为 [batch_size, target_len]。
12
+ mask (torch.Tensor): 填充掩码,1 表示有效数据,0 表示填充部分。
13
+ """
14
+
15
+ modality_list = []
16
+ #max_len = max((token.size(0) + signal.size(1)) for token, signal in zip(tensor_list, modality_list))
17
+ max_len = 0
18
+ for token, signal in zip(tensor_list, signal_list):
19
+
20
+ modality = self.modality(signal)
21
+ if modality is False:
22
+ modality_list.append(False)
23
+ continue
24
+ modality_list.append(modality)
25
+ max_len = max(token.size(0) + modality.size(1), max_len)
26
+
27
+ # 计算目标长度(向上取整到16的整数倍)
28
+ target_len = ((max_len + 15) // 16 * 16)+1
29
+
30
+ if self.args.ctx_len is not None:
31
+ target_len = min(target_len, self.args.ctx_len+1)
32
+
33
+ masks = torch.zeros((len(tensor_list), target_len-1), dtype=torch.int32)
34
+ x = []
35
+ y = []
36
+ s = []
37
+ m = []
38
+ for token, signal, mask in zip(tensor_list, modality_list, masks):
39
+ if signal is False:
40
+ continue
41
+ pad_len = target_len-(token.size(0) + signal.size(1))
42
+
43
+ padded_token = F.pad(token, (0, pad_len), value=0)
44
+
45
+ x_token = padded_token[:-1]
46
+ y_token = F.pad(padded_token, (signal.size(1)-1, 0), value=0)
47
+
48
+ mask[signal.size(1) : -pad_len] = 1
49
+
50
+ s.append(signal)
51
+ x.append(x_token)
52
+ y.append(y_token)
53
+ m.append(mask)
54
+
55
+ y = torch.stack(y, dim=0)
56
+ m = torch.stack(m, dim=0).cuda()
57
+
58
+ return s, x, y, m
59
+
60
+
61
+
62
+ def mod_pad_text(self, signal_list, text_inputs, text_labels):
63
+ """
64
+ 对一个包含不同长度张量的列表进行填充,使所有张量的长度相同且为16的整数倍,并生成掩码。
65
+ 参数:
66
+ tensor_list (list of torch.Tensor): 输入的张量列表,每个张量形状为 [seq_len]。
67
+ pad_value (int, optional): 填充值,默认值为 0。
68
+ 返回:
69
+ padded_tensor (torch.Tensor): 填充后的张量,形状为 [batch_size, target_len]。
70
+ mask (torch.Tensor): 填充掩码,1 表示有效数据,0 表示填充部分。
71
+ """
72
+
73
+ modality_list = []
74
+ #max_len = max((token.size(0) + signal.size(1)) for token, signal in zip(tensor_list, modality_list))
75
+ max_len = 0
76
+ for i, (signal, token, label) in enumerate(zip(signal_list, text_inputs, text_labels)):
77
+
78
+ modality = self.modality(signal)
79
+ modality_list.append(modality)
80
+ mod_label = torch.full((modality.size(1),), -100, device='cuda')
81
+ text_labels[i] = torch.cat([mod_label, label])
82
+ max_len = max(token.size(0) + modality.size(1), max_len)
83
+
84
+ # 计算目标长度(向上取整到16的整数倍)
85
+ target_len = ((max_len + 15) // 16 * 16)+1
86
+
87
+ if self.args.ctx_len is not None:
88
+ target_len = min(target_len, self.args.ctx_len+1)
89
+
90
+
91
+ for i, (signal, token, label) in enumerate(zip(modality_list , text_inputs, text_labels)):
92
+ pad_len = target_len-(token.size(0) + signal.size(1))
93
+
94
+ text_inputs[i] = F.pad(token, (0, pad_len), value=0)[:-1]
95
+ text_labels[i] = F.pad(label, (0, pad_len), value=-100)[1:]
96
+
97
+
98
+ targets = torch.stack(text_labels, dim=0).cuda()
99
+
100
+ return modality_list, text_inputs, targets
101
+
102
+
103
+
104
+ def cat_tts(self, tensor_list, signal_list):
105
+ """
106
+ 对一个包含不同长度张量的列表进行填充,使所有张量的长度相同且为16的整数倍,并生成掩码。
107
+ 参数:
108
+ tensor_list (list of torch.Tensor): 输入的张量列表,每个张量形状为 [seq_len]。
109
+ pad_value (int, optional): 填充值,默认值为 0。
110
+ 返回:
111
+ padded_tensor (torch.Tensor): 填充后的张量,形状为 [batch_size, target_len]。
112
+ mask (torch.Tensor): 填充掩码,1 表示有效数据,0 表示填充部分。
113
+ """
114
+
115
+ modality_list = []
116
+ atokens = []
117
+ labels_list = [] #多模态拼接标签
118
+ #max_len = max((token.size(0) + signal.size(1)) for token, signal in zip(tensor_list, modality_list))
119
+ max_len = 0
120
+ for token, signal in zip(tensor_list, signal_list):
121
+ global_tokens, semantic_tokens = self.modality.world_encoder.encoder(signal)
122
+ # print(global_tokens.squeeze(0).squeeze(0), global_tokens.squeeze(0).squeeze(0)+8192)
123
+ global_tokens = global_tokens.squeeze(0).squeeze(0)+8194
124
+ # global_tokens = F.pad(global_tokens.squeeze(0).squeeze(0)+8194, (0, 1), value=8193)
125
+ semantic_tokens = F.pad(semantic_tokens.squeeze(0), (0, 1), value=8192)
126
+ audio_token = torch.cat([global_tokens,semantic_tokens])
127
+ mask_gt = torch.full_like(global_tokens, -100)
128
+ label = torch.cat([global_tokens-1,semantic_tokens-1])
129
+ # modality = self.modality.encoder(audio_token)
130
+
131
+ # if modality is False:
132
+ # modality_list.append(False)
133
+ # continue
134
+
135
+ mask_t = torch.full_like(token, -100)
136
+ label = torch.cat([mask_t,label])
137
+ atokens.append(audio_token)
138
+ # modality_list.append(modality)
139
+ labels_list.append(label)
140
+ max_len = max(label.size(0), max_len)
141
+
142
+ # 计算目标长度(向上取整到16的整数倍)
143
+ target_len = ((max_len + 15) // 16 * 16)+1
144
+
145
+ if self.args.ctx_len is not None:
146
+ target_len = min(target_len, self.args.ctx_len+1)
147
+
148
+ text_token = []
149
+ labels = []
150
+ mod_token = []
151
+
152
+ for token, atoken, mask in zip(tensor_list, atokens, labels_list):
153
+
154
+ pad_len = target_len-(token.size(0) + atoken.size(0))
155
+
156
+ padded_atoken = F.pad(atoken, (0, pad_len), value=8192)
157
+
158
+ atoken = padded_atoken[:-1]
159
+ mod = self.modality(atoken)
160
+ # padded_token = F.pad(signal, (0, 0, 0, pad_len), value=0)
161
+
162
+ # pad_mod = padded_token[:,:-1,:]
163
+
164
+ pad_mask = F.pad(mask, (0, pad_len), value=-100)[1:]
165
+
166
+ mod_token.append(mod)
167
+
168
+ labels.append(pad_mask)
169
+
170
+ labels = torch.stack(labels, dim=0)
171
+
172
+
173
+ return mod_token, tensor_list, labels
world/dataset.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+ import torch.nn.functional as F
5
+
6
+ import json
7
+ import math
8
+ import random
9
+ import os
10
+ import sys
11
+ import numpy as np
12
+ import torch
13
+ import lightning as L
14
+ from torch.utils.data import Dataset
15
+ from torch.utils.data import DataLoader
16
+ from lightning_utilities.core.rank_zero import rank_zero_info
17
+ from infer.rwkv.utils import PIPELINE
18
+ pipeline = PIPELINE('rwkv', "rwkv_vocab_v20230424")
19
+ from PIL import Image
20
+
21
+ import pandas as pd
22
+ import librosa
23
+ import io
24
+ import soundfile as sf
25
+ # 读取parquet文件
26
+ from torchvision import transforms
27
+
28
+
29
+
30
+ transform = transforms.Compose([
31
+ transforms.Resize((512, 512)),
32
+ transforms.ToTensor() # 将图像转换为张量
33
+ ])
34
+
35
+ def process_conversation_text(conversations):
36
+ conversation_text = f"\x16"
37
+
38
+ for conv in conversations:
39
+ role = conv.get('from', '').lower()
40
+ content = conv.get('value', '')
41
+
42
+ if role == 'human':
43
+ conversation_text += f"User: {content}\x17"
44
+ elif role in ['assistant', 'gpt']:
45
+ conversation_text += f"Assistant: {content}\x17"
46
+
47
+ return conversation_text
48
+
49
+ def process_tokens(conversations):
50
+ # conversation_text = f"\x16"
51
+ inputs = []
52
+ labels = []
53
+ for conv in conversations:
54
+ role = conv.get('from', '').lower()
55
+ content = conv.get('value', '')
56
+
57
+ if role in ['human', 'user']:
58
+ question = f"\x16User: {content}\x17"
59
+ input = torch.tensor(pipeline.encode(question))
60
+ label = torch.full_like(input, -100)
61
+ elif role in ['assistant', 'gpt']:
62
+ answer = f"\x16Assistant: {content}\x17"
63
+ input= torch.tensor(pipeline.encode(answer))
64
+ label = input
65
+ inputs.append(input)
66
+ labels.append(label)
67
+ inputs =torch.cat(inputs)
68
+ labels =torch.cat(labels)
69
+ return inputs, labels
70
+
71
+ def bytes_to_audio(audio_bytes):
72
+ with io.BytesIO(audio_bytes) as buf:
73
+ # 使用 soundfile 读取音频数据
74
+ audio_array, sr = sf.read(buf)
75
+
76
+ # 确保是单声道
77
+ if len(audio_array.shape) > 1:
78
+ audio_array = audio_array.mean(axis=1)
79
+
80
+ # 确保是 float32 类型
81
+ audio_array = audio_array.astype(np.float32)
82
+
83
+ return {
84
+ 'array': audio_array,
85
+ 'sampling_rate': sr
86
+ }
87
+
88
+
89
+
90
+ def get_data_by_l_version(trainer: L.Trainer, args):
91
+ if L.__version__[0] == '2':
92
+ train_data = MyDataModule(args)
93
+ else:
94
+ raise ValueError(f"Unsupported PyTorch Lightning version: {L.__version__}")
95
+ return train_data
96
+
97
+ class GlobalIndexManager:
98
+ def __init__(self, rank=0, device_num=1, shuffle=True):
99
+ self.current_idx = 0
100
+ self.rank = rank
101
+ self.device_num = device_num
102
+ self.shuffle = shuffle
103
+
104
+ def get_next_idx(self, idx_t):
105
+ if self.shuffle:
106
+ idx = idx_t
107
+ else:
108
+ idx = self.current_idx * self.device_num + self.rank
109
+ self.current_idx += 1
110
+ return idx
111
+
112
+ class MyDataModule(L.LightningDataModule):
113
+ def __init__(self, args):
114
+ super().__init__()
115
+ self.args = args
116
+ self.train_data = None
117
+
118
+
119
+ def setup(self, stage=None):
120
+ self.train_data = MyDataset(self.args)
121
+ self.args.vocab_size = self.train_data.vocab_size
122
+ self.train_data.real_epoch = self.trainer.current_epoch
123
+ self.train_data.rank = self.trainer.global_rank
124
+ self.train_data.world_size = self.trainer.world_size
125
+ self.train_data.setup(self.trainer.global_rank, self.trainer.world_size,
126
+ int(self.args.devices), self.args.data_shuffle)
127
+
128
+ def train_dataloader(self):
129
+ # must set shuffle=False, persistent_workers=False (because worker is in another thread)
130
+ return DataLoader(
131
+ self.train_data,
132
+ shuffle=self.args.data_shuffle,
133
+ pin_memory=True,
134
+ batch_size=self.args.micro_bsz,
135
+ num_workers=1,
136
+ persistent_workers=False,
137
+ drop_last=True
138
+ )
139
+
140
+ class WorldDataset(Dataset):
141
+ def __init__(self, args, emb=None):
142
+ self.args = args
143
+ self.rank = 0
144
+ self.real_epoch = 0
145
+ self.world_size = 0
146
+ self.index_manager = None
147
+ self.emb = emb
148
+
149
+ if args.data_type =='wav':
150
+ import jsonlines
151
+
152
+ # 打开并读取 JSON 文件
153
+ #with open(f'{args.data_file}/answer.jsonl', 'r') as file:
154
+ with jsonlines.open(f'{args.data_file}/answer.jsonl') as file:
155
+ self.data = list(file)
156
+ elif args.data_type =='img':
157
+ import jsonlines
158
+
159
+ # 打开并读取 JSON 文件
160
+ #with open(f'{args.data_file}/answer.jsonl', 'r') as file:
161
+ with jsonlines.open(f'{args.data_file}/answer.jsonl') as file:
162
+ self.data = list(file)
163
+ elif args.data_type=='hf_img':
164
+ import jsonlines
165
+ # with open(f'{args.data_file}/chat.json', 'r', encoding='utf-8') as file:
166
+ # self.data = json.load(file)
167
+ with jsonlines.open(f'{args.data_file}/chat.jsonl') as file:
168
+ self.data = list(file)
169
+ elif args.data_type=='visual':
170
+ import jsonlines
171
+ # with open(f'{args.data_file}/chat.json', 'r', encoding='utf-8') as file:
172
+ # self.data = json.load(file)
173
+ with jsonlines.open(f'{args.data_file}/chat.jsonl') as file:
174
+ self.data = list(file)
175
+ elif args.data_type == 'visual-r1-cs':
176
+
177
+ llava_path = os.path.join(args.data_file, 'vision_r1_llava_cot_full.json')
178
+ mulberry_path = os.path.join(args.data_file, 'vision_r1_mulberry_sft_full.json')
179
+
180
+ import json
181
+ with open(f'{llava_path}', 'r', encoding='utf-8') as file:
182
+ llava_data = json.load(file)
183
+ with open(f'{mulberry_path}', 'r', encoding='utf-8') as file:
184
+ mulberry_data = json.load(file)
185
+
186
+ # 合并数据集并添加来源标识
187
+ for item in llava_data:
188
+ item['_source'] = 'llava_cot'
189
+ for item in mulberry_data:
190
+ item['_source'] = 'mulberry'
191
+
192
+ self.data = llava_data + mulberry_data
193
+
194
+ elif args.data_type =='hf' or args.data_type =='qa' or args.data_type =='cnqa' or args.data_type =='cnasr' or args.data_type =='tts':
195
+ from datasets import load_dataset, concatenate_datasets
196
+
197
+ def list_subdirectories(base_path):
198
+ return [
199
+ name for name in os.listdir(base_path)
200
+ if os.path.isdir(os.path.join(base_path, name)) and not name.startswith('.')
201
+ ]
202
+
203
+ datasets = []
204
+ files = list_subdirectories(args.data_file)
205
+ if not files:
206
+ datasets = load_dataset(args.data_file, split="train")
207
+ else:
208
+ for file in files:
209
+ dataset = load_dataset(f'{args.data_file}/{file}', split="train")
210
+ datasets.append(dataset)
211
+ datasets = concatenate_datasets(datasets)
212
+ self.data = datasets
213
+ print(len(datasets))
214
+
215
+ elif args.data_type == "jsonl":
216
+ import jsonlines
217
+
218
+ with jsonlines.open(args.data_file) as file:
219
+ self.data = list(file)
220
+
221
+ else:
222
+ self.data = pd.read_parquet(args.data_file)
223
+
224
+
225
+
226
+ def setup(self, rank, world_size, devices, shuffle):
227
+ self.rank = rank
228
+ self.world_size = world_size
229
+ self.index_manager = GlobalIndexManager(rank=rank, device_num=devices, shuffle=shuffle)
230
+
231
+ def __len__(self):
232
+ return self.args.epoch_steps * self.args.micro_bsz
233
+
234
+
235
+ def __getitem__(self, idx):
236
+ idx = self.index_manager.get_next_idx(idx_t=idx) if self.index_manager else idx
237
+ args = self.args
238
+ if args.data_type =='wav':
239
+
240
+ mod_name = self.data[idx]['file_name']
241
+ data_answer = self.data[idx]['answer']
242
+ mod_path = f'{args.data_file}/{mod_name}'
243
+ audio, sample_rate = librosa.load(mod_path, sr=16000) # sr=None 保持原采样率
244
+ #sign,_ = self.speech_encoder(audio)
245
+ sign = audio
246
+ token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
247
+ elif args.data_type =='hf':
248
+ sample = self.data[idx]
249
+ audio = sample['audio']
250
+ data_answer = sample['text'] #####caption
251
+ audio = librosa.resample(audio['array'],orig_sr= audio['sampling_rate'],target_sr= 16000) # sr=None 保持原采样率
252
+ sign = audio
253
+ token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
254
+ elif args.data_type =='tts':
255
+ sample = self.data[idx]
256
+ audio = sample['audio']
257
+ data_answer = sample['text'] #####caption
258
+ audio = librosa.resample(audio['array'],orig_sr= audio['sampling_rate'],target_sr= 16000) # sr=None 保持原采样率
259
+ sign = audio
260
+ token = torch.tensor(pipeline.encode(f'User: {data_answer}\x17Assistant:'))
261
+ elif args.data_type =='qa':
262
+ sample = self.data[idx]
263
+ # audio = sample['speech_cosy'][0]
264
+ # data_answer = sample['answer']
265
+
266
+ audio = sample['question_audio']
267
+ data_answer = sample['answer']
268
+ sign = librosa.resample(audio['array'],orig_sr= audio['sampling_rate'],target_sr= 16000) # sr=None 保持原采样率
269
+
270
+ token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
271
+ elif args.data_type =='cnqa':
272
+ sample = self.data[idx]
273
+ audio = sample['audio']
274
+ data_answer = sample['answer']
275
+ sign = librosa.resample(audio['array'],orig_sr= audio['sampling_rate'],target_sr= 16000) # sr=None 保持原采样率
276
+ token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
277
+ elif args.data_type =='cnasr':
278
+ sample = self.data[idx]
279
+ audio = sample['audio']
280
+ data_answer = sample['transcript']
281
+ sign = librosa.resample(audio['array'],orig_sr= audio['sampling_rate'],target_sr= 16000) # sr=None 保持原采样率
282
+ token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
283
+ elif args.data_type == "jsonl":
284
+ ctx_len = args.ctx_len
285
+ req_len = ctx_len + 1
286
+ ctx = self.data[idx]['text']
287
+ token = torch.tensor(pipeline.encode(ctx))
288
+ token_len = len(token)
289
+ pad_len = req_len - token_len
290
+
291
+ dix = F.pad(token, (0, pad_len), value=0)
292
+ x = dix[:-1]
293
+ y = dix[1:]
294
+ mask = torch.zeros(req_len - 1)
295
+ mask[:token_len - 1] = 1
296
+ return x, y, mask
297
+ elif args.data_type == "img":
298
+
299
+ mod_name = self.data[idx]['file_name']
300
+ data_answer = self.data[idx]['answer']
301
+ mod_path = f'{args.data_file}/{mod_name}'
302
+ token = torch.tensor(pipeline.encode(f'\n\nAssistant: {data_answer}\x17'))
303
+ image = Image.open(mod_path).convert('RGB')
304
+ sign = transform(image)
305
+ elif args.data_type == 'visual':
306
+
307
+ img_name = self.data[idx]['image']
308
+ conversation_text = self.data[idx]['conversations']
309
+
310
+ mod_path = f'{args.data_file}/images/{img_name}'
311
+ image = Image.open(mod_path).convert('RGB')
312
+ sign = image
313
+ text_tokens, text_labels = process_tokens(conversation_text)
314
+ return sign, text_tokens, text_labels
315
+ elif args.data_type== 'hf_img':
316
+
317
+ img_name = self.data[idx]['image']
318
+ conversation_text = self.data[idx]['conversations']
319
+ conversation_text = process_conversation_text(conversation_text)
320
+
321
+ mod_path = f'{args.data_file}/images/{img_name}'
322
+ token = torch.tensor(pipeline.encode(conversation_text))
323
+ image = Image.open(mod_path).convert('RGB')
324
+ sign = image
325
+ elif args.data_type == 'visual-r1-cs':
326
+ item = self.data[idx]
327
+ conversations = item['conversations']
328
+
329
+ # 根据来源处理图像路径
330
+ if item['_source'] == 'llava_cot':
331
+ img_name = item['image']
332
+
333
+ else:
334
+ img_name = item['images']
335
+
336
+
337
+
338
+ mod_path = f'{self.args.data_file}/{img_name}'
339
+ image = Image.open(mod_path).convert('RGB')
340
+ sign = image
341
+
342
+ # 处理文本对话
343
+ text_tokens, text_labels = process_tokens(conversations)
344
+
345
+ return sign, text_tokens, text_labels
346
+ else:
347
+ data_audio = bytes_to_audio(self.data['question_audio'][idx]['bytes'])
348
+ data_answer = self.data['answer'][idx]
349
+ audio = librosa.resample(data_audio['array'],orig_sr= 48000,target_sr= 16000)
350
+ #sign,_ = self.speech_encoder(audio)
351
+ sign = audio
352
+ token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
353
+ #print(idx, f'Assistant: {data_answer}\x17')
354
+ return sign, token
world/encoder/clip_encoder.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
7
+
8
+
9
+
10
+ class VisualAdapter(nn.Module):
11
+ """
12
+ 2D Image to Patch Embedding
13
+ """
14
+ def __init__(self, encoder_dim, project_dim, hidden_dim=None):
15
+
16
+ super().__init__()
17
+ self.encoder_dim = encoder_dim
18
+ self.project_dim = project_dim
19
+ self.hidden_dim = hidden_dim
20
+
21
+ if self.hidden_dim==None:
22
+ self.hidden_dim = project_dim*2
23
+
24
+ self.pre_norm = nn.LayerNorm(self.project_dim)
25
+ self.proj = nn.Sequential(
26
+ nn.Linear(self.encoder_dim, self.hidden_dim),
27
+ nn.ReLU(),
28
+ nn.Linear(self.hidden_dim, self.project_dim),
29
+ )
30
+ # self.proj = nn.Sequential(
31
+ # nn.Linear(self.encoder_dim, self.hidden_dim),
32
+ # nn.GELU(),
33
+ # nn.Linear(self.hidden_dim, self.hidden_dim),
34
+ # nn.GELU(),
35
+ # nn.Linear(self.hidden_dim, self.project_dim),
36
+ # )
37
+
38
+
39
+ def forward(self, x):
40
+ x = self.proj(x)
41
+ return x + self.pre_norm(x)
42
+
43
+
44
+
45
+ class ClipEncoder(nn.Module):
46
+
47
+ def __init__(
48
+ self,
49
+ encoder_path,
50
+ project_dim,
51
+ train_mode="adapter",
52
+ device="cuda",) -> None:
53
+ super(ClipEncoder, self).__init__()
54
+
55
+ self.device = device
56
+ self.image_processor = CLIPImageProcessor.from_pretrained(encoder_path)
57
+ self.model = CLIPVisionModel.from_pretrained(encoder_path)
58
+ self.encoder_dim = self.model.config.hidden_size
59
+
60
+ self.adapter = VisualAdapter(self.encoder_dim, project_dim)
61
+ def forward(self, x):
62
+
63
+ x= torch.from_numpy(self.image_processor(x)['pixel_values'][0]).to(self.device,dtype=torch.bfloat16)
64
+
65
+ x = self.model(x.unsqueeze(0), output_hidden_states=True).last_hidden_state
66
+
67
+ x = self.adapter(x)
68
+
69
+ return x
world/encoder/siglip_encoder.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from transformers import AutoModel, SiglipImageProcessor
7
+
8
+ class VisualAdapter(nn.Module):
9
+ """
10
+ 2D Image to Patch Embedding
11
+ """
12
+ def __init__(self, encoder_dim, project_dim, hidden_dim=None):
13
+
14
+ super().__init__()
15
+ self.encoder_dim = encoder_dim
16
+ self.project_dim = project_dim
17
+ self.hidden_dim = hidden_dim
18
+
19
+ if self.hidden_dim==None:
20
+ self.hidden_dim = project_dim*2
21
+
22
+ self.pre_norm = nn.LayerNorm(self.project_dim)
23
+ self.proj = nn.Sequential(
24
+ nn.Linear(self.encoder_dim, self.hidden_dim),
25
+ nn.ReLU(),
26
+ nn.Linear(self.hidden_dim, self.project_dim),
27
+ )
28
+ # self.conv = nn.Conv1d(
29
+ # in_channels=encoder_dim,
30
+ # out_channels=encoder_dim,
31
+ # bias=False,
32
+ # kernel_size=5,
33
+ # stride=4
34
+ # )
35
+
36
+
37
+ def forward(self, x):
38
+ # x = self.conv(x.permute(0,2,1)).permute(0,2,1)
39
+ x = self.proj(x)
40
+ return x + self.pre_norm(x)
41
+
42
+
43
+
44
+ class SiglipEncoder(nn.Module):
45
+
46
+ def __init__(
47
+ self,
48
+ encoder_path,
49
+ project_dim,
50
+ encoder_device='cpu',
51
+ train_mode="adapter"
52
+ ) -> None:
53
+ super(SiglipEncoder, self).__init__()
54
+
55
+ self.device = encoder_device
56
+
57
+ self.model = AutoModel.from_pretrained(encoder_path).vision_model
58
+ self.image_processor = SiglipImageProcessor.from_pretrained(encoder_path)
59
+ self.encoder_dim = 768 #self.model.config.hidden_size
60
+
61
+ self.adapter = VisualAdapter(self.encoder_dim, project_dim)
62
+ def forward(self, x):
63
+
64
+ x= torch.from_numpy(self.image_processor(x)['pixel_values'][0]).to(self.device,dtype=torch.bfloat16)
65
+ x = self.model(x.unsqueeze(0), output_hidden_states=True).last_hidden_state
66
+ x = self.adapter(x)
67
+
68
+ return x
world/encoder/speech_encoder.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
5
+ import numpy as np
6
+
7
+ from transformers import AutoProcessor, AutoModel
8
+
9
+
10
+
11
+ # class SpeechAdapter(nn.Module):
12
+ # def __init__(self, input_dim, output_dim):
13
+ # super(SpeechAdapter, self).__init__()
14
+ # self.conv = nn.Conv1d(in_channels=input_dim, out_channels=3072, kernel_size=3, stride=2)
15
+ # self.transformer = nn.TransformerEncoderLayer(d_model=3072, nhead=8, dim_feedforward=4096)
16
+ # self.linear = nn.Linear(3072, output_dim)
17
+ # def forward(self, x):
18
+ # # if x.size(1)<5 or x.size(1)>5000:
19
+ # # return False
20
+ # # x shape: (batch_size, seq_len, input_dim)
21
+ # x = x.permute(0, 2, 1)
22
+ # # x shape: (batch_size, input_dim, seq_len)
23
+ # x = self.conv(x)
24
+ # # x shape after conv: (batch_size, input_dim, new_seq_len)
25
+ # x = x.permute(2, 0, 1) # Transformer expects (seq_len, batch_size, input_dim)
26
+ # # x = self.transformer(x, src_key_padding_mask=mask.bool())
27
+ # x = self.transformer(x)
28
+ # x = x.permute(1, 0, 2) # Back to (batch_size, seq_len, input_dim)
29
+ # x = self.linear(x)
30
+ # return x
31
+
32
+ class SpeechAdapter(nn.Module):
33
+ def __init__(self, encoder_dim, project_dim, hidden_dim=None):
34
+ super(SpeechAdapter, self).__init__()
35
+ self.encoder_dim = encoder_dim
36
+ self.project_dim = project_dim
37
+ self.hidden_dim = hidden_dim
38
+
39
+ if self.hidden_dim==None:
40
+ self.hidden_dim = project_dim*2
41
+ self.conv = nn.Conv1d(in_channels=self.encoder_dim , out_channels=self.hidden_dim, kernel_size=3, stride=2, padding=2)
42
+ self.proj = nn.Sequential(
43
+ nn.Linear(self.hidden_dim, self.hidden_dim),
44
+ nn.ReLU(),
45
+ nn.Linear(self.hidden_dim, self.project_dim),
46
+ )
47
+ def forward(self, x):
48
+ # if x.size(1)<5 or x.size(1)>5000:
49
+ # return False
50
+
51
+ # x shape: (batch_size, seq_len, input_dim)
52
+ x = x.permute(0, 2, 1)
53
+ # x shape: (batch_size, input_dim, seq_len)
54
+ x = self.conv(x).permute(0, 2, 1)
55
+ # x shape after conv: (batch_size, input_dim, new_seq_len)
56
+ x = self.proj(x)
57
+ if x.size(1)>1023:
58
+ return False
59
+ return x
60
+
61
+ class SpeechEncoder(nn.Module):
62
+ def __init__(
63
+ self,
64
+ encoder_path,
65
+ project_dim,
66
+ train_mode="adapter",
67
+ device="cuda",
68
+ ):
69
+ assert train_mode in ["adapter", "full"]
70
+ super(SpeechEncoder, self).__init__()
71
+
72
+ self.device = device
73
+
74
+ try:
75
+ self.processor = AutoProcessor.from_pretrained(encoder_path)
76
+ except:
77
+ self.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
78
+
79
+ self.time_reduction_factor = int(
80
+ self.processor.feature_extractor.sampling_rate / 50
81
+ )
82
+ self.padding_length = 320
83
+
84
+ self.model = AutoModel.from_pretrained(encoder_path)
85
+ self.model.eval()
86
+ self.model_output_dim = self.model.config.hidden_size
87
+ self.project_dim = project_dim
88
+
89
+ self.project_dim = project_dim
90
+ self.adapter = SpeechAdapter(self.model_output_dim, self.project_dim).to(self.device,dtype=torch.bfloat16)
91
+ # self.set_gradient(train_mode)
92
+
93
+
94
+
95
+
96
+ def forward(self, x):
97
+ input_dict = self.processor(
98
+ x, return_tensors="pt", padding=True, sampling_rate=16000
99
+ ).to(self.device,dtype=torch.bfloat16)
100
+
101
+ # encoder only
102
+ x = self.model(**input_dict).last_hidden_state
103
+
104
+ # stf encoder
105
+ # x = self.model(**input_dict, output_hidden_states=True).hidden_states[-1]
106
+
107
+ x= self.adapter(x)#x:(B,T,hidden dim)
108
+ # mask = torch.ones(x.shape[0],x.shape[1]).to(self.device,dtype=torch.bfloat16)
109
+ return x
world/encoder/visual_encoder.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from diffusers import AutoencoderKL
7
+
8
+
9
+ class Patch(nn.Module):
10
+ def __init__(self, Imgsize=64, Patchsize=16) -> None:
11
+ super(Patch, self).__init__()
12
+ self.Patchsize = Patchsize
13
+ self.Imgsize = Imgsize
14
+ def encoder(self, x):
15
+ assert x.size(3)==self.Imgsize
16
+ imgsize = self.Imgsize
17
+ patchsize = self.Patchsize
18
+ x = x.unfold(2, patchsize, patchsize).unfold(3, patchsize, patchsize)
19
+ x = x.contiguous().reshape(x.size(0), x.size(1), int(pow(imgsize/patchsize, 2)), -1)
20
+ x = x.transpose(1, 2)
21
+ x = x.reshape(x.size(0), x.size(1), -1)
22
+ return x
23
+
24
+ def decoder(self, x):
25
+ imgsize = self.Imgsize
26
+ patchsize = self.Patchsize
27
+ x = x.reshape(x.size(0), x.size(1), x.size(2)//patchsize, patchsize)
28
+ x = x.transpose(1, 2)
29
+ x = x.unfold(2,imgsize//patchsize,imgsize//patchsize).unfold(3, patchsize, patchsize)
30
+ x = x.reshape(x.size(0), 4, imgsize, imgsize)
31
+ return x
32
+
33
+ class SD_Auto():
34
+ def __init__(self, path="sdxl", input_dtype=torch.float32) -> None:
35
+ self.autoencoder = AutoencoderKL.from_pretrained(path, subfolder="vae")
36
+ self.autoencoder = self.autoencoder.to('cuda', input_dtype)
37
+
38
+ def encoder(self, x):
39
+ with torch.no_grad():
40
+ x = self.autoencoder.encode(x).latent_dist.sample()
41
+ return x
42
+
43
+ def decoder(self, x):
44
+
45
+ with torch.no_grad():
46
+ x = self.autoencoder.decode(x).sample
47
+ #x = self.autoencoder.decode(x).sample
48
+ return x
49
+
50
+ def kld_loss(mu, logvar):
51
+ KLD = - 0.5 * torch.sum(1 + logvar - mu.pow(2) -
52
+ logvar.exp()) / mu.shape[0]
53
+ return KLD
54
+
55
+
56
+ class VisualAdapter(nn.Module):
57
+ """
58
+ 2D Image to Patch Embedding
59
+ """
60
+ def __init__(self, img_size=512//8, patch_size=16, in_c=4, text_dim=2560, head_size=64):
61
+
62
+ super().__init__()
63
+ self.head_size = head_size
64
+ # self.img_receptance = nn.Linear((patch_size*patch_size*in_c), text_dim, bias=False)
65
+ # self.img_key = nn.Linear((patch_size*patch_size*in_c), text_dim, bias=False)
66
+ # self.img_value = nn.Linear((patch_size*patch_size*in_c), text_dim, bias=False)
67
+ self.linear = nn.Linear((patch_size*patch_size*in_c), text_dim, bias=False)
68
+ self.patch = Patch(Imgsize=img_size, Patchsize=patch_size)
69
+
70
+
71
+ def forward(self, x):
72
+ B, C, H, W = x.shape
73
+ x = self.patch.encoder(x)
74
+ # r = self.img_receptance(x)
75
+ # k = self.img_key(x)
76
+ # v = self.img_value(x)
77
+ # r = r.view(*x.shape[:2], -1, self.head_size).transpose(1, 2)
78
+ # k = k.view(*x.shape[:2], -1, self.head_size).transpose(1, 2)
79
+ # v = v.view(*x.shape[:2], -1, self.head_size).transpose(1, 2)
80
+ # x_img = torch.nn.functional.scaled_dot_product_attention(
81
+ # r, k, v, is_causal=True, scale=1 / self.head_size
82
+ # )
83
+ # x = x_img.transpose(1, 2).reshape(*x.shape[:2], -1)
84
+ return self.linear(x)
85
+
86
+
87
+
88
+ class VisualEncoder(nn.Module):
89
+ def __init__(
90
+ self,
91
+ encoder_path,
92
+ project_dim,
93
+ train_mode="adapter",
94
+ device="cuda",) -> None:
95
+ super(VisualEncoder, self).__init__()
96
+
97
+ self.model = AutoencoderKL.from_pretrained(path, subfolder="vae", allow_pickle=False).to('cuda', input_dtype)
98
+ self.adapter = VisualAdapter(text_dim=llm_dim).to('cuda',dtype=torch.bfloat16)
99
+
100
+ def forward(self, x):
101
+ x = self.model.encode(x.unsqueeze(0)).latent_dist.sample()
102
+ #print(x.view(-1))
103
+ x = self.adapter(x)
104
+ #print(x.view(-1))
105
+
106
+ return x
world/encoder/whisper_encoder.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
7
+
8
+ class SpeechAdapter(nn.Module):
9
+ def __init__(self, encoder_dim, project_dim, hidden_dim=None):
10
+ super(SpeechAdapter, self).__init__()
11
+ self.encoder_dim = encoder_dim
12
+ self.project_dim = project_dim
13
+ self.hidden_dim = hidden_dim
14
+
15
+ if self.hidden_dim==None:
16
+ self.hidden_dim = project_dim*2
17
+ self.conv = nn.Conv1d(in_channels=self.encoder_dim , out_channels=self.hidden_dim, kernel_size=3, stride=2, padding=2)
18
+ self.proj = nn.Sequential(
19
+ nn.Linear(self.hidden_dim, self.hidden_dim),
20
+ nn.ReLU(),
21
+ nn.Linear(self.hidden_dim, self.project_dim),
22
+ )
23
+ def forward(self, x):
24
+ # if x.size(1)<5 or x.size(1)>5000:
25
+ # return False
26
+
27
+ # x shape: (batch_size, seq_len, input_dim)
28
+ x = x.permute(0, 2, 1)
29
+ # x shape: (batch_size, input_dim, seq_len)
30
+ x = self.conv(x).permute(0, 2, 1)
31
+ # x shape after conv: (batch_size, input_dim, new_seq_len)
32
+ x = self.proj(x)
33
+ return x
34
+
35
+
36
+
37
+ class WhisperEncoder(nn.Module):
38
+ def __init__(
39
+ self,
40
+ encoder_path,
41
+ project_dim,
42
+ train_mode="adapter",
43
+ device="cuda",
44
+ ):
45
+ assert train_mode in ["adapter", "full"]
46
+ super(WhisperEncoder, self).__init__()
47
+ self.device = device
48
+ self.processor = WhisperProcessor.from_pretrained(encoder_path)
49
+
50
+ self.model = WhisperForConditionalGeneration.from_pretrained(encoder_path).model.encoder
51
+
52
+ self.model_output_dim = self.model.config.d_model
53
+
54
+ self.project_dim = project_dim
55
+ self.adapter = SpeechAdapter(self.model_output_dim, self.project_dim)
56
+
57
+ def forward(self, x):
58
+ input_dict = self.processor(
59
+ x, return_tensors="pt", sampling_rate=16000, return_attention_mask=True
60
+ ).to(self.device,dtype=torch.bfloat16)
61
+
62
+ chunk = torch.sum(input_dict['attention_mask'], dim=-1)//2+1
63
+
64
+ x = self.model(**input_dict).last_hidden_state
65
+ x = x[:,:chunk,:]
66
+ x= self.adapter(x)#x:(B,T,hidden dim)
67
+
68
+ return x
world/loss.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ class L2Wrap(torch.autograd.Function):
3
+ @staticmethod
4
+ def forward(ctx, loss, y):
5
+ ctx.save_for_backward(y)
6
+ return loss
7
+
8
+ @staticmethod
9
+ def backward(ctx, grad_output):
10
+ y = ctx.saved_tensors[0]
11
+ # to encourage the logits to be close to 0
12
+ factor = 1e-4 / (y.shape[0] * y.shape[1])
13
+ maxx, ids = torch.max(y, -1, keepdim=True)
14
+ gy = torch.zeros_like(y)
15
+ gy.scatter_(-1, ids, maxx * factor)
16
+ return (grad_output, gy)
world/model.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+ from torch.utils.checkpoint import checkpoint as torch_checkpoint
5
+ from torch.profiler import profile, record_function, ProfilerActivity
6
+ #from adam_mini import Adam_mini
7
+
8
+ import os, math, gc, importlib
9
+ import torch
10
+
11
+ import torch.nn as nn
12
+ from torch.nn import functional as F
13
+ import lightning as pl
14
+ from lightning.pytorch.strategies import DeepSpeedStrategy
15
+ if importlib.util.find_spec('deepspeed'):
16
+ import deepspeed
17
+ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
18
+
19
+ from .block import Block
20
+ from .loss import L2Wrap
21
+ from .cat import mod_pad_text
22
+
23
+ from rwkv.utils import PIPELINE
24
+ pipeline = PIPELINE('rwkv6', "rwkv_vocab_v20230424")
25
+
26
+ class RWKV(pl.LightningModule):
27
+ def __init__(self, args, modality=None):
28
+ super().__init__()
29
+ self.args = args
30
+ if not hasattr(args, 'dim_att'):
31
+ args.dim_att = args.n_embd
32
+ if not hasattr(args, 'dim_ffn'):
33
+ args.dim_ffn = args.n_embd * 4
34
+
35
+ assert args.n_embd % 32 == 0
36
+ assert args.dim_att % 32 == 0
37
+ assert args.dim_ffn % 32 == 0
38
+
39
+ self.emb = nn.Embedding(args.vocab_size, args.n_embd)
40
+
41
+ self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
42
+
43
+ self.ln_out = nn.LayerNorm(args.n_embd)
44
+ self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
45
+
46
+ self.modality = modality
47
+
48
+
49
+ def pad_mod(self, tensor_list, signal_list):
50
+ """
51
+ 对一个包含不同长度张量的列表进行填充,使所有张量的长度相同且为16的整数倍,并生成掩码。
52
+ 参数:
53
+ tensor_list (list of torch.Tensor): 输入的张量列表,每个张量形状为 [seq_len]。
54
+ pad_value (int, optional): 填充值,默认值为 0。
55
+ 返回:
56
+ padded_tensor (torch.Tensor): 填充后的张量,形状为 [batch_size, target_len]。
57
+ mask (torch.Tensor): 填充掩码,1 表示有效数据,0 表示填充部分。
58
+ """
59
+
60
+ modality_list = []
61
+ #max_len = max((token.size(0) + signal.size(1)) for token, signal in zip(tensor_list, modality_list))
62
+ max_len = 0
63
+ for token, signal in zip(tensor_list, signal_list):
64
+
65
+ modality = self.modality(signal)
66
+ if modality is False:
67
+ modality_list.append(False)
68
+ continue
69
+ modality_list.append(modality)
70
+ max_len = max(token.size(0) + modality.size(1), max_len)
71
+
72
+ # 计算目标长度(向上取整到16的整数倍)
73
+ target_len = ((max_len + 15) // 16 * 16)+1
74
+
75
+ if self.args.ctx_len is not None:
76
+ target_len = min(target_len, self.args.ctx_len+1)
77
+
78
+ masks = torch.zeros((len(tensor_list), target_len-1), dtype=torch.int32)
79
+ x = []
80
+ y = []
81
+ s = []
82
+ m = []
83
+ for token, signal, mask in zip(tensor_list, modality_list, masks):
84
+ if signal is False:
85
+ continue
86
+ pad_len = target_len-(token.size(0) + signal.size(1))
87
+
88
+ padded_token = F.pad(token, (0, pad_len), value=0)
89
+
90
+ x_token = padded_token[:-1]
91
+ y_token = F.pad(padded_token, (signal.size(1)-1, 0), value=0)
92
+
93
+ mask[signal.size(1) : -pad_len] = 1
94
+
95
+ s.append(signal)
96
+ x.append(x_token)
97
+ y.append(y_token)
98
+ m.append(mask)
99
+
100
+ y = torch.stack(y, dim=0)
101
+ m = torch.stack(m, dim=0).cuda()
102
+
103
+ return s, x, y, m
104
+
105
+
106
+ def forward(self, idx, signs= None):
107
+ args = self.args
108
+ #B, T = idx.size()
109
+ # assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
110
+
111
+ x_list = []
112
+ if signs!=None:
113
+ for token, sign in zip(idx, signs):
114
+ sign_emb = sign#self.adapter(sign)
115
+ x_emb = self.emb(token)
116
+ # #print(sign_emb.shape, x.shape)
117
+ x_list.append(torch.cat([sign_emb.squeeze(0),x_emb], dim=0))
118
+ x = torch.stack(x_list, dim=0)
119
+ else:
120
+ x = self.emb(idx)
121
+
122
+ v_first = torch.empty_like(x)
123
+ for block in self.blocks:
124
+ if args.grad_cp == 1:
125
+ if args.state_tune or args.train_type == 'state' or args.peft !='none':
126
+ x, v_first = torch_checkpoint(block, x, v_first ,use_reentrant=False)
127
+ else:
128
+ x, v_first = deepspeed.checkpointing.checkpoint(block, x, v_first)
129
+ else:
130
+ x, v_first = block(x, v_first)
131
+
132
+ x = self.ln_out(x)
133
+ x = self.head(x)
134
+
135
+ return x
136
+
137
+ def training_step(self, batch, batch_idx):
138
+ args = self.args
139
+ if args.data_type == "jsonl": ########test
140
+ idx, targets, mask = batch
141
+
142
+ mask = mask.view(-1)
143
+ sum_mask = torch.sum(mask).item()
144
+ logits = self(idx)
145
+
146
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), reduction='none')
147
+ loss = torch.sum(loss * mask) / sum_mask
148
+ return loss
149
+
150
+ if args.data_type in ["visual", "visual-r1-cs"]: ########test
151
+ signs, text_tokens, text_labels = batch
152
+ sign, idx, targets = mod_pad_text(self, signs, text_tokens, text_labels)
153
+
154
+ logits = self(idx,sign)
155
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
156
+
157
+ return loss
158
+
159
+ signs, tokens = batch
160
+ sign, idx, targets, mask = self.pad_mod(tokens, signs)
161
+
162
+ mask = mask.view(-1)
163
+ sum_mask = torch.sum(mask).item()
164
+ logits = self(idx,sign)
165
+
166
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), reduction='none')
167
+ loss = torch.sum(loss * mask) / sum_mask
168
+
169
+ return loss
170
+
171
+
172
+ def configure_optimizers(self):
173
+ args = self.args
174
+
175
+ lr_decay = set()
176
+ lr_1x = set()
177
+ lr_2x = set()
178
+ lr_3x = set()
179
+ for n, p in self.named_parameters():
180
+ if not p.requires_grad:
181
+ continue
182
+ if (("_w1" in n) or ("_w2" in n)) and (args.layerwise_lr > 0):
183
+ lr_1x.add(n)
184
+ elif (("time_mix" in n) or ("time_maa" in n)) and (args.layerwise_lr > 0):
185
+ if args.my_pile_stage == 2:
186
+ lr_2x.add(n)
187
+ else:
188
+ lr_1x.add(n)
189
+ elif (("time_decay" in n) or ("time_daaaa" in n)) and (args.layerwise_lr > 0):
190
+ if args.my_pile_stage == 2:
191
+ lr_3x.add(n)
192
+ else:
193
+ lr_2x.add(n)
194
+ elif ("time_faaaa" in n) and (args.layerwise_lr > 0):
195
+ if args.my_pile_stage == 2:
196
+ lr_2x.add(n)
197
+ else:
198
+ lr_1x.add(n)
199
+ elif ("time_first" in n) and (args.layerwise_lr > 0):
200
+ lr_3x.add(n)
201
+ elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0):
202
+ lr_decay.add(n)
203
+ else:
204
+ lr_1x.add(n)
205
+
206
+ lr_decay = sorted(list(lr_decay))
207
+ lr_1x = sorted(list(lr_1x))
208
+ lr_2x = sorted(list(lr_2x))
209
+ lr_3x = sorted(list(lr_3x))
210
+
211
+ param_dict = {n: p for n, p in self.named_parameters()}
212
+
213
+ if args.layerwise_lr > 0:
214
+ if args.my_pile_stage == 2:
215
+ optim_groups = [
216
+ {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
217
+ {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
218
+ {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
219
+ ]
220
+ else:
221
+ optim_groups = [
222
+ {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
223
+ {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
224
+ {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
225
+ ]
226
+ else:
227
+ optim_groups = [{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}]
228
+
229
+ if args.weight_decay > 0:
230
+ optim_groups += [{"params": [param_dict[n] for n in lr_decay], "weight_decay": args.weight_decay, "my_lr_scale": 1.0}]
231
+
232
+ if self.deepspeed_offload:
233
+ return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=True, amsgrad=False)
234
+ return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=True, amsgrad=False)
235
+ else:
236
+
237
+ if self.deepspeed_offload:
238
+ return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
239
+ return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
240
+ # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
241
+
242
+ @property
243
+ def deepspeed_offload(self) -> bool:
244
+ strategy = self.trainer.strategy
245
+ if isinstance(strategy, DeepSpeedStrategy):
246
+ cfg = strategy.config["zero_optimization"]
247
+ return cfg.get("offload_optimizer") or cfg.get("offload_param")
248
+ return False
249
+
250
+ def generate_init_weight(self):
251
+ print(
252
+ f"""
253
+ ############################################################################
254
+ #
255
+ # Init model weight (slow for large models)...
256
+ #
257
+ ############################################################################
258
+ """
259
+ )
260
+ m = {}
261
+ for n in self.state_dict():
262
+ p = self.state_dict()[n]
263
+ shape = p.shape
264
+
265
+ gain = 1.0
266
+ scale = 1.0
267
+ if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
268
+ if 'ln_x.weight' in n:
269
+ layer_scale = (1+int(n.split('.')[1])) / self.args.n_layer
270
+ m[n] = (p * 0.0) + (layer_scale ** 0.7)
271
+ else:
272
+ m[n] = p
273
+ else:
274
+ if n == "emb.weight":
275
+ scale = -1 * self.args.lr_init
276
+ else:
277
+ if shape[0] > shape[1]:
278
+ gain = math.sqrt(shape[0] / shape[1])
279
+
280
+ zero = [".att.output.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']
281
+
282
+ for kk in zero:
283
+ if kk in n:
284
+ scale = 0
285
+ if n == "head.weight":
286
+ scale = 0.5
287
+ if "head_k." in n:
288
+ scale = 0.1
289
+ if "head_q." in n:
290
+ scale = 0
291
+
292
+ print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
293
+
294
+ if self.args.accelerator.upper() == "GPU":
295
+ m[n] = torch.empty((shape[0], shape[1]), device="cuda")
296
+ else:
297
+ m[n] = torch.empty((shape[0], shape[1]))
298
+
299
+ if scale == 0:
300
+ nn.init.zeros_(m[n])
301
+ elif scale < 0:
302
+ nn.init.uniform_(m[n], a=scale, b=-scale)
303
+ else:
304
+ nn.init.orthogonal_(m[n], gain=gain * scale)
305
+
306
+ m[n] = m[n].cpu()
307
+ if os.environ["RWKV_FLOAT_MODE"] == "fp16":
308
+ m[n] = m[n].half()
309
+ elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
310
+ m[n] = m[n].bfloat16()
311
+
312
+ # if n == "emb.weight":
313
+ # print(m[n])
314
+
315
+ gc.collect()
316
+ torch.cuda.empty_cache()
317
+ return m
world/world_encoder.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from world.encoder.speech_encoder import SpeechEncoder
5
+ # from world.encoder.visual_encoder import VisualEncoder
6
+ from world.encoder.whisper_encoder import WhisperEncoder
7
+ from world.encoder.clip_encoder import ClipEncoder
8
+ from world.encoder.siglip_encoder import SiglipEncoder
9
+
10
+
11
+ class WorldEncoder(nn.Module):
12
+ def __init__(self, encoder_type: str, **kwargs):
13
+ super().__init__()
14
+ self.world_encoder = self._build_encoder(encoder_type, kwargs)
15
+
16
+ def _build_encoder(self, encoder_type: str, config: dict):
17
+ encoder_map = {
18
+ "clip": ClipEncoder,
19
+ "whisper": WhisperEncoder,
20
+ # "visual": VisualEncoder,
21
+ "speech": SpeechEncoder,
22
+ "siglip": SiglipEncoder
23
+ }
24
+
25
+ if encoder_type not in encoder_map:
26
+ raise ValueError(f"Unsupported encoder type: {encoder_type}")
27
+
28
+ # 动态过滤有效参数
29
+ encoder_class = encoder_map[encoder_type]
30
+
31
+ return encoder_class(**config)
32
+
33
+ def forward(self, x):
34
+ return self.world_encoder(x)
35
+
36
+ def load_checkpoint(self, state_dict):
37
+ self.world_encoder.load_state_dict(state_dict, strict=False)
world/world_load.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from world.model import RWKV
2
+ from world.world_encoder import WorldEncoder
3
+ def WorldLoading(args):
4
+ config = {
5
+ 'encoder_type': args.encoder_type,
6
+ 'encoder_path': args.encoder_path,
7
+ 'project_dim' : args.n_embd
8
+ }
9
+ modality = WorldEncoder(**config)
10
+
11
+ model = RWKV(args, modality=modality)
12
+ #model = RWKV(args)
13
+ print(model)
14
+
15
+ if 'moda' not in args.train_step:
16
+ for param in model.modality.world_encoder.model.parameters():
17
+ param.requires_grad = False
18
+ if 'adapter' not in args.train_step:
19
+ for param in model.modality.world_encoder.adapter.parameters():
20
+ param.requires_grad = False
21
+ if 'rwkv' not in args.train_step:
22
+ for param in model.emb.parameters():
23
+ param.requires_grad = False
24
+ for param in model.blocks.parameters():
25
+ param.requires_grad = False
26
+ for param in model.ln_out.parameters():
27
+ param.requires_grad = False
28
+ for param in model.head.parameters():
29
+ param.requires_grad = False
30
+ return model