	Upload 16 files
- app.py +169 -16
- cuda/gemm_fp16_cublas.cpp +75 -0
- cuda/operators.cu +246 -0
- cuda/rwkv5.cu +88 -0
- cuda/rwkv5_op.cpp +34 -0
- cuda/rwkv6.cu +87 -0
- cuda/rwkv6_op.cpp +34 -0
- cuda/wrapper.cpp +141 -0
- examples_bluejay.jpg +0 -0
- examples_extreme_ironing.jpg +0 -0
- examples_pizza.jpg +0 -0
- examples_waterview.jpg +0 -0
- examples_woman_and_dog.png +0 -0
- modeling_rwkv.py +1237 -0
- modeling_vision.py +48 -0
- requirements.txt +3 -4
    	
app.py CHANGED

@@ -1,20 +1,29 @@
+import os
+os.environ["RWKV_JIT_ON"] = '1'
+os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
+# make sure cuda dir is in the same level as modeling_rwkv.py
+from modeling_rwkv import RWKV
+
+import gc
 import gradio as gr
-import 
+import base64
+from io import BytesIO
+import torch
+import torch.nn.functional as F
 from datetime import datetime
+from transformers import CLIPImageProcessor
 from huggingface_hub import hf_hub_download
 from pynvml import *
 nvmlInit()
 gpu_h = nvmlDeviceGetHandleByIndex(0)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 ctx_limit = 3500
+########################## text rwkv ################################################################
+from rwkv.utils import PIPELINE, PIPELINE_ARGS
 title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
-
-os.environ["RWKV_JIT_ON"] = '1'
-os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
-
-from rwkv.model import RWKV
 model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
 model = RWKV(model=model_path, strategy='cuda fp16')
-from rwkv.utils import PIPELINE, PIPELINE_ARGS
 pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
 
 def generate_prompt(instruction, input=""):
@@ -22,17 +31,12 @@ def generate_prompt(instruction, input=""):
     input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
     if input:
         return f"""Instruction: {instruction}
-
 Input: {input}
-
 Response:"""
     else:
         return f"""User: hi
-
 Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
-
 User: {instruction}
-
 Assistant:"""
@@ -55,7 +59,8 @@ def evaluate(
     occurrence = {}
     state = None
     for i in range(int(token_count)):
-        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
+        input_ids = pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token]
+        out, state = model.forward(tokens=input_ids, state=state)
         for n in occurrence:
             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
 
@@ -94,9 +99,7 @@ examples = [
     [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), 333, 1, 0.3, 0, 1],
     ["Assistant: Here is a very detailed plan to kill all mosquitoes:", 333, 1, 0.3, 0, 1],
     ['''Edward: I am Edward Elric from fullmetal alchemist. I am in the world of full metal alchemist and know nothing of the real world.
-
 Player: Hello Edward. What have you been up to recently?
-
 Edward:''', 333, 1, 0.3, 0, 1],
     [generate_prompt("写一篇关于水利工程的流体力学模型的论文,需要详细全面。"), 333, 1, 0.3, 0, 1],
     ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。
@@ -107,8 +110,142 @@ Edward:''', 333, 1, 0.3, 0, 1],
 小宇宙中只剩下漂流瓶和生态球。漂流瓶隐没于黑暗里,在一千米见方的宇宙中,只有生态球里的小太阳发出一点光芒。在这个小小的生命世界中,几只清澈的水球在零重力环境中静静地飘浮着,有一条小鱼从一只水球中蹦出,跃入另一只水球,轻盈地穿游于绿藻之间。在一小块陆地上的草丛中,有一滴露珠从一片草叶上脱离,旋转着飘起,向太空中折射出一缕晶莹的阳光。''', 333, 1, 0.3, 0, 1],
 ]
 
+########################## visual rwkv ################################################################
+visual_title = 'ViusualRWKV-v5'
+rwkv_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_rwkv.pth"
+vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
+vision_tower_name = 'openai/clip-vit-large-patch14-336'
+
+model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
+visual_rwkv = RWKV(model=model_path, strategy='cuda fp16')
+
+##########################################################################
+from modeling_vision import VisionEncoder, VisionEncoderConfig
+config = VisionEncoderConfig(n_embd=model.args.n_embd, 
+                             vision_tower_name=vision_tower_name, 
+                             grid_size=-1)
+visual_encoder = VisionEncoder(config)
+vision_local_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=vision_remote_path)
+vision_state_dict = torch.load(vision_local_path, map_location='cpu')
+visual_encoder.load_state_dict(vision_state_dict)
+image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
+visual_encoder = visual_encoder.to(device)
+##########################################################################
+def visual_generate_prompt(instruction):
+    instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
+    return f"\n{instruction}\n\nAssistant:"
+
+def generate(
+    ctx,
+    image_state,
+    token_count=200,
+    temperature=0.2,
+    top_p=0.3,
+    presencePenalty = 0.0,
+    countPenalty = 1.0,
+):
+    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
+                    alpha_frequency = countPenalty,
+                    alpha_presence = presencePenalty,
+                    token_ban = [], # ban the generation of some tokens
+                    token_stop = [0, 261]) # stop generation whenever you see any token here
+    ctx = ctx.strip()
+    all_tokens = []
+    out_last = 0
+    out_str = ''
+    occurrence = {}
+    for i in range(int(token_count)):
+        if i == 0:
+            input_ids = pipeline.encode(ctx)[-ctx_limit:]
+            out, state = visual_rwkv.forward(tokens=input_ids, state=image_state)
+        else:
+            input_ids = [token]
+            out, state = visual_rwkv.forward(tokens=input_ids, state=state)
+        for n in occurrence:
+            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+        if token in args.token_stop:
+            break
+        all_tokens += [token]
+        for xxx in occurrence:
+            occurrence[xxx] *= 0.996
+        if token not in occurrence:
+            occurrence[token] = 1
+        else:
+            occurrence[token] += 1
+
+        tmp = pipeline.decode(all_tokens[out_last:])
+        if '\ufffd' not in tmp:
+            out_str += tmp
+            yield out_str.strip()
+            out_last = i + 1
+
+    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+    del out
+    del state
+    gc.collect()
+    torch.cuda.empty_cache()
+    yield out_str.strip()
+
+
 ##########################################################################
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+visual_examples = [
+    [
+        f"{cur_dir}/examples_pizza.jpg",
+        "What are steps to cook it?"
+    ],
+    [
+        f"{cur_dir}/examples_bluejay.jpg",
+        "what is the name of this bird?",
+    ],
+    [
+        f"{cur_dir}/examples_woman_and_dog.png",
+        "describe this image",
+    ],
+]
+
+
+def pil_image_to_base64(pil_image):
+    buffered = BytesIO()
+    pil_image.save(buffered, format="JPEG")  # You can change the format as needed (JPEG, PNG, etc.)
+    # Encodes the image data into base64 format as a bytes object
+    base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
+    return base64_image
+
+image_cache = {}
+ln0_weight = model.w['blocks.0.ln0.weight'].to(torch.float32).to(device)
+ln0_bias = model.w['blocks.0.ln0.bias'].to(torch.float32).to(device)
+def compute_image_state(image):
+    base64_image = pil_image_to_base64(image)
+    if base64_image in image_cache:
+        image_state = image_cache[base64_image]
+    else:
+        image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values'].to(device)
+        image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
+        # apply layer norm to image feature, very important
+        image_features = F.layer_norm(image_features, 
+                                    (image_features.shape[-1],), 
+                                    weight=ln0_weight, 
+                                    bias=ln0_bias)
+        _, image_state = model.forward(embs=image_features, state=None)
+        image_cache[base64_image] = image_state
+    return image_state
 
+def chatbot(image, question):
+    if image is None:
+        yield "Please upload an image."
+        return
+    image_state = compute_image_state(image)
+    input_text = visual_generate_prompt(question)
+    for output in generate(input_text, image_state):
+        yield output
+
+
+##################################################################################################################
 with gr.Blocks(title=title) as demo:
     gr.HTML(f"<div style=\"text-align: center;\">\n<h1>RWKV-5 World v2 - {title}</h1>\n</div>")
     with gr.Tab("Raw Generation"):
@@ -130,6 +267,22 @@ with gr.Blocks(title=title) as demo:
         submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
         clear.click(lambda: None, [], [output])
         data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
+    with gr.Tab("Visual RWKV"):
+        with gr.Row():
+            with gr.Column():
+                image = gr.Image(type='pil', label="Image")
+            with gr.Column():
+                prompt = gr.Textbox(lines=8, label="Prompt",
+                    value="Render a clear and concise summary of the photo.")
+                with gr.Row():
+                    submit = gr.Button("Submit", variant="primary")
+                    clear = gr.Button("Clear", variant="secondary")
+            with gr.Column():
+                output = gr.Textbox(label="Output", lines=10)
+        data = gr.Dataset(components=[image, prompt], samples=visual_examples, label="Examples", headers=["Image", "Prompt"])
+        submit.click(chatbot, [image, prompt], [output])
+        clear.click(lambda: None, [], [output])
+        data.click(lambda x: x, [data], [image, prompt])
 
 demo.queue(concurrency_count=1, max_size=10)
-demo.launch(share=False)
+demo.launch(share=False)
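The new generate() loop uses a decaying repetition penalty: every previously emitted token is penalized in the logits by alpha_presence + count * alpha_frequency, and all stored counts decay by 0.996 each step, so old repetitions are gradually forgiven (0.996^500 ≈ 0.13, so a token from ~500 steps back keeps only about 13% of its penalty). Below is a minimal, self-contained sketch of just that bookkeeping; toy_logits and the greedy argmax are hypothetical stand-ins for model.forward and pipeline.sample_logits, not the app's actual API.

import numpy as np

def toy_logits(vocab=8, rng=np.random.default_rng(1)):
    # hypothetical stand-in for a model forward pass
    return rng.normal(size=vocab)

alpha_presence, alpha_frequency = 0.0, 1.0   # presencePenalty / countPenalty above
occurrence = {}
tokens = []
for i in range(16):
    out = toy_logits()
    for n in occurrence:                     # penalize previously seen tokens
        out[n] -= alpha_presence + occurrence[n] * alpha_frequency
    token = int(np.argmax(out))              # greedy in place of sample_logits
    tokens.append(token)
    for n in occurrence:                     # decay all stored penalties
        occurrence[n] *= 0.996
    occurrence[token] = occurrence.get(token, 0) + 1
print(tokens)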
    	
cuda/gemm_fp16_cublas.cpp ADDED

@@ -0,0 +1,75 @@
+#include <cublas_v2.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#define CUBLAS_CHECK(condition)                                                \
+  for (cublasStatus_t _cublas_check_status = (condition);                      \
+       _cublas_check_status != CUBLAS_STATUS_SUCCESS;)                         \
+    throw std::runtime_error("cuBLAS error " +                                 \
+                             std::to_string(_cublas_check_status) + " at " +   \
+                             std::to_string(__LINE__));
+
+#define CUDA_CHECK(condition)                                                  \
+  for (cudaError_t _cuda_check_status = (condition);                           \
+       _cuda_check_status != cudaSuccess;)                                     \
+    throw std::runtime_error(                                                  \
+        "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) +  \
+        " at " + std::to_string(__LINE__));
+
+/*
+  NOTE: blas gemm is column-major by default, but we need row-major output.
+  The data of row-major, transposed matrix is exactly the same as the
+  column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
+ */
+void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
+  const auto cuda_data_type = CUDA_R_16F;
+  const auto cuda_c_data_type =
+      c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
+  const auto compute_type = CUDA_R_32F;
+  const float sp_alpha = 1.f;
+  // swap a and b, and use CUBLAS_OP_N. see the notes above
+  std::swap(a, b);
+  const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
+  const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
+  // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
+  // negative axis is used because of the existence of batch matmul.
+  const int m = a.size(-1);
+  const int k = a.size(-2);
+  const int n = b.size(-2);
+  const int cublas_lda = m;
+  const int cublas_ldb = k;
+  const int cublas_ldc = m;
+  cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
+
+#if CUDA_VERSION >= 11000
+  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+#else
+  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+#endif
+  const float sp_beta = 0.f;
+  if (a.sizes().size() == 2 && b.sizes().size() == 2) {
+    CUBLAS_CHECK(cublasGemmEx(
+        cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
+        a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
+        cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
+        compute_type, algo));
+  } else {
+    // batch matmul
+    assert(a.sizes().size() == 3 && b.sizes().size() == 3);
+
+    const long long int cublas_stride_a = m * k;
+    const long long int cublas_stride_b = k * n;
+    const long long int cublas_stride_c = m * n;
+    CUBLAS_CHECK(cublasGemmStridedBatchedEx(
+        cublas_handle, cublas_trans_a, cublas_trans_b, m,
+        n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
+        cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
+        &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
+        a.size(0), compute_type, algo));
+  }
+}
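The NOTE in this file is the whole trick: cuBLAS works column-major while PyTorch tensors are row-major, and a row-major buffer reinterpreted as column-major is exactly the transpose. So rather than transpose anything, gemm_fp16_cublas swaps the operands and computes C^T = B^T * A^T with CUBLAS_OP_N; the resulting column-major buffer is byte-for-byte the row-major C. A NumPy sketch of the same buffer reinterpretation (illustration only, no cuBLAS involved):

import numpy as np

M, K, N = 4, 6, 3
A = np.random.rand(M, K).astype(np.float32)   # row-major, like a torch tensor
B = np.random.rand(K, N).astype(np.float32)

# The bytes of a row-major (M, K) matrix, read column-major as (K, M),
# form the transpose. Same for B.
A_cm = A.ravel(order='C').reshape(K, M, order='F')   # == A.T
B_cm = B.ravel(order='C').reshape(N, K, order='F')   # == B.T

# Column-major GEMM with the operands swapped (b first, a second), no transposes:
C_cm = B_cm @ A_cm                                   # (N, M) == (A @ B).T

# Reading the column-major result buffer back row-major yields A @ B.
C = C_cm.ravel(order='F').reshape(M, N, order='C')
assert np.allclose(C, A @ B, atol=1e-5)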
    	
cuda/operators.cu ADDED

@@ -0,0 +1,246 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+#include <cuda_fp16.h>
+#define MIN_VALUE (-1e38)
+typedef at::Half fp16;
+__half *cast(fp16 *ptr) {
+    return reinterpret_cast<__half *>(ptr);
+}
+
+template <typename F>
+__global__ void kernel_wkv_forward(const int B, const int T, const int C,
+                               const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                               F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+    const int _state_offset = _b * C + _c;
+
+    float u = _u[_c];
+    float w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    F *__restrict__ const y = _y + _offset;
+
+    float aa = _aa[_state_offset];
+    float bb = _bb[_state_offset];
+    float pp = _pp[_state_offset];
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const float kk = float(k[ii]);
+        const float vv = float(v[ii]);
+        float ww = u + kk;
+        float p = max(pp, ww);
+        float e1 = exp(pp - p);
+        float e2 = exp(ww - p);
+        y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+    _aa[_state_offset] = aa;
+    _bb[_state_offset] = bb;
+    _pp[_state_offset] = pp;
+}
+
+template <typename F>
+void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
+    dim3 threadsPerBlock( min(C, 32) );
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
+}
+
+template void cuda_wkv_forward<fp16>(
+    int B, int T, int C,
+    float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
+    float *aa, float *bb, float *pp);
+template void cuda_wkv_forward<float>(
+    int B, int T, int C,
+    float *w, float *u, float *k, float *v, float *y,
+    float *aa, float *bb, float *pp);
+
+__global__ void kernel_mm_seq_fp32i8(
+    const int B, const int N, const int M,
+    const float *__restrict__ const x, const int x_stride,
+    const uint8_t *__restrict__ const w, const int w_stride,
+    const float *__restrict__ const mx,
+    const float *__restrict__ const rx,
+    const float *__restrict__ const my,
+    const float *__restrict__ const ry,
+    float *__restrict__ const y, const int y_stride) {
+
+    const int i = blockIdx.x * blockDim.x + threadIdx.x;
+    const int k = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i < B && k < M) {
+        float y_local = 0;
+        for (int j = 0; j < N; ++j) {
+            y_local += x[i * x_stride + j] * (
+                (float(w[j * w_stride + k]) + 0.5f)
+                * rx[k] * ry[j] + mx[k] + my[j]
+            );
+        }
+        y[i * y_stride + k] = y_local;
+    }
+}
+
+template <typename F>
+void cuda_mm8_seq(int B, int N, int M,
+                  F *x, int x_stride,
+                  uint8_t *w, int w_stride,
+                  F *mx, F *rx,
+                  F *my, F *ry,
+                  F *y, int y_stride);
+
+template <>
+void cuda_mm8_seq<float>(int B, int N, int M,
+                         float *x, int x_stride,
+                         uint8_t *w, int w_stride,
+                         float *mx, float *rx,
+                         float *my, float *ry,
+                         float *y, int y_stride) {
+    dim3 blockSize(1, 128);
+    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
+    kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
+        B, N, M, x, x_stride, w, w_stride,
+        mx, rx, my, ry, y, y_stride);
+}
+
+__global__ void kernel_mm_seq_fp16i8(
+    const int B, const int N, const int M,
+    const __half *__restrict__ const x, const int x_stride,
+    const uint8_t *__restrict__ const w, const int w_stride,
+    const __half *__restrict__ const mx,
+    const __half *__restrict__ const rx,
+    const __half *__restrict__ const my,
+    const __half *__restrict__ const ry,
+    __half *__restrict__ const y, const int y_stride) {
+
+    const int i = blockIdx.x * blockDim.x + threadIdx.x;
+    const int k = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i < B && k < M) {
+        float y_local = 0;
+        for (int j = 0; j < N; ++j) {
+            y_local += __half2float(x[i * x_stride + j]) * (
+                (float(w[j * w_stride + k]) + 0.5f)
+                * __half2float(rx[k]) * __half2float(ry[j])
+                + __half2float(mx[k]) + __half2float(my[j])
+            );
+        }
+        y[i * y_stride + k] = __float2half(y_local);
+    }
+}
+
+template <>
+void cuda_mm8_seq<fp16>(int B, int N, int M,
+                        fp16 *x, int x_stride,
+                        uint8_t *w, int w_stride,
+                        fp16 *mx, fp16 *rx,
+                        fp16 *my, fp16 *ry,
+                        fp16 *y, int y_stride) {
+    dim3 blockSize(1, 128);
+    dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
+    kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
+        B, N, M, cast(x), x_stride, w, w_stride,
+        cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
+}
+
+#define MM8_ONE_JSPLIT 24
+#define MM8_ONE_TILE 1024
+
+__global__ void kernel_mm_one_fp32i8(
+    const int N, const int M,
+    const float *__restrict__ const x,
+    const uint8_t *__restrict__ const w, const int w_stride,
+    const float *__restrict__ const mx,
+    const float *__restrict__ const rx,
+    const float *__restrict__ const my,
+    const float *__restrict__ const ry,
+    float *__restrict__ const y) {
+
+    const int k = blockIdx.y * blockDim.y + threadIdx.y;
+    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
+    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
+
+    if (k < M) {
+        float y_local = 0;
+        for (int j = j0; j < j1; ++j) {
+            y_local += x[j] * (
+                (float(w[j * w_stride + k]) + 0.5f)
+                * rx[k] * ry[j] + mx[k] + my[j]
+            );
+        }
+        atomicAdd(&y[k], y_local);
+    }
+}
+
+template <typename F>
+void cuda_mm8_one(int N, int M,
+                  F *x,
+                  uint8_t *w, int w_stride,
+                  F *mx, F *rx,
+                  F *my, F *ry,
+                  float *y);
+
+template <>
+void cuda_mm8_one<float>(int N, int M,
+                        float *x,
+                        uint8_t *w, int w_stride,
+                        float *mx, float *rx,
+                        float *my, float *ry,
+                        float *y) {
+    dim3 blockSize(1, MM8_ONE_TILE);
+    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
+    kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
+        N, M, x, w, w_stride,
+        mx, rx, my, ry, y);
+}
+
+__global__ void kernel_mm_one_fp16i8(
+    const int N, const int M,
+    const __half *__restrict__ const x,
+    const uint8_t *__restrict__ const w, const int w_stride,
+    const __half *__restrict__ const mx,
+    const __half *__restrict__ const rx,
+    const __half *__restrict__ const my,
+    const __half *__restrict__ const ry,
+    float *__restrict__ const y) {
+
+    const int k = blockIdx.y * blockDim.y + threadIdx.y;
+    const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
+    const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
+
+    if (k < M) {
+        float y_local = 0;
+        for (int j = j0; j < j1; ++j) {
+            y_local += __half2float(x[j]) * (
+                (float(w[j * w_stride + k]) + 0.5f)
+                * __half2float(rx[k]) * __half2float(ry[j])
+                + __half2float(mx[k]) + __half2float(my[j])
+            );
+        }
+        atomicAdd(&y[k], y_local);
+    }
+}
+
+template <>
+void cuda_mm8_one<fp16>(int N, int M,
+                        fp16 *x,
+                        uint8_t *w, int w_stride,
+                        fp16 *mx, fp16 *rx,
+                        fp16 *my, fp16 *ry,
+                        float *y) {
+    dim3 blockSize(1, MM8_ONE_TILE);
+    dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
+    kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
+        N, M, cast(x), w, w_stride,
+        cast(mx), cast(rx), cast(my), cast(ry), y);
+}
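kernel_wkv_forward above computes the WKV recurrence in a numerically safe form: aa and bb carry the exponentially weighted numerator and denominator of the attention-like average, and pp tracks the running maximum exponent so every exp() argument stays at or below zero and can never overflow. A per-channel NumPy reference of the same arithmetic (a sketch for checking the recurrence, not the CUDA launch configuration):

import numpy as np

def wkv_forward_ref(w, u, k, v):
    # Mirrors one (batch, channel) lane of kernel_wkv_forward.
    T = len(k)
    y = np.zeros(T, dtype=np.float32)
    aa, bb, pp = 0.0, 0.0, -1e38     # MIN_VALUE in the kernel
    for t in range(T):
        ww = u + k[t]
        p = max(pp, ww)
        e1 = np.exp(pp - p)          # arguments are <= 0 by construction
        e2 = np.exp(ww - p)
        y[t] = (e1 * aa + e2 * v[t]) / (e1 * bb + e2)
        ww = w + pp
        p = max(ww, k[t])
        e1 = np.exp(ww - p)
        e2 = np.exp(k[t] - p)
        aa = e1 * aa + e2 * v[t]     # decayed numerator plus new contribution
        bb = e1 * bb + e2            # decayed denominator
        pp = p                       # new running max exponent
    return y

rng = np.random.default_rng(0)
print(wkv_forward_ref(w=-0.5, u=0.3, k=rng.normal(size=8), v=rng.normal(size=8)))

The int8 kernels in the same file dequantize on the fly: each uint8 weight enters the dot product as (w + 0.5) * rx[k] * ry[j] + mx[k] + my[j], so the float weight matrix is never materialized.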
    	
cuda/rwkv5.cu ADDED

@@ -0,0 +1,88 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+typedef at::Half fp16;
+typedef float fp32;
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
+                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
+                               F *__restrict__ const _y)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    _w += h*_N_;
+    _u += h*_N_;
+    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
+
+    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
+
+    float state[_N_];
+    #pragma unroll
+    for (int j = 0; j < _N_; j++)
+        state[j] = _state[j];
+
+    __syncthreads();
+    u[i] = float(_u[i]);
+    w[i] = _w[i];
+    __syncthreads();
+
+    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
+    {
+        __syncthreads();
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        __syncthreads();
+
+        const float v = float(_v[t]);
+        float y = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j+=4)
+        {
+            const float4& r_ = (float4&)(r[j]);
+            const float4& k_ = (float4&)(k[j]);
                        const float4& k_ = (float4&)(k[j]);
         | 
| 47 | 
            +
                        const float4& w_ = (float4&)(w[j]);
         | 
| 48 | 
            +
                        const float4& u_ = (float4&)(u[j]);
         | 
| 49 | 
            +
                        float4& s = (float4&)(state[j]);
         | 
| 50 | 
            +
                        float4 x;
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                        x.x = k_.x * v;
         | 
| 53 | 
            +
                        x.y = k_.y * v;
         | 
| 54 | 
            +
                        x.z = k_.z * v;
         | 
| 55 | 
            +
                        x.w = k_.w * v;
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                        y += r_.x * (u_.x * x.x + s.x);
         | 
| 58 | 
            +
                        y += r_.y * (u_.y * x.y + s.y);
         | 
| 59 | 
            +
                        y += r_.z * (u_.z * x.z + s.z);
         | 
| 60 | 
            +
                        y += r_.w * (u_.w * x.w + s.w);
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                        s.x = s.x * w_.x + x.x;
         | 
| 63 | 
            +
                        s.y = s.y * w_.y + x.y;
         | 
| 64 | 
            +
                        s.z = s.z * w_.z + x.z;
         | 
| 65 | 
            +
                        s.w = s.w * w_.w + x.w;
         | 
| 66 | 
            +
                    }
         | 
| 67 | 
            +
                    _y[t] = F(y);
         | 
| 68 | 
            +
                }
         | 
| 69 | 
            +
                #pragma unroll
         | 
| 70 | 
            +
                for (int j = 0; j < _N_; j++)
         | 
| 71 | 
            +
                    _state[j] = state[j];
         | 
| 72 | 
            +
            }
         | 
| 73 | 
            +
             | 
| 74 | 
            +
            void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
         | 
| 75 | 
            +
            {
         | 
| 76 | 
            +
                assert(H*_N_ == C);
         | 
| 77 | 
            +
                kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
         | 
| 78 | 
            +
            }
         | 
| 79 | 
            +
            void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
         | 
| 80 | 
            +
            {
         | 
| 81 | 
            +
                assert(H*_N_ == C);
         | 
| 82 | 
            +
                kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
         | 
| 83 | 
            +
            }
         | 
| 84 | 
            +
            void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
         | 
| 85 | 
            +
            {
         | 
| 86 | 
            +
                assert(H*_N_ == C);
         | 
| 87 | 
            +
                kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
         | 
| 88 | 
            +
            }
         | 
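In plain terms, each thread owns one output channel i of one head: per token it forms kv[i][j] = k[j]·v[i], emits y[i] = Σ_j r[j]·(u[j]·kv[i][j] + state[i][j]), then decays state along j by w and accumulates kv. A minimal single-head PyTorch reference of that recurrence (function and variable names are illustrative, not part of this extension; float32 tensors assumed):

import torch

def rwkv5_head_reference(r, k, v, w, u, state):
    # r, k, v: (T, N) per-head projections; w, u: (N,) decay and bonus;
    # state: (N, N), matching the kernel's per-head state layout.
    T, N = r.shape
    y = torch.empty(T, N)
    for t in range(T):
        kv = torch.outer(v[t], k[t])             # kv[i, j] = v[t, i] * k[t, j]
        y[t] = (r[t] * (u * kv + state)).sum(1)  # y[i] = sum_j r_j (u_j kv_ij + s_ij)
        state = state * w + kv                   # decay along j, then accumulate
    return y, state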
    	
cuda/rwkv5_op.cpp
ADDED

@@ -0,0 +1,34 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);

void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
}
TORCH_LIBRARY(rwkv5, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
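Built with is_python_module=False (as modeling_rwkv.py does for its own extension below), this binding registers under torch.ops.rwkv5 rather than as an importable module. A hedged sketch of building and invoking it; the extension name, file paths, head size N, and the -D_N_ compile flag are illustrative assumptions, since _N_ must be a compile-time constant:

import torch
from torch.utils.cpp_extension import load

N, H, B, T = 64, 32, 1, 8   # assumed head size and layout; C = H * N
C = H * N
load(name="rwkv5",
     sources=["cuda/rwkv5_op.cpp", "cuda/rwkv5.cu"],
     extra_cuda_cflags=["-O3", f"-D_N_={N}"],   # bake the head size into the kernel
     is_python_module=False)                    # exposes torch.ops.rwkv5.forward_*

state = torch.zeros(H, N, N, dtype=torch.float32, device="cuda")
r, k, v = (torch.randn(T, C, dtype=torch.float16, device="cuda") for _ in range(3))
w = torch.rand(C, dtype=torch.float32, device="cuda")    # per-channel decay, fp32
u = torch.zeros(C, dtype=torch.float16, device="cuda")
y = torch.empty(T, C, dtype=torch.float16, device="cuda")
torch.ops.rwkv5.forward_fp16(B, T, C, H, state, r, k, v, w, u, y)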
    	
cuda/rwkv6.cu
ADDED

@@ -0,0 +1,87 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _u += h*_N_;
    _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];

    float state[_N_];
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    __syncthreads();
    u[i] = float(_u[i]);
    __syncthreads();

    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        __syncthreads();
        w[i] = _w[t];
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
    #pragma unroll
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
{
    assert(H*_N_ == C);
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
}
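Functionally this differs from rwkv5.cu in one place: the decay is time-dependent, so w[i] is reloaded from _w[t] inside the token loop instead of once per head. The single-head reference recurrence changes accordingly (same illustrative names as the rwkv5 sketch above, with w now shaped (T, N)):

import torch

def rwkv6_head_reference(r, k, v, w, u, state):
    # Identical to the rwkv5 reference except w is (T, N): one decay per token.
    T, N = r.shape
    y = torch.empty(T, N)
    for t in range(T):
        kv = torch.outer(v[t], k[t])
        y[t] = (r[t] * (u * kv + state)).sum(1)
        state = state * w[t] + kv  # time-varying decay
    return y, state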
    	
cuda/rwkv6_op.cpp
ADDED

@@ -0,0 +1,34 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <c10/cuda/CUDAGuard.h>
typedef at::BFloat16 bf16;
typedef at::Half fp16;
typedef float fp32;

void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);

void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
}
void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
    cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
    m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
    m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
}
TORCH_LIBRARY(rwkv6, m) {
    m.def("forward_bf16", forward_bf16);
    m.def("forward_fp16", forward_fp16);
    m.def("forward_fp32", forward_fp32);
}
    	
cuda/wrapper.cpp
ADDED

@@ -0,0 +1,141 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <iostream>
#include <c10/cuda/CUDAGuard.h>

typedef at::Half fp16;

template <typename F>
void cuda_wkv_forward(int B, int T, int C,
                      float *w, float *u, F *k, F *v, F *y,
                      float *aa, float *bb, float *pp);
template <typename F>
void cuda_mm8_seq(int B, int N, int M,
                  F *x, int x_stride,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  F *y, int y_stride);
template <typename F>
void cuda_mm8_one(int N, int M,
                  F *x,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  float *y);

void wkv_forward(int64_t B, int64_t T, int64_t C,
                 torch::Tensor &w, torch::Tensor &u,
                 torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
                 torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
    switch (k.scalar_type()) {
    case c10::ScalarType::Half:
        cuda_wkv_forward(B, T, C,
                         w.data_ptr<float>(), u.data_ptr<float>(),
                         k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
                         aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
        break;
    case c10::ScalarType::Float:
        cuda_wkv_forward(B, T, C,
                         w.data_ptr<float>(), u.data_ptr<float>(),
                         k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
                         aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
        break;
    default:
        assert(false && "Only FP16 and FP32 are currently supported");
    }
}

void mm8_seq(int64_t B, int64_t N, int64_t M,
             torch::Tensor &x, torch::Tensor &w,
             torch::Tensor &mx, torch::Tensor &rx,
             torch::Tensor &my, torch::Tensor &ry,
             torch::Tensor &y) {
    assert(x.stride(1) == 1);
    assert(w.stride(1) == 1);
    assert(mx.stride(0) == 1 && rx.stride(0) == 1);
    assert(my.stride(0) == 1 && ry.stride(0) == 1);
    assert(y.stride(1) == 1);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
    switch (x.scalar_type()) {
    case c10::ScalarType::Half:
        cuda_mm8_seq(
            B, N, M,
            x.data_ptr<fp16>(), x.stride(0),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
            my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
            y.data_ptr<fp16>(), y.stride(0));
        break;
    case c10::ScalarType::Float:
        cuda_mm8_seq(
            B, N, M,
            x.data_ptr<float>(), x.stride(0),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<float>(), rx.data_ptr<float>(),
            my.data_ptr<float>(), ry.data_ptr<float>(),
            y.data_ptr<float>(), y.stride(0));
        break;
    default:
        assert(false && "Only FP16 and FP32 are currently supported");
    }
}
void mm8_one(int64_t N, int64_t M,
             torch::Tensor &x, torch::Tensor &w,
             torch::Tensor &mx, torch::Tensor &rx,
             torch::Tensor &my, torch::Tensor &ry,
             torch::Tensor &y) {
    assert(x.stride(0) == 1);
    assert(w.stride(1) == 1);
    assert(mx.stride(0) == 1 && rx.stride(0) == 1);
    assert(my.stride(0) == 1 && ry.stride(0) == 1);
    assert(y.stride(0) == 1);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
    switch (x.scalar_type()) {
    case c10::ScalarType::Half:
        cuda_mm8_one(
            N, M,
            x.data_ptr<fp16>(),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
            my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
            y.data_ptr<float>());
        break;
    case c10::ScalarType::Float:
        cuda_mm8_one(
            N, M,
            x.data_ptr<float>(),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<float>(), rx.data_ptr<float>(),
            my.data_ptr<float>(), ry.data_ptr<float>(),
            y.data_ptr<float>());
        break;
    default:
        assert(false && "Only FP16 and FP32 are currently supported");
    }
}

using torch::Tensor;

#ifndef DISABLE_CUBLAS_GEMM
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
#endif

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wkv_forward", &wkv_forward, "wkv forward");
    m.def("mm8_seq", &mm8_seq, "mm8 seq");
    m.def("mm8_one", &mm8_one, "mm8 one");
#ifndef DISABLE_CUBLAS_GEMM
    m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
#endif
}

TORCH_LIBRARY(rwkv, m) {
    m.def("wkv_forward", wkv_forward);
    m.def("mm8_seq", mm8_seq);
    m.def("mm8_one", mm8_one);
#ifndef DISABLE_CUBLAS_GEMM
    m.def("gemm_fp16_cublas", gemm_fp16_cublas);
#endif
}
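The asserts above pin down the stride and shape contract the Python side must satisfy before calling these ops; the cuda_mm8_* helpers in modeling_rwkv.py (below) enforce the same shapes. A hedged illustration of the torch.ops.rwkv.mm8_one surface with contiguous CUDA tensors (sizes and values are arbitrary):

import torch

N, M = 2048, 2048  # illustrative sizes
w = torch.randint(0, 256, (N, M), dtype=torch.uint8, device="cuda")
mx, rx = (torch.zeros(M, dtype=torch.float16, device="cuda") for _ in range(2))
my, ry = (torch.zeros(N, 1, dtype=torch.float16, device="cuda") for _ in range(2))
x = torch.zeros(N, dtype=torch.float16, device="cuda")
# The kernel accumulates with atomicAdd, so y must start zeroed and be fp32.
y = torch.zeros(M, dtype=torch.float32, device="cuda")
torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y)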
    	
examples_bluejay.jpg
ADDED

examples_extreme_ironing.jpg
ADDED

examples_pizza.jpg
ADDED

examples_waterview.jpg
ADDED

examples_woman_and_dog.png
ADDED
    	
        modeling_rwkv.py
    ADDED
    
    | @@ -0,0 +1,1237 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ########################################################################################################
         | 
| 2 | 
            +
            # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
         | 
| 3 | 
            +
            ########################################################################################################
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from typing import Optional
         | 
| 6 | 
            +
            import types, gc, os, time, re
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            import torch.nn as nn
         | 
| 9 | 
            +
            from torch.nn import functional as F
         | 
| 10 | 
            +
            torch.backends.cudnn.benchmark = True
         | 
| 11 | 
            +
            torch.backends.cudnn.allow_tf32 = True
         | 
| 12 | 
            +
            torch.backends.cuda.matmul.allow_tf32 = True
         | 
| 13 | 
            +
            current_path = os.path.dirname(os.path.abspath(__file__))
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            ########################################################################################################
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            if os.environ.get('RWKV_JIT_ON') != '0':
         | 
| 18 | 
            +
                os.environ["RWKV_JIT_ON"] = '1'
         | 
| 19 | 
            +
                MyModule = torch.jit.ScriptModule
         | 
| 20 | 
            +
                MyFunction = torch.jit.script_method
         | 
| 21 | 
            +
                MyStatic = torch.jit.script
         | 
| 22 | 
            +
            else:
         | 
| 23 | 
            +
                MyModule = torch.nn.Module
         | 
| 24 | 
            +
                def __nop(ob):
         | 
| 25 | 
            +
                    return ob
         | 
| 26 | 
            +
                MyFunction = __nop
         | 
| 27 | 
            +
                MyStatic = __nop
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            if os.environ.get('RWKV_CUDA_ON') == '1':
         | 
| 30 | 
            +
                from torch.utils.cpp_extension import load
         | 
| 31 | 
            +
                try:
         | 
| 32 | 
            +
                    load(
         | 
| 33 | 
            +
                        name=f"wkv_cuda",
         | 
| 34 | 
            +
                        sources=[f"{current_path}/cuda/wrapper.cpp", f"{current_path}/cuda/operators.cu", f"{current_path}/cuda/gemm_fp16_cublas.cpp"],
         | 
| 35 | 
            +
                        verbose=True,
         | 
| 36 | 
            +
                        extra_ldflags=["cublas.lib" if os.name == "nt" else ""],
         | 
| 37 | 
            +
                        extra_cuda_cflags=["--use_fast_math", "-O3", "--extra-device-vectorization"],
         | 
| 38 | 
            +
                        is_python_module=False)
         | 
| 39 | 
            +
                    DISABLE_CUBLAS_GEMM = False
         | 
| 40 | 
            +
                except:
         | 
| 41 | 
            +
                    print("Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow.")
         | 
| 42 | 
            +
                    load(
         | 
| 43 | 
            +
                        name=f"wkv_cuda",
         | 
| 44 | 
            +
                        sources=[f"{current_path}/cuda/wrapper.cpp", f"{current_path}/cuda/operators.cu"],
         | 
| 45 | 
            +
                        verbose=True,
         | 
| 46 | 
            +
                        extra_cuda_cflags=["--use_fast_math", "-O3", "--extra-device-vectorization"],
         | 
| 47 | 
            +
                        extra_cflags=["-DDISABLE_CUBLAS_GEMM"],
         | 
| 48 | 
            +
                        is_python_module=False)
         | 
| 49 | 
            +
                    DISABLE_CUBLAS_GEMM = True
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                @MyStatic
         | 
| 52 | 
            +
                def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp):
         | 
| 53 | 
            +
                    assert 1 * C % min(C, 32) == 0
         | 
| 54 | 
            +
                    assert k.dtype == v.dtype == torch.float16 or k.dtype == v.dtype == torch.float32
         | 
| 55 | 
            +
                    assert w.dtype == u.dtype == aa.dtype == bb.dtype == pp.dtype == torch.float32
         | 
| 56 | 
            +
                    w = w.contiguous()
         | 
| 57 | 
            +
                    u = u.contiguous()
         | 
| 58 | 
            +
                    k = k.contiguous()
         | 
| 59 | 
            +
                    v = v.contiguous()
         | 
| 60 | 
            +
                    y = torch.empty((T, C), device=w.device, memory_format=torch.contiguous_format, dtype=k.dtype)
         | 
| 61 | 
            +
                    torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp)
         | 
| 62 | 
            +
                    return y, aa, bb, pp
         | 
| 63 | 
            +
                @MyStatic
         | 
| 64 | 
            +
                def cuda_mm8_seq(B: int, N: int, M: int, x, w, mx, rx, my, ry):
         | 
| 65 | 
            +
                    assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype
         | 
| 66 | 
            +
                    assert x.dtype == torch.float32 or x.dtype == torch.float16
         | 
| 67 | 
            +
                    assert w.dtype == torch.uint8
         | 
| 68 | 
            +
                    assert x.shape == (B, N)
         | 
| 69 | 
            +
                    assert w.shape == (N, M)
         | 
| 70 | 
            +
                    assert rx.shape == mx.shape == (M,)
         | 
| 71 | 
            +
                    assert ry.shape == my.shape == (N, 1)
         | 
| 72 | 
            +
                    y = torch.empty((B, M), device=w.device, dtype=x.dtype)
         | 
| 73 | 
            +
                    torch.ops.rwkv.mm8_seq(B, N, M, x, w, mx, rx, my, ry, y)
         | 
| 74 | 
            +
                    return y
         | 
| 75 | 
            +
                @MyStatic
         | 
| 76 | 
            +
                def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry):
         | 
| 77 | 
            +
                    assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype
         | 
| 78 | 
            +
                    assert x.dtype == torch.float32 or x.dtype == torch.float16
         | 
| 79 | 
            +
                    assert w.dtype == torch.uint8
         | 
| 80 | 
            +
                    assert x.shape == (N,)
         | 
| 81 | 
            +
                    assert w.shape == (N, M)
         | 
| 82 | 
            +
                    assert rx.shape == mx.shape == (M,)
         | 
| 83 | 
            +
                    assert ry.shape == my.shape == (N, 1)
         | 
| 84 | 
            +
                    y = torch.zeros((M,), device=w.device, dtype=torch.float32)
         | 
| 85 | 
            +
                    torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y)
         | 
| 86 | 
            +
                    return y.to(dtype=x.dtype)
         | 
| 87 | 
            +
            else:
         | 
| 88 | 
            +
                os.environ["RWKV_CUDA_ON"] = '0'
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            @MyStatic
         | 
| 92 | 
            +
            def torch_mm8_seq(x, w, mx, rx, my, ry):
         | 
| 93 | 
            +
                return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
            @MyStatic
         | 
| 96 | 
            +
            def torch_mm8_one(x, w, mx, rx, my, ry):
         | 
| 97 | 
            +
                return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
            if os.environ.get('RWKV_CUDA_ON') == '1':
         | 
| 100 | 
            +
                @MyStatic
         | 
| 101 | 
            +
                def mm8_seq(x, w, mx, rx, my, ry):
         | 
| 102 | 
            +
                    if w.device.type == 'cuda' and x.dtype == torch.float16:
         | 
| 103 | 
            +
                        B, N, M = x.shape[0], w.shape[0], w.shape[1]
         | 
| 104 | 
            +
                        return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry)
         | 
| 105 | 
            +
                    else:
         | 
| 106 | 
            +
                        return torch_mm8_seq(x, w, mx, rx, my, ry)
         | 
| 107 | 
            +
                @MyStatic
         | 
| 108 | 
            +
                def mm8_one(x, w, mx, rx, my, ry):
         | 
| 109 | 
            +
                    if w.device.type == 'cuda':
         | 
| 110 | 
            +
                        N, M = w.shape[0], w.shape[1]
         | 
| 111 | 
            +
                        return cuda_mm8_one(N, M, x, w, mx, rx, my, ry)
         | 
| 112 | 
            +
                    else:
         | 
| 113 | 
            +
                        return torch_mm8_one(x, w, mx, rx, my, ry)
         | 
| 114 | 
            +
            else:
         | 
| 115 | 
            +
                @MyStatic
         | 
| 116 | 
            +
                def mm8_seq(x, w, mx, rx, my, ry):
         | 
| 117 | 
            +
                    return torch_mm8_seq(x, w, mx, rx, my, ry)
         | 
| 118 | 
            +
                @MyStatic
         | 
| 119 | 
            +
                def mm8_one(x, w, mx, rx, my, ry):
         | 
| 120 | 
            +
                    return torch_mm8_one(x, w, mx, rx, my, ry)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
            def mm8(x: torch.Tensor, w: torch.Tensor, mx: torch.Tensor, rx: torch.Tensor, my: torch.Tensor, ry: torch.Tensor):
         | 
| 123 | 
            +
                if len(x.shape) == 1:
         | 
| 124 | 
            +
                    return mm8_one(x, w, mx, rx, my, ry)
         | 
| 125 | 
            +
                return mm8_seq(x, w, mx, rx, my, ry)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
            def matmul(a, b, mx: Optional[torch.Tensor]=None, rx: Optional[torch.Tensor]=None, my: Optional[torch.Tensor]=None, ry: Optional[torch.Tensor]=None, output_dtype: Optional[torch.dtype]=None) -> torch.Tensor:
         | 
| 128 | 
            +
                if output_dtype is None:
         | 
| 129 | 
            +
                    output_dtype = a.dtype
         | 
| 130 | 
            +
                if b.dtype in [torch.float16, torch.bfloat16, torch.float32]:
         | 
| 131 | 
            +
                    assert a.dtype == b.dtype
         | 
| 132 | 
            +
                    return matmul_float(a, b, output_dtype=output_dtype)
         | 
| 133 | 
            +
                elif b.dtype == torch.uint8:
         | 
| 134 | 
            +
                    assert mx is not None
         | 
| 135 | 
            +
                    assert rx is not None
         | 
| 136 | 
            +
                    assert my is not None
         | 
| 137 | 
            +
                    assert ry is not None
         | 
| 138 | 
            +
                    return mm8(a, b, mx, rx, my, ry).to(output_dtype)
         | 
| 139 | 
            +
                else:
         | 
| 140 | 
            +
                    raise ValueError("Unsupported dtype")
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
if os.environ.get('RWKV_CUDA_ON') == '1' and not DISABLE_CUBLAS_GEMM:
    def matmul_float(a, b, output_dtype: Optional[torch.dtype]=None):
        if output_dtype is None:
            output_dtype = a.dtype
        if a.dtype == b.dtype == torch.float16 and a.device.type == 'cuda':
            if len(a.shape) == 1:
                assert len(b.shape) == 2
                c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device)
                a = a.unsqueeze(0)
            else:
                assert len(a.shape) == len(b.shape)
                assert len(a.shape) == 2 or len(a.shape) == 3
                # torch.empty((*a.shape[:-1], b.shape[-1])) doesn't work with jit
                if len(a.shape) == 2:
                    c = torch.empty((a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device)
                else:
                    c = torch.empty((a.shape[0], a.shape[1], b.shape[-1]), dtype=output_dtype, device=a.device)
            torch.ops.rwkv.gemm_fp16_cublas(a, b, c)
            return c
        else:
            return (a @ b).to(output_dtype)

else:
    def matmul_float(a, b, output_dtype: Optional[torch.dtype]=None):
        return (a @ b).to(output_dtype)

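# Note on the CUDA branch above: the output tensor c is pre-allocated with
# output_dtype before the custom op runs, so the fp16 GEMM can write its result in
# e.g. fp32 without a separate cast afterwards (see cuda/gemm_fp16_cublas.cpp;
# cuBLAS GEMM accepts fp16 inputs with a wider output/compute type).
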
if os.environ.get('RWKV_DML_ON') == '1':
    import torch_directml
    print("PyTorch with DirectML Enabled")

########################################################################################################

class RWKV(MyModule):
    def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None):
        super().__init__()
        if verbose:
            prxxx = lambda *args, **kwargs: print(*args, **kwargs)
        else:
            prxxx = lambda *args, **kwargs: None

        STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
        if not re.match(STRATEGY_REGEX, strategy):
            raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/")

        strategy = ('->'.join([x.strip() for x in strategy.split('->')])).replace('->', ' -> ')
        self.args = types.SimpleNamespace()
        args = self.args
        args.MODEL_NAME = model
        args.strategy_string = strategy

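        # Examples of strategy strings the regex above accepts (see the rwkv docs):
        #   'cpu fp32'                     - everything on CPU
        #   'cuda fp16'                    - everything on GPU 0 ('cuda:1 fp16' picks GPU 1)
        #   'cuda fp16i8 *10 -> cpu fp32'  - first 10 layers on GPU in int8, the rest on CPU
        #   'cuda fp16 *6+'                - 6 layers resident on GPU, remaining layers streamed to it
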
        # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow)
        try:
            self.RESCALE_LAYER = int(os.environ["RWKV_RESCALE_LAYER"]) # !!! NOTE: SEEMS YOU SHOULD SET IT TO 999 (disable) FOR RWKV-MUSIC MODELS !!!
        except:
            self.RESCALE_LAYER = 6 if 'fp16' in strategy else 0
        prxxx(f'RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]} RWKV_CUDA_ON {os.environ["RWKV_CUDA_ON"]} RESCALE_LAYER {self.RESCALE_LAYER}\n')

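        # How the rescale trick works (sketch): with RESCALE_LAYER = 6, the weights
        # att.output.weight and ffn.value.weight of layer L are pre-divided by
        # 2 ** (L // 6) during loading below, and the forward pass halves the
        # residual stream every 6 layers, so activations stay inside fp16 range
        # while the final logits are unchanged up to rounding.
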
        args.MODEL_NAME = args.MODEL_NAME.strip()
        if not args.MODEL_NAME.endswith('.pth'):
            args.MODEL_NAME += '.pth'
        prxxx(f'Loading {args.MODEL_NAME} ...')
        with torch.no_grad():
            self.w = torch.load(args.MODEL_NAME, map_location='cpu') # load model to CPU first
            gc.collect()
            w = self.w

            ALREADY_CONVERTED = False
            if '_strategy' in w:
                ALREADY_CONVERTED = True
                assert convert_and_save_and_exit == None # you should only convert a raw model
                prxxx(f"Converted model: strategy {w['_strategy']}, version {w['_version']}\n")
                assert w['_strategy'] == args.strategy_string # if you are using a new strategy, re-convert the model
                assert float(w['_version']) >= 0.7 # sometimes you should re-convert using latest convert_model.py
                assert w['_rescale_layer'] == self.RESCALE_LAYER # must use same RESCALE_LAYER to avoid mistakes
                del w['_strategy']
                del w['_version']
                del w['_rescale_layer']

            args.n_embd = w['emb.weight'].shape[1]
            args.n_att = w['blocks.0.att.key.weight'].shape[0] # note: transposed matrix
            args.n_ffn = w['blocks.0.ffn.key.weight'].shape[0] # note: transposed matrix
            args.n_layer = 0
            keys = list(w.keys())
            self.version = 4
            for x in keys:
                layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0
                args.n_layer = max(args.n_layer, layer_id+1)
                if 'ln_x' in x:
                    self.version = max(5, self.version)
                if 'gate.weight' in x:
                    self.version = max(5.1, self.version)
                if int(self.version) == 5 and 'att.time_decay' in x:
                    args.n_head = w[x].shape[0]
                    if len(w[x].shape) > 1:
                        if w[x].shape[1] > 1:
                            self.version = max(5.2, self.version)
                if 'time_maa' in x:
                    self.version = max(6, self.version)
                if int(self.version) == 6 and 'time_faaaa' in x:
                    args.n_head = w[x].shape[0]
            prxxx(f'Model detected: v{self.version:.1f}')

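            # Version detection summary (inferred from the checkpoint keys above):
            #   'ln_x' present          -> at least v5 (multi-head WKV with group_norm)
            #   'gate.weight' present   -> at least v5.1 (adds a SiLU gate)
            #   2-D 'att.time_decay'    -> v5.2 (per-channel decay within each head)
            #   'time_maa' present      -> v6 (data-dependent token shift)
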
            ####################### Compute strategy

            s = [x.strip().split(' ') for x in strategy.split('->')]
            plan = [0] * len(s)
            stream_i = -1
            stream_count = 0
            to_allocate = args.n_layer + 1
            allocated = 0
            free_slots = 0
            for i in range(len(s)):
                si = s[i]
                si1 = si[1]
                if si1.startswith('fp32'): si[1] = [torch.float]
                elif si1.startswith('fp16'): si[1] = [torch.float16]
                elif si1.startswith('bf16'): si[1] = [torch.bfloat16]
                if si1.endswith('i8'): si[1] += [torch.uint8]
                else: si[1] += [si[1][0]]
                if len(si) > 2:
                    ss = si[2]
                    assert ss.startswith('*')
                    if ss.endswith('+'):
                        plan[i] = int(ss[1:-1])
                        stream_i = i
                    else:
                        plan[i] = int(ss[1:])
                    allocated += plan[i]
                    if allocated >= to_allocate:
                        plan[i] += to_allocate - allocated
                        break
                else:
                    free_slots += 1
            if stream_i < 0:
                if free_slots > 0 and to_allocate > allocated:
                    for i in range(len(s)):
                        if plan[i] == 0:
                            plan[i] = (to_allocate - allocated) // free_slots
                            allocated += plan[i]
                            free_slots -= 1
                if to_allocate > allocated:
                    plan[len(s)-1] += to_allocate - allocated
            else:
                if to_allocate > allocated:
                    stream_count = to_allocate - allocated
                    plan[stream_i] += stream_count
            prxxx(f'Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)')
            for i in range(len(s)):
                ss = s[i]
                if i != stream_i:
                    prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers')
                else:
                    prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers')
                plan[i] += (0 if i == 0 else plan[i-1])
            self.strategy = [None] * (args.n_layer + 1)
            strategy = self.strategy
            for n in range(args.n_layer + 1):
                for i in range(len(s)):
                    if n < plan[i]:
                        strategy[n] = types.SimpleNamespace()
                        strategy[n].device = s[i][0]
                        strategy[n].atype = s[i][1][0]
                        strategy[n].wtype = s[i][1][1]
                        strategy[n].stream = False
                        if strategy[n].device == 'dml':
                            strategy[n].device = torch_directml.device()
                        if i == stream_i and n >= (plan[i] - stream_count):
                            strategy[n].stream = True
                        break
                prxxx(f"{n}-{strategy[n].device}-{str(strategy[n].atype).replace('torch.','')}-{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}",end=' ')
            prxxx()

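            # Worked example (hypothetical): strategy 'cuda fp16 *10 -> cpu fp32'
            # with n_layer = 24 gives to_allocate = 25 (layers plus ln_out/head).
            # plan starts as [10, 15] (the free cpu slot absorbs the remainder),
            # becomes cumulative [10, 25], so layers 0-9 run on cuda fp16 and
            # layers 10-24 (including the head) run on cpu fp32.
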
            ####################### Load weights to self.w

            if not ALREADY_CONVERTED:
                try: # precompute embedding
                    w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
                except:
                    w['emb.weight'] = F.layer_norm(w['emb.weight'].float(), (args.n_embd,), weight=w['blocks.0.ln0.weight'].float(), bias=w['blocks.0.ln0.bias'].float())
                # del w['blocks.0.ln0.weight']
                # del w['blocks.0.ln0.bias']

            print_need_newline = False

            REAL_TIME_FIRST = False
            for x in list(w.keys()):
                if '.time_faaaa' in x: REAL_TIME_FIRST = True
            if REAL_TIME_FIRST:
                w = {k.replace('.time_faaaa','.time_first') if '.time_faaaa' in k else k: v for k, v in w.items()}
                self.w = w

            keys = list(w.keys())
            for x in keys:
                w[x].requires_grad = False
                layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0
                if ('ln_out.' in x) or ('head.' in x):
                    layer_id = args.n_layer
                dd = strategy[layer_id]
                DEVICE = dd.device
                ATYPE = dd.atype
                WTYPE = dd.wtype

                if not ALREADY_CONVERTED:
                    if self.RESCALE_LAYER > 0:
                        if 'att.output.weight' in x:
                            w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER))
                        if 'ffn.value.weight' in x:
                            w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER))

                    if '.time_' in x:
                        w[x] = w[x].squeeze()
                    if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'gate.weight' in x or 'output.weight' in x or 'head.weight' in x:
                        w[x] = w[x].t()

                    if '.time_decay' in x and '_w' not in x: # need fp32 for this
                        if self.version == 4:
                            w[x] = -torch.exp(w[x].float())
                        elif int(self.version) == 5:
                            w[x] = torch.exp(-torch.exp(w[x].float())).reshape(-1,1,1)
                            if self.version == 5.2:
                                w[x] = w[x].reshape(args.n_head, -1, 1)
                        elif self.version == 6.0:
                            w[x] = w[x].float().reshape(args.n_head, -1, 1)
                    elif '.time_first' in x: # need fp32 for this
                        if self.version == 4:
                            w[x] = w[x].float()
                        elif int(self.version) in [5, 6]:
                            if REAL_TIME_FIRST:
                                w[x] = w[x].float().reshape(-1,1,1)
                            else:
                                w[x] = torch.exp(w[x].float()).reshape(-1,1,1)
                            if self.version in [5.2, 6.0]:
                                w[x] = w[x].reshape(args.n_head, -1, 1)
                    elif '.ln_x' in x: # need fp32 for group_norm
                        w[x] = w[x].float()
                    else:
                        if (len(w[x].shape) == 2) and ('emb' not in x):
                            if WTYPE != torch.uint8:
                                w[x] = w[x].to(dtype=WTYPE)
                            else:
                                w[x] = w[x].float()

                                if w[x].shape[0] > w[x].shape[1]:
                                    w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1)
                                    w[x] = w[x] - w[x+'_my']
                                    w[x+'_mx'] = torch.amin(w[x], dim=0)
                                    w[x] = w[x] - w[x+'_mx']
                                    w[x+'_rx'] = torch.amax(w[x], dim=0)
                                    w[x] = w[x] / w[x+'_rx']
                                    w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1)
                                    w[x] = w[x] / w[x+'_ry']
                                else:
                                    w[x+'_mx'] = torch.amin(w[x], dim=0)
                                    w[x] = w[x] - w[x+'_mx']
                                    w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1)
                                    w[x] = w[x] - w[x+'_my']
                                    w[x+'_rx'] = torch.amax(w[x], dim=0)
                                    w[x] = w[x] / w[x+'_rx']
                                    w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1)
                                    w[x] = w[x] / w[x+'_ry']

                                w[x] = torch.clip(torch.floor(w[x] * 256), min=0, max=255).to(dtype=torch.uint8)
                                w[x+'_mx'] = w[x+'_mx'].to(dtype=ATYPE).contiguous()
                                w[x+'_rx'] = (w[x+'_rx'] / 16).to(dtype=ATYPE).contiguous()
                                w[x+'_my'] = w[x+'_my'].to(dtype=ATYPE).contiguous()
                                w[x+'_ry'] = (w[x+'_ry'] / 16).to(dtype=ATYPE).contiguous()
                        else:
                            w[x] = w[x].to(dtype=ATYPE)

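                # i8 quantization sketch: after the shifts and scales above, the
                # matrix is in [0, 1] and is stored as q = floor(w * 256) in uint8.
                # Dequantization in the torch_mm8_* helpers reconstructs
                # (q + 0.5) * ry * rx + my + mx; since rx and ry are each stored
                # pre-divided by 16, their product carries the missing 1/256 factor,
                # recovering approximately w_orig = w_norm * ry * rx + mx + my.
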
                if convert_and_save_and_exit == None:
                    if 'emb.' in x:
                        w[x] = w[x].contiguous()
                    elif (dd.stream) and (x.endswith('key.weight') or x.endswith('value.weight') or x.endswith('receptance.weight') or x.endswith('output.weight')):
                        try:
                            w[x] = w[x].contiguous().pin_memory() # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :)
                        except:
                            print('Note: You are running out of RAM. Get more CPU RAM. Now this will run much slower.')
                    elif DEVICE != 'cpu':
                        w[x] = w[x].to(device=DEVICE).contiguous()

                    if (dd.stream) or (DEVICE != 'cpu'):
                        try:
                            w[x+'_mx'] = w[x+'_mx'].to(device=DEVICE).contiguous()
                            w[x+'_rx'] = w[x+'_rx'].to(device=DEVICE).contiguous()
                            w[x+'_my'] = w[x+'_my'].to(device=DEVICE).contiguous()
                            w[x+'_ry'] = w[x+'_ry'].to(device=DEVICE).contiguous()
                        except:
                            pass

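                # Streamed layers keep their large matrices in pinned (page-locked)
                # CPU RAM so they can be copied to the GPU asynchronously on each
                # forward pass instead of living in VRAM permanently.
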
                if 'ffn.value.weight' in x:
                    gc.collect()
                    if 'cuda' in args.strategy_string:
                        torch.cuda.empty_cache()

                shape = [i for i in w[x].shape if i != 1]
                if len(shape) > 1:
                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
                else:
                    shape = f" {str(shape[0]).rjust(5)}      "
                if layer_id == 0 or layer_id >= args.n_layer-1:
                    if print_need_newline:
                        prxxx('\n', end = '')
                        print_need_newline = False
                    dt = str(w[x].dtype).replace('torch.', '')
                    dt = dt.replace('float32', 'f32').replace('bfloat16', 'bf16').replace('float16', 'f16').replace('uint8', 'i8')
                    prxxx(x.ljust(32), dt.rjust(4), str(w[x].device).rjust(8), shape, ' (pinned)' if w[x].is_pinned() else '')
                else:
                    print_need_newline = True
                    prxxx('.', end = '', flush = True)

            if convert_and_save_and_exit:
                w['_strategy'] = args.strategy_string
                w['_rescale_layer'] = self.RESCALE_LAYER
                w['_version'] = '0.7'
                if not convert_and_save_and_exit.endswith('.pth'):
                    convert_and_save_and_exit += '.pth'
                prxxx(f'Saving to {convert_and_save_and_exit}...')
                torch.save(w, convert_and_save_and_exit)
                prxxx('Converted and saved. Now this will exit.')
                exit(0)

            if self.version == 5.2 and os.environ["RWKV_CUDA_ON"] == '1':
                HEAD_SIZE = args.n_att // args.n_head
                rwkv5 = load(name="rwkv5", sources=[f"{current_path}/cuda/rwkv5_op.cpp", f"{current_path}/cuda/rwkv5.cu"],
                                verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3" if os.name != "nt" else "", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])

                class RWKV_5(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, B, T, C, H, state, r, k, v, w, u):
                        with torch.no_grad():
                            assert HEAD_SIZE == C // H
                            ctx.B = B
                            ctx.T = T
                            ctx.C = C
                            ctx.H = H
                            assert state.dtype == torch.float32
                            assert w.dtype == torch.float32
                            assert r.is_contiguous()
                            assert k.is_contiguous()
                            assert v.is_contiguous()
                            assert w.is_contiguous()
                            assert u.is_contiguous()
                            assert state.is_contiguous()

                            y = torch.empty((B, T, C), device=w.device, dtype=r.dtype, memory_format=torch.contiguous_format)
                            if r.dtype == torch.bfloat16:
                                rwkv5.forward_bf16(B, T, C, H, state, r, k, v, w, u, y)
                            elif r.dtype == torch.float16:
                                rwkv5.forward_fp16(B, T, C, H, state, r, k, v, w, u, y)
                            elif r.dtype == torch.float32:
                                rwkv5.forward_fp32(B, T, C, H, state, r, k, v, w, u, y)
                            return y, state
                self.RWKV_5 = RWKV_5

            if self.version == 6.0 and os.environ["RWKV_CUDA_ON"] == '1':
                HEAD_SIZE = args.n_att // args.n_head
                rwkv6 = load(name="rwkv6", sources=[f"{current_path}/cuda/rwkv6_op.cpp", f"{current_path}/cuda/rwkv6.cu"],
                                verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={4096}"])

                class RWKV_6(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, B, T, C, H, state, r, k, v, w, u):
                        with torch.no_grad():
                            assert HEAD_SIZE == C // H
                            ctx.B = B
                            ctx.T = T
                            ctx.C = C
                            ctx.H = H
                            assert state.dtype == torch.float32
                            assert w.dtype == torch.float32
                            assert r.is_contiguous()
                            assert k.is_contiguous()
                            assert v.is_contiguous()
                            assert w.is_contiguous()
                            assert u.is_contiguous()
                            eew = torch.exp(-torch.exp(w.float())).contiguous()

                            y = torch.empty((B, T, C), device=w.device, dtype=r.dtype, memory_format=torch.contiguous_format)
                            if r.dtype == torch.bfloat16:
                                rwkv6.forward_bf16(B, T, C, H, state, r, k, v, eew, u, y)
                            elif r.dtype == torch.float16:
                                rwkv6.forward_fp16(B, T, C, H, state, r, k, v, eew, u, y)
                            elif r.dtype == torch.float32:
                                rwkv6.forward_fp32(B, T, C, H, state, r, k, v, eew, u, y)
                            return y, state
                self.RWKV_6 = RWKV_6

            gc.collect()
            if 'cuda' in args.strategy_string:
                torch.cuda.empty_cache()

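    # Note on the two JIT-compiled kernels above: only forward() is defined because
    # this module targets inference. Each kernel is specialized at compile time for
    # the head size (-D_N_), and the rwkv6 kernel additionally bakes in -D_T_=4096,
    # which appears to cap the sequence length handled per kernel call.
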
    def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u):
        return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u)

    def RUN_RWKV_6(self, B, T, C, H, state, r, k, v, w, u):
        return self.RWKV_6.apply(B, T, C, H, state, r, k, v, w, u)

    ########################################################################################################

    @MyFunction
    def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry):
        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
        kx = xx * k_mix + sx * (1 - k_mix)
        rx = xx * r_mix + sx * (1 - r_mix)

        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
        vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2
        out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
        return x + out, xx

    @MyFunction
    def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry):
        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
        sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
        kx = xx * k_mix + sx * (1 - k_mix)
        rx = xx * r_mix + sx * (1 - r_mix)

        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
        vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2
        out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
        return x + out, xx[-1,:]

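    # Channel-mix sketch: ffn_one handles a single token (sx is the previous
    # token's normalized input), ffn_seq a whole sequence (the torch.cat builds the
    # one-token-shifted sequence). Both compute the RWKV FFN
    #   out = sigmoid(R(rx)) * V(relu(K(kx)) ** 2)
    # where kx/rx interpolate between the current and previous token (token shift).
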
    @MyFunction
    def ffn_one_v6(self, x, sx, ln_w, ln_b, k_maa, r_maa, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry):
        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
        sx = sx - xx
        kx = xx + sx * k_maa
        rx = xx + sx * r_maa

        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
        vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2
        out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
        return x + out, xx

    @MyFunction
    def ffn_seq_v6(self, x, sx, ln_w, ln_b, k_maa, r_maa, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry):
        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
        sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
        sx = sx - xx
        kx = xx + sx * k_maa
        rx = xx + sx * r_maa

        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
        vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2
        out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
        return x + out, xx[-1,:]

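    # v6 token shift is expressed in difference form: after the subtraction above,
    # sx holds (previous - current), so kx = xx + sx * k_maa interpolates by a
    # learned per-channel amount, equivalent to kx = xx * (1 - k_maa) + prev * k_maa.
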
    ########################################################################################################

    @MyFunction
    def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
        kx = xx * k_mix + sx * (1 - k_mix)
        vx = xx * v_mix + sx * (1 - v_mix)
        rx = xx * r_mix + sx * (1 - r_mix)

        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)

        ww = t_first + k
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
        ww = t_decay + pp
        p = torch.maximum(ww, k)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k - p)

        out = matmul(r * wkv, ow, omx, orx, omy, ory)
        return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p

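    # Numerical-stability sketch for the v4 WKV above: aa and bb are the numerator
    # and denominator of the attention average, both kept scaled by exp(-pp) where
    # pp tracks the largest exponent seen so far. Taking p = max(...) before each
    # exp() keeps every exponent <= 0, so nothing overflows even in fp32.
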
    @MyFunction
    def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
        sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
        kx = xx * k_mix + sx * (1 - k_mix)
        vx = xx * v_mix + sx * (1 - v_mix)
        rx = xx * r_mix + sx * (1 - r_mix)

        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)

        T = x.shape[0]
        for t in range(T):
            kk = k[t]
            vv = v[t]
            ww = t_first + kk
            p = torch.maximum(pp, ww)
            e1 = torch.exp(pp - p)
            e2 = torch.exp(ww - p)
            sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
            ww = t_decay + pp
            p = torch.maximum(ww, kk)
            e1 = torch.exp(ww - p)
            e2 = torch.exp(kk - p)
            aa = e1 * aa + e2 * vv
            bb = e1 * bb + e2
            pp = p
        out = matmul(r * sx, ow, omx, orx, omy, ory)
        return x + out, xx[-1,:], aa, bb, pp

    ########################################################################################################

    @MyFunction
    def att_one_v5(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
        kx = xx * k_mix + sx * (1 - k_mix)
        vx = xx * v_mix + sx * (1 - v_mix)
        rx = xx * r_mix + sx * (1 - r_mix)

        H = t_decay.shape[0]
        N = x.shape[-1] // H

        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N)
        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1)
        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N)

        a = matmul(k, v)
        out = r @ (t_first * a + s)
        s = a + t_decay * s

        out = out.flatten()
        out = F.group_norm(out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps = 64e-5).squeeze(0)
        out = out.to(dtype=x.dtype)
        out = matmul(out, ow, omx, orx, omy, ory)

        return x + out, xx, s

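    # v5 recurrence sketch: each head keeps an N x N state matrix s. Per token,
    # a = k @ v is the outer product of key and value; the output reads
    # r @ (t_first * a + s) (current token weighted by t_first, history by s),
    # and the state decays as s <- a + t_decay * s.
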
    @MyFunction
    def att_seq_v5(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
        sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
        kx = xx * k_mix + sx * (1 - k_mix)
        vx = xx * v_mix + sx * (1 - v_mix)
        rx = xx * r_mix + sx * (1 - r_mix)

        H = t_decay.shape[0]
        N = x.shape[-1] // H
        T = x.shape[0]

        w = t_decay.reshape(-1, 1)
        u = t_first.reshape(-1, 1)
        ws = w.pow(T).reshape(H, 1, 1)
        ind = torch.arange(T-1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1)
        w = w.repeat(1, T).pow(ind)
        wk = w.reshape(H, 1, T)
        wb = wk.transpose(-2, -1).flip(1)
        w = torch.cat([w[:, 1:], u], dim=1)
        w = F.pad(w, (0, T))
        w = torch.tile(w, [T])
        w = w[:, :-T].reshape(-1, T, 2 * T - 1)
        w = w[:, :, T-1:].reshape(H, T, T)

        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(T, H, N).transpose(0, 1)
        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(T, H, N).permute(1, 2, 0)
        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(T, H, N).transpose(0, 1)

        out = ((r @ k) * w) @ v + (r @ s) * wb
        s = ws * s + (k * wk) @ v

        out = out.transpose(0, 1).contiguous().reshape(T, H*N)
        out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps = 64e-5)
        out = out.to(dtype=x.dtype)
        out = matmul(out, ow, omx, orx, omy, ory)

        return x + out, xx[-1,:], s

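    # Parallel-form sketch for att_seq_v5: the pad/tile trick above materializes a
    # lower-triangular T x T matrix per head whose (i, j) entry is decay^(i-1-j)
    # for j < i and t_first on the diagonal, so ((r @ k) * w) @ v computes
    # sum_{j<i} decay^(i-1-j) (r_i . k_j) v_j + u (r_i . k_i) v_i in one batched
    # matmul, while wb/wk/ws blend in and update the carried state s.
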
                    return x + out, xx[-1,:], s
         | 
| 713 | 
            +
             | 
| 714 | 
            +
                ########################################################################################################
         | 
| 715 | 
            +
             | 
| 716 | 
            +
    @MyFunction
    def att_one_v5_1(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, gmx, grx, gmy, gry, omx, orx, omy, ory):
        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
        kx = xx * k_mix + sx * (1 - k_mix)
        vx = xx * v_mix + sx * (1 - v_mix)
        rx = xx * r_mix + sx * (1 - r_mix)
        gx = xx * g_mix + sx * (1 - g_mix)

        H = t_decay.shape[0]
        N = x.shape[-1] // H

        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N)
        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1)
        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N)
        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))

        a = matmul(k, v)
        out = r @ (t_first * a + s)
        s = a + t_decay * s

        out = out.flatten()
        out = F.group_norm(out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5).squeeze(0)
        out = out.to(dtype=x.dtype) * g
        out = matmul(out, ow, omx, orx, omy, ory)

        return x + out, xx, s
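att_one_v5_1 is the single-token decode path: per head it performs one rank-1 update of the (N, N) state. A shape-level sketch with toy sizes (the real t_first / t_decay are learned per-channel parameters; per-head scalars are used here for brevity):

```python
# Minimal sketch (not part of the commit): one v5.1 decode step per head.
import torch

H, N = 8, 64                     # heads, head size (n_embd = H * N)
r = torch.randn(H, 1, N)         # receptance of the current token
k = torch.randn(H, N, 1)         # key, a column vector per head
v = torch.randn(H, 1, N)         # value, a row vector per head
s = torch.zeros(H, N, N)         # recurrent state carried across tokens
t_first = torch.rand(H, 1, 1)    # bonus applied only to the current token
t_decay = torch.rand(H, 1, 1)    # decay applied when the state is kept

a = k @ v                        # rank-1 outer product, (H, N, N)
out = r @ (t_first * a + s)      # read-out, (H, 1, N)
s = a + t_decay * s              # state update for the next token
print(out.flatten().shape)       # torch.Size([512]) == n_embd
```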
    @MyFunction
    def att_seq_v5_1(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, gmx, grx, gmy, gry, omx, orx, omy, ory):
        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
        sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
        kx = xx * k_mix + sx * (1 - k_mix)
        vx = xx * v_mix + sx * (1 - v_mix)
        rx = xx * r_mix + sx * (1 - r_mix)
        gx = xx * g_mix + sx * (1 - g_mix)

        H = t_decay.shape[0]
        N = x.shape[-1] // H
        T = x.shape[0]

        w = t_decay.reshape(-1, 1)
        u = t_first.reshape(-1, 1)
        ws = w.pow(T).reshape(H, 1, 1)
        ind = torch.arange(T-1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1)
        w = w.repeat(1, T).pow(ind)
        wk = w.reshape(H, 1, T)
        wb = wk.transpose(-2, -1).flip(1)
        w = torch.cat([w[:, 1:], u], dim=1)
        w = F.pad(w, (0, T))
        w = torch.tile(w, [T])
        w = w[:, :-T].reshape(-1, T, 2 * T - 1)
        w = w[:, :, T-1:].reshape(H, T, T)

        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(T, H, N).transpose(0, 1)
        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(T, H, N).permute(1, 2, 0)
        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(T, H, N).transpose(0, 1)
        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))

        out = ((r @ k) * w) @ v + (r @ s) * wb
        s = ws * s + (k * wk) @ v

        out = out.transpose(0, 1).contiguous().reshape(T, H*N)
        out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5)
        out = out.to(dtype=x.dtype) * g
        out = matmul(out, ow, omx, orx, omy, ory)

        return x + out, xx[-1,:], s

    ########################################################################################################
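Relative to att_seq_v5, the v5.1 path adds a SiLU gate computed from its own token-shift mix. A small sketch of just the gating step, with random stand-in weights:

```python
# Minimal sketch (not part of the commit): the SiLU gate added in v5.1.
import torch
import torch.nn.functional as F

T, H, N = 3, 2, 4
out = torch.randn(T, H * N)           # stand-in for the WKV read-out
gx = torch.randn(T, H * N)            # stand-in for the g-mixed input
gw = torch.randn(H * N, H * N)        # stand-in for the gate projection

g = F.silu(gx @ gw)
out = F.group_norm(out, num_groups=H,
                   weight=torch.ones(H * N), bias=torch.zeros(H * N),
                   eps=64e-5)
out = out * g                          # element-wise gating before the output projection
print(out.shape)                       # torch.Size([3, 8])
```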
    @MyFunction
    def att_seq_v5_2(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, gmx, grx, gmy, gry, omx, orx, omy, ory):
        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
        sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
        kx = xx * k_mix + sx * (1 - k_mix)
        vx = xx * v_mix + sx * (1 - v_mix)
        rx = xx * r_mix + sx * (1 - r_mix)
        gx = xx * g_mix + sx * (1 - g_mix)

        H = t_decay.shape[0]
        N = x.shape[-1] // H
        T = x.shape[0]

        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(T, H, N).transpose(0, 1)
        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(T, H, N).permute(1, 2, 0)
        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(T, H, N).transpose(0, 1)
        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))

        out = torch.empty((T, H, N), dtype=r.dtype, device=r.device)
        for t in range(T):
            rt = r[:,t:t+1,:]
            kt = k[:,:,t:t+1]
            vt = v[:,t:t+1,:]
            at = matmul(kt, vt)
            out[t] = (rt @ (t_first * at + s)).squeeze(1)
            s = at + t_decay * s

        out = out.reshape(T, H*N)
        out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5)
        out = out.to(dtype=x.dtype) * g
        out = matmul(out, ow, omx, orx, omy, ory)

        return x + out, xx[-1,:], s

    ########################################################################################################
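att_seq_v5_2 drops the T×T decay matrix and instead walks the chunk token by token, keeping memory at O(N²) per head regardless of T; note also that t_decay broadcasts per channel here, (H, N, 1), rather than per head. The recurrence in isolation, with toy sizes:

```python
# Minimal sketch (not part of the commit): the per-token loop of att_seq_v5_2.
# t_decay / t_first are per-channel, (H, N, 1), matching the broadcast above.
import torch

H, N, T = 2, 4, 6
r = torch.randn(H, T, N)
k = torch.randn(H, N, T)
v = torch.randn(H, T, N)
t_first = torch.rand(H, N, 1)
t_decay = torch.rand(H, N, 1)
s = torch.zeros(H, N, N)

out = torch.empty(T, H, N)
for t in range(T):
    at = k[:, :, t:t+1] @ v[:, t:t+1, :]                 # (H, N, N)
    out[t] = (r[:, t:t+1, :] @ (t_first * at + s)).squeeze(1)
    s = at + t_decay * s                                  # row-wise decay
print(out.shape)                                          # torch.Size([6, 2, 4])
```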
    @MyFunction
    def att_one_v6_0(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, gmx, grx, gmy, gry, omx, orx, omy, ory):
        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)

        sx = sx - xx
        xxx = xx + sx * x_maa
        xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
        xxx = torch.bmm(xxx, tm_w2).view(5, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        wx = xx + sx * (w_maa + mw)
        kx = xx + sx * (k_maa + mk)
        vx = xx + sx * (v_maa + mv)
        rx = xx + sx * (r_maa + mr)
        gx = xx + sx * (g_maa + mg)

        H = t_decay.shape[0]
        N = x.shape[-1] // H

        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N)
        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1)
        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N)
        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))

        w = t_decay + (torch.tanh(wx @ td_w1) @ td_w2).float().view(H, N, 1)
        w = torch.exp(-torch.exp(w.float()))

        a = matmul(k, v)
        out = r @ (t_first * a + s)
        s = a + w * s

        out = out.flatten()
        out = F.group_norm(out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5).squeeze(0)
        out = out.to(dtype=x.dtype) * g
        out = matmul(out, ow, omx, orx, omy, ory)

        return x + out, xx, s
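v6 makes both the token-shift mixing and the decay data-dependent through small low-rank adapters (tm_w1/tm_w2 and td_w1/td_w2). A standalone sketch of those two steps, with hypothetical sizes:

```python
# Minimal sketch (not part of the commit): v6's data-dependent mixing and
# decay. C (embedding size) and R (adapter rank) are hypothetical toy sizes.
import torch

C, R = 16, 4
xx = torch.randn(C)                 # layer-normed current token
sx = torch.randn(C) - xx            # shifted-minus-current, as in att_one_v6_0
x_maa = torch.rand(C)

tm_w1 = torch.randn(C, 5 * R)       # shared down-projection
tm_w2 = torch.randn(5, R, C)        # one up-projection per mix target

xxx = xx + sx * x_maa
xxx = torch.tanh(xxx @ tm_w1).view(5, 1, R)
xxx = torch.bmm(xxx, tm_w2).view(5, C)
mw, mk, mv, mr, mg = xxx.unbind(dim=0)    # input-conditioned mix corrections

# data-dependent decay through a second low-rank adapter
td_w1, td_w2 = torch.randn(C, R), torch.randn(R, C)
w_maa, t_decay = torch.rand(C), torch.randn(C)
wx = xx + sx * (w_maa + mw)
w = t_decay + torch.tanh(wx @ td_w1) @ td_w2
w = torch.exp(-torch.exp(w))              # maps any real value into (0, 1)
assert ((w > 0) & (w < 1)).all()
```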
    @MyFunction
    def att_seq_v6_0(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, gmx, grx, gmy, gry, omx, orx, omy, ory):
        H = t_decay.shape[0]
        N = x.shape[-1] // H
        T = x.shape[0]

        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
        sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) - xx
        xxx = xx + sx * x_maa
        xxx = torch.tanh(xxx @ tm_w1).view(T, 5, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, tm_w2).view(5, T, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        wx = xx + sx * (w_maa + mw)
        kx = xx + sx * (k_maa + mk)
        vx = xx + sx * (v_maa + mv)
        rx = xx + sx * (r_maa + mr)
        gx = xx + sx * (g_maa + mg)

        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(T, H, N).transpose(0, 1)
        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(T, H, N).permute(1, 2, 0)
        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(T, H, N).transpose(0, 1)
        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))

        w = t_decay.view(1, H, N, 1) + (torch.tanh(wx @ td_w1) @ td_w2).float().view(T, H, N, 1)
        w = torch.exp(-torch.exp(w.float()))
        out = torch.empty((T, H, N), dtype=r.dtype, device=r.device)
        for t in range(T):
            rt = r[:,t:t+1,:]
            kt = k[:,:,t:t+1]
            vt = v[:,t:t+1,:]
            at = matmul(kt, vt)
            out[t] = (rt @ (t_first * at + s)).squeeze(1)
            s = at + w[t] * s

        out = out.reshape(T, H*N)
        out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5)
        out = out.to(dtype=x.dtype) * g
        out = matmul(out, ow, omx, orx, omy, ory)

        return x + out, xx[-1,:], s

    ########################################################################################################
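In the v6 sequence path the decay is recomputed per token, so the state shrinks by a different, input-conditioned factor at every step; this is why the path keeps an explicit per-token loop. The core of that loop in isolation:

```python
# Minimal sketch (not part of the commit): per-token, per-channel decay in
# the v6 sequence loop. Toy sizes.
import torch

T, H, N = 4, 2, 3
w = torch.exp(-torch.exp(torch.randn(T, H, N, 1)))   # in (0, 1), varies with t
k = torch.randn(H, N, T)
v = torch.randn(H, T, N)
s = torch.zeros(H, N, N)
for t in range(T):
    at = k[:, :, t:t+1] @ v[:, t:t+1, :]
    s = at + w[t] * s                                # w[t]: (H, N, 1)
print(s.shape)                                       # torch.Size([2, 3, 3])
```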
    if os.environ["RWKV_CUDA_ON"] == '1':
        @MyFunction
        def cuda_att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
            T, C = x.shape
            xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b)
            sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
            kx = xx * k_mix + sx * (1 - k_mix)
            vx = xx * v_mix + sx * (1 - v_mix)
            rx = xx * r_mix + sx * (1 - r_mix)

            r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
            k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
            v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
            y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp)

            out = matmul(r * y.to(x.dtype), ow, omx, orx, omy, ory)
            return x + out, xx[-1,:], aa, bb, pp

        @MyFunction
        def v5_2_before(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, gmx, grx, gmy, gry, omx, orx, omy, ory):
            xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
            sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
            kx = xx * k_mix + sx * (1 - k_mix)
            vx = xx * v_mix + sx * (1 - v_mix)
            rx = xx * r_mix + sx * (1 - r_mix)
            gx = xx * g_mix + sx * (1 - g_mix)

            r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
            k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
            v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
            g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))

            return r, k, v, g, xx[-1,:], s.transpose(-1,-2).contiguous()

        @MyFunction
        def v5_2_after(self, t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory):
            H = t_decay.shape[0]
            N = x.shape[-1] // H
            T = x.shape[0]

            s = s.transpose(-1,-2)
            out = out.reshape(T, H*N)
            out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5)
            out = out.to(dtype=x.dtype) * g
            out = matmul(out, ow, omx, orx, omy, ory)

            return x + out, xxx, s

        def cuda_att_seq_v5_2(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, gmx, grx, gmy, gry, omx, orx, omy, ory):
            H = t_decay.shape[0]
            N = x.shape[-1] // H
            T = x.shape[0]

            r, k, v, g, xxx, ss = self.v5_2_before(x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, gmx, grx, gmy, gry, omx, orx, omy, ory)

            out, s = self.RUN_RWKV_5(1, T, self.args.n_att, H, ss, r, k, v, w=t_decay, u=t_first)

            return self.v5_2_after(t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory)

        @MyFunction
        def v6_0_before(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, gmx, grx, gmy, gry, omx, orx, omy, ory):
            H = t_decay.shape[0]
            N = x.shape[-1] // H
            T = x.shape[0]

            xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
            sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) - xx
            xxx = xx + sx * x_maa
            xxx = torch.tanh(xxx @ tm_w1).view(T, 5, -1).transpose(0, 1)
            xxx = torch.bmm(xxx, tm_w2).view(5, T, -1)
            mw, mk, mv, mr, mg = xxx.unbind(dim=0)

            wx = xx + sx * (w_maa + mw)
            kx = xx + sx * (k_maa + mk)
            vx = xx + sx * (v_maa + mv)
            rx = xx + sx * (r_maa + mr)
            gx = xx + sx * (g_maa + mg)

            r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
            k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
            v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
            g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))

            w = t_decay.view(1, H, N, 1) + (torch.tanh(wx @ td_w1) @ td_w2).float().view(T, H, N, 1)

            return r, k, v, g, w, xx[-1,:], s.transpose(-1,-2).contiguous()

        def cuda_att_seq_v6_0(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, gmx, grx, gmy, gry, omx, orx, omy, ory):
            H = t_decay.shape[0]
            N = x.shape[-1] // H
            T = x.shape[0]

            r, k, v, g, w, xxx, ss = self.v6_0_before(x, sx, s, ln_w, ln_b, lx_w, lx_b, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, gmx, grx, gmy, gry, omx, orx, omy, ory)

            out, s = self.RUN_RWKV_6(1, T, self.args.n_att, H, ss, r, k, v, w=w, u=t_first)
            return self.v5_2_after(t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory)

    ########################################################################################################
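The CUDA wrappers split each block into a scripted `before` (mixing and projections), the custom kernel (RUN_RWKV_5 / RUN_RWKV_6), and a shared `after`; v6.0 reuses v5_2_after because both kernels hand back (out, s) in the same layout. A sketch of that shared tail with stand-in tensors:

```python
# Minimal sketch (not part of the commit): the post-kernel tail shared by the
# v5.2 and v6.0 CUDA paths. Weights here are identity stand-ins.
import torch
import torch.nn.functional as F

def shared_after(out, g, H, T, N):
    out = out.reshape(T, H * N)
    out = F.group_norm(out, num_groups=H,
                       weight=torch.ones(H * N), bias=torch.zeros(H * N),
                       eps=64e-5)
    return out * g                       # gate; the real code then applies ow

T, H, N = 3, 2, 4
out_from_kernel = torch.randn(T, H, N)   # stand-in for RUN_RWKV_5/6 output
g = torch.randn(T, H * N)
print(shared_after(out_from_kernel, g, H, T, N).shape)   # torch.Size([3, 8])
```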
    def forward(self, tokens=None, state=None, full_output=False, embs=None):
        with torch.no_grad():
            w = self.w
            args = self.args

            if state == None:
                if self.version == 4:
                    state = [None] * args.n_layer * 5
                    for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
                        dd = self.strategy[i]
                        dev = dd.device
                        atype = dd.atype
                        state[i*5+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
                        state[i*5+1] = torch.zeros(args.n_att, dtype=torch.float, requires_grad=False, device=dev).contiguous()
                        state[i*5+2] = torch.zeros(args.n_att, dtype=torch.float, requires_grad=False, device=dev).contiguous()
                        state[i*5+3] = torch.zeros(args.n_att, dtype=torch.float, requires_grad=False, device=dev).contiguous() - 1e30
                        state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
                elif int(self.version) in [5,6]:
                    state = [None] * args.n_layer * 3
                    for i in range(args.n_layer): # state: 0=att_xx 1=att_kv 2=ffn_xx
                        dd = self.strategy[i]
                        dev = dd.device
                        atype = dd.atype
                        state[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
                        state[i*3+1] = torch.zeros((args.n_head, args.n_att//args.n_head, args.n_att//args.n_head), dtype=torch.float, requires_grad=False, device=dev).contiguous()
                        state[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous()
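For orientation, the two state layouts have very different footprints: v4 carries five n_embd/n_att vectors per layer, while v5/v6 carry two shift vectors plus one (n_head, head, head) matrix. A back-of-the-envelope count, assuming the 1B5 world model's shapes (n_layer=24, n_embd=n_att=2048, head size 64):

```python
# Minimal sketch (not part of the commit): float count of the two layouts,
# assuming n_layer=24, n_embd=n_att=2048, head size 64 (32 heads).
n_layer, n_embd, n_att, head_size = 24, 2048, 2048, 64
n_head = n_att // head_size

v4_floats = n_layer * (2 * n_embd + 3 * n_att)                  # 5 vectors/layer
v56_floats = n_layer * (2 * n_embd + n_head * head_size**2)     # + (H, N, N) state
print(v4_floats, v56_floats)   # 245760 vs 3244032: the kv state dominates
```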
            if embs is None:
                seq_mode = len(tokens) > 1
                x = w['emb.weight'][tokens if seq_mode else tokens[0]]
            else:
                x = embs
                seq_mode = True

            for i in range(args.n_layer):
                bbb = f'blocks.{i}.'
                att = f'blocks.{i}.att.'
                ffn = f'blocks.{i}.ffn.'
                dd = self.strategy[i]
                dev = dd.device
                atype = dd.atype
                wtype = dd.wtype
                if seq_mode:
                    cuda_applicable = os.environ["RWKV_CUDA_ON"] == '1' and 'cuda' in str(dev)
                    if cuda_applicable:
                        ATT = self.cuda_att_seq
                    else:
                        ATT = self.att_seq
                    if self.version == 5:
                        ATT = self.att_seq_v5
                    elif self.version == 5.1:
                        ATT = self.att_seq_v5_1
                    elif self.version == 5.2:
                        ATT = self.att_seq_v5_2
                        if cuda_applicable:
                            ATT = self.cuda_att_seq_v5_2
                    elif self.version == 6.0:
                        ATT = self.att_seq_v6_0
                        if cuda_applicable:
                            ATT = self.cuda_att_seq_v6_0
                    FFN = self.ffn_seq
                    if self.version >= 6.0:
                        FFN = self.ffn_seq_v6
                else:
                    ATT = self.att_one
                    if self.version == 5:
                        ATT = self.att_one_v5
                    elif self.version == 5.1:
                        ATT = self.att_one_v5_1
                    elif self.version == 5.2:
                        ATT = self.att_one_v5_1 # same as v5.1
                    elif self.version == 6.0:
                        ATT = self.att_one_v6_0
                    FFN = self.ffn_one
                    if self.version >= 6.0:
                        FFN = self.ffn_one_v6

                x = x.to(dtype=atype, device=dev)

                kw = w[f'{att}key.weight']
                vw = w[f'{att}value.weight']
                rw = w[f'{att}receptance.weight']
                ow = w[f'{att}output.weight']
                if dd.stream:
                    kw = kw.to(device=dev, non_blocking=True)
                    vw = vw.to(device=dev, non_blocking=True)
                    rw = rw.to(device=dev, non_blocking=True)
                    ow = ow.to(device=dev, non_blocking=True)
                kmx = w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x
                krx = w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x
                kmy = w[f'{att}key.weight_my'] if wtype == torch.uint8 else x
                kry = w[f'{att}key.weight_ry'] if wtype == torch.uint8 else x
                vmx = w[f'{att}value.weight_mx'] if wtype == torch.uint8 else x
                vrx = w[f'{att}value.weight_rx'] if wtype == torch.uint8 else x
                vmy = w[f'{att}value.weight_my'] if wtype == torch.uint8 else x
                vry = w[f'{att}value.weight_ry'] if wtype == torch.uint8 else x
                rmx = w[f'{att}receptance.weight_mx'] if wtype == torch.uint8 else x
                rrx = w[f'{att}receptance.weight_rx'] if wtype == torch.uint8 else x
                rmy = w[f'{att}receptance.weight_my'] if wtype == torch.uint8 else x
                rry = w[f'{att}receptance.weight_ry'] if wtype == torch.uint8 else x
                omx = w[f'{att}output.weight_mx'] if wtype == torch.uint8 else x
                orx = w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x
                omy = w[f'{att}output.weight_my'] if wtype == torch.uint8 else x
                ory = w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x
                if self.version in [5.1, 5.2, 6.0]:
                    gw = w[f'{att}gate.weight']
                    if dd.stream:
                        gw = gw.to(device=dev, non_blocking=True)
                    gmx = w[f'{att}gate.weight_mx'] if wtype == torch.uint8 else x
                    grx = w[f'{att}gate.weight_rx'] if wtype == torch.uint8 else x
                    gmy = w[f'{att}gate.weight_my'] if wtype == torch.uint8 else x
                    gry = w[f'{att}gate.weight_ry'] if wtype == torch.uint8 else x
                if self.version == 4:
                    x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT(
                        x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3],
                        w[f'{bbb}ln1.weight'], w[f'{bbb}ln1.bias'],
                        w[f'{att}time_mix_k'], w[f'{att}time_mix_v'], w[f'{att}time_mix_r'],
                        w[f'{att}time_decay'], w[f'{att}time_first'],
                        kw, vw, rw, ow,
                        kmx, krx, kmy, kry,
                        vmx, vrx, vmy, vry,
                        rmx, rrx, rmy, rry,
                        omx, orx, omy, ory,
                        )
                elif self.version == 5:
                    x, state[i*3+0], state[i*3+1] = ATT(
                        x, state[i*3+0], state[i*3+1],
                        w[f'{bbb}ln1.weight'], w[f'{bbb}ln1.bias'],
                        w[f'{att}ln_x.weight'], w[f'{att}ln_x.bias'],
                        w[f'{att}time_mix_k'], w[f'{att}time_mix_v'], w[f'{att}time_mix_r'],
                        w[f'{att}time_decay'], w[f'{att}time_first'],
                        kw, vw, rw, ow,
                        kmx, krx, kmy, kry,
                        vmx, vrx, vmy, vry,
                        rmx, rrx, rmy, rry,
                        omx, orx, omy, ory,
                        )
                elif self.version in [5.1, 5.2]:
                    x, state[i*3+0], state[i*3+1] = ATT(
                        x, state[i*3+0], state[i*3+1],
                        w[f'{bbb}ln1.weight'], w[f'{bbb}ln1.bias'],
                        w[f'{att}ln_x.weight'], w[f'{att}ln_x.bias'],
                        w[f'{att}time_mix_k'], w[f'{att}time_mix_v'], w[f'{att}time_mix_r'], w[f'{att}time_mix_g'],
                        w[f'{att}time_decay'], w[f'{att}time_first'],
                        kw, vw, rw, gw, ow,
                        kmx, krx, kmy, kry,
                        vmx, vrx, vmy, vry,
                        rmx, rrx, rmy, rry,
                        gmx, grx, gmy, gry,
                        omx, orx, omy, ory,
                        )
                elif self.version == 6.0:
                    x, state[i*3+0], state[i*3+1] = ATT(
                        x, state[i*3+0], state[i*3+1],
                        w[f'{bbb}ln1.weight'], w[f'{bbb}ln1.bias'],
                        w[f'{att}ln_x.weight'], w[f'{att}ln_x.bias'],
                        w[f'{att}time_maa_x'], w[f'{att}time_maa_w'], w[f'{att}time_maa_k'], w[f'{att}time_maa_v'], w[f'{att}time_maa_r'], w[f'{att}time_maa_g'],
                        w[f'{att}time_maa_w1'], w[f'{att}time_maa_w2'], w[f'{att}time_decay_w1'], w[f'{att}time_decay_w2'],
                        w[f'{att}time_decay'], w[f'{att}time_first'],
                        kw, vw, rw, gw, ow,
                        kmx, krx, kmy, kry,
                        vmx, vrx, vmy, vry,
                        rmx, rrx, rmy, rry,
                        gmx, grx, gmy, gry,
                        omx, orx, omy, ory,
                        )
                if dd.stream:
                    del kw, vw, rw, ow
                    if self.version in [5.1, 5.2, 6.0]:
                        del gw

                kw = w[f'{ffn}key.weight']
                vw = w[f'{ffn}value.weight']
                rw = w[f'{ffn}receptance.weight']
                if dd.stream:
                    kw = kw.to(device=dev, non_blocking=True)
                    vw = vw.to(device=dev, non_blocking=True)
                    rw = rw.to(device=dev, non_blocking=True)
                kmx = w[f'{ffn}key.weight_mx'] if wtype == torch.uint8 else x
                krx = w[f'{ffn}key.weight_rx'] if wtype == torch.uint8 else x
                kmy = w[f'{ffn}key.weight_my'] if wtype == torch.uint8 else x
                kry = w[f'{ffn}key.weight_ry'] if wtype == torch.uint8 else x
                vmx = w[f'{ffn}value.weight_mx'] if wtype == torch.uint8 else x
                vrx = w[f'{ffn}value.weight_rx'] if wtype == torch.uint8 else x
                vmy = w[f'{ffn}value.weight_my'] if wtype == torch.uint8 else x
                vry = w[f'{ffn}value.weight_ry'] if wtype == torch.uint8 else x
                rmx = w[f'{ffn}receptance.weight_mx'] if wtype == torch.uint8 else x
                rrx = w[f'{ffn}receptance.weight_rx'] if wtype == torch.uint8 else x
                rmy = w[f'{ffn}receptance.weight_my'] if wtype == torch.uint8 else x
                rry = w[f'{ffn}receptance.weight_ry'] if wtype == torch.uint8 else x
                if self.version == 4:
                    offset = i*5+4
                elif int(self.version) in [5,6]:
                    offset = i*3+2
                if self.version < 6.0:
                    x, state[offset] = FFN(
                        x, state[offset],
                        w[f'{bbb}ln2.weight'], w[f'{bbb}ln2.bias'],
                        w[f'{ffn}time_mix_k'], w[f'{ffn}time_mix_r'],
                        kw, vw, rw,
                        kmx, krx, kmy, kry,
                        vmx, vrx, vmy, vry,
                        rmx, rrx, rmy, rry,
                        )
                else:
                    x, state[offset] = FFN(
                        x, state[offset],
                        w[f'{bbb}ln2.weight'], w[f'{bbb}ln2.bias'],
                        w[f'{ffn}time_maa_k'], w[f'{ffn}time_maa_r'],
                        kw, vw, rw,
                        kmx, krx, kmy, kry,
                        vmx, vrx, vmy, vry,
                        rmx, rrx, rmy, rry,
                        )
                if dd.stream:
                    del kw, vw, rw

                if self.RESCALE_LAYER > 0:
                    if (i+1) % self.RESCALE_LAYER == 0:
                        x = x / 2

            dd = self.strategy[args.n_layer]
            x = x[-1,:] if (seq_mode and (not full_output)) else x
            x = x.to(dtype=dd.atype, device=dd.device)

            x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias'])
            if w['head.weight'].dtype != torch.uint8:
                x = x @ w['head.weight']
            else:
                if seq_mode and full_output:
                    x = mm8_seq(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry'])
                else:
                    x = mm8_one(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry'])

            return x.float(), state
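The forward signature above is meant to be threaded during generation: prefill once with the whole prompt (sequence mode), then feed single tokens while reusing `state`. A hedged usage sketch, assuming an already-constructed RWKV instance as in app.py; `model` and `tokens` are placeholders:

```python
# Usage sketch (not part of the commit): state threading during generation.
# `model` is an instantiated RWKV model and `tokens` a list of prompt token
# ids; both are placeholders here.
out, state = model.forward(tokens, None)        # prefill: len(tokens) > 1
for _ in range(16):
    token = int(out.argmax())                   # greedy pick, for illustration
    out, state = model.forward([token], state)  # single-token decode path
```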
    	
modeling_vision.py
ADDED

@@ -0,0 +1,48 @@
from transformers import CLIPVisionModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

@dataclass
class VisionEncoderConfig:
    n_embd: int = 2048
    vision_tower_name: str = 'openai/clip-vit-large-patch14-336'
    grid_size: int = -1 # -1: no grid pooling, 0: take cls token, 1: global avg pooling, 2, 3, 4, ...: grid pooling

class VisionEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.vit = CLIPVisionModel.from_pretrained(args.vision_tower_name)
        self.proj = nn.Linear(self.vit.config.hidden_size, args.n_embd, bias=False)

    def encode_images(self, images):
        B, N, C, H, W = images.shape
        images = images.view(B*N, C, H, W)
        image_features = self.vit(images).last_hidden_state
        L, D = image_features.shape[1], image_features.shape[2]
        # rearrange [B*N, L, D] -> [B, N, L, D], then keep the first image per sample
        image_features = image_features.view(B, N, L, D)[:, 0, :, :]
        image_features = self.grid_pooling(image_features)
        return self.proj(image_features)

    def grid_pooling(self, image_features):
        if self.args.grid_size == -1: # no grid pooling
            return image_features
        if self.args.grid_size == 0: # take cls token
            return image_features[:, 0:1, :]
        if self.args.grid_size == 1: # global avg pooling
            return image_features.mean(dim=1, keepdim=True)
        cls_features = image_features[:, 0:1, :]
        image_features = image_features[:, 1:, :] # drop cls token
        B, L, D = image_features.shape
        H_or_W = int(L**0.5)
        image_features = image_features.view(B, H_or_W, H_or_W, D)
        grid_stride = H_or_W // self.args.grid_size
        image_features = F.avg_pool2d(image_features.permute(0, 3, 1, 2),
                                      padding=0,
                                      kernel_size=grid_stride,
                                      stride=grid_stride)
        image_features = image_features.permute(0, 2, 3, 1).view(B, -1, D)
        return torch.cat((cls_features, image_features), dim=1)
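grid_pooling reduces the ViT's patch grid to a grid_size × grid_size summary while keeping the cls token. The shape arithmetic in isolation, assuming CLIP ViT-L/14 at 336px (1 cls + 24×24 patch tokens) and a hypothetical grid_size of 8:

```python
# Minimal sketch (not part of the commit): shape arithmetic of grid_pooling,
# assuming 1 cls + 24*24 patch tokens (CLIP ViT-L/14-336) and grid_size=8.
import torch
import torch.nn.functional as F

B, L, D = 1, 577, 1024
feats = torch.randn(B, L, D)
cls_features, patches = feats[:, :1, :], feats[:, 1:, :]
side = int(patches.shape[1] ** 0.5)        # 24
grid_size = 8                               # hypothetical choice
stride = side // grid_size                  # 3
pooled = F.avg_pool2d(patches.view(B, side, side, D).permute(0, 3, 1, 2),
                      kernel_size=stride, stride=stride)
pooled = pooled.permute(0, 2, 3, 1).reshape(B, -1, D)
print(torch.cat((cls_features, pooled), dim=1).shape)   # torch.Size([1, 65, 1024])
```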
    	
requirements.txt
CHANGED

@@ -1,8 +1,7 @@
-gradio==3.28.1
 torch
+transformers
 ninja
 tokenizers
-rwkv==0.8.
+rwkv==0.8.22
 pynvml
-huggingface_hub
-gradio==3.28.1
+huggingface_hub