Update world RWKV CPU
- infer/__init__.py +0 -0
- infer/rwkv/__init__.py +0 -0
- infer/rwkv/cuda/gemm_fp16_cublas.cpp +75 -0
- infer/rwkv/cuda/operators.cu +246 -0
- infer/rwkv/cuda/rwkv5.cu +88 -0
- infer/rwkv/cuda/rwkv5_op.cpp +34 -0
- infer/rwkv/cuda/rwkv6.cu +87 -0
- infer/rwkv/cuda/rwkv6_op.cpp +34 -0
- infer/rwkv/cuda/rwkv7.cu +77 -0
- infer/rwkv/cuda/rwkv7_op.cpp +26 -0
- infer/rwkv/cuda/wrapper.cpp +141 -0
- infer/rwkv/model.py +469 -0
- infer/rwkv/rwkv_tokenizer.py +103 -0
- infer/rwkv/rwkv_vocab_v20230424.txt +0 -0
- infer/rwkv/utils.py +148 -0
- infer/worldmodel.py +79 -0
- world/__init__.py +0 -0
- world/block.py +34 -0
- world/cat.py +173 -0
- world/dataset.py +354 -0
- world/encoder/clip_encoder.py +69 -0
- world/encoder/siglip_encoder.py +68 -0
- world/encoder/speech_encoder.py +109 -0
- world/encoder/visual_encoder.py +106 -0
- world/encoder/whisper_encoder.py +68 -0
- world/loss.py +16 -0
- world/model.py +317 -0
- world/world_encoder.py +37 -0
- world/world_load.py +30 -0
infer/__init__.py
ADDED
File without changes
infer/rwkv/__init__.py
ADDED
File without changes
infer/rwkv/cuda/gemm_fp16_cublas.cpp
ADDED
@@ -0,0 +1,75 @@
+#include <cublas_v2.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#define CUBLAS_CHECK(condition)                                                \
+  for (cublasStatus_t _cublas_check_status = (condition);                      \
+       _cublas_check_status != CUBLAS_STATUS_SUCCESS;)                         \
+    throw std::runtime_error("cuBLAS error " +                                 \
+                             std::to_string(_cublas_check_status) + " at " +   \
+                             std::to_string(__LINE__));
+
+#define CUDA_CHECK(condition)                                                  \
+  for (cudaError_t _cuda_check_status = (condition);                           \
+       _cuda_check_status != cudaSuccess;)                                     \
+    throw std::runtime_error(                                                  \
+        "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) +  \
+        " at " + std::to_string(__LINE__));
+
+/*
+  NOTE: blas gemm is column-major by default, but we need row-major output.
+  The data of row-major, transposed matrix is exactly the same as the
+  column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
+*/
+void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
+  const auto cuda_data_type = CUDA_R_16F;
+  const auto cuda_c_data_type =
+      c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
+  const auto compute_type = CUDA_R_32F;
+  const float sp_alpha = 1.f;
+  // swap a and b, and use CUBLAS_OP_N. see the notes above
+  std::swap(a, b);
+  const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
+  const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
+  // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
+  // negative axis is used because of the existence of batch matmul.
+  const int m = a.size(-1);
+  const int k = a.size(-2);
+  const int n = b.size(-2);
+  const int cublas_lda = m;
+  const int cublas_ldb = k;
+  const int cublas_ldc = m;
+  cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
+
+#if CUDA_VERSION >= 11000
+  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+#else
+  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+#endif
+  const float sp_beta = 0.f;
+  if (a.sizes().size() == 2 && b.sizes().size() == 2) {
+    CUBLAS_CHECK(cublasGemmEx(
+        cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
+        a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
+        cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
+        compute_type, algo));
+  } else {
+    // batch matmul
+    assert(a.sizes().size() == 3 && b.sizes().size() == 3);
+
+    const long long int cublas_stride_a = m * k;
+    const long long int cublas_stride_b = k * n;
+    const long long int cublas_stride_c = m * n;
+    CUBLAS_CHECK(cublasGemmStridedBatchedEx(
+        cublas_handle, cublas_trans_a, cublas_trans_b, m,
+        n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
+        cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
+        &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
+        a.size(0), compute_type, algo));
+  }
+}
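For reference, a minimal sketch (not part of this commit) of what the row-major/column-major swap in the comment above amounts to from the Python side. It assumes the "wkv_cuda" extension has been built with the cuBLAS path enabled (RWKV_CUDA_ON=1), so that wrapper.cpp below has registered the op as `torch.ops.rwkv.gemm_fp16_cublas`; the tensor shapes are illustrative only.

```python
import torch

# a, b are fp16 CUDA tensors; c must be preallocated (fp16 or fp32 output).
a = torch.randn(4, 8, dtype=torch.float16, device="cuda")
b = torch.randn(8, 16, dtype=torch.float16, device="cuda")
c = torch.empty(4, 16, dtype=torch.float16, device="cuda")

torch.ops.rwkv.gemm_fp16_cublas(a, b, c)   # computes row-major C = A @ B via cuBLAS
print(torch.allclose(c, a @ b, atol=1e-2)) # fp16 inputs, fp32 accumulation -> loose tolerance
```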
infer/rwkv/cuda/operators.cu
ADDED
@@ -0,0 +1,246 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+#include <cuda_fp16.h>
+#define MIN_VALUE (-1e38)
+typedef at::Half fp16;
+__half *cast(fp16 *ptr) {
+    return reinterpret_cast<__half *>(ptr);
+}
+
+template <typename F>
+__global__ void kernel_wkv_forward(const int B, const int T, const int C,
+                               const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                               F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+    const int _state_offset = _b * C + _c;
+
+    float u = _u[_c];
+    float w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    F *__restrict__ const y = _y + _offset;
+
+    float aa = _aa[_state_offset];
+    float bb = _bb[_state_offset];
+    float pp = _pp[_state_offset];
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const float kk = float(k[ii]);
+        const float vv = float(v[ii]);
+        float ww = u + kk;
+        float p = max(pp, ww);
+        float e1 = exp(pp - p);
+        float e2 = exp(ww - p);
+        y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+    _aa[_state_offset] = aa;
+    _bb[_state_offset] = bb;
+    _pp[_state_offset] = pp;
+}
+
+template <typename F>
+void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
+    dim3 threadsPerBlock( min(C, 32) );
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
+}
+
+template void cuda_wkv_forward<fp16>(
+    int B, int T, int C,
+    float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
+    float *aa, float *bb, float *pp);
+template void cuda_wkv_forward<float>(
+    int B, int T, int C,
+    float *w, float *u, float *k, float *v, float *y,
+    float *aa, float *bb, float *pp);
+
+__global__ void kernel_mm_seq_fp32i8(
+    const int B, const int N, const int M,
+    const float *__restrict__ const x, const int x_stride,
+    const uint8_t *__restrict__ const w, const int w_stride,
+    const float *__restrict__ const mx,
+    const float *__restrict__ const rx,
+    const float *__restrict__ const my,
+    const float *__restrict__ const ry,
+    float *__restrict__ const y, const int y_stride) {
+
+    const int i = blockIdx.x * blockDim.x + threadIdx.x;
+    const int k = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i < B && k < M) {
+        float y_local = 0;
+        for (int j = 0; j < N; ++j) {
+            y_local += x[i * x_stride + j] * (
+                (float(w[j * w_stride + k]) + 0.5f)
+                * rx[k] * ry[j] + mx[k] + my[j]
+            );
+        }
+        y[i * y_stride + k] = y_local;
+    }
+}
+
+template <typename F>
+void cuda_mm8_seq(int B, int N, int M,
+                  F *x, int x_stride,
+                  uint8_t *w, int w_stride,
+                  F *mx, F *rx,
+                  F *my, F *ry,
+                  F *y, int y_stride);
+
+template <>
+void cuda_mm8_seq<float>(int B, int N, int M,
+                         float *x, int x_stride,
+                         uint8_t *w, int w_stride,
+                         float *mx, float *rx,
+                         float *my, float *ry,
+                         float *y, int y_stride) {
+    dim3 blockSize(1, 128);
+    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
+    kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
+        B, N, M, x, x_stride, w, w_stride,
+        mx, rx, my, ry, y, y_stride);
+}
+
+__global__ void kernel_mm_seq_fp16i8(
+    const int B, const int N, const int M,
+    const __half *__restrict__ const x, const int x_stride,
+    const uint8_t *__restrict__ const w, const int w_stride,
+    const __half *__restrict__ const mx,
+    const __half *__restrict__ const rx,
+    const __half *__restrict__ const my,
+    const __half *__restrict__ const ry,
+    __half *__restrict__ const y, const int y_stride) {
+
+    const int i = blockIdx.x * blockDim.x + threadIdx.x;
+    const int k = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i < B && k < M) {
+        float y_local = 0;
+        for (int j = 0; j < N; ++j) {
+            y_local += __half2float(x[i * x_stride + j]) * (
+                (float(w[j * w_stride + k]) + 0.5f)
+                * __half2float(rx[k]) * __half2float(ry[j])
+                + __half2float(mx[k]) + __half2float(my[j])
+            );
+        }
+        y[i * y_stride + k] = __float2half(y_local);
+    }
+}
+
+template <>
+void cuda_mm8_seq<fp16>(int B, int N, int M,
+                        fp16 *x, int x_stride,
+                        uint8_t *w, int w_stride,
+                        fp16 *mx, fp16 *rx,
+                        fp16 *my, fp16 *ry,
+                        fp16 *y, int y_stride) {
+    dim3 blockSize(1, 128);
+    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
+    kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
+        B, N, M, cast(x), x_stride, w, w_stride,
+        cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
+}
+
+#define MM8_ONE_JSPLIT 24
+#define MM8_ONE_TILE 1024
+
+__global__ void kernel_mm_one_fp32i8(
+    const int N, const int M,
+    const float *__restrict__ const x,
+    const uint8_t *__restrict__ const w, const int w_stride,
+    const float *__restrict__ const mx,
+    const float *__restrict__ const rx,
+    const float *__restrict__ const my,
+    const float *__restrict__ const ry,
+    float *__restrict__ const y) {
+
+    const int k = blockIdx.y * blockDim.y + threadIdx.y;
+    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
+    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
+
+    if (k < M) {
+        float y_local = 0;
+        for (int j = j0; j < j1; ++j) {
+            y_local += x[j] * (
+                (float(w[j * w_stride + k]) + 0.5f)
+                * rx[k] * ry[j] + mx[k] + my[j]
+            );
+        }
+        atomicAdd(&y[k], y_local);
+    }
+}
+
+template <typename F>
+void cuda_mm8_one(int N, int M,
+                  F *x,
+                  uint8_t *w, int w_stride,
+                  F *mx, F *rx,
+                  F *my, F *ry,
+                  float *y);
+
+template <>
+void cuda_mm8_one<float>(int N, int M,
+                         float *x,
+                         uint8_t *w, int w_stride,
+                         float *mx, float *rx,
+                         float *my, float *ry,
+                         float *y) {
+    dim3 blockSize(1, MM8_ONE_TILE);
+    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
+    kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
+        N, M, x, w, w_stride,
+        mx, rx, my, ry, y);
+}
+
+__global__ void kernel_mm_one_fp16i8(
+    const int N, const int M,
+    const __half *__restrict__ const x,
+    const uint8_t *__restrict__ const w, const int w_stride,
+    const __half *__restrict__ const mx,
+    const __half *__restrict__ const rx,
+    const __half *__restrict__ const my,
+    const __half *__restrict__ const ry,
+    float *__restrict__ const y) {
+
+    const int k = blockIdx.y * blockDim.y + threadIdx.y;
+    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
+    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
+
+    if (k < M) {
+        float y_local = 0;
+        for (int j = j0; j < j1; ++j) {
+            y_local += __half2float(x[j]) * (
+                (float(w[j * w_stride + k]) + 0.5f)
+                * __half2float(rx[k]) * __half2float(ry[j])
+                + __half2float(mx[k]) + __half2float(my[j])
+            );
+        }
+        atomicAdd(&y[k], y_local);
+    }
+}
+
+template <>
+void cuda_mm8_one<fp16>(int N, int M,
+                        fp16 *x,
+                        uint8_t *w, int w_stride,
+                        fp16 *mx, fp16 *rx,
+                        fp16 *my, fp16 *ry,
+                        float *y) {
+    dim3 blockSize(1, MM8_ONE_TILE);
+    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
+    kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
+        N, M, cast(x), w, w_stride,
+        cast(mx), cast(rx), cast(my), cast(ry), y);
+}
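For reference, a minimal sketch (not part of this commit) of the numerically stable WKV recurrence that kernel_wkv_forward above computes per channel: it keeps a running max-shift exponent `pp` so the exponentials never overflow. Shapes follow the kernel (per-channel `w`, `u`, running state `aa`, `bb`, `pp`).

```python
import torch

def wkv_forward_reference(w, u, k, v, aa, bb, pp):
    # w, u: (C,) float32; k, v: (T, C); aa, bb, pp: (C,) running state carried across calls
    T, C = k.shape
    y = torch.empty_like(k)
    for t in range(T):
        kk, vv = k[t].float(), v[t].float()
        ww = u + kk
        p = torch.maximum(pp, ww)
        e1, e2 = torch.exp(pp - p), torch.exp(ww - p)
        y[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(k.dtype)
        ww = w + pp
        p = torch.maximum(ww, kk)
        e1, e2 = torch.exp(ww - p), torch.exp(kk - p)
        aa, bb, pp = e1 * aa + e2 * vv, e1 * bb + e2, p
    return y, aa, bb, pp
```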
infer/rwkv/cuda/rwkv5.cu
ADDED
@@ -0,0 +1,88 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+typedef at::Half fp16;
+typedef float fp32;
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
+                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
+                               F *__restrict__ const _y)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    _w += h*_N_;
+    _u += h*_N_;
+    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
+
+    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
+
+    float state[_N_];
+    #pragma unroll
+    for (int j = 0; j < _N_; j++)
+        state[j] = _state[j];
+
+    __syncthreads();
+    u[i] = float(_u[i]);
+    w[i] = _w[i];
+    __syncthreads();
+
+    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
+    {
+        __syncthreads();
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        __syncthreads();
+
+        const float v = float(_v[t]);
+        float y = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j+=4)
+        {
+            const float4& r_ = (float4&)(r[j]);
+            const float4& k_ = (float4&)(k[j]);
+            const float4& w_ = (float4&)(w[j]);
+            const float4& u_ = (float4&)(u[j]);
+            float4& s = (float4&)(state[j]);
+            float4 x;
+
+            x.x = k_.x * v;
+            x.y = k_.y * v;
+            x.z = k_.z * v;
+            x.w = k_.w * v;
+
+            y += r_.x * (u_.x * x.x + s.x);
+            y += r_.y * (u_.y * x.y + s.y);
+            y += r_.z * (u_.z * x.z + s.z);
+            y += r_.w * (u_.w * x.w + s.w);
+
+            s.x = s.x * w_.x + x.x;
+            s.y = s.y * w_.y + x.y;
+            s.z = s.z * w_.z + x.z;
+            s.w = s.w * w_.w + x.w;
+        }
+        _y[t] = F(y);
+    }
+    #pragma unroll
+    for (int j = 0; j < _N_; j++)
+        _state[j] = state[j];
+}
+
+void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
+{
+    assert(H*_N_ == C);
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
+}
+void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
+{
+    assert(H*_N_ == C);
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
+}
+void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
+{
+    assert(H*_N_ == C);
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
+}
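For reference, a minimal sketch (not part of this commit) of the per-head, per-timestep update that the rwkv5 kernel vectorizes with float4 loads. `S` is the (N, N) state of one head, with `S[i, j]` feeding output channel `i`; in rwkv6.cu below the only difference is that the decay `w` is read per timestep inside the loop rather than once per head.

```python
import torch

def rwkv5_head_step(S, r, k, v, w, u):
    # r, k, v, w, u: (N,) vectors for one head at one timestep; S: (N, N) state
    y = S @ r + (r * u * k).sum() * v      # y[i] = sum_j r[j] * (u[j]*k[j]*v[i] + S[i, j])
    S = S * w + torch.outer(v, k)          # S[i, j] = S[i, j]*w[j] + k[j]*v[i]
    return y, S
```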
infer/rwkv/cuda/rwkv5_op.cpp
ADDED
@@ -0,0 +1,34 @@
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+#include <c10/cuda/CUDAGuard.h>
+typedef at::BFloat16 bf16;
+typedef at::Half fp16;
+typedef float fp32;
+
+void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
+void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
+void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
+
+void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
+    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
+    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
+}
+void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
+    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
+    m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
+    m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
+}
+TORCH_LIBRARY(rwkv5, m) {
+    m.def("forward_bf16", forward_bf16);
+    m.def("forward_fp16", forward_fp16);
+    m.def("forward_fp32", forward_fp32);
+}
infer/rwkv/cuda/rwkv6.cu
ADDED
@@ -0,0 +1,87 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+typedef at::Half fp16;
+typedef float fp32;
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
+                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
+                               F *__restrict__ const _y)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    _u += h*_N_;
+    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
+
+    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
+
+    float state[_N_];
+    #pragma unroll
+    for (int j = 0; j < _N_; j++)
+        state[j] = _state[j];
+
+    __syncthreads();
+    u[i] = float(_u[i]);
+    __syncthreads();
+
+    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
+    {
+        __syncthreads();
+        w[i] = _w[t];
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        __syncthreads();
+
+        const float v = float(_v[t]);
+        float y = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j+=4)
+        {
+            const float4& r_ = (float4&)(r[j]);
+            const float4& k_ = (float4&)(k[j]);
+            const float4& w_ = (float4&)(w[j]);
+            const float4& u_ = (float4&)(u[j]);
+            float4& s = (float4&)(state[j]);
+            float4 x;
+
+            x.x = k_.x * v;
+            x.y = k_.y * v;
+            x.z = k_.z * v;
+            x.w = k_.w * v;
+
+            y += r_.x * (u_.x * x.x + s.x);
+            y += r_.y * (u_.y * x.y + s.y);
+            y += r_.z * (u_.z * x.z + s.z);
+            y += r_.w * (u_.w * x.w + s.w);
+
+            s.x = s.x * w_.x + x.x;
+            s.y = s.y * w_.y + x.y;
+            s.z = s.z * w_.z + x.z;
+            s.w = s.w * w_.w + x.w;
+        }
+        _y[t] = F(y);
+    }
+    #pragma unroll
+    for (int j = 0; j < _N_; j++)
+        _state[j] = state[j];
+}
+
+void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
+{
+    assert(H*_N_ == C);
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
+}
+void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
+{
+    assert(H*_N_ == C);
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
+}
+void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
+{
+    assert(H*_N_ == C);
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
+}
infer/rwkv/cuda/rwkv6_op.cpp
ADDED
@@ -0,0 +1,34 @@
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+#include <c10/cuda/CUDAGuard.h>
+typedef at::BFloat16 bf16;
+typedef at::Half fp16;
+typedef float fp32;
+
+void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
+void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
+void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
+
+void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
+    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
+    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
+}
+void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
+    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
+    m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
+    m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
+}
+TORCH_LIBRARY(rwkv6, m) {
+    m.def("forward_bf16", forward_bf16);
+    m.def("forward_fp16", forward_fp16);
+    m.def("forward_fp32", forward_fp32);
+}
infer/rwkv/cuda/rwkv7.cu
ADDED
@@ -0,0 +1,77 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+
+typedef at::Half fp16;
+typedef at::BFloat16 bf16;
+typedef float fp32;
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C, const int H,
+    float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
+    F *__restrict__ const _y)
+{
+    const int e = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
+
+    float state[_N_];
+    #pragma unroll
+    for (int j = 0; j < _N_; j++)
+        state[j] = _state[j];
+
+    __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
+
+    for (int _t = 0; _t < T; _t++)
+    {
+        const int t = e*T*C + h*_N_ + i + _t * C;
+        __syncthreads();
+        r[i] = float(_r[t]);
+        w[i] = __expf(-__expf(float(_w[t])));
+        k[i] = float(_k[t]);
+        a[i] = float(_a[t]);
+        b[i] = float(_b[t]);
+        __syncthreads();
+
+        float sa = 0;
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            sa += a[j] * state[j];
+        }
+
+        float vv = float(_v[t]);
+        float y = 0;
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = state[j];
+            s = s * w[j] + k[j] * vv + sa * b[j];
+            y += s * r[j];
+        }
+        _y[t] = F(y);
+    }
+    #pragma unroll
+    for (int j = 0; j < _N_; j++)
+        _state[j] = state[j];
+}
+
+void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
+{
+    assert(H*_N_ == C);
+    assert(B == 1); // only for B=1
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
+}
+void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16* w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y)
+{
+    assert(H*_N_ == C);
+    assert(B == 1); // only for B=1
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
+}
+void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32* w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y)
+{
+    assert(H*_N_ == C);
+    assert(B == 1); // only for B=1
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
+}
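For reference, a minimal sketch (not part of this commit) of the per-head, per-timestep RWKV-7 update the kernel above performs for one head. `S` is the (N, N) head state; the decay transform `exp(-exp(w))` and the update-before-readout order match the kernel, and the same math appears in RWKV_x070_TMix_one in model.py below.

```python
import torch

def rwkv7_head_step(S, r, w_raw, k, v, a, b):
    # r, w_raw, k, v, a, b: (N,) vectors for one head at one timestep; S: (N, N) state
    w = torch.exp(-torch.exp(w_raw))                    # same transform as the kernel
    sa = S @ a                                          # sa[i] = sum_j a[j] * S[i, j]
    S = S * w + torch.outer(v, k) + torch.outer(sa, b)  # S[i, j] = S[i, j]*w[j] + k[j]*v[i] + sa[i]*b[j]
    y = S @ r                                           # readout uses the updated state
    return y, S
```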
infer/rwkv/cuda/rwkv7_op.cpp
ADDED
@@ -0,0 +1,26 @@
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+
+typedef at::Half fp16;
+typedef at::BFloat16 bf16;
+typedef float fp32;
+
+void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
+void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y);
+void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y);
+
+void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
+    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
+    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), w.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), a.data_ptr<fp16>(), b.data_ptr<fp16>(), y.data_ptr<fp16>());
+}
+void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
+    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), w.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), a.data_ptr<fp32>(), b.data_ptr<fp32>(), y.data_ptr<fp32>());
+}
+
+TORCH_LIBRARY(wkv7s, m) {
+    m.def("forward_bf16", forward_bf16);
+    m.def("forward_fp16", forward_fp16);
+    m.def("forward_fp32", forward_fp32);
+}
infer/rwkv/cuda/wrapper.cpp
ADDED
@@ -0,0 +1,141 @@
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+#include <iostream>
+#include <c10/cuda/CUDAGuard.h>
+
+typedef at::Half fp16;
+
+template <typename F>
+void cuda_wkv_forward(int B, int T, int C,
+                      float *w, float *u, F *k, F *v, F *y,
+                      float *aa, float *bb, float *pp);
+template <typename F>
+void cuda_mm8_seq(int B, int N, int M,
+                  F *x, int x_stride,
+                  uint8_t *w, int w_stride,
+                  F *mx, F *rx,
+                  F *my, F *ry,
+                  F *y, int y_stride);
+template <typename F>
+void cuda_mm8_one(int N, int M,
+                  F *x,
+                  uint8_t *w, int w_stride,
+                  F *mx, F *rx,
+                  F *my, F *ry,
+                  float *y);
+
+void wkv_forward(int64_t B, int64_t T, int64_t C,
+                 torch::Tensor &w, torch::Tensor &u,
+                 torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
+                 torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
+    switch (k.scalar_type()) {
+    case c10::ScalarType::Half:
+        cuda_wkv_forward(B, T, C,
+                         w.data_ptr<float>(), u.data_ptr<float>(),
+                         k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
+                         aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
+        break;
+    case c10::ScalarType::Float:
+        cuda_wkv_forward(B, T, C,
+                         w.data_ptr<float>(), u.data_ptr<float>(),
+                         k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
+                         aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
+        break;
+    default:
+        assert(false && "Only FP16 and FP32 are currently supported");
+    }
+}
+
+void mm8_seq(int64_t B, int64_t N, int64_t M,
+             torch::Tensor &x, torch::Tensor &w,
+             torch::Tensor &mx, torch::Tensor &rx,
+             torch::Tensor &my, torch::Tensor &ry,
+             torch::Tensor &y) {
+    assert(x.stride(1) == 1);
+    assert(w.stride(1) == 1);
+    assert(mx.stride(0) == 1 && rx.stride(0) == 1);
+    assert(my.stride(0) == 1 && ry.stride(0) == 1);
+    assert(y.stride(1) == 1);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
+    switch (x.scalar_type()) {
+    case c10::ScalarType::Half:
+        cuda_mm8_seq(
+            B, N, M,
+            x.data_ptr<fp16>(), x.stride(0),
+            w.data_ptr<uint8_t>(), w.stride(0),
+            mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
+            my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
+            y.data_ptr<fp16>(), y.stride(0));
+        break;
+    case c10::ScalarType::Float:
+        cuda_mm8_seq(
+            B, N, M,
+            x.data_ptr<float>(), x.stride(0),
+            w.data_ptr<uint8_t>(), w.stride(0),
+            mx.data_ptr<float>(), rx.data_ptr<float>(),
+            my.data_ptr<float>(), ry.data_ptr<float>(),
+            y.data_ptr<float>(), y.stride(0));
+        break;
+    default:
+        assert(false && "Only FP16 and FP32 are currently supported");
+    }
+}
+void mm8_one(int64_t N, int64_t M,
+             torch::Tensor &x, torch::Tensor &w,
+             torch::Tensor &mx, torch::Tensor &rx,
+             torch::Tensor &my, torch::Tensor &ry,
+             torch::Tensor &y) {
+    assert(x.stride(0) == 1);
+    assert(w.stride(1) == 1);
+    assert(mx.stride(0) == 1 && rx.stride(0) == 1);
+    assert(my.stride(0) == 1 && ry.stride(0) == 1);
+    assert(y.stride(0) == 1);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
+    switch (x.scalar_type()) {
+    case c10::ScalarType::Half:
+        cuda_mm8_one(
+            N, M,
+            x.data_ptr<fp16>(),
+            w.data_ptr<uint8_t>(), w.stride(0),
+            mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
+            my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
+            y.data_ptr<float>());
+        break;
+    case c10::ScalarType::Float:
+        cuda_mm8_one(
+            N, M,
+            x.data_ptr<float>(),
+            w.data_ptr<uint8_t>(), w.stride(0),
+            mx.data_ptr<float>(), rx.data_ptr<float>(),
+            my.data_ptr<float>(), ry.data_ptr<float>(),
+            y.data_ptr<float>());
+        break;
+    default:
+        assert(false && "Only FP16 and FP32 are currently supported");
+    }
+}
+
+using torch::Tensor;
+
+#ifndef DISABLE_CUBLAS_GEMM
+void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
+#endif
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("wkv_forward", &wkv_forward, "wkv forward");
+    m.def("mm8_seq", &mm8_seq, "mm8 seq");
+    m.def("mm8_one", &mm8_one, "mm8 one");
+#ifndef DISABLE_CUBLAS_GEMM
+    m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
+#endif
+}
+
+TORCH_LIBRARY(rwkv, m) {
+    m.def("wkv_forward", wkv_forward);
+    m.def("mm8_seq", mm8_seq);
+    m.def("mm8_one", mm8_one);
+#ifndef DISABLE_CUBLAS_GEMM
+    m.def("gemm_fp16_cublas", gemm_fp16_cublas);
+#endif
+}
infer/rwkv/model.py
ADDED
@@ -0,0 +1,469 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+from typing import Optional
+import types, gc, os, time, re, math
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+torch.backends.cudnn.benchmark = True
+torch.backends.cudnn.allow_tf32 = True
+torch.backends.cuda.matmul.allow_tf32 = True
+current_path = os.path.dirname(os.path.abspath(__file__))
+
+########################################################################################################
+
+if os.environ.get('RWKV_JIT_ON') != '0':
+    os.environ["RWKV_JIT_ON"] = '1'
+    MyModule = torch.jit.ScriptModule
+    MyFunction = torch.jit.script_method
+    MyStatic = torch.jit.script
+else:
+    MyModule = torch.nn.Module
+    def __nop(ob):
+        return ob
+    MyFunction = __nop
+    MyStatic = __nop
+
+if os.environ.get('RWKV_CUDA_ON') == '1':
+    from torch.utils.cpp_extension import load
+    try:
+        load(
+            name=f"wkv_cuda",
+            sources=[f"{current_path}/cuda/wrapper.cpp", f"{current_path}/cuda/operators.cu", f"{current_path}/cuda/gemm_fp16_cublas.cpp"],
+            verbose=True,
+            extra_ldflags=["cublas.lib" if os.name == "nt" else ""],
+            extra_cuda_cflags=["--use_fast_math", "-O3", "--extra-device-vectorization"],
+            is_python_module=False)
+        DISABLE_CUBLAS_GEMM = False
+    except:
+        print("Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow.")
+        load(
+            name=f"wkv_cuda",
+            sources=[f"{current_path}/cuda/wrapper.cpp", f"{current_path}/cuda/operators.cu"],
+            verbose=True,
+            extra_cuda_cflags=["--use_fast_math", "-O3", "--extra-device-vectorization"],
+            extra_cflags=["-DDISABLE_CUBLAS_GEMM"],
+            is_python_module=False)
+        DISABLE_CUBLAS_GEMM = True
+
+    @MyStatic
+    def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp):
+        assert 1 * C % min(C, 32) == 0
+        assert k.dtype == v.dtype == torch.float16 or k.dtype == v.dtype == torch.float32
+        assert w.dtype == u.dtype == aa.dtype == bb.dtype == pp.dtype == torch.float32
+        w = w.contiguous()
+        u = u.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+        y = torch.empty((T, C), device=w.device, memory_format=torch.contiguous_format, dtype=k.dtype)
+        torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp)
+        return y, aa, bb, pp
+    @MyStatic
+    def cuda_mm8_seq(B: int, N: int, M: int, x, w, mx, rx, my, ry):
+        assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype
+        assert x.dtype == torch.float32 or x.dtype == torch.float16
+        assert w.dtype == torch.uint8
+        assert x.shape == (B, N)
+        assert w.shape == (N, M)
+        assert rx.shape == mx.shape == (M,)
+        assert ry.shape == my.shape == (N, 1)
+        y = torch.empty((B, M), device=w.device, dtype=x.dtype)
+        torch.ops.rwkv.mm8_seq(B, N, M, x, w, mx, rx, my, ry, y)
+        return y
+    @MyStatic
+    def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry):
+        assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype
+        assert x.dtype == torch.float32 or x.dtype == torch.float16
+        assert w.dtype == torch.uint8
+        assert x.shape == (N,)
+        assert w.shape == (N, M)
+        assert rx.shape == mx.shape == (M,)
+        assert ry.shape == my.shape == (N, 1)
+        y = torch.zeros((M,), device=w.device, dtype=torch.float32)
+        torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y)
+        return y.to(dtype=x.dtype)
+else:
+    os.environ["RWKV_CUDA_ON"] = '0'
+
+
+@MyStatic
+def torch_mm8_seq(x, w, mx, rx, my, ry):
+    return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
+
+@MyStatic
+def torch_mm8_one(x, w, mx, rx, my, ry):
+    return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
+
+if os.environ.get('RWKV_CUDA_ON') == '1':
+    @MyStatic
+    def mm8_seq(x, w, mx, rx, my, ry):
+        if w.device.type == 'cuda' and x.dtype == torch.float16:
+            B, N, M = x.shape[0], w.shape[0], w.shape[1]
+            return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry)
+        else:
+            return torch_mm8_seq(x, w, mx, rx, my, ry)
+    @MyStatic
+    def mm8_one(x, w, mx, rx, my, ry):
+        if w.device.type == 'cuda':
+            N, M = w.shape[0], w.shape[1]
+            return cuda_mm8_one(N, M, x, w, mx, rx, my, ry)
+        else:
+            return torch_mm8_one(x, w, mx, rx, my, ry)
+else:
+    @MyStatic
+    def mm8_seq(x, w, mx, rx, my, ry):
+        return torch_mm8_seq(x, w, mx, rx, my, ry)
+    @MyStatic
+    def mm8_one(x, w, mx, rx, my, ry):
+        return torch_mm8_one(x, w, mx, rx, my, ry)
+
+def mm8(x: torch.Tensor, w: torch.Tensor, mx: torch.Tensor, rx: torch.Tensor, my: torch.Tensor, ry: torch.Tensor):
+    if len(x.shape) == 1:
+        return mm8_one(x, w, mx, rx, my, ry)
+    return mm8_seq(x, w, mx, rx, my, ry)
+
+def matmul(a, b, mx: Optional[torch.Tensor]=None, rx: Optional[torch.Tensor]=None, my: Optional[torch.Tensor]=None, ry: Optional[torch.Tensor]=None, output_dtype: Optional[torch.dtype]=None) -> torch.Tensor:
+    if output_dtype is None:
+        output_dtype = a.dtype
+    if b.dtype in [torch.float16, torch.bfloat16, torch.float32]:
+        assert a.dtype == b.dtype
+        return matmul_float(a, b, output_dtype=output_dtype)
+    elif b.dtype == torch.uint8:
+        assert mx is not None
+        assert rx is not None
+        assert my is not None
+        assert ry is not None
+        return mm8(a, b, mx, rx, my, ry).to(output_dtype)
+    else:
+        raise ValueError("Unsupported dtype")
+
+
+if os.environ.get('RWKV_CUDA_ON') == '1' and not DISABLE_CUBLAS_GEMM:
+    def matmul_float(a, b, output_dtype: Optional[torch.dtype]=None):
+        if output_dtype is None:
+            output_dtype = a.dtype
+        if a.dtype == b.dtype == torch.float16 and a.device.type == 'cuda':
+            if len(a.shape) == 1:
+                assert len(b.shape) == 2
+                c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device)
+                a = a.unsqueeze(0)
+            else:
+                assert len(a.shape) == len(b.shape)
+                assert len(a.shape) == 2 or len(a.shape) == 3
+                # torch.empty((*a.shape[:-1], b.shape[-1])) doesn't work with jit
+                if len(a.shape) == 2:
+                    c = torch.empty((a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device)
+                else:
+                    c = torch.empty((a.shape[0], a.shape[1], b.shape[-1]), dtype=output_dtype, device=a.device)
+            torch.ops.rwkv.gemm_fp16_cublas(a, b, c)
+            return c
+        else:
+            return (a @ b).to(output_dtype)
+
+else:
+    def matmul_float(a, b, output_dtype: Optional[torch.dtype]=None):
+        return (a @ b).to(output_dtype)
+
+
+if os.environ.get('RWKV_DML_ON') == '1':
+    import torch_directml
+    print("PyTorch with DirectML Enabled")
+
+
+print(f'\n### RWKV-7 "Goose" enabled ###\n')
+
+torch.backends.cudnn.benchmark = True
+torch.backends.cudnn.allow_tf32 = True
+torch.backends.cuda.matmul.allow_tf32 = True
+# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+torch._C._jit_set_autocast_mode(False)
+
+MyModule = torch.jit.ScriptModule
+MyFunction = torch.jit.script_method
+MyStatic = torch.jit.script
+from typing import List
+
+DTYPE = None
+DEVICE = None
+HEAD_SIZE = 64
+
+if os.environ.get('RWKV_CUDA_ON') == '1':
+    from torch.utils.cpp_extension import load
+    load(name="wkv7s", sources=[f"{current_path}/cuda/rwkv7_op.cpp", f"{current_path}/cuda/rwkv7.cu"], is_python_module=False,
+         verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])
+    class WKV_7(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, state, r, w, k, v, a, b):
+            with torch.no_grad():
+                T, C = r.size()
+                H = C // HEAD_SIZE
+                N = HEAD_SIZE
+                assert HEAD_SIZE == C // H
+                assert all(x.dtype == DTYPE for x in [r,w,k,v,a,b])
+                assert all(x.is_contiguous() for x in [r,w,k,v,a,b])
+                y = torch.empty((T, C), device=DEVICE, dtype=r.dtype, requires_grad=False, memory_format=torch.contiguous_format)
+
+                if DTYPE == torch.float16:
+                    torch.ops.wkv7s.forward_fp16(1, T, C, H, state, r, w, k, v, a, b, y)
+                elif DTYPE == torch.bfloat16:
+                    torch.ops.wkv7s.forward_bf16(1, T, C, H, state, r, w, k, v, a, b, y)
+                elif DTYPE == torch.float32:
+                    torch.ops.wkv7s.forward_fp32(1, T, C, H, state, r, w, k, v, a, b, y)
+
+                return y
+    def RWKV7_OP(state, r, w, k, v, a, b):
+        return WKV_7.apply(state, r, w, k, v, a, b)
+
+########################################################################################################
+
+class RWKV(MyModule):
+    def __init__(self, model, strategy):
+        global DTYPE, DEVICE
+        super().__init__()
+        self.eval()
+        args = types.SimpleNamespace()
+        self.args = args
+        # args.MODEL_NAME = model
+
+        # print(f'Loading {model} ({strategy})\n')
+
+        ss = strategy.split(' ')
+        DEVICE = ss[0]
+        if ss[1] == 'fp16':
+            DTYPE = torch.half
+        elif ss[1] == 'fp32':
+            DTYPE = torch.float32
+        elif ss[1] == 'bf16':
+            DTYPE = torch.bfloat16
+        else:
+            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16"
+
+        # self.z = torch.load(args.MODEL_NAME + '.pth', map_location=DEVICE)
+        self.z = model
+
+        # for k,v in self.z.items():
+        #     print(k, v.shape)
+        z = self.z
+
+        self.n_head, self.head_size = z['blocks.0.att.r_k'].shape
+        args.head_size = self.head_size
+        args.vocab_size, args.n_embd = z['emb.weight'].shape
+
+        args.n_layer = 0
+        keys = list(z.keys())
+        for k in keys:
+            layer_id = int(k.split('.')[1]) if ('blocks.' in k) else 0
+            args.n_layer = max(args.n_layer, layer_id+1)
+            if 'key.weight' in k or 'value.weight' in k or 'receptance.weight' in k or 'output.weight' in k or 'head.weight' in k:
+                z[k] = z[k].t()
+            z[k] = z[k].squeeze().to(dtype=DTYPE)
+            if k.endswith('att.r_k'): z[k] = z[k].flatten()
+
+        self.n_embd = args.n_embd
+        self.n_layer = args.n_layer
+
+        # z['emb.weight'] = F.layer_norm(z['emb.weight'], (args.n_embd,), weight=z['blocks.0.ln0.weight'], bias=z['blocks.0.ln0.bias'])
+        z['blocks.0.att.v0'] = z['blocks.0.att.a0'] # actually ignored
+        z['blocks.0.att.v1'] = z['blocks.0.att.a1'] # actually ignored
+        z['blocks.0.att.v2'] = z['blocks.0.att.a2'] # actually ignored
+
+    def forward(self, idx, state, full_output=False, sign=None):
+        if state == None:
+            state = [None for _ in range(self.args.n_layer * 3)]
+            for i in range(self.args.n_layer): # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev
+                state[i*3+0] = torch.zeros(self.args.n_embd, dtype=DTYPE, requires_grad=False, device=DEVICE)
+                state[i*3+1] = torch.zeros((self.args.n_embd // self.args.head_size, self.args.head_size, self.args.head_size), dtype=torch.float, requires_grad=False, device=DEVICE)
+                state[i*3+2] = torch.zeros(self.args.n_embd, dtype=DTYPE, requires_grad=False, device=DEVICE)
+
+        x = self.z['emb.weight'][idx]
+        if isinstance(sign, torch.Tensor):
+            sign = sign.squeeze(0)
+            # sign = F.layer_norm(sign, (self.args.n_embd,), weight=self.z['blocks.0.ln0.weight'], bias=self.z['blocks.0.ln0.bias'])
+
+            x = torch.cat((sign, x.to(DEVICE)), dim=0)
+
+        x = F.layer_norm(x, (self.args.n_embd,), weight=self.z['blocks.0.ln0.weight'], bias=self.z['blocks.0.ln0.bias'])
+        # if isinstance(sign, torch.Tensor):
+        #     print(x)
+
+        if type(idx) is list:
+            if len(idx) > 1:
+                return self.forward_seq(x, state, full_output)
+            else:
+                return self.forward_one(x[0], state)
+        else:
+            return self.forward_one(x, state)
+
+    @MyFunction
+    def forward_one(self, x, state:List[torch.Tensor]):
+        with torch.no_grad():
+            z = self.z
+            #x = z['emb.weight'][idx]
+
+            v_first = torch.empty_like(x)
+            for i in range(self.n_layer):
+                bbb = f'blocks.{i}.'
+                att = f'blocks.{i}.att.'
+                ffn = f'blocks.{i}.ffn.'
+
+                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias'])
+
+                xx, state[i*3+0], state[i*3+1], v_first = RWKV_x070_TMix_one(i, self.n_head, self.head_size, xx, state[i*3+0], v_first, state[i*3+1],
+                    z[att+'x_r'], z[att+'x_w'], z[att+'x_k'], z[att+'x_v'], z[att+'x_a'], z[att+'x_g'],
+                    z[att+'w0'], z[att+'w1'], z[att+'w2'], z[att+'a0'], z[att+'a1'], z[att+'a2'], z[att+'v0'], z[att+'v1'], z[att+'v2'],
+                    z[att+'g1'], z[att+'g2'], z[att+'k_k'], z[att+'k_a'], z[att+'r_k'],
+                    z[att+'receptance.weight'], z[att+'key.weight'], z[att+'value.weight'], z[att+'output.weight'],
+                    z[att+'ln_x.weight'], z[att+'ln_x.bias'])
+                x = x + xx
+
+                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
+
+                xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+                x = x + xx
+
+                # if math.isnan(torch.min(x).item()): print(idx, i)
+
+            x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias'])
+            x = x @ z['head.weight']
+            return x, state
+
+    @MyFunction
+    def forward_seq(self, x, state:List[torch.Tensor], full_output:bool=False):
+        with torch.no_grad():
+            z = self.z
+            #x = z['emb.weight'][idx]
+
+            v_first = torch.empty_like(x)
+            for i in range(self.n_layer):
+                bbb = f'blocks.{i}.'
+                att = f'blocks.{i}.att.'
+                ffn = f'blocks.{i}.ffn.'
+
+                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias'])
+
+                xx, state[i*3+0], state[i*3+1], v_first = RWKV_x070_TMix_seq(i, self.n_head, self.head_size, xx, state[i*3+0], v_first, state[i*3+1],
+                    z[att+'x_r'], z[att+'x_w'], z[att+'x_k'], z[att+'x_v'], z[att+'x_a'], z[att+'x_g'],
+                    z[att+'w0'], z[att+'w1'], z[att+'w2'], z[att+'a0'], z[att+'a1'], z[att+'a2'], z[att+'v0'], z[att+'v1'], z[att+'v2'],
+                    z[att+'g1'], z[att+'g2'], z[att+'k_k'], z[att+'k_a'], z[att+'r_k'],
+                    z[att+'receptance.weight'], z[att+'key.weight'], z[att+'value.weight'], z[att+'output.weight'],
+                    z[att+'ln_x.weight'], z[att+'ln_x.bias'])
+                x = x + xx
+
+                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
+
+                xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+                x = x + xx
+
+            if not full_output: x = x[-1,:]
+            x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias'])
+            x = x @ z['head.weight']
+            return x, state
+
+########################################################################################################
+
+@MyStatic
+def RWKV_x070_TMix_one(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x_r, x_w, x_k, x_v, x_a, x_g, w0, w1, w2, a0, a1, a2, v0, v1, v2, g1, g2, k_k, k_a, r_k, R_, K_, V_, O_, ln_w, ln_b):
+    xx = x_prev - x
+    xr, xw, xk, xv, xa, xg = x+xx*x_r, x+xx*x_w, x+xx*x_k, x+xx*x_v, x+xx*x_a, x+xx*x_g
+
+    r = xr @ R_
+    w = torch.tanh(xw @ w1) @ w2
+    k = xk @ K_
+    v = xv @ V_
+    a = torch.sigmoid(a0 + (xa @ a1) @ a2)
+    g = torch.sigmoid(xg @ g1) @ g2
+
+    kk = torch.nn.functional.normalize((k * k_k).view(H,N), dim=-1, p=2.0).view(H*N)
+    k = k * (1 + (a-1) * k_a)
+    if layer_id == 0: v_first = v
+    else: v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2)
+    w = torch.exp(-0.606531 * torch.sigmoid((w0 + w).float())) # 0.606531 = exp(-0.5)
+
+    vk = v.view(H,N,1) @ k.view(H,1,N)
+    ab = (-kk).view(H,N,1) @ (kk*a).view(H,1,N)
+    state = state * w.view(H,1,N) + state @ ab.float() + vk.float()
+    xx = (state.to(dtype=x.dtype) @ r.view(H,N,1))
+
+    xx = torch.nn.functional.group_norm(xx.view(1,H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(H*N)
+    xx = xx + ((r * k * r_k).view(H,N).sum(dim=-1, keepdim=True) * v.view(H,N)).view(H*N)
+    return (xx * g) @ O_, x, state, v_first
+
+if os.environ.get('RWKV_CUDA_ON') == '1':
+    @MyStatic
+    def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x_r, x_w, x_k, x_v, x_a, x_g, w0, w1, w2, a0, a1, a2, v0, v1, v2, g1, g2, k_k, k_a, r_k, R_, K_, V_, O_, ln_w, ln_b):
+        T = x.shape[0]
+        xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
+        xr, xw, xk, xv, xa, xg = x+xx*x_r, x+xx*x_w, x+xx*x_k, x+xx*x_v, x+xx*x_a, x+xx*x_g
+
+        r = xr @ R_
+        w = torch.tanh(xw @ w1) @ w2
+        k = xk @ K_
+        v = xv @ V_
+        a = torch.sigmoid(a0 + (xa @ a1) @ a2)
+        g = torch.sigmoid(xg @ g1) @ g2
+
+        kk = torch.nn.functional.normalize((k * k_k).view(T,H,N), dim=-1, p=2.0).view(T,H*N)
+        k = k * (1 + (a-1) * k_a)
+        if layer_id == 0: v_first = v
+        else: v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2)
+
+        w = -torch.nn.functional.softplus(-(w0 + w)) - 0.5
+        xx = RWKV7_OP(state, r, w, k, v, -kk, kk*a)
+
+        xx = torch.nn.functional.group_norm(xx.view(T,H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(T,H*N)
+        xx = xx + ((r * k * r_k).view(T,H,N).sum(dim=-1, keepdim=True) * v.view(T,H,N)).view(T,H*N)
+        return (xx * g) @ O_, x[-1,:], state, v_first
+else:
+    @MyStatic
+    def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x_r, x_w, x_k, x_v, x_a, x_g, w0, w1, w2, a0, a1, a2, v0, v1, v2, g1, g2, k_k, k_a, r_k, R_, K_, V_, O_, ln_w, ln_b):
|
422 |
+
T = x.shape[0]
|
423 |
+
xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
|
424 |
+
xr, xw, xk, xv, xa, xg = x+xx*x_r, x+xx*x_w, x+xx*x_k, x+xx*x_v, x+xx*x_a, x+xx*x_g
|
425 |
+
|
426 |
+
r = xr @ R_
|
427 |
+
w = torch.tanh(xw @ w1) @ w2
|
428 |
+
k = xk @ K_
|
429 |
+
v = xv @ V_
|
430 |
+
a = torch.sigmoid(a0 + (xa @ a1) @ a2)
|
431 |
+
g = torch.sigmoid(xg @ g1) @ g2
|
432 |
+
|
433 |
+
kk = torch.nn.functional.normalize((k * k_k).view(T,H,N), dim=-1, p=2.0).view(T,H*N)
|
434 |
+
k = k * (1 + (a-1) * k_a)
|
435 |
+
if layer_id == 0: v_first = v
|
436 |
+
else: v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2)
|
437 |
+
|
438 |
+
w = torch.exp(-0.606531 * torch.sigmoid((w0 + w).float())) # 0.606531 = exp(-0.5)
|
439 |
+
for t in range(T):
|
440 |
+
r_, w_, k_, v_, kk_, a_ = r[t], w[t], k[t], v[t], kk[t], a[t]
|
441 |
+
vk = v_.view(H,N,1) @ k_.view(H,1,N)
|
442 |
+
ab = (-kk_).view(H,N,1) @ (kk_*a_).view(H,1,N)
|
443 |
+
state = state * w_.view(H,1,N) + state @ ab.float() + vk.float()
|
444 |
+
xx[t] = (state.to(dtype=x.dtype) @ r_.view(H,N,1)).view(H*N)
|
445 |
+
|
446 |
+
xx = torch.nn.functional.group_norm(xx.view(T,H*N), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).view(T,H*N)
|
447 |
+
xx = xx + ((r * k * r_k).view(T,H,N).sum(dim=-1, keepdim=True) * v.view(T,H,N)).view(T,H*N)
|
448 |
+
return (xx * g) @ O_, x[-1,:], state, v_first
|
449 |
+
|
450 |
+
########################################################################################################
|
451 |
+
|
452 |
+
@MyStatic
|
453 |
+
def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_):
|
454 |
+
xx = x_prev - x
|
455 |
+
k = x + xx * x_k
|
456 |
+
k = torch.relu(k @ K_) ** 2
|
457 |
+
return k @ V_, x
|
458 |
+
|
459 |
+
@MyStatic
|
460 |
+
def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_):
|
461 |
+
xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
|
462 |
+
k = x + xx * x_k
|
463 |
+
k = torch.relu(k @ K_) ** 2
|
464 |
+
return k @ V_, x[-1,:]
|
465 |
+
|
466 |
+
########################################################################################################
|
467 |
+
|
468 |
+
|
469 |
+
|
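For readers tracing RWKV_x070_TMix_one above, here is a minimal self-contained sketch of the per-head state recurrence with toy shapes; the variable names mirror the kernel above and all values are random placeholders, so treat it as an illustration rather than the reference implementation.

import torch

H, N = 2, 4                                   # toy head count and head size
state = torch.zeros(H, N, N)                  # per-head recurrent state
r, k, v = (torch.randn(H * N) for _ in range(3))
a = torch.sigmoid(torch.randn(H * N))         # in-context "learning rate"
kk = torch.nn.functional.normalize(k.view(H, N), dim=-1, p=2.0).view(H * N)
w = torch.exp(-0.606531 * torch.sigmoid(torch.randn(H * N)))   # per-channel decay in (0, 1)

# One token step, matching the update inside RWKV_x070_TMix_one:
vk = v.view(H, N, 1) @ k.view(H, 1, N)             # outer product v k^T
ab = (-kk).view(H, N, 1) @ (kk * a).view(H, 1, N)  # forget-then-reinject term
state = state * w.view(H, 1, N) + state @ ab + vk  # decay, correct, accumulate
out = (state @ r.view(H, N, 1)).view(H * N)        # read out with receptance
print(out.shape)                                    # torch.Size([8])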
infer/rwkv/rwkv_tokenizer.py
ADDED
@@ -0,0 +1,103 @@
1 |
+
########################################################################################################
|
2 |
+
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
3 |
+
########################################################################################################
|
4 |
+
|
5 |
+
class TRIE:
|
6 |
+
__slots__ = tuple("ch,to,values,front".split(","))
|
7 |
+
to:list
|
8 |
+
values:set
|
9 |
+
def __init__(self, front=None, ch=None):
|
10 |
+
self.ch = ch
|
11 |
+
self.to = [None for ch in range(256)]
|
12 |
+
self.values = set()
|
13 |
+
self.front = front
|
14 |
+
|
15 |
+
def __repr__(self):
|
16 |
+
fr = self
|
17 |
+
ret = []
|
18 |
+
while(fr!=None):
|
19 |
+
if(fr.ch!=None):
|
20 |
+
ret.append(fr.ch)
|
21 |
+
fr = fr.front
|
22 |
+
return "<TRIE %s %s>"%(ret[::-1], self.values)
|
23 |
+
|
24 |
+
def add(self, key:bytes, idx:int=0, val=None):
|
25 |
+
if(idx == len(key)):
|
26 |
+
if(val is None):
|
27 |
+
val = key
|
28 |
+
self.values.add(val)
|
29 |
+
return self
|
30 |
+
ch = key[idx]
|
31 |
+
if(self.to[ch] is None):
|
32 |
+
self.to[ch] = TRIE(front=self, ch=ch)
|
33 |
+
return self.to[ch].add(key, idx=idx+1, val=val)
|
34 |
+
|
35 |
+
def find_longest(self, key:bytes, idx:int=0):
|
36 |
+
u:TRIE = self
|
37 |
+
ch:int = key[idx]
|
38 |
+
|
39 |
+
while(u.to[ch] is not None):
|
40 |
+
u = u.to[ch]
|
41 |
+
idx += 1
|
42 |
+
if(u.values):
|
43 |
+
ret = idx, u, u.values
|
44 |
+
if(idx==len(key)):
|
45 |
+
break
|
46 |
+
ch = key[idx]
|
47 |
+
return ret
|
48 |
+
|
49 |
+
class TRIE_TOKENIZER():
|
50 |
+
def __init__(self, file_name):
|
51 |
+
self.idx2token = {}
|
52 |
+
sorted = [] # must be already sorted
|
53 |
+
with open(file_name, "r", encoding="utf-8") as f:
|
54 |
+
lines = f.readlines()
|
55 |
+
for l in lines:
|
56 |
+
idx = int(l[:l.index(' ')])
|
57 |
+
x = eval(l[l.index(' '):l.rindex(' ')])
|
58 |
+
x = x.encode("utf-8") if isinstance(x, str) else x
|
59 |
+
assert isinstance(x, bytes)
|
60 |
+
assert len(x) == int(l[l.rindex(' '):])
|
61 |
+
sorted += [x]
|
62 |
+
self.idx2token[idx] = x
|
63 |
+
|
64 |
+
self.token2idx = {}
|
65 |
+
for k,v in self.idx2token.items():
|
66 |
+
self.token2idx[v] = int(k)
|
67 |
+
|
68 |
+
self.root = TRIE()
|
69 |
+
for t, i in self.token2idx.items():
|
70 |
+
_ = self.root.add(t, val=(t, i))
|
71 |
+
|
72 |
+
def encodeBytes(self, src:bytes):
|
73 |
+
idx:int = 0
|
74 |
+
tokens = []
|
75 |
+
while (idx < len(src)):
|
76 |
+
_idx:int = idx
|
77 |
+
idx, _, values = self.root.find_longest(src, idx)
|
78 |
+
assert(idx != _idx)
|
79 |
+
_, token = next(iter(values))
|
80 |
+
tokens.append(token)
|
81 |
+
return tokens
|
82 |
+
|
83 |
+
def decodeBytes(self, tokens):
|
84 |
+
return b''.join(map(lambda i: self.idx2token[i], tokens))
|
85 |
+
|
86 |
+
def encode(self, src):
|
87 |
+
return self.encodeBytes(src.encode("utf-8"))
|
88 |
+
|
89 |
+
def decode(self, tokens):
|
90 |
+
try:
|
91 |
+
return self.decodeBytes(tokens).decode('utf-8')
|
92 |
+
except:
|
93 |
+
return '\ufffd' # bad utf-8
|
94 |
+
|
95 |
+
def printTokens(self, tokens):
|
96 |
+
for i in tokens:
|
97 |
+
s = self.idx2token[i]
|
98 |
+
try:
|
99 |
+
s = s.decode('utf-8')
|
100 |
+
except:
|
101 |
+
pass
|
102 |
+
print(f'{repr(s)}{i}', end=' ')
|
103 |
+
print()
|
infer/rwkv/rwkv_vocab_v20230424.txt
ADDED
The diff for this file is too large to render.
infer/rwkv/utils.py
ADDED
@@ -0,0 +1,148 @@
1 |
+
########################################################################################################
|
2 |
+
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
3 |
+
########################################################################################################
|
4 |
+
|
5 |
+
import os, sys
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
class PIPELINE_ARGS():
|
11 |
+
def __init__(self, temperature=1.0, top_p=0.85, top_k=0, alpha_frequency=0.2, alpha_presence=0.2, alpha_decay=0.996, token_ban=[], token_stop=[], chunk_len=256):
|
12 |
+
self.temperature = temperature
|
13 |
+
self.top_p = top_p
|
14 |
+
self.top_k = top_k
|
15 |
+
self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
|
16 |
+
self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
|
17 |
+
self.alpha_decay = alpha_decay # gradually decay the penalty
|
18 |
+
self.token_ban = token_ban # ban the generation of some tokens
|
19 |
+
self.token_stop = token_stop # stop generation whenever you see any token here
|
20 |
+
self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)
|
21 |
+
|
22 |
+
class PIPELINE():
|
23 |
+
def __init__(self, model, WORD_NAME):
|
24 |
+
self.model = model
|
25 |
+
if WORD_NAME == 'cl100k_base':
|
26 |
+
import tiktoken
|
27 |
+
self.tokenizer = tiktoken.get_encoding(WORD_NAME)
|
28 |
+
elif WORD_NAME == 'rwkv_vocab_v20230424':
|
29 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
30 |
+
from rwkv_tokenizer import TRIE_TOKENIZER
|
31 |
+
self.tokenizer = TRIE_TOKENIZER(os.path.dirname(os.path.abspath(__file__)) + '/rwkv_vocab_v20230424.txt')
|
32 |
+
else:
|
33 |
+
from tokenizers import Tokenizer
|
34 |
+
self.tokenizer = Tokenizer.from_file(WORD_NAME)
|
35 |
+
|
36 |
+
def refine_context(self, context):
|
37 |
+
context = context.strip().split('\n')
|
38 |
+
for c in range(len(context)):
|
39 |
+
context[c] = context[c].strip().strip('\u3000').strip('\r')
|
40 |
+
context = list(filter(lambda c: c != '', context))
|
41 |
+
context = '\n' + ('\n'.join(context)).strip()
|
42 |
+
if context == '':
|
43 |
+
context = '\n'
|
44 |
+
return context
|
45 |
+
|
46 |
+
def encode(self, x):
|
47 |
+
if 'Tokenizer' in str(type(self.tokenizer)):
|
48 |
+
return self.tokenizer.encode(x).ids
|
49 |
+
else:
|
50 |
+
return self.tokenizer.encode(x)
|
51 |
+
|
52 |
+
def decode(self, x):
|
53 |
+
return self.tokenizer.decode(x)
|
54 |
+
|
55 |
+
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
|
56 |
+
if temperature == 0:
|
57 |
+
temperature = 1.0
|
58 |
+
top_p = 0
|
59 |
+
probs = F.softmax(logits.float(), dim=-1)
|
60 |
+
top_k = int(top_k)
|
61 |
+
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
|
62 |
+
if probs.device.type in ['cpu', 'privateuseone']:
|
63 |
+
probs = probs.cpu().numpy()
|
64 |
+
sorted_ids = np.argsort(probs)
|
65 |
+
sorted_probs = probs[sorted_ids][::-1]
|
66 |
+
cumulative_probs = np.cumsum(sorted_probs)
|
67 |
+
cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
|
68 |
+
probs[probs < cutoff] = 0
|
69 |
+
if top_k < len(probs) and top_k > 0:
|
70 |
+
probs[sorted_ids[:-top_k]] = 0
|
71 |
+
if temperature != 1.0:
|
72 |
+
probs = probs ** (1.0 / temperature)
|
73 |
+
probs = probs / np.sum(probs)
|
74 |
+
out = np.random.choice(a=len(probs), p=probs)
|
75 |
+
return int(out)
|
76 |
+
else:
|
77 |
+
sorted_ids = torch.argsort(probs)
|
78 |
+
sorted_probs = probs[sorted_ids]
|
79 |
+
sorted_probs = torch.flip(sorted_probs, dims=(0,))
|
80 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
|
81 |
+
cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
|
82 |
+
probs[probs < cutoff] = 0
|
83 |
+
if top_k < len(probs) and top_k > 0:
|
84 |
+
probs[sorted_ids[:-top_k]] = 0
|
85 |
+
if temperature != 1.0:
|
86 |
+
probs = probs ** (1.0 / temperature)
|
87 |
+
out = torch.multinomial(probs, num_samples=1)[0]
|
88 |
+
return int(out)
|
89 |
+
|
90 |
+
def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None, sign=None):
|
91 |
+
all_tokens = []
|
92 |
+
out_last = 0
|
93 |
+
out_str = ''
|
94 |
+
occurrence = {}
|
95 |
+
for i in range(token_count):
|
96 |
+
|
97 |
+
# forward & adjust prob.
|
98 |
+
tokens = self.encode(ctx) if i == 0 else [token]
|
99 |
+
while len(tokens) > 0:
|
100 |
+
out, state = self.model.forward(tokens[:args.chunk_len], state, sign=sign)
|
101 |
+
tokens = tokens[args.chunk_len:]
|
102 |
+
|
103 |
+
for n in args.token_ban:
|
104 |
+
out[n] = -float('inf')
|
105 |
+
for n in occurrence:
|
106 |
+
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
|
107 |
+
|
108 |
+
# sampler
|
109 |
+
token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
|
110 |
+
|
111 |
+
if token in args.token_stop:
|
112 |
+
break
|
113 |
+
all_tokens += [token]
|
114 |
+
for xxx in occurrence:
|
115 |
+
occurrence[xxx] *= args.alpha_decay
|
116 |
+
|
117 |
+
ttt = self.decode([token])
|
118 |
+
|
119 |
+
www = 1
|
120 |
+
if ttt in ' \t0123456789':
|
121 |
+
www = 0
|
122 |
+
# elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
|
123 |
+
# www = 0.5
|
124 |
+
if token not in occurrence:
|
125 |
+
occurrence[token] = www
|
126 |
+
else:
|
127 |
+
occurrence[token] += www
|
128 |
+
# print(occurrence) # debug
|
129 |
+
|
130 |
+
# output
|
131 |
+
tmp = self.decode(all_tokens[out_last:])
|
132 |
+
if '\ufffd' not in tmp: # is valid utf-8 string?
|
133 |
+
if callback:
|
134 |
+
callback(tmp)
|
135 |
+
out_str += tmp
|
136 |
+
out_last = i + 1
|
137 |
+
sign = None
|
138 |
+
return out_str, state
|
139 |
+
|
140 |
+
def prefill(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None, sign=None):
|
141 |
+
tokens = self.encode(ctx)
|
142 |
+
out, state = self.model.forward(tokens[:args.chunk_len], state, sign=sign,full_output=True)
|
143 |
+
max_indices = torch.argmax(out, dim=-1)
|
144 |
+
token = self.sample_logits(out[19,:], temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
|
145 |
+
print(token, self.decode([token]),out[19,:])
|
146 |
+
print(state[0].view(-1))
|
147 |
+
print(self.decode(max_indices.tolist()))
|
148 |
+
return self.decode(max_indices.tolist())
|
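A hedged sketch of how PIPELINE is normally driven; `model` is assumed to be an already-constructed infer.rwkv.model.RWKV instance (see infer/worldmodel.py below for a concrete wiring).

from infer.rwkv.utils import PIPELINE, PIPELINE_ARGS

# `model` must already exist; constructing it is elided here.
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
gen_args = PIPELINE_ARGS(temperature=1.0, top_p=0.85, top_k=0,
                         alpha_frequency=0.2, alpha_presence=0.2,
                         token_ban=[0], token_stop=[], chunk_len=256)

out_str, state = pipeline.generate("\x16User: Hello\x17Assistant:",
                                   token_count=100, args=gen_args,
                                   callback=print, state=None, sign=None)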
infer/worldmodel.py
ADDED
@@ -0,0 +1,79 @@
import numpy as np

import os, sys, torch, time
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import torch
print(torch.__version__)
print(torch.version.cuda)

# set these before import RWKV
# os.environ['RWKV_JIT_ON'] = '1'
# os.environ["RWKV_CUDA_ON"] = '1' # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries
from infer.rwkv.model import RWKV # pip install rwkv
from infer.rwkv.utils import PIPELINE, PIPELINE_ARGS


from world.world_encoder import WorldEncoder

class Worldinfer():
    def __init__(self, model_path, encoder_type, encoder_path, strategy='cpu bf16', args=None):

        ss = strategy.split(' ')
        DEVICE = ss[0]
        if ss[1] == 'fp16':
            self.DTYPE = torch.half
        elif ss[1] == 'fp32':
            self.DTYPE = torch.float32
        elif ss[1] == 'bf16':
            self.DTYPE = torch.bfloat16
        else:
            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16"

        self.model_weight = torch.load(model_path + '.pth', map_location=DEVICE)
        modality_dict = {}
        for key, value in self.model_weight.items():
            if 'emb.weight' in key:
                _, n_embd = value.shape
            if 'modality' in key:
                k = key.replace('modality.world_encoder.', '')
                modality_dict[k] = value
        model = RWKV(model=self.model_weight, strategy=strategy)
        self.pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

        if args==None:
            self.args = PIPELINE_ARGS(temperature = 1.0, top_p = 0.0, top_k=0, # top_k = 0 then ignore
                                      alpha_frequency = 0.0,
                                      alpha_presence = 0.0,
                                      token_ban = [0], # ban the generation of some tokens
                                      token_stop = [24], # stop generation whenever you see any token here
                                      chunk_len = 256) # split input into chunks to save VRAM (shorter -> slower)
        else:
            self.args=args
        print('RWKV finish!!!')

        config = {
            'encoder_type': encoder_type,
            'encoder_path': encoder_path,
            'project_dim' : n_embd
        }
        self.modality = WorldEncoder(**config).to(DEVICE, torch.bfloat16)
        self.modality.load_checkpoint(modality_dict)


    def generate(self, text, modality='none', state=None):
        if isinstance(modality, str):
            y=None
        else:
            y = self.modality(modality).to(self.DTYPE)
        result, state = self.pipeline.generate(text, token_count=500, args=self.args, callback=None, state=state, sign=y)
        return result, state

    # def prefill(self, text, modality='none', state=None):
    #     if isinstance(modality, str):
    #         y=None
    #     else:
    #         y = self.modality(modality).to(self.DTYPE)
    #     result, state = self.pipeline.forward(text, token_count=500, args=self.args, callback=None, state=state, sign=y)
    #     return result, state
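A minimal usage sketch for Worldinfer; the checkpoint paths and the encoder_type value are placeholders and must match whatever WorldEncoder supports in your setup.

from PIL import Image
from infer.worldmodel import Worldinfer

infer = Worldinfer(model_path='/path/to/rwkv-world',        # loads '/path/to/rwkv-world.pth'
                   encoder_type='siglip',                   # placeholder encoder type
                   encoder_path='/path/to/vision-encoder',  # placeholder encoder checkpoint
                   strategy='cpu bf16')

img = Image.open('example.jpg').convert('RGB')
result, state = infer.generate('\x16User: describe the image\x17Assistant:', modality=img)
print(result)                                               # pass modality='none' for text-only use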
world/__init__.py
ADDED
File without changes
|
world/block.py
ADDED
@@ -0,0 +1,34 @@
import os
import torch
from src.infctx_module import *
import torch.nn as nn
from torch.nn import functional as F

from src.rwkv7.Channel_mix import RWKV_CMix_x070
from src.rwkv7.Time_mix import RWKV_Tmix_x070

class Block(nn.Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(args.n_embd)

        self.att = RWKV_Tmix_x070(args, layer_id)
        self.ffn = RWKV_CMix_x070(args, layer_id)


    def forward(self, x, v_first):
        if self.layer_id == 0:
            x = self.ln0(x)

        x_attn, v_first = self.att(self.ln1(x), v_first)
        x = x + x_attn

        x = x + self.ffn(self.ln2(x))
        return x, v_first
world/cat.py
ADDED
@@ -0,0 +1,173 @@
1 |
+
|
2 |
+
import torch
|
3 |
+
from torch.nn import functional as F
|
4 |
+
def pad_mod(self, tensor_list, signal_list):
|
    """
    Pad a list of variable-length tensors so they all share the same length (a multiple of 16) and build a mask.

    Args:
        tensor_list (list of torch.Tensor): input tensors, each of shape [seq_len].
        pad_value (int, optional): padding value, defaults to 0.

    Returns:
        padded_tensor (torch.Tensor): the padded tensor, of shape [batch_size, target_len].
        mask (torch.Tensor): padding mask, 1 marks valid data and 0 marks padding.
    """
14 |
+
|
15 |
+
modality_list = []
|
16 |
+
#max_len = max((token.size(0) + signal.size(1)) for token, signal in zip(tensor_list, modality_list))
|
17 |
+
max_len = 0
|
18 |
+
for token, signal in zip(tensor_list, signal_list):
|
19 |
+
|
20 |
+
modality = self.modality(signal)
|
21 |
+
if modality is False:
|
22 |
+
modality_list.append(False)
|
23 |
+
continue
|
24 |
+
modality_list.append(modality)
|
25 |
+
max_len = max(token.size(0) + modality.size(1), max_len)
|
26 |
+
|
27 |
+
# compute the target length (round up to a multiple of 16)
|
28 |
+
target_len = ((max_len + 15) // 16 * 16)+1
|
29 |
+
|
30 |
+
if self.args.ctx_len is not None:
|
31 |
+
target_len = min(target_len, self.args.ctx_len+1)
|
32 |
+
|
33 |
+
masks = torch.zeros((len(tensor_list), target_len-1), dtype=torch.int32)
|
34 |
+
x = []
|
35 |
+
y = []
|
36 |
+
s = []
|
37 |
+
m = []
|
38 |
+
for token, signal, mask in zip(tensor_list, modality_list, masks):
|
39 |
+
if signal is False:
|
40 |
+
continue
|
41 |
+
pad_len = target_len-(token.size(0) + signal.size(1))
|
42 |
+
|
43 |
+
padded_token = F.pad(token, (0, pad_len), value=0)
|
44 |
+
|
45 |
+
x_token = padded_token[:-1]
|
46 |
+
y_token = F.pad(padded_token, (signal.size(1)-1, 0), value=0)
|
47 |
+
|
48 |
+
mask[signal.size(1) : -pad_len] = 1
|
49 |
+
|
50 |
+
s.append(signal)
|
51 |
+
x.append(x_token)
|
52 |
+
y.append(y_token)
|
53 |
+
m.append(mask)
|
54 |
+
|
55 |
+
y = torch.stack(y, dim=0)
|
56 |
+
m = torch.stack(m, dim=0).cuda()
|
57 |
+
|
58 |
+
return s, x, y, m
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
def mod_pad_text(self, signal_list, text_inputs, text_labels):
|
    """
    Pad a list of variable-length tensors so they all share the same length (a multiple of 16) and build a mask.

    Args:
        tensor_list (list of torch.Tensor): input tensors, each of shape [seq_len].
        pad_value (int, optional): padding value, defaults to 0.

    Returns:
        padded_tensor (torch.Tensor): the padded tensor, of shape [batch_size, target_len].
        mask (torch.Tensor): padding mask, 1 marks valid data and 0 marks padding.
    """
72 |
+
|
73 |
+
modality_list = []
|
74 |
+
#max_len = max((token.size(0) + signal.size(1)) for token, signal in zip(tensor_list, modality_list))
|
75 |
+
max_len = 0
|
76 |
+
for i, (signal, token, label) in enumerate(zip(signal_list, text_inputs, text_labels)):
|
77 |
+
|
78 |
+
modality = self.modality(signal)
|
79 |
+
modality_list.append(modality)
|
80 |
+
mod_label = torch.full((modality.size(1),), -100, device='cuda')
|
81 |
+
text_labels[i] = torch.cat([mod_label, label])
|
82 |
+
max_len = max(token.size(0) + modality.size(1), max_len)
|
83 |
+
|
84 |
+
# compute the target length (round up to a multiple of 16)
|
85 |
+
target_len = ((max_len + 15) // 16 * 16)+1
|
86 |
+
|
87 |
+
if self.args.ctx_len is not None:
|
88 |
+
target_len = min(target_len, self.args.ctx_len+1)
|
89 |
+
|
90 |
+
|
91 |
+
for i, (signal, token, label) in enumerate(zip(modality_list , text_inputs, text_labels)):
|
92 |
+
pad_len = target_len-(token.size(0) + signal.size(1))
|
93 |
+
|
94 |
+
text_inputs[i] = F.pad(token, (0, pad_len), value=0)[:-1]
|
95 |
+
text_labels[i] = F.pad(label, (0, pad_len), value=-100)[1:]
|
96 |
+
|
97 |
+
|
98 |
+
targets = torch.stack(text_labels, dim=0).cuda()
|
99 |
+
|
100 |
+
return modality_list, text_inputs, targets
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
def cat_tts(self, tensor_list, signal_list):
|
    """
    Pad a list of variable-length tensors so they all share the same length (a multiple of 16) and build a mask.

    Args:
        tensor_list (list of torch.Tensor): input tensors, each of shape [seq_len].
        pad_value (int, optional): padding value, defaults to 0.

    Returns:
        padded_tensor (torch.Tensor): the padded tensor, of shape [batch_size, target_len].
        mask (torch.Tensor): padding mask, 1 marks valid data and 0 marks padding.
    """
114 |
+
|
115 |
+
modality_list = []
|
116 |
+
atokens = []
|
117 |
+
labels_list = [] # labels for the concatenated multimodal sequence
|
118 |
+
#max_len = max((token.size(0) + signal.size(1)) for token, signal in zip(tensor_list, modality_list))
|
119 |
+
max_len = 0
|
120 |
+
for token, signal in zip(tensor_list, signal_list):
|
121 |
+
global_tokens, semantic_tokens = self.modality.world_encoder.encoder(signal)
|
122 |
+
# print(global_tokens.squeeze(0).squeeze(0), global_tokens.squeeze(0).squeeze(0)+8192)
|
123 |
+
global_tokens = global_tokens.squeeze(0).squeeze(0)+8194
|
124 |
+
# global_tokens = F.pad(global_tokens.squeeze(0).squeeze(0)+8194, (0, 1), value=8193)
|
125 |
+
semantic_tokens = F.pad(semantic_tokens.squeeze(0), (0, 1), value=8192)
|
126 |
+
audio_token = torch.cat([global_tokens,semantic_tokens])
|
127 |
+
mask_gt = torch.full_like(global_tokens, -100)
|
128 |
+
label = torch.cat([global_tokens-1,semantic_tokens-1])
|
129 |
+
# modality = self.modality.encoder(audio_token)
|
130 |
+
|
131 |
+
# if modality is False:
|
132 |
+
# modality_list.append(False)
|
133 |
+
# continue
|
134 |
+
|
135 |
+
mask_t = torch.full_like(token, -100)
|
136 |
+
label = torch.cat([mask_t,label])
|
137 |
+
atokens.append(audio_token)
|
138 |
+
# modality_list.append(modality)
|
139 |
+
labels_list.append(label)
|
140 |
+
max_len = max(label.size(0), max_len)
|
141 |
+
|
142 |
+
# compute the target length (round up to a multiple of 16)
|
143 |
+
target_len = ((max_len + 15) // 16 * 16)+1
|
144 |
+
|
145 |
+
if self.args.ctx_len is not None:
|
146 |
+
target_len = min(target_len, self.args.ctx_len+1)
|
147 |
+
|
148 |
+
text_token = []
|
149 |
+
labels = []
|
150 |
+
mod_token = []
|
151 |
+
|
152 |
+
for token, atoken, mask in zip(tensor_list, atokens, labels_list):
|
153 |
+
|
154 |
+
pad_len = target_len-(token.size(0) + atoken.size(0))
|
155 |
+
|
156 |
+
padded_atoken = F.pad(atoken, (0, pad_len), value=8192)
|
157 |
+
|
158 |
+
atoken = padded_atoken[:-1]
|
159 |
+
mod = self.modality(atoken)
|
160 |
+
# padded_token = F.pad(signal, (0, 0, 0, pad_len), value=0)
|
161 |
+
|
162 |
+
# pad_mod = padded_token[:,:-1,:]
|
163 |
+
|
164 |
+
pad_mask = F.pad(mask, (0, pad_len), value=-100)[1:]
|
165 |
+
|
166 |
+
mod_token.append(mod)
|
167 |
+
|
168 |
+
labels.append(pad_mask)
|
169 |
+
|
170 |
+
labels = torch.stack(labels, dim=0)
|
171 |
+
|
172 |
+
|
173 |
+
return mod_token, tensor_list, labels
|
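All three helpers above round the combined modality+text length up to a multiple of 16 plus one extra position for the shifted label; a quick worked example of that target-length arithmetic, assuming ctx_len is large enough not to clip:

max_len = 147                                   # longest (modality + text) sequence in the batch
target_len = ((max_len + 15) // 16 * 16) + 1    # 160 + 1 = 161
pad_len = target_len - 147                      # zeros appended to this particular sample
assert (target_len - 1) % 16 == 0 and pad_len == 14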
world/dataset.py
ADDED
@@ -0,0 +1,354 @@
1 |
+
########################################################################################################
|
2 |
+
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
3 |
+
########################################################################################################
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
import json
|
7 |
+
import math
|
8 |
+
import random
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import lightning as L
|
14 |
+
from torch.utils.data import Dataset
|
15 |
+
from torch.utils.data import DataLoader
|
16 |
+
from lightning_utilities.core.rank_zero import rank_zero_info
|
17 |
+
from infer.rwkv.utils import PIPELINE
|
18 |
+
pipeline = PIPELINE('rwkv', "rwkv_vocab_v20230424")
|
19 |
+
from PIL import Image
|
20 |
+
|
21 |
+
import pandas as pd
|
22 |
+
import librosa
|
23 |
+
import io
|
24 |
+
import soundfile as sf
|
25 |
+
# read the parquet file
|
26 |
+
from torchvision import transforms
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
transform = transforms.Compose([
|
31 |
+
transforms.Resize((512, 512)),
|
32 |
+
transforms.ToTensor() # convert the image to a tensor
|
33 |
+
])
|
34 |
+
|
35 |
+
def process_conversation_text(conversations):
|
36 |
+
conversation_text = f"\x16"
|
37 |
+
|
38 |
+
for conv in conversations:
|
39 |
+
role = conv.get('from', '').lower()
|
40 |
+
content = conv.get('value', '')
|
41 |
+
|
42 |
+
if role == 'human':
|
43 |
+
conversation_text += f"User: {content}\x17"
|
44 |
+
elif role in ['assistant', 'gpt']:
|
45 |
+
conversation_text += f"Assistant: {content}\x17"
|
46 |
+
|
47 |
+
return conversation_text
|
48 |
+
|
49 |
+
def process_tokens(conversations):
|
50 |
+
# conversation_text = f"\x16"
|
51 |
+
inputs = []
|
52 |
+
labels = []
|
53 |
+
for conv in conversations:
|
54 |
+
role = conv.get('from', '').lower()
|
55 |
+
content = conv.get('value', '')
|
56 |
+
|
57 |
+
if role in ['human', 'user']:
|
58 |
+
question = f"\x16User: {content}\x17"
|
59 |
+
input = torch.tensor(pipeline.encode(question))
|
60 |
+
label = torch.full_like(input, -100)
|
61 |
+
elif role in ['assistant', 'gpt']:
|
62 |
+
answer = f"\x16Assistant: {content}\x17"
|
63 |
+
input= torch.tensor(pipeline.encode(answer))
|
64 |
+
label = input
|
65 |
+
inputs.append(input)
|
66 |
+
labels.append(label)
|
67 |
+
inputs =torch.cat(inputs)
|
68 |
+
labels =torch.cat(labels)
|
69 |
+
return inputs, labels
|
70 |
+
|
71 |
+
def bytes_to_audio(audio_bytes):
|
72 |
+
with io.BytesIO(audio_bytes) as buf:
|
73 |
+
# read the audio data with soundfile
|
74 |
+
audio_array, sr = sf.read(buf)
|
75 |
+
|
76 |
+
# make sure the audio is mono
|
77 |
+
if len(audio_array.shape) > 1:
|
78 |
+
audio_array = audio_array.mean(axis=1)
|
79 |
+
|
80 |
+
# make sure the dtype is float32
|
81 |
+
audio_array = audio_array.astype(np.float32)
|
82 |
+
|
83 |
+
return {
|
84 |
+
'array': audio_array,
|
85 |
+
'sampling_rate': sr
|
86 |
+
}
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
def get_data_by_l_version(trainer: L.Trainer, args):
|
91 |
+
if L.__version__[0] == '2':
|
92 |
+
train_data = MyDataModule(args)
|
93 |
+
else:
|
94 |
+
raise ValueError(f"Unsupported PyTorch Lightning version: {L.__version__}")
|
95 |
+
return train_data
|
96 |
+
|
97 |
+
class GlobalIndexManager:
|
98 |
+
def __init__(self, rank=0, device_num=1, shuffle=True):
|
99 |
+
self.current_idx = 0
|
100 |
+
self.rank = rank
|
101 |
+
self.device_num = device_num
|
102 |
+
self.shuffle = shuffle
|
103 |
+
|
104 |
+
def get_next_idx(self, idx_t):
|
105 |
+
if self.shuffle:
|
106 |
+
idx = idx_t
|
107 |
+
else:
|
108 |
+
idx = self.current_idx * self.device_num + self.rank
|
109 |
+
self.current_idx += 1
|
110 |
+
return idx
|
111 |
+
|
112 |
+
class MyDataModule(L.LightningDataModule):
|
113 |
+
def __init__(self, args):
|
114 |
+
super().__init__()
|
115 |
+
self.args = args
|
116 |
+
self.train_data = None
|
117 |
+
|
118 |
+
|
119 |
+
def setup(self, stage=None):
|
120 |
+
self.train_data = MyDataset(self.args)
|
121 |
+
self.args.vocab_size = self.train_data.vocab_size
|
122 |
+
self.train_data.real_epoch = self.trainer.current_epoch
|
123 |
+
self.train_data.rank = self.trainer.global_rank
|
124 |
+
self.train_data.world_size = self.trainer.world_size
|
125 |
+
self.train_data.setup(self.trainer.global_rank, self.trainer.world_size,
|
126 |
+
int(self.args.devices), self.args.data_shuffle)
|
127 |
+
|
128 |
+
def train_dataloader(self):
|
129 |
+
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
|
130 |
+
return DataLoader(
|
131 |
+
self.train_data,
|
132 |
+
shuffle=self.args.data_shuffle,
|
133 |
+
pin_memory=True,
|
134 |
+
batch_size=self.args.micro_bsz,
|
135 |
+
num_workers=1,
|
136 |
+
persistent_workers=False,
|
137 |
+
drop_last=True
|
138 |
+
)
|
139 |
+
|
140 |
+
class WorldDataset(Dataset):
|
141 |
+
def __init__(self, args, emb=None):
|
142 |
+
self.args = args
|
143 |
+
self.rank = 0
|
144 |
+
self.real_epoch = 0
|
145 |
+
self.world_size = 0
|
146 |
+
self.index_manager = None
|
147 |
+
self.emb = emb
|
148 |
+
|
149 |
+
if args.data_type =='wav':
|
150 |
+
import jsonlines
|
151 |
+
|
152 |
+
# open and read the JSON file
|
153 |
+
#with open(f'{args.data_file}/answer.jsonl', 'r') as file:
|
154 |
+
with jsonlines.open(f'{args.data_file}/answer.jsonl') as file:
|
155 |
+
self.data = list(file)
|
156 |
+
elif args.data_type =='img':
|
157 |
+
import jsonlines
|
158 |
+
|
159 |
+
# open and read the JSON file
|
160 |
+
#with open(f'{args.data_file}/answer.jsonl', 'r') as file:
|
161 |
+
with jsonlines.open(f'{args.data_file}/answer.jsonl') as file:
|
162 |
+
self.data = list(file)
|
163 |
+
elif args.data_type=='hf_img':
|
164 |
+
import jsonlines
|
165 |
+
# with open(f'{args.data_file}/chat.json', 'r', encoding='utf-8') as file:
|
166 |
+
# self.data = json.load(file)
|
167 |
+
with jsonlines.open(f'{args.data_file}/chat.jsonl') as file:
|
168 |
+
self.data = list(file)
|
169 |
+
elif args.data_type=='visual':
|
170 |
+
import jsonlines
|
171 |
+
# with open(f'{args.data_file}/chat.json', 'r', encoding='utf-8') as file:
|
172 |
+
# self.data = json.load(file)
|
173 |
+
with jsonlines.open(f'{args.data_file}/chat.jsonl') as file:
|
174 |
+
self.data = list(file)
|
175 |
+
elif args.data_type == 'visual-r1-cs':
|
176 |
+
|
177 |
+
llava_path = os.path.join(args.data_file, 'vision_r1_llava_cot_full.json')
|
178 |
+
mulberry_path = os.path.join(args.data_file, 'vision_r1_mulberry_sft_full.json')
|
179 |
+
|
180 |
+
import json
|
181 |
+
with open(f'{llava_path}', 'r', encoding='utf-8') as file:
|
182 |
+
llava_data = json.load(file)
|
183 |
+
with open(f'{mulberry_path}', 'r', encoding='utf-8') as file:
|
184 |
+
mulberry_data = json.load(file)
|
185 |
+
|
186 |
+
# merge the two datasets and tag each item with its source
|
187 |
+
for item in llava_data:
|
188 |
+
item['_source'] = 'llava_cot'
|
189 |
+
for item in mulberry_data:
|
190 |
+
item['_source'] = 'mulberry'
|
191 |
+
|
192 |
+
self.data = llava_data + mulberry_data
|
193 |
+
|
194 |
+
elif args.data_type =='hf' or args.data_type =='qa' or args.data_type =='cnqa' or args.data_type =='cnasr' or args.data_type =='tts':
|
195 |
+
from datasets import load_dataset, concatenate_datasets
|
196 |
+
|
197 |
+
def list_subdirectories(base_path):
|
198 |
+
return [
|
199 |
+
name for name in os.listdir(base_path)
|
200 |
+
if os.path.isdir(os.path.join(base_path, name)) and not name.startswith('.')
|
201 |
+
]
|
202 |
+
|
203 |
+
datasets = []
|
204 |
+
files = list_subdirectories(args.data_file)
|
205 |
+
if not files:
|
206 |
+
datasets = load_dataset(args.data_file, split="train")
|
207 |
+
else:
|
208 |
+
for file in files:
|
209 |
+
dataset = load_dataset(f'{args.data_file}/{file}', split="train")
|
210 |
+
datasets.append(dataset)
|
211 |
+
datasets = concatenate_datasets(datasets)
|
212 |
+
self.data = datasets
|
213 |
+
print(len(datasets))
|
214 |
+
|
215 |
+
elif args.data_type == "jsonl":
|
216 |
+
import jsonlines
|
217 |
+
|
218 |
+
with jsonlines.open(args.data_file) as file:
|
219 |
+
self.data = list(file)
|
220 |
+
|
221 |
+
else:
|
222 |
+
self.data = pd.read_parquet(args.data_file)
|
223 |
+
|
224 |
+
|
225 |
+
|
226 |
+
def setup(self, rank, world_size, devices, shuffle):
|
227 |
+
self.rank = rank
|
228 |
+
self.world_size = world_size
|
229 |
+
self.index_manager = GlobalIndexManager(rank=rank, device_num=devices, shuffle=shuffle)
|
230 |
+
|
231 |
+
def __len__(self):
|
232 |
+
return self.args.epoch_steps * self.args.micro_bsz
|
233 |
+
|
234 |
+
|
235 |
+
def __getitem__(self, idx):
|
236 |
+
idx = self.index_manager.get_next_idx(idx_t=idx) if self.index_manager else idx
|
237 |
+
args = self.args
|
238 |
+
if args.data_type =='wav':
|
239 |
+
|
240 |
+
mod_name = self.data[idx]['file_name']
|
241 |
+
data_answer = self.data[idx]['answer']
|
242 |
+
mod_path = f'{args.data_file}/{mod_name}'
|
243 |
+
audio, sample_rate = librosa.load(mod_path, sr=16000) # sr=None keeps the original sampling rate
|
244 |
+
#sign,_ = self.speech_encoder(audio)
|
245 |
+
sign = audio
|
246 |
+
token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
|
247 |
+
elif args.data_type =='hf':
|
248 |
+
sample = self.data[idx]
|
249 |
+
audio = sample['audio']
|
250 |
+
data_answer = sample['text'] #####caption
|
251 |
+
audio = librosa.resample(audio['array'],orig_sr= audio['sampling_rate'],target_sr= 16000) # sr=None keeps the original sampling rate
|
252 |
+
sign = audio
|
253 |
+
token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
|
254 |
+
elif args.data_type =='tts':
|
255 |
+
sample = self.data[idx]
|
256 |
+
audio = sample['audio']
|
257 |
+
data_answer = sample['text'] #####caption
|
258 |
+
audio = librosa.resample(audio['array'],orig_sr= audio['sampling_rate'],target_sr= 16000) # sr=None keeps the original sampling rate
|
259 |
+
sign = audio
|
260 |
+
token = torch.tensor(pipeline.encode(f'User: {data_answer}\x17Assistant:'))
|
261 |
+
elif args.data_type =='qa':
|
262 |
+
sample = self.data[idx]
|
263 |
+
# audio = sample['speech_cosy'][0]
|
264 |
+
# data_answer = sample['answer']
|
265 |
+
|
266 |
+
audio = sample['question_audio']
|
267 |
+
data_answer = sample['answer']
|
268 |
+
sign = librosa.resample(audio['array'],orig_sr= audio['sampling_rate'],target_sr= 16000) # sr=None keeps the original sampling rate
|
269 |
+
|
270 |
+
token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
|
271 |
+
elif args.data_type =='cnqa':
|
272 |
+
sample = self.data[idx]
|
273 |
+
audio = sample['audio']
|
274 |
+
data_answer = sample['answer']
|
275 |
+
sign = librosa.resample(audio['array'],orig_sr= audio['sampling_rate'],target_sr= 16000) # sr=None keeps the original sampling rate
|
276 |
+
token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
|
277 |
+
elif args.data_type =='cnasr':
|
278 |
+
sample = self.data[idx]
|
279 |
+
audio = sample['audio']
|
280 |
+
data_answer = sample['transcript']
|
281 |
+
sign = librosa.resample(audio['array'],orig_sr= audio['sampling_rate'],target_sr= 16000) # sr=None keeps the original sampling rate
|
282 |
+
token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
|
283 |
+
elif args.data_type == "jsonl":
|
284 |
+
ctx_len = args.ctx_len
|
285 |
+
req_len = ctx_len + 1
|
286 |
+
ctx = self.data[idx]['text']
|
287 |
+
token = torch.tensor(pipeline.encode(ctx))
|
288 |
+
token_len = len(token)
|
289 |
+
pad_len = req_len - token_len
|
290 |
+
|
291 |
+
dix = F.pad(token, (0, pad_len), value=0)
|
292 |
+
x = dix[:-1]
|
293 |
+
y = dix[1:]
|
294 |
+
mask = torch.zeros(req_len - 1)
|
295 |
+
mask[:token_len - 1] = 1
|
296 |
+
return x, y, mask
|
297 |
+
elif args.data_type == "img":
|
298 |
+
|
299 |
+
mod_name = self.data[idx]['file_name']
|
300 |
+
data_answer = self.data[idx]['answer']
|
301 |
+
mod_path = f'{args.data_file}/{mod_name}'
|
302 |
+
token = torch.tensor(pipeline.encode(f'\n\nAssistant: {data_answer}\x17'))
|
303 |
+
image = Image.open(mod_path).convert('RGB')
|
304 |
+
sign = transform(image)
|
305 |
+
elif args.data_type == 'visual':
|
306 |
+
|
307 |
+
img_name = self.data[idx]['image']
|
308 |
+
conversation_text = self.data[idx]['conversations']
|
309 |
+
|
310 |
+
mod_path = f'{args.data_file}/images/{img_name}'
|
311 |
+
image = Image.open(mod_path).convert('RGB')
|
312 |
+
sign = image
|
313 |
+
text_tokens, text_labels = process_tokens(conversation_text)
|
314 |
+
return sign, text_tokens, text_labels
|
315 |
+
elif args.data_type== 'hf_img':
|
316 |
+
|
317 |
+
img_name = self.data[idx]['image']
|
318 |
+
conversation_text = self.data[idx]['conversations']
|
319 |
+
conversation_text = process_conversation_text(conversation_text)
|
320 |
+
|
321 |
+
mod_path = f'{args.data_file}/images/{img_name}'
|
322 |
+
token = torch.tensor(pipeline.encode(conversation_text))
|
323 |
+
image = Image.open(mod_path).convert('RGB')
|
324 |
+
sign = image
|
325 |
+
elif args.data_type == 'visual-r1-cs':
|
326 |
+
item = self.data[idx]
|
327 |
+
conversations = item['conversations']
|
328 |
+
|
329 |
+
# resolve the image path according to the item's source dataset
|
330 |
+
if item['_source'] == 'llava_cot':
|
331 |
+
img_name = item['image']
|
332 |
+
|
333 |
+
else:
|
334 |
+
img_name = item['images']
|
335 |
+
|
336 |
+
|
337 |
+
|
338 |
+
mod_path = f'{self.args.data_file}/{img_name}'
|
339 |
+
image = Image.open(mod_path).convert('RGB')
|
340 |
+
sign = image
|
341 |
+
|
342 |
+
# process the text conversation
|
343 |
+
text_tokens, text_labels = process_tokens(conversations)
|
344 |
+
|
345 |
+
return sign, text_tokens, text_labels
|
346 |
+
else:
|
347 |
+
data_audio = bytes_to_audio(self.data['question_audio'][idx]['bytes'])
|
348 |
+
data_answer = self.data['answer'][idx]
|
349 |
+
audio = librosa.resample(data_audio['array'],orig_sr= 48000,target_sr= 16000)
|
350 |
+
#sign,_ = self.speech_encoder(audio)
|
351 |
+
sign = audio
|
352 |
+
token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
|
353 |
+
#print(idx, f'Assistant: {data_answer}\x17')
|
354 |
+
return sign, token
|
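For the plain `jsonl` branch above, x/y/mask form a standard shift-by-one language-modelling triple; the same logic with toy numbers:

import torch
import torch.nn.functional as F

token = torch.tensor([5, 6, 7, 8])     # pretend pipeline.encode(...) returned 4 ids
req_len = 9                            # ctx_len + 1, toy value
dix = F.pad(token, (0, req_len - len(token)), value=0)
x, y = dix[:-1], dix[1:]               # inputs and next-token targets
mask = torch.zeros(req_len - 1)
mask[:len(token) - 1] = 1              # loss only on the real (non-padded) positions
print(x.tolist(), y.tolist(), mask.tolist())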
world/encoder/clip_encoder.py
ADDED
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig



class VisualAdapter(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, encoder_dim, project_dim, hidden_dim=None):

        super().__init__()
        self.encoder_dim = encoder_dim
        self.project_dim = project_dim
        self.hidden_dim = hidden_dim

        if self.hidden_dim==None:
            self.hidden_dim = project_dim*2

        self.pre_norm = nn.LayerNorm(self.project_dim)
        self.proj = nn.Sequential(
            nn.Linear(self.encoder_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.project_dim),
        )
        # self.proj = nn.Sequential(
        #     nn.Linear(self.encoder_dim, self.hidden_dim),
        #     nn.GELU(),
        #     nn.Linear(self.hidden_dim, self.hidden_dim),
        #     nn.GELU(),
        #     nn.Linear(self.hidden_dim, self.project_dim),
        # )


    def forward(self, x):
        x = self.proj(x)
        return x + self.pre_norm(x)



class ClipEncoder(nn.Module):

    def __init__(
        self,
        encoder_path,
        project_dim,
        train_mode="adapter",
        device="cuda",) -> None:
        super(ClipEncoder, self).__init__()

        self.device = device
        self.image_processor = CLIPImageProcessor.from_pretrained(encoder_path)
        self.model = CLIPVisionModel.from_pretrained(encoder_path)
        self.encoder_dim = self.model.config.hidden_size

        self.adapter = VisualAdapter(self.encoder_dim, project_dim)
    def forward(self, x):

        x= torch.from_numpy(self.image_processor(x)['pixel_values'][0]).to(self.device,dtype=torch.bfloat16)

        x = self.model(x.unsqueeze(0), output_hidden_states=True).last_hidden_state

        x = self.adapter(x)

        return x
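A hedged usage sketch for ClipEncoder; the checkpoint name is only an example of a CLIP vision checkpoint with a matching processor, and the module expects a CUDA device as written.

import torch
from PIL import Image

enc = ClipEncoder(encoder_path='openai/clip-vit-large-patch14',   # example checkpoint, swap in your own
                  project_dim=2048).to('cuda', torch.bfloat16)

img = Image.open('example.jpg').convert('RGB')
with torch.no_grad():
    feats = enc(img)        # (1, num_patches + 1, project_dim) projected visual tokens
print(feats.shape)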
world/encoder/siglip_encoder.py
ADDED
@@ -0,0 +1,68 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from transformers import AutoModel, SiglipImageProcessor
|
7 |
+
|
8 |
+
class VisualAdapter(nn.Module):
|
9 |
+
"""
|
10 |
+
2D Image to Patch Embedding
|
11 |
+
"""
|
12 |
+
def __init__(self, encoder_dim, project_dim, hidden_dim=None):
|
13 |
+
|
14 |
+
super().__init__()
|
15 |
+
self.encoder_dim = encoder_dim
|
16 |
+
self.project_dim = project_dim
|
17 |
+
self.hidden_dim = hidden_dim
|
18 |
+
|
19 |
+
if self.hidden_dim==None:
|
20 |
+
self.hidden_dim = project_dim*2
|
21 |
+
|
22 |
+
self.pre_norm = nn.LayerNorm(self.project_dim)
|
23 |
+
self.proj = nn.Sequential(
|
24 |
+
nn.Linear(self.encoder_dim, self.hidden_dim),
|
25 |
+
nn.ReLU(),
|
26 |
+
nn.Linear(self.hidden_dim, self.project_dim),
|
27 |
+
)
|
28 |
+
# self.conv = nn.Conv1d(
|
29 |
+
# in_channels=encoder_dim,
|
30 |
+
# out_channels=encoder_dim,
|
31 |
+
# bias=False,
|
32 |
+
# kernel_size=5,
|
33 |
+
# stride=4
|
34 |
+
# )
|
35 |
+
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
# x = self.conv(x.permute(0,2,1)).permute(0,2,1)
|
39 |
+
x = self.proj(x)
|
40 |
+
return x + self.pre_norm(x)
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
class SiglipEncoder(nn.Module):
|
45 |
+
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
encoder_path,
|
49 |
+
project_dim,
|
50 |
+
encoder_device='cpu',
|
51 |
+
train_mode="adapter"
|
52 |
+
) -> None:
|
53 |
+
super(SiglipEncoder, self).__init__()
|
54 |
+
|
55 |
+
self.device = encoder_device
|
56 |
+
|
57 |
+
self.model = AutoModel.from_pretrained(encoder_path).vision_model
|
58 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(encoder_path)
|
59 |
+
self.encoder_dim = 768 #self.model.config.hidden_size
|
60 |
+
|
61 |
+
self.adapter = VisualAdapter(self.encoder_dim, project_dim)
|
62 |
+
def forward(self, x):
|
63 |
+
|
64 |
+
x= torch.from_numpy(self.image_processor(x)['pixel_values'][0]).to(self.device,dtype=torch.bfloat16)
|
65 |
+
x = self.model(x.unsqueeze(0), output_hidden_states=True).last_hidden_state
|
66 |
+
x = self.adapter(x)
|
67 |
+
|
68 |
+
return x
|
world/encoder/speech_encoder.py
ADDED
@@ -0,0 +1,109 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from transformers import AutoProcessor, AutoModel
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
# class SpeechAdapter(nn.Module):
|
12 |
+
# def __init__(self, input_dim, output_dim):
|
13 |
+
# super(SpeechAdapter, self).__init__()
|
14 |
+
# self.conv = nn.Conv1d(in_channels=input_dim, out_channels=3072, kernel_size=3, stride=2)
|
15 |
+
# self.transformer = nn.TransformerEncoderLayer(d_model=3072, nhead=8, dim_feedforward=4096)
|
16 |
+
# self.linear = nn.Linear(3072, output_dim)
|
17 |
+
# def forward(self, x):
|
18 |
+
# # if x.size(1)<5 or x.size(1)>5000:
|
19 |
+
# # return False
|
20 |
+
# # x shape: (batch_size, seq_len, input_dim)
|
21 |
+
# x = x.permute(0, 2, 1)
|
22 |
+
# # x shape: (batch_size, input_dim, seq_len)
|
23 |
+
# x = self.conv(x)
|
24 |
+
# # x shape after conv: (batch_size, input_dim, new_seq_len)
|
25 |
+
# x = x.permute(2, 0, 1) # Transformer expects (seq_len, batch_size, input_dim)
|
26 |
+
# # x = self.transformer(x, src_key_padding_mask=mask.bool())
|
27 |
+
# x = self.transformer(x)
|
28 |
+
# x = x.permute(1, 0, 2) # Back to (batch_size, seq_len, input_dim)
|
29 |
+
# x = self.linear(x)
|
30 |
+
# return x
|
31 |
+
|
32 |
+
class SpeechAdapter(nn.Module):
|
33 |
+
def __init__(self, encoder_dim, project_dim, hidden_dim=None):
|
34 |
+
super(SpeechAdapter, self).__init__()
|
35 |
+
self.encoder_dim = encoder_dim
|
36 |
+
self.project_dim = project_dim
|
37 |
+
self.hidden_dim = hidden_dim
|
38 |
+
|
39 |
+
if self.hidden_dim==None:
|
40 |
+
self.hidden_dim = project_dim*2
|
41 |
+
self.conv = nn.Conv1d(in_channels=self.encoder_dim , out_channels=self.hidden_dim, kernel_size=3, stride=2, padding=2)
|
42 |
+
self.proj = nn.Sequential(
|
43 |
+
nn.Linear(self.hidden_dim, self.hidden_dim),
|
44 |
+
nn.ReLU(),
|
45 |
+
nn.Linear(self.hidden_dim, self.project_dim),
|
46 |
+
)
|
47 |
+
def forward(self, x):
|
48 |
+
# if x.size(1)<5 or x.size(1)>5000:
|
49 |
+
# return False
|
50 |
+
|
51 |
+
# x shape: (batch_size, seq_len, input_dim)
|
52 |
+
x = x.permute(0, 2, 1)
|
53 |
+
# x shape: (batch_size, input_dim, seq_len)
|
54 |
+
x = self.conv(x).permute(0, 2, 1)
|
55 |
+
# x shape after conv: (batch_size, input_dim, new_seq_len)
|
56 |
+
x = self.proj(x)
|
57 |
+
if x.size(1)>1023:
|
58 |
+
return False
|
59 |
+
return x
|
60 |
+
|
61 |
+
class SpeechEncoder(nn.Module):
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
encoder_path,
|
65 |
+
project_dim,
|
66 |
+
train_mode="adapter",
|
67 |
+
device="cuda",
|
68 |
+
):
|
69 |
+
assert train_mode in ["adapter", "full"]
|
70 |
+
super(SpeechEncoder, self).__init__()
|
71 |
+
|
72 |
+
self.device = device
|
73 |
+
|
74 |
+
try:
|
75 |
+
self.processor = AutoProcessor.from_pretrained(encoder_path)
|
76 |
+
except:
|
77 |
+
self.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
|
78 |
+
|
79 |
+
self.time_reduction_factor = int(
|
80 |
+
self.processor.feature_extractor.sampling_rate / 50
|
81 |
+
)
|
82 |
+
self.padding_length = 320
|
83 |
+
|
84 |
+
self.model = AutoModel.from_pretrained(encoder_path)
|
85 |
+
self.model.eval()
|
86 |
+
self.model_output_dim = self.model.config.hidden_size
|
87 |
+
self.project_dim = project_dim
|
88 |
+
|
89 |
+
self.project_dim = project_dim
|
90 |
+
self.adapter = SpeechAdapter(self.model_output_dim, self.project_dim).to(self.device,dtype=torch.bfloat16)
|
91 |
+
# self.set_gradient(train_mode)
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
input_dict = self.processor(
|
98 |
+
x, return_tensors="pt", padding=True, sampling_rate=16000
|
99 |
+
).to(self.device,dtype=torch.bfloat16)
|
100 |
+
|
101 |
+
# encoder only
|
102 |
+
x = self.model(**input_dict).last_hidden_state
|
103 |
+
|
104 |
+
# stf encoder
|
105 |
+
# x = self.model(**input_dict, output_hidden_states=True).hidden_states[-1]
|
106 |
+
|
107 |
+
x= self.adapter(x)#x:(B,T,hidden dim)
|
108 |
+
# mask = torch.ones(x.shape[0],x.shape[1]).to(self.device,dtype=torch.bfloat16)
|
109 |
+
return x
|
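The Conv1d in SpeechAdapter (kernel 3, stride 2, padding 2) roughly halves the encoder's 50 Hz frame rate before projection; the standard output-length formula makes the 1023-frame cut-off above concrete.

def conv_out_len(seq_len, kernel=3, stride=2, padding=2):
    # standard Conv1d output-length formula
    return (seq_len + 2 * padding - kernel) // stride + 1

print(conv_out_len(1500))   # ~30 s of 16 kHz audio -> ~1500 frames at 50 Hz -> 751 adapter frames
# Clips whose adapter output still exceeds 1023 frames are rejected by SpeechAdapter.forward.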
world/encoder/visual_encoder.py
ADDED
@@ -0,0 +1,106 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from diffusers import AutoencoderKL


class Patch(nn.Module):
    def __init__(self, Imgsize=64, Patchsize=16) -> None:
        super(Patch, self).__init__()
        self.Patchsize = Patchsize
        self.Imgsize = Imgsize

    def encoder(self, x):
        assert x.size(3) == self.Imgsize
        imgsize = self.Imgsize
        patchsize = self.Patchsize
        x = x.unfold(2, patchsize, patchsize).unfold(3, patchsize, patchsize)
        x = x.contiguous().reshape(x.size(0), x.size(1), int(pow(imgsize / patchsize, 2)), -1)
        x = x.transpose(1, 2)
        x = x.reshape(x.size(0), x.size(1), -1)
        return x

    def decoder(self, x):
        imgsize = self.Imgsize
        patchsize = self.Patchsize
        x = x.reshape(x.size(0), x.size(1), x.size(2) // patchsize, patchsize)
        x = x.transpose(1, 2)
        x = x.unfold(2, imgsize // patchsize, imgsize // patchsize).unfold(3, patchsize, patchsize)
        x = x.reshape(x.size(0), 4, imgsize, imgsize)
        return x


class SD_Auto():
    def __init__(self, path="sdxl", input_dtype=torch.float32) -> None:
        self.autoencoder = AutoencoderKL.from_pretrained(path, subfolder="vae")
        self.autoencoder = self.autoencoder.to('cuda', input_dtype)

    def encoder(self, x):
        with torch.no_grad():
            x = self.autoencoder.encode(x).latent_dist.sample()
        return x

    def decoder(self, x):
        with torch.no_grad():
            x = self.autoencoder.decode(x).sample
        return x


def kld_loss(mu, logvar):
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.shape[0]
    return KLD


class VisualAdapter(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=512//8, patch_size=16, in_c=4, text_dim=2560, head_size=64):
        super().__init__()
        self.head_size = head_size
        # self.img_receptance = nn.Linear((patch_size*patch_size*in_c), text_dim, bias=False)
        # self.img_key = nn.Linear((patch_size*patch_size*in_c), text_dim, bias=False)
        # self.img_value = nn.Linear((patch_size*patch_size*in_c), text_dim, bias=False)
        self.linear = nn.Linear((patch_size*patch_size*in_c), text_dim, bias=False)
        self.patch = Patch(Imgsize=img_size, Patchsize=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.patch.encoder(x)
        # r = self.img_receptance(x)
        # k = self.img_key(x)
        # v = self.img_value(x)
        # r = r.view(*x.shape[:2], -1, self.head_size).transpose(1, 2)
        # k = k.view(*x.shape[:2], -1, self.head_size).transpose(1, 2)
        # v = v.view(*x.shape[:2], -1, self.head_size).transpose(1, 2)
        # x_img = torch.nn.functional.scaled_dot_product_attention(
        #     r, k, v, is_causal=True, scale=1 / self.head_size
        # )
        # x = x_img.transpose(1, 2).reshape(*x.shape[:2], -1)
        return self.linear(x)


class VisualEncoder(nn.Module):
    def __init__(
        self,
        encoder_path,
        project_dim,
        train_mode="adapter",
        device="cuda",
    ) -> None:
        super(VisualEncoder, self).__init__()

        self.model = AutoencoderKL.from_pretrained(encoder_path, subfolder="vae", allow_pickle=False).to(device, torch.bfloat16)
        self.adapter = VisualAdapter(text_dim=project_dim).to(device, dtype=torch.bfloat16)

    def forward(self, x):
        x = self.model.encode(x.unsqueeze(0)).latent_dist.sample()
        # print(x.view(-1))
        x = self.adapter(x)
        # print(x.view(-1))

        return x
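For orientation, a minimal shape-check sketch of the patch-embedding path above. The 4-channel 64x64 latent corresponds to a 512x512 image after the SDXL VAE's 8x downsample; the batch size and text_dim of 2560 are illustrative assumptions, and the snippet assumes the repo root is on PYTHONPATH.

import torch
from world.encoder.visual_encoder import Patch, VisualAdapter

latent = torch.randn(2, 4, 64, 64)                        # (B, C, H, W) VAE latent
tokens = Patch(Imgsize=64, Patchsize=16).encoder(latent)  # (2, 16, 1024): a 4x4 grid of 16x16 patches over 4 channels
adapter = VisualAdapter(img_size=64, patch_size=16, in_c=4, text_dim=2560)
print(tokens.shape, adapter(latent).shape)                # torch.Size([2, 16, 1024]) torch.Size([2, 16, 2560])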
world/encoder/whisper_encoder.py
ADDED
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from transformers import WhisperProcessor, WhisperForConditionalGeneration

class SpeechAdapter(nn.Module):
    def __init__(self, encoder_dim, project_dim, hidden_dim=None):
        super(SpeechAdapter, self).__init__()
        self.encoder_dim = encoder_dim
        self.project_dim = project_dim
        self.hidden_dim = hidden_dim

        if self.hidden_dim is None:
            self.hidden_dim = project_dim * 2
        self.conv = nn.Conv1d(in_channels=self.encoder_dim, out_channels=self.hidden_dim, kernel_size=3, stride=2, padding=2)
        self.proj = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.project_dim),
        )

    def forward(self, x):
        # if x.size(1)<5 or x.size(1)>5000:
        #     return False

        # x shape: (batch_size, seq_len, encoder_dim)
        x = x.permute(0, 2, 1)
        # x shape: (batch_size, encoder_dim, seq_len)
        x = self.conv(x).permute(0, 2, 1)
        # x shape after conv + permute: (batch_size, new_seq_len, hidden_dim)
        x = self.proj(x)
        return x


class WhisperEncoder(nn.Module):
    def __init__(
        self,
        encoder_path,
        project_dim,
        train_mode="adapter",
        device="cuda",
    ):
        assert train_mode in ["adapter", "full"]
        super(WhisperEncoder, self).__init__()
        self.device = device
        self.processor = WhisperProcessor.from_pretrained(encoder_path)

        self.model = WhisperForConditionalGeneration.from_pretrained(encoder_path).model.encoder

        self.model_output_dim = self.model.config.d_model

        self.project_dim = project_dim
        self.adapter = SpeechAdapter(self.model_output_dim, self.project_dim)

    def forward(self, x):
        input_dict = self.processor(
            x, return_tensors="pt", sampling_rate=16000, return_attention_mask=True
        ).to(self.device, dtype=torch.bfloat16)

        # approximate number of valid encoder frames (the encoder's strided conv halves the mel-frame length)
        chunk = torch.sum(input_dict['attention_mask'], dim=-1) // 2 + 1

        x = self.model(**input_dict).last_hidden_state
        x = x[:, :chunk, :]
        x = self.adapter(x)  # x: (B, T, hidden_dim)

        return x
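As a rough sketch of the adapter's downsampling: the hidden size of 1280 and the 1500-frame sequence below are assumptions matching a whisper-large checkpoint (30 s of 16 kHz audio); the adapter itself only depends on encoder_dim and project_dim.

import torch
from world.encoder.whisper_encoder import SpeechAdapter

hidden = torch.randn(1, 1500, 1280)            # (B, T, encoder_dim), e.g. Whisper encoder output
adapter = SpeechAdapter(encoder_dim=1280, project_dim=2560)
print(adapter(hidden).shape)                   # torch.Size([1, 751, 2560]); 751 = (1500 + 2*2 - 3)//2 + 1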
world/loss.py
ADDED
@@ -0,0 +1,16 @@
import torch

class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)
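L2Wrap's forward is the identity on the loss; its backward adds a tiny gradient that pulls the largest logit at every position toward zero. A minimal usage sketch follows; the call site is illustrative (world/model.py imports L2Wrap, but where it is applied in the training step may differ).

import torch
import torch.nn.functional as F
from world.loss import L2Wrap

logits = torch.randn(2, 8, 100, requires_grad=True)   # (B, T, vocab)
targets = torch.randint(0, 100, (2, 8))
loss = F.cross_entropy(logits.reshape(-1, 100), targets.reshape(-1))
loss = L2Wrap.apply(loss, logits)   # identity in the forward pass, extra regularizing gradient in backward
loss.backward()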
world/model.py
ADDED
@@ -0,0 +1,317 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from torch.profiler import profile, record_function, ProfilerActivity
# from adam_mini import Adam_mini

import os, math, gc, importlib
import torch

import torch.nn as nn
from torch.nn import functional as F
import lightning as pl
from lightning.pytorch.strategies import DeepSpeedStrategy
if importlib.util.find_spec('deepspeed'):
    import deepspeed
    from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

from .block import Block
from .loss import L2Wrap
from .cat import mod_pad_text

from rwkv.utils import PIPELINE
pipeline = PIPELINE('rwkv6', "rwkv_vocab_v20230424")

class RWKV(pl.LightningModule):
    def __init__(self, args, modality=None):
        super().__init__()
        self.args = args
        if not hasattr(args, 'dim_att'):
            args.dim_att = args.n_embd
        if not hasattr(args, 'dim_ffn'):
            args.dim_ffn = args.n_embd * 4

        assert args.n_embd % 32 == 0
        assert args.dim_att % 32 == 0
        assert args.dim_ffn % 32 == 0

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        self.modality = modality

    def pad_mod(self, tensor_list, signal_list):
        """
        Pad a list of variable-length token tensors so every sample has the same length
        (a multiple of 16, plus one for the label shift), and build the loss mask.

        Args:
            tensor_list (list of torch.Tensor): text token tensors, each of shape [seq_len].
            signal_list (list): raw modality inputs; each is run through self.modality to
                obtain an embedding of shape [1, sig_len, n_embd].

        Returns:
            s (list of torch.Tensor): modality embeddings per sample.
            x (list of torch.Tensor): padded input token tensors.
            y (torch.Tensor): padded target tensor, shape [batch_size, target_len - 1].
            m (torch.Tensor): loss mask; 1 marks valid positions, 0 marks padding.
        """

        modality_list = []
        # max_len = max((token.size(0) + signal.size(1)) for token, signal in zip(tensor_list, modality_list))
        max_len = 0
        for token, signal in zip(tensor_list, signal_list):

            modality = self.modality(signal)
            if modality is False:
                modality_list.append(False)
                continue
            modality_list.append(modality)
            max_len = max(token.size(0) + modality.size(1), max_len)

        # target length: round up to the next multiple of 16, plus 1 for the label shift
        target_len = ((max_len + 15) // 16 * 16) + 1

        if self.args.ctx_len is not None:
            target_len = min(target_len, self.args.ctx_len + 1)

        masks = torch.zeros((len(tensor_list), target_len - 1), dtype=torch.int32)
        x = []
        y = []
        s = []
        m = []
        for token, signal, mask in zip(tensor_list, modality_list, masks):
            if signal is False:
                continue
            pad_len = target_len - (token.size(0) + signal.size(1))

            padded_token = F.pad(token, (0, pad_len), value=0)

            x_token = padded_token[:-1]
            y_token = F.pad(padded_token, (signal.size(1) - 1, 0), value=0)

            mask[signal.size(1) : -pad_len] = 1

            s.append(signal)
            x.append(x_token)
            y.append(y_token)
            m.append(mask)

        y = torch.stack(y, dim=0)
        m = torch.stack(m, dim=0).cuda()

        return s, x, y, m

    def forward(self, idx, signs=None):
        args = self.args
        # B, T = idx.size()
        # assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x_list = []
        if signs is not None:
            for token, sign in zip(idx, signs):
                sign_emb = sign  # self.adapter(sign)
                x_emb = self.emb(token)
                # print(sign_emb.shape, x.shape)
                x_list.append(torch.cat([sign_emb.squeeze(0), x_emb], dim=0))
            x = torch.stack(x_list, dim=0)
        else:
            x = self.emb(idx)

        v_first = torch.empty_like(x)
        for block in self.blocks:
            if args.grad_cp == 1:
                if args.state_tune or args.train_type == 'state' or args.peft != 'none':
                    x, v_first = torch_checkpoint(block, x, v_first, use_reentrant=False)
                else:
                    x, v_first = deepspeed.checkpointing.checkpoint(block, x, v_first)
            else:
                x, v_first = block(x, v_first)

        x = self.ln_out(x)
        x = self.head(x)

        return x

    def training_step(self, batch, batch_idx):
        args = self.args
        if args.data_type == "jsonl":  ######## test
            idx, targets, mask = batch

            mask = mask.view(-1)
            sum_mask = torch.sum(mask).item()
            logits = self(idx)

            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), reduction='none')
            loss = torch.sum(loss * mask) / sum_mask
            return loss

        if args.data_type in ["visual", "visual-r1-cs"]:  ######## test
            signs, text_tokens, text_labels = batch
            sign, idx, targets = mod_pad_text(self, signs, text_tokens, text_labels)

            logits = self(idx, sign)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

            return loss

        signs, tokens = batch
        sign, idx, targets, mask = self.pad_mod(tokens, signs)

        mask = mask.view(-1)
        sum_mask = torch.sum(mask).item()
        logits = self(idx, sign)

        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), reduction='none')
        loss = torch.sum(loss * mask) / sum_mask

        return loss

    def configure_optimizers(self):
        args = self.args

        lr_decay = set()
        lr_1x = set()
        lr_2x = set()
        lr_3x = set()
        for n, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if (("_w1" in n) or ("_w2" in n)) and (args.layerwise_lr > 0):
                lr_1x.add(n)
            elif (("time_mix" in n) or ("time_maa" in n)) and (args.layerwise_lr > 0):
                if args.my_pile_stage == 2:
                    lr_2x.add(n)
                else:
                    lr_1x.add(n)
            elif (("time_decay" in n) or ("time_daaaa" in n)) and (args.layerwise_lr > 0):
                if args.my_pile_stage == 2:
                    lr_3x.add(n)
                else:
                    lr_2x.add(n)
            elif ("time_faaaa" in n) and (args.layerwise_lr > 0):
                if args.my_pile_stage == 2:
                    lr_2x.add(n)
                else:
                    lr_1x.add(n)
            elif ("time_first" in n) and (args.layerwise_lr > 0):
                lr_3x.add(n)
            elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0):
                lr_decay.add(n)
            else:
                lr_1x.add(n)

        lr_decay = sorted(list(lr_decay))
        lr_1x = sorted(list(lr_1x))
        lr_2x = sorted(list(lr_2x))
        lr_3x = sorted(list(lr_3x))

        param_dict = {n: p for n, p in self.named_parameters()}

        if args.layerwise_lr > 0:
            if args.my_pile_stage == 2:
                optim_groups = [
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},  # test: 2e-3 / args.lr_init},
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},  # test: 3e-3 / args.lr_init},
                ]
            else:
                optim_groups = [
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
                ]
        else:
            optim_groups = [{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}]

        if args.weight_decay > 0:
            optim_groups += [{"params": [param_dict[n] for n in lr_decay], "weight_decay": args.weight_decay, "my_lr_scale": 1.0}]

            if self.deepspeed_offload:
                return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=True, amsgrad=False)
            return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=True, amsgrad=False)
        else:

            if self.deepspeed_offload:
                return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
            return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)

    @property
    def deepspeed_offload(self) -> bool:
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            cfg = strategy.config["zero_optimization"]
            return cfg.get("offload_optimizer") or cfg.get("offload_param")
        return False

    def generate_init_weight(self):
        print(
            f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        m = {}
        for n in self.state_dict():
            p = self.state_dict()[n]
            shape = p.shape

            gain = 1.0
            scale = 1.0
            if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
                if 'ln_x.weight' in n:
                    layer_scale = (1 + int(n.split('.')[1])) / self.args.n_layer
                    m[n] = (p * 0.0) + (layer_scale ** 0.7)
                else:
                    m[n] = p
            else:
                if n == "emb.weight":
                    scale = -1 * self.args.lr_init
                else:
                    if shape[0] > shape[1]:
                        gain = math.sqrt(shape[0] / shape[1])

                    zero = [".att.output.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']

                    for kk in zero:
                        if kk in n:
                            scale = 0
                    if n == "head.weight":
                        scale = 0.5
                    if "head_k." in n:
                        scale = 0.1
                    if "head_q." in n:
                        scale = 0

                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")

                if self.args.accelerator.upper() == "GPU":
                    m[n] = torch.empty((shape[0], shape[1]), device="cuda")
                else:
                    m[n] = torch.empty((shape[0], shape[1]))

                if scale == 0:
                    nn.init.zeros_(m[n])
                elif scale < 0:
                    nn.init.uniform_(m[n], a=scale, b=-scale)
                else:
                    nn.init.orthogonal_(m[n], gain=gain * scale)

            m[n] = m[n].cpu()
            if os.environ["RWKV_FLOAT_MODE"] == "fp16":
                m[n] = m[n].half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                m[n] = m[n].bfloat16()

            # if n == "emb.weight":
            #     print(m[n])

        gc.collect()
        torch.cuda.empty_cache()
        return m
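A small worked example of pad_mod's length arithmetic (the token and frame counts below are made up purely for illustration):

# Illustrative numbers only: a 7-token text paired with a 20-frame modality embedding.
max_len = 7 + 20                               # token.size(0) + modality.size(1)
target_len = ((max_len + 15) // 16 * 16) + 1   # round up to a multiple of 16, +1 for the label shift -> 33
pad_len = target_len - max_len                 # 6 zero tokens are appended to the text
print(target_len, pad_len)                     # 33 6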
world/world_encoder.py
ADDED
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from world.encoder.speech_encoder import SpeechEncoder
# from world.encoder.visual_encoder import VisualEncoder
from world.encoder.whisper_encoder import WhisperEncoder
from world.encoder.clip_encoder import ClipEncoder
from world.encoder.siglip_encoder import SiglipEncoder


class WorldEncoder(nn.Module):
    def __init__(self, encoder_type: str, **kwargs):
        super().__init__()
        self.world_encoder = self._build_encoder(encoder_type, kwargs)

    def _build_encoder(self, encoder_type: str, config: dict):
        encoder_map = {
            "clip": ClipEncoder,
            "whisper": WhisperEncoder,
            # "visual": VisualEncoder,
            "speech": SpeechEncoder,
            "siglip": SiglipEncoder
        }

        if encoder_type not in encoder_map:
            raise ValueError(f"Unsupported encoder type: {encoder_type}")

        # look up the encoder class for this type and pass the config through
        encoder_class = encoder_map[encoder_type]

        return encoder_class(**config)

    def forward(self, x):
        return self.world_encoder(x)

    def load_checkpoint(self, state_dict):
        self.world_encoder.load_state_dict(state_dict, strict=False)
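A minimal construction sketch for WorldEncoder. The Whisper checkpoint id is a placeholder (any local path or Hub id accepted by WhisperProcessor.from_pretrained would do), and project_dim should match the RWKV embedding width.

from world.world_encoder import WorldEncoder

encoder = WorldEncoder(
    encoder_type="whisper",                  # one of "clip", "whisper", "speech", "siglip"
    encoder_path="openai/whisper-large-v3",  # placeholder checkpoint
    project_dim=2560,                        # should equal args.n_embd of the RWKV model
)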
world/world_load.py
ADDED
@@ -0,0 +1,30 @@
from world.model import RWKV
from world.world_encoder import WorldEncoder

def WorldLoading(args):
    config = {
        'encoder_type': args.encoder_type,
        'encoder_path': args.encoder_path,
        'project_dim': args.n_embd
    }
    modality = WorldEncoder(**config)

    model = RWKV(args, modality=modality)
    # model = RWKV(args)
    print(model)

    if 'moda' not in args.train_step:
        for param in model.modality.world_encoder.model.parameters():
            param.requires_grad = False
    if 'adapter' not in args.train_step:
        for param in model.modality.world_encoder.adapter.parameters():
            param.requires_grad = False
    if 'rwkv' not in args.train_step:
        for param in model.emb.parameters():
            param.requires_grad = False
        for param in model.blocks.parameters():
            param.requires_grad = False
        for param in model.ln_out.parameters():
            param.requires_grad = False
        for param in model.head.parameters():
            param.requires_grad = False
    return model
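Finally, a hedged usage sketch for WorldLoading. Only the fields read directly by world_load.py are shown; a real run additionally needs the RWKV hyperparameters consumed by world/model.py and world/block.py (n_layer, vocab_size, ctx_len, grad_cp, and so on), which the training script is expected to supply.

from types import SimpleNamespace
from world.world_load import WorldLoading

args = SimpleNamespace(
    encoder_type="whisper",
    encoder_path="openai/whisper-large-v3",  # placeholder checkpoint
    n_embd=2560,
    train_step="adapter",                    # freezes the encoder and the RWKV backbone, trains only the adapter
    # ... plus the usual RWKV training arguments (n_layer, vocab_size, ctx_len, grad_cp, ...)
)
model = WorldLoading(args)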