RWKV LM

2025-12-10

RWKV: Parallelizable RNN with Transformer-level LLM Performance (pronounced "RwaKuv" (rʌkuv in IPA), from its 4 major parameters: R W K V)

RWKV website: https://rwkv.com (with 90+ RWKV-related papers)

RWKV twitter: https://twitter.com/BlinkDL_AI (latest news)

RWKV discord: https://discord.gg/bDSBUMeFpc (9k+ members)

RWKV-7 "Goose" is the strongest linear-time & constant-space (no kv-cache) & attention-free & 100% RNN architecture on this planet at this moment, suitable for LLM and multimodal applications and more (see rwkv.com).

IMPORTANT: Use PreLN LayerNorm (instead of RMSNorm) for RWKV. I think it's related to a better initial state, because I am not using a trainable initial state (I found it useless when using LayerNorm).

RWKV-7 is a meta-in-context learner, test-time-training its state on the context via in-context gradient descent at every token.
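For intuition, here is a minimal NumPy sketch of a single-head state update in the spirit of the RNN-mode demo (rwkv_v7_demo_rnn.py). It is simplified and illustrative, not the reference code: there is no token-shift, no LoRAs and no normalization, and the names `rwkv7_state_step` and `query_state` are my own. `w` is the per-channel decay, `a` the in-context learning rate, and `kk` the normalized key. With w = 1, a = 1 and a unit-norm key, the update reduces to one gradient-descent step (learning rate 1) on ½‖S k − v‖², so afterwards the state maps k exactly to v.

```python
import numpy as np

def rwkv7_state_step(S, k, v, w, a):
    """One simplified RWKV-7-style state update for a single head (illustrative).

    S: (N, N) state mapping keys to values (S @ k ~ v)
    k, v: (N,) key and value; w: (N,) decay in (0, 1]; a: (N,) in-context learning rate
    """
    kk = k / np.linalg.norm(k)           # normalized key
    ab = np.outer(-kk, kk * a)           # rank-1 term that erases what S stores along kk
    vk = np.outer(v, k)                  # writes the new key -> value association
    return S * w[None, :] + S @ ab + vk  # decay key columns, erase, then write

def query_state(S, r):
    """Read the state with a receptance vector r."""
    return S @ r

rng = np.random.default_rng(0)
N = 8
S = rng.normal(size=(N, N))              # some existing state
k = rng.normal(size=N)
k /= np.linalg.norm(k)                   # unit-norm key
v = rng.normal(size=N)

S_new = rwkv7_state_step(S, k, v, w=np.ones(N), a=np.ones(N))
print(np.allclose(query_state(S_new, k), v))           # True
print(np.allclose(S_new, S - np.outer(S @ k - v, k)))  # True: one delta-rule step
```

In the real model, w and a are produced per token from the input, so each token decides how much to forget and how strongly to overwrite.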

RWKV is a Linux Foundation AI project, so it is totally free. The RWKV runtime is already in Windows & Office.

You are welcome to ask the RWKV community (such as the RWKV discord) for advice on upgrading your attention/SSM models to RWKV-7 models 🙂

===

Please use https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v7/train_temp as the RWKV-7 reference implementation. The default config only requires 1 GPU with 10G VRAM (you can reduce bsz if you have less VRAM), so it's easy to test.

Note: FLA RWKV-7 is NOT yet aligned with the reference implementation, so you will get lower performance.

This is because RWKV-7 is the whole model with carefully chosen settings, including different init / wd / lr for each parameter, so it's readily scalable and very stable (spike-free).

But the price to pay is that there is no good, simple "RWKV-7 layer", because a PyTorch layer can't make sure by itself that it is using the correct init and hyperparameters.

So if you need to use RWKV-7 for another task, please study train_temp code (only several hundred lines) and change it to suit you.

===

RWKV-7 can do math. See https://github.com/BlinkDL/RWKV-LM/blob/main/Research/rwkv7-g0-7.2b.md for details.

History of RWKV (from v1 to v7): https://wiki.rwkv.com/advance/architecture.html (note: AI-written, might contain errors)

Gradio Demo 1: https://huggingface.co/spaces/BlinkDL/RWKV-Gradio-1

Gradio Demo 2: https://huggingface.co/spaces/BlinkDL/RWKV-Gradio-2

WebGPU Demo: https://cryscan.github.io/web-rwkv-puzzles/#/chat

Latest RWKV weights: https://huggingface.co/BlinkDL

===

RWKV-Runner GUI: https://github.com/josStorer/RWKV-Runner/releases

Ai00 Server: https://github.com/Ai00-X/ai00_server

RWKV pip pkg: https://pypi.org/project/rwkv/

PEFT (LoRA etc.): https://github.com/JL-er/RWKV-PEFT

RLHF: https://github.com/OpenMOSE/RWKV-LM-RLHF

400+ RWKV projects: https://github.com/search?o=desc&q=rwkv&s=updated&type=Repositories

Faster RWKV-7 kernels: https://github.com/johanwind/wind_rwkv

===

RWKV-5/6 Eagle/Finch paper: https://arxiv.org/abs/2404.05892

Chat demo code: https://github.com/BlinkDL/ChatRWKV/blob/main/API_DEMO_CHAT.py

RWKV-7 demo code: https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v7

https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v7/rwkv_v7_demo.py (GPT-like mode)

https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v7/rwkv_v7_demo_rnn.py (RNN mode)

https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v7/rwkv_v7_demo_fast.py (both modes, fastest)

RWKV-6 demo code: https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/rwkv_v6_demo.py

RWKV-6 demo code: https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_v6_demo.py

HOW TO TRAIN RWKV-7/6/5 on MiniPile (1.5G tokens)

For reference, use Python 3.10+, torch 2.5+, CUDA 12.5+, the latest deepspeed, but keep pytorch-lightning==1.9.5

Train RWKV-7:

# you can use latest torch + latest cuda (not limited to cu121)
pip install torch --upgrade --extra-index-url https://download.pytorch.org/whl/cu121
pip install pytorch-lightning==1.9.5 deepspeed wandb ninja --upgrade

# train RWKV-7
cd RWKV-v7/train_temp/ 

# download minipile .bin .idx to train_temp/data first (check demo-training-prepare.sh)
# this will generate the initial weight rwkv-init.pth in out/....../
sh ./demo-training-prepare.sh

# this will load rwkv-init.pth and train the model. you may want to log in to wandb first
sh ./demo-training-run.sh

Your out/....../train_log.txt should have losses similar to:
0 4.875856 131.0863 0.00059975 2025-04-24 02:23:42.481256 0
1 4.028621 56.1834 0.00059899 2025-04-24 02:28:16.674463 1
2 3.801625 44.7739 0.00059773 2025-04-24 02:32:51.059568 2
3 3.663070 38.9808 0.00059597 2025-04-24 02:37:25.409892 3
4 3.578974 35.8368 0.00059371 2025-04-24 02:41:59.711315 4
5 3.510906 33.4786 0.00059096 2025-04-24 02:46:33.990839 5
6 3.462345 31.8917 0.00058771 2025-04-24 02:51:08.378331 6
7 3.412196 30.3318 0.00058399 2025-04-24 02:55:42.927474 7
8 3.376724 29.2747 0.00057978 2025-04-24 03:00:17.504665 8
9 3.336911 28.1321 0.00057511 2025-04-24 03:04:52.006063 9
10 3.313411 27.4787 0.00056999 2025-04-24 03:09:27.563336 10
11 3.295895 27.0016 0.00056441 2025-04-24 03:14:01.786079 11
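If you want to compare your own run against these numbers programmatically, a small parser helps. The column meanings here are my inference from the numbers, not a documented format: the third column matches exp(loss) (e.g. exp(4.875856) ≈ 131.09), the fourth looks like the learning rate, followed by a timestamp.

```python
import math

# Two lines copied from the sample train_log.txt above.
SAMPLE = """
0 4.875856 131.0863 0.00059975 2025-04-24 02:23:42.481256 0
11 3.295895 27.0016 0.00056441 2025-04-24 03:14:01.786079 11
"""

def parse_log(text):
    """Parse train_log.txt lines (assumed: mini-epoch, loss, exp(loss), lr, date, time, ...)."""
    rows = []
    for line in text.strip().splitlines():
        parts = line.split()
        rows.append({
            "epoch": int(parts[0]),
            "loss": float(parts[1]),
            "ppl": float(parts[2]),
            "lr": float(parts[3]),
            "timestamp": parts[4] + " " + parts[5],
        })
    return rows

for row in parse_log(SAMPLE):
    # exp(loss) reproduces the third column to about 4 significant digits
    print(row["epoch"], row["loss"], round(math.exp(row["loss"]), 2), row["ppl"])
```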

RWKV-7 weight example for 1.5B (L24-D2048, vocab 65536):

| name | shape | comment | initialization |
|---|---|---|---|
| emb.weight | [65536, 2048] | wdecay | see code |
| blocks.0.ln0.weight | [2048] | for layer 0 | 1 |
| blocks.0.ln0.bias | [2048] | for layer 0 | 0 |
| blocks.*.ln1.weight | [2048] | | 1 |
| blocks.*.ln1.bias | [2048] | | 0 |
| blocks.*.att.x_r | [1, 1, 2048] | | see code |
| blocks.*.att.x_w | [1, 1, 2048] | | see code |
| blocks.*.att.x_k | [1, 1, 2048] | | see code |
| blocks.*.att.x_v | [1, 1, 2048] | | see code |
| blocks.*.att.x_a | [1, 1, 2048] | | see code |
| blocks.*.att.x_g | [1, 1, 2048] | | see code |
| blocks.*.att.w0 | [1, 1, 2048] | lr 2x | see code |
| blocks.*.att.w1 | [2048, 96] | | 0 |
| blocks.*.att.w2 | [96, 2048] | | see code |
| blocks.*.att.a0 | [1, 1, 2048] | | 0 |
| blocks.*.att.a1 | [2048, 96] | | 0 |
| blocks.*.att.a2 | [96, 2048] | | see code |
| blocks.*.att.v0 | [1, 1, 2048] | for layer 1+ | 1 |
| blocks.*.att.v1 | [2048, 64] | for layer 1+ | 0 |
| blocks.*.att.v2 | [64, 2048] | for layer 1+ | see code |
| blocks.*.att.g1 | [2048, 256] | | 0 |
| blocks.*.att.g2 | [256, 2048] | | see code |
| blocks.*.att.k_k | [1, 1, 2048] | | 1 |
| blocks.*.att.k_a | [1, 1, 2048] | | 1 |
| blocks.*.att.r_k | [32, 64] | | 0 |
| blocks.*.att.receptance.weight | [2048, 2048] | wdecay | see code |
| blocks.*.att.key.weight | [2048, 2048] | wdecay | see code |
| blocks.*.att.value.weight | [2048, 2048] | wdecay | see code |
| blocks.*.att.output.weight | [2048, 2048] | wdecay | 0 |
| blocks.*.att.ln_x.weight | [2048] | | see code |
| blocks.*.att.ln_x.bias | [2048] | | 0 |
| blocks.*.ln2.weight | [2048] | | 1 |
| blocks.*.ln2.bias | [2048] | | 0 |
| blocks.*.ffn.x_k | [1, 1, 2048] | | see code |
| blocks.*.ffn.key.weight | [8192, 2048] | wdecay | see code |
| blocks.*.ffn.value.weight | [2048, 8192] | wdecay | 0 |
| ln_out.weight | [2048] | | 1 |
| ln_out.bias | [2048] | | 0 |
| head.weight | [65536, 2048] | wdecay | see code |
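As a sanity check on the table, the shapes add up to roughly 1.5B parameters. This sketch simply sums the shapes listed above, assuming (as the table says) that v0/v1/v2 exist only for layer 1+ and ln0 only for layer 0; the LoRA ranks are read off the table.

```python
D, VOCAB, N_LAYER = 2048, 65536, 24
D_FFN = 8192                                  # ffn.key is [8192, 2048]
R_DECAY, R_A, R_V, R_GATE = 96, 96, 64, 256   # ranks of w1/w2, a1/a2, v1/v2, g1/g2

def block_params(layer_id):
    n = 2 * D                        # ln1.weight, ln1.bias
    n += 6 * D                       # att.x_r, x_w, x_k, x_v, x_a, x_g
    n += D + 2 * D * R_DECAY         # att.w0, w1, w2
    n += D + 2 * D * R_A             # att.a0, a1, a2
    if layer_id > 0:
        n += D + 2 * D * R_V         # att.v0, v1, v2 (layer 1+ only)
    n += 2 * D * R_GATE              # att.g1, g2
    n += 2 * D                       # att.k_k, att.k_a
    n += 32 * 64                     # att.r_k
    n += 4 * D * D                   # receptance, key, value, output
    n += 2 * D                       # att.ln_x.weight, att.ln_x.bias
    n += 2 * D                       # ln2.weight, ln2.bias
    n += D                           # ffn.x_k
    n += 2 * D_FFN * D               # ffn.key, ffn.value
    if layer_id == 0:
        n += 2 * D                   # ln0.weight, ln0.bias
    return n

total = VOCAB * D                                      # emb.weight
total += sum(block_params(i) for i in range(N_LAYER))
total += 2 * D                                         # ln_out.weight, ln_out.bias
total += VOCAB * D                                     # head.weight
print(f"{total:,} parameters (~{total / 1e9:.2f}B)")   # 1,527,404,544 parameters (~1.53B)
```

Roughly two thirds of each block is the FFN (2 x 8192 x 2048), and the embedding plus head add about 0.27B.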

Train RWKV-6: use /RWKV-v5/ and use --my_testing "x060" in demo-training-prepare.sh and demo-training-run.sh

Your loss curve should look almost exactly the same as this, with the same ups and downs (if you use the same bsz & config):

You can run your model using https://pypi.org/project/rwkv/ (use "rwkv_vocab_v20230424" instead of "20B_tokenizer.json")

Use https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/make_data.py to prepare binidx data from jsonl, and compute "--my_exit_tokens" and "--magic_prime".

Use https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/compute_magic_prime.py to compute "--my_exit_tokens" and "--magic_prime" for existing binidx.

Much faster tokenizers for large data: https://github.com/cahya-wirawan/json2bin https://github.com/cahya-wirawan/rwkv-tokenizer https://github.com/m8than/RWKV-World-Tokenizer-CPP

The "epoch" in train.py is a "mini-epoch" (not a real epoch, only for convenience), and 1 mini-epoch = 40320 * ctx_len tokens.

For example, if your binidx has 1498226207 tokens and ctxlen=4096, set "--my_exit_tokens 1498226207" (this will override epoch_count), and it will be 1498226207/(40320 * 4096) = 9.07 mini-epochs. The trainer will auto-exit after "--my_exit_tokens" tokens. Set "--magic_prime" to the largest 3n+2 prime smaller than datalen/ctxlen-1 (= int(1498226207/4096) - 1 = 365776), which is "--magic_prime 365759" in this case.
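The arithmetic above can be reproduced with a few lines of Python. This is a sketch that follows the rule as worded here (largest prime p with p % 3 == 2 and p < datalen // ctxlen - 1); make_data.py and compute_magic_prime.py in the repo remain the authority, and the helper names are my own.

```python
def is_prime(n):
    """Plain trial division; fast enough for numbers of this size."""
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    f = 3
    while f * f <= n:
        if n % f == 0:
            return False
        f += 2
    return True

def magic_prime(data_len, ctx_len):
    """Largest prime of the form 3n+2 strictly below data_len // ctx_len - 1."""
    bound = data_len // ctx_len - 1
    for p in range(bound - 1, 1, -1):
        if p % 3 == 2 and is_prime(p):
            return p
    raise ValueError("dataset too small for this ctx_len")

def mini_epochs(exit_tokens, ctx_len):
    """1 mini-epoch = 40320 * ctx_len tokens."""
    return exit_tokens / (40320 * ctx_len)

tokens, ctx = 1498226207, 4096
print(magic_prime(tokens, ctx))             # 365759
print(round(mini_epochs(tokens, ctx), 2))   # 9.07
```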

Simple: prepare your SFT jsonl => repeat your SFT data 3 or 4 times in make_data.py. More repetition leads to overfitting.

Advanced: repeat your SFT data 3 or 4 times in your jsonl (note make_data.py will shuffle all jsonl items) => add some base data (such as SlimPajama) to your jsonl => and only repeat 1 time in make_data.py.
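The "advanced" recipe boils down to building one shuffled jsonl from repeated SFT documents plus base documents. A minimal sketch follows; the function name and sample texts are hypothetical, and the conversation layout shown is only a placeholder, so use whatever format your base model expects.

```python
import json
import random

def build_sft_mix(sft_texts, base_texts, sft_repeat=3, seed=42):
    """Return jsonl lines: each SFT document `sft_repeat` times plus each base document once."""
    docs = list(sft_texts) * sft_repeat + list(base_texts)
    random.Random(seed).shuffle(docs)   # make_data.py shuffles anyway; this just mixes the file
    return [json.dumps({"text": t}, ensure_ascii=False) for t in docs]

# Hypothetical inputs; in practice these come from your SFT set and a base corpus.
sft = ["User: hi\n\nAssistant: Hello!", "User: 1+1?\n\nAssistant: 2"]
base = ["Some pretraining-style text.", "Another base document."]

lines = build_sft_mix(sft, base, sft_repeat=3)
print(len(lines))   # 8 = 3 * 2 SFT docs + 2 base docs
```

Write each returned line to your .jsonl file (one document per line), then run make_data.py with a repeat count of 1.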

Fix training spikes: see the "Fixing RWKV-6 Spikes" part on this page.

Or use RWKV-7 (much better). RWKV-7 is very stable and spike-free (verified for 0.1/0.4/1.5/2.9b):

Simple inference for RWKV-6: https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_v6_demo.py

Simple inference for RWKV-5: https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_v5_demo.py

Note: In [state = kv + w * state] everything must be in fp32 because w can be very close to 1. So we can keep state and w in fp32, and convert kv to fp32.
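A quick NumPy illustration of why w must stay in fp32 (the numbers are made up for illustration): in fp16, a decay of 0.9999 rounds to exactly 1.0, so the state never decays and drifts far from the correct value.

```python
import numpy as np

w = 0.9999                                   # a decay very close to 1

# fp16 keeps only 10 explicit mantissa bits, so the nearest fp16 value to 0.9999 is 1.0
print(np.float16(w) == np.float16(1.0))      # True: the decay vanishes
print(np.float32(w) == np.float32(1.0))      # False

def run_state(dtype, steps=1000, kv=1.0):
    """Iterate state = kv + w * state entirely in `dtype`."""
    state = dtype(0.0)
    w_d, kv_d = dtype(w), dtype(kv)
    for _ in range(steps):
        state = dtype(kv_d + w_d * state)
    return float(state)

# Exact value: (1 - w**1000) / (1 - w), about 951.67
print(run_state(np.float32))                 # close to 951.67
print(run_state(np.float16))                 # 1000.0: the state never decays
```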

lm_eval: https://github.com/BlinkDL/ChatRWKV/blob/main/run_lm_eval.py

Tips for small model / small data: When I train RWKV music models, I use deep & narrow (such as L29-D512) dimensions, and apply wd and dropout (such as wd=2 dropout=0.02). Note RWKV-LM dropout is very effective – use 1/4 of your usual value.

HOW TO TRAIN RWKV-7 on Pile (332G tokens)

See https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/demo-training-prepare-v7-pile.sh and https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/demo-training-run-v7-pile.sh

Get these files first:

pile_20B_tokenizer_text_document.bin (664230651068 bytes)

pile_20B_tokenizer_text_document.idx (4212099722 bytes)

HOW TO FINETUNE RWKV-5 MODELS

Use .jsonl format for your data (see https://huggingface.co/BlinkDL/rwkv-5-world for formats).

Use https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v5/make_data.py to tokenize it with the World tokenizer into binidx, suitable for finetuning World models.

Rename the base checkpoint in your model folder to rwkv-init.pth, and change the training commands to use --n_layer 32 --n_embd 4096 --vocab_size 65536 --lr_init 1e-5 --lr_final 1e-5 for 7B.

0.1B = --n_layer 12 --n_embd 768 // 0.4B = --n_layer 24 --n_embd 1024 // 1.5B = --n_layer 24 --n_embd 2048 // 3B = --n_layer 32 --n_embd 2560 // 7B = --n_layer 32 --n_embd 4096

State-tuning (tuning the initial state; zero inference overhead)

The current implementation is unoptimized and takes the same VRAM as full SFT.

--train_type "states" --load_partial 1 --lr_init 1 --lr_final 0.01 --warmup_steps 10 (yes, use a very high LR)

Use rwkv 0.8.26+ to auto-load the trained "time_state".

Initializing RWKV 5/6 Models

When you train RWKV from scratch, try my initialization for best performance. Check generate_init_weight() of src/model.py:

emb.weight => nn.init.uniform_(a=-1e-4, b=1e-4)
(Note ln0 of block0 is the layernorm for emb.weight)
head.weight => nn.init.orthogonal_(gain=0.5*sqrt(n_vocab / n_embd))

att.receptance.weight => nn.init.orthogonal_(gain=1)
att.key.weight => nn.init.orthogonal_(gain=0.1)
att.value.weight => nn.init.orthogonal_(gain=1)
att.gate.weight => nn.init.orthogonal_(gain=0.1)
att.output.weight => zero

att.ln_x.weight (groupnorm) => ((1 + layer_id) / total_layers) ** 0.7

ffn.key.weight => nn.init.orthogonal_(gain=1)
ffn.value.weight => zero
ffn.receptance.weight => zero
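A NumPy sketch of the initialization list above, useful for checking shapes and scales outside the training code. It is not generate_init_weight() itself: `orthogonal` is my own QR-based stand-in for nn.init.orthogonal_, shapes assume the nn.Linear (out_features, in_features) convention, and dim_ffn is left as a free parameter.

```python
import numpy as np

def orthogonal(shape, gain, rng):
    """(Semi-)orthogonal matrix scaled by `gain`; a NumPy stand-in for nn.init.orthogonal_."""
    rows, cols = shape
    a = rng.normal(size=(max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)               # q has orthonormal columns
    q *= np.sign(np.diag(r))             # fix the sign ambiguity of QR
    if rows < cols:
        q = q.T                          # orthonormal rows instead
    return gain * q

def init_rwkv56(n_embd, n_vocab, n_layer, dim_ffn, seed=0):
    """Initial values following the list above, as a dict of NumPy arrays."""
    rng = np.random.default_rng(seed)
    p = {
        "emb.weight": rng.uniform(-1e-4, 1e-4, size=(n_vocab, n_embd)),
        "head.weight": orthogonal((n_vocab, n_embd), 0.5 * np.sqrt(n_vocab / n_embd), rng),
    }
    for i in range(n_layer):
        b = f"blocks.{i}."
        p[b + "att.receptance.weight"] = orthogonal((n_embd, n_embd), 1.0, rng)
        p[b + "att.key.weight"] = orthogonal((n_embd, n_embd), 0.1, rng)
        p[b + "att.value.weight"] = orthogonal((n_embd, n_embd), 1.0, rng)
        p[b + "att.gate.weight"] = orthogonal((n_embd, n_embd), 0.1, rng)
        p[b + "att.output.weight"] = np.zeros((n_embd, n_embd))
        p[b + "att.ln_x.weight"] = np.full(n_embd, ((1 + i) / n_layer) ** 0.7)
        p[b + "ffn.key.weight"] = orthogonal((dim_ffn, n_embd), 1.0, rng)
        p[b + "ffn.value.weight"] = np.zeros((n_embd, dim_ffn))
        p[b + "ffn.receptance.weight"] = np.zeros((n_embd, n_embd))
    return p

params = init_rwkv56(n_embd=64, n_vocab=256, n_layer=2, dim_ffn=224)
print({name: arr.shape for name, arr in params.items() if name.startswith("blocks.0.")})
```

Zero-initialized output projections make every block start as an identity map on the residual stream, which is one reason this scheme trains stably.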

!!! If you are using positional embedding, maybe it's better to remove block.0.ln0 and use default initialization for emb.weight instead of my uniform_(a=-1e-4, b=1e-4) !!!

Fixing RWKV-6 Spikes

  1. Upgrade to RWKV-7. It's very stable.

  2. When training from scratch, add "k = k * torch.clamp(w, max=0).exp()" before "RUN_CUDA_RWKV6(r, k, v, w, u)", and remember to change your inference code too. You will see faster convergence.

  3. Use "--adam_eps 1e-18"

  4. Use "--beta2 0.95" if you see spikes

  5. In trainer.py, do "lr = lr * (0.01 + 0.99 * trainer.global_step / w_step)" (originally 0.2 + 0.8), and use "--warmup_steps 20"

  6. "--weight_decay 0.1" leads to better final loss if you are training on lots of data. Set lr_final to 1/100 of lr_init when doing this.
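Tip 5 changes the warmup to start at 1% of the learning rate instead of 20%. A sketch of that schedule follows; the function name, and returning the full lr once global_step >= w_step, are my assumptions about how the formula is used in trainer.py.

```python
def warmup_lr(lr, global_step, w_step, start=0.01):
    """Linear warmup from start * lr to lr over w_step steps (tip 5: start=0.01; old value 0.2)."""
    if global_step >= w_step:
        return lr                    # assumption: warmup only applies while global_step < w_step
    return lr * (start + (1 - start) * global_step / w_step)

for step in (0, 5, 10, 20):
    print(step, warmup_lr(6e-4, step, w_step=20))
```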

Introducing RWKV

RWKV is an RNN with Transformer-level LLM performance, which can also be directly trained like a GPT transformer (parallelizable). And it's 100% attention-free. You only need the hidden state at position t to compute the state at position t+1. You can use the "GPT" mode to quickly compute the hidden state for the "RNN" mode.

So it combines the best of RNN and transformer: great performance, fast inference, low VRAM usage, fast training, "infinite" ctx_len, and free sentence embedding (using the final hidden state).

All latest RWKV weights: https://huggingface.co/BlinkDL

HF-compatible RWKV weights: https://huggingface.co/RWKV

import os
os.environ["RWKV_JIT_ON"] = '1'                      # set these before importing rwkv
os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV                         # pip install rwkv
model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16')

out, state = model.forward([187, 510, 1563, 310, 247], None)   # use 20B_tokenizer.json
print(out.detach().cpu().numpy())                   # get logits
out, state = model.forward([187, 510], None)
out, state = model.forward([1563], state)           # RNN has state (use deepcopy if you want to clone it)
out, state = model.forward([310, 247], state)
print(out.detach().cpu().numpy())                   # same result as above

nanoRWKV: https://github.com/BlinkDL/nanoRWKV (does not require custom CUDA kernel to train, works for any GPU/CPU)

Cool Community RWKV Projects:

All (400+) RWKV projects: https://github.com/search?o=desc&q=rwkv&s=updated&type=Repositories

https://github.com/OpenGVLab/Vision-RWKV Vision RWKV

https://github.com/feizc/Diffusion-RWKV Diffusion RWKV

https://github.com/cgisky1980/ai00_rwkv_server Fastest WebGPU inference (nVidia/AMD/Intel)

https://github.com/cryscan/web-rwkv backend for ai00_rwkv_server

https://github.com/saharNooby/rwkv.cpp Fast CPU/cuBLAS/CLBlast inference: int4/int8/fp16/fp32

https://github.com/JL-er/RWKV-PEFT lora/pissa/Qlora/Qpissa/state tuning

https://github.com/RWKV/RWKV-infctx-trainer Infctx trainer

https://github.com/daquexian/faster-rwkv

mlc-ai/mlc-llm#1275

https://github.com/TheRamU/Fay/blob/main/README_EN.md Digital Assistant with RWKV

https://github.com/harrisonvanderbyl/rwkv-cpp-cuda Fast GPU inference with cuda/amd/vulkan

RWKV v6 in 250 lines (with tokenizer too): https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_v6_demo.py

RWKV v5 in 250 lines (with tokenizer too): https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_v5_demo.py

RWKV v4 in 150 lines (model, inference, text generation): https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py

RWKV v4 preprint https://arxiv.org/abs/2305.13048

RWKV v4 introduction, and in 100 lines of numpy: https://johanwind.github.io/2023/03/23/rwkv_overview.html https://johanwind.github.io/2023/03/23/rwkv_details.html

RWKV v6 illustrated:

A cool paper (Spiking Neural Network) using RWKV: https://github.com/ridgerchu/SpikeGPT

You are welcome to join the RWKV discord https://discord.gg/bDSBUMeFpc to build upon it. We have plenty of potential compute (A100 40Gs) now (thanks to Stability and EleutherAI), so if you have interesting ideas I can run them.

RWKV [loss vs token position] for 10000 ctx4k+ documents in Pile. RWKV 1B5-4k is mostly flat after ctx1500, but 3B-4k and 7B-4k and 14B-4k have some slopes, and they are getting better. This debunks the old view that RNNs cannot model long ctxlens. We can predict that RWKV 100B will be great, and RWKV 1T is probably all you need 🙂

ChatRWKV with RWKV 14B ctx8192:

I believe RNN is a better candidate for fundamental models, because: (1) It's more friendly for ASICs (no kv cache). (2) It's more friendly for RL. (3) When we write, our brain is more similar to RNN. (4) The universe is like an RNN too (because of locality). Transformers are non-local models.

RWKV-3 1.5B on A40 (tf32) = always 0.015 sec/token, tested using simple pytorch code (no CUDA), GPU utilization 45%, VRAM 7823M

GPT2-XL 1.3B on A40 (tf32) = 0.032 sec/token (for ctxlen 1000), tested using HF, GPU utilization 45% too (interesting), VRAM 9655M

Training speed: (new training code) RWKV-4 14B BF16 ctxlen4096 = 114K tokens/s on 8×8 A100 80G (ZERO2+CP). (old training code) RWKV-4 1.5B BF16 ctxlen1024 = 106K tokens/s on 8xA100 40G.

I am doing image experiments too (for example: https://huggingface.co/BlinkDL/clip-guided-binary-autoencoder) and RWKV will be able to do txt2img diffusion 🙂 My idea: 256×256 rgb image -> 32x32x13bit latents -> apply RWKV to compute transition probability for each of the 32×32 grid -> pretend the grids are independent and "diffuse" using these probabilities.

Smooth training – no loss spikes! (lr & bsz change around 15G tokens)

All of the trained models will be open-source. Inference is very fast (only matrix-vector multiplications, no matrix-matrix multiplications) even on CPUs, so you can even run an LLM on your phone.

How it works: RWKV gathers information into a number of channels, which also decay at different speeds as you move to the next token. It's very simple once you understand it.

RWKV is parallelizable because the time-decay of each channel is data-independent (and trainable). For example, in a usual RNN you can adjust the time-decay of a channel from say 0.8 to 0.5 (these are called "gates"), while in RWKV you simply move the information from a W-0.8-channel to a W-0.5-channel to achieve the same effect. Moreover, you can fine-tune RWKV into a non-parallelizable RNN (then you can use outputs of later layers of the previous token) if you want extra performance.

Here are some of my TODOs. Let's work together 🙂

  • HuggingFace integration (check huggingface/transformers#17230), and optimized CPU & iOS & Android & WASM & WebGL inference. RWKV is an RNN and very friendly for edge devices. Let's make it possible to run an LLM on your phone.

  • Test it on bidirectional & MLM tasks, and image & audio & video tokens. I think RWKV can support Encoder-Decoder via this: for each decoder token, use a learned mixture of [decoder previous hidden state] & [encoder final hidden state]. Hence all decoder tokens will have access to the encoder output.

  • Now training RWKV-4a with one single tiny extra attention (just a few extra lines compared with RWKV-4) to further improve some difficult zero-shot tasks (such as LAMBADA) for smaller models. See https://github.com/BlinkDL/RWKV-LM/commit/a268cd2e40351ee31c30c5f8a5d1266d35b41829

User feedback:

I've so far toyed around the character-based model on our relatively small pre-training dataset (around 10GB of text), and the results are extremely good – similar ppl to models taking much, much longer to train.

dear god rwkv is fast. i switched to another tab after starting training it from scratch & when i returned it was emitting plausible english & maori words, i left to go microwave some coffee & when i came back it was producing fully grammatically correct sentences.

Tweet from Sepp Hochreiter (thank you!): https://twitter.com/HochreiterSepp/status/1524270961314484227

You can find me (BlinkDL) in the EleutherAI Discord too: https://www.eleuther.ai/get-involved/

Quick start

IMPORTANT: Use deepspeed==0.7.0 pytorch-lightning==1.9.5 torch==1.13.1+cu117 and cuda 11.7.1 or 11.7 (note torch2 + deepspeed has weird bugs and hurts model performance)

Use https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v4neo (latest code, compatible with v4).

Here is a great prompt for testing Q&A of LLMs. Works for any model: (found by minimizing ChatGPT ppls for RWKV 1.5B)

prompt = f'\nQ & A\n\nQuestion:\n{qq}\n\nDetailed Expert Answer:\n' # let the model generate after this

Inference

Run RWKV-4 Pile models: Download models from https://huggingface.co/BlinkDL. Set TOKEN_MODE = 'pile' in run.py and run it. It's fast even on CPU (the default mode).

Colab for RWKV-4 Pile 1.5B: https://colab.research.google.com/drive/1F7tZoPZaWJf1fsCmZ5tjw6sYHiFOYVWM

Run RWKV-4 Pile models in your browser (and onnx version): see this issue #7

RWKV-4 Web Demo: https://josephrocca.github.io/rwkv-v4-web/demo/ (note: only greedy sampling for now)

For the old RWKV-2: see the release here for a 27M params model on enwik8 with 0.72 BPC (dev). Run run.py in https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v2-RNN. You can even run it in your browser: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng https://blinkdl.github.io/AI-Writer/eng/ (this is using tf.js WASM single-thread mode).

Training / Fine-tuning

pip install deepspeed==0.7.0 // pip install pytorch-lightning==1.9.5 // torch 1.13.1+cu117

NOTE: add weight decay (0.1 or 0.01) and dropout (0.1 or 0.01) when training on a small amount of data. Try x=x+dropout(att(x)) x=x+dropout(ffn(x)), or x=dropout(x+att(x)) x=dropout(x+ffn(x)), etc.

Training RWKV-4 from scratch: run train.py, which by default uses the enwik8 dataset (unzip https://data.deepai.org/enwik8.zip).

You will be training the "GPT" version because it's parallelizable and faster to train. RWKV-4 can extrapolate, so training with ctxLen 1024 can work for ctxLen of 2500+. You can fine-tune the model with a longer ctxLen and it can quickly adapt to longer ctxLens.

Fine-tuning RWKV-4 Pile models: use 'prepare-data.py' in https://github.com/BlinkDL/RWKV-v2-RNN-Pile/tree/main/RWKV-v3 to tokenize .txt into train.npy data. Then use https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/train.py to train it.

Read the inference code in src/model.py and try using the final hidden state(.xx .aa .bb) as a faithful sentence embedding for other tasks. Probably you should begin with .xx and .aa/.bb (.aa divided by .bb).

Colab for fine-tuning RWKV-4 Pile models: https://colab.research.google.com/github/resloved/RWKV-notebooks/blob/master/RWKV_v4_RNN_Pile_Fine_Tuning.ipynb

Large corpus: Use https://github.com/Abel2076/json2binidx_tool to convert .jsonl into .bin and .idx

The jsonl format sample (one line for each document):

{"text": "This is the first document."}
{"text": "Hello\nWorld"}
{"text": "1+1=2\n1+2=3\n2+2=4"}

generated by code like this:

import json
ss = json.dumps({"text": text}, ensure_ascii=False)  # keep non-ASCII characters as-is
out.write(ss + "\n")  # one JSON document per line

Infinite ctxlen training (WIP): https://github.com/Blealtan/RWKV-LM-LoRA/tree/dev-infctx

How to use RWKV hidden state as text embedding

Consider RWKV 14B. The state has 200 vectors, that is, 5 vectors for each block: fp16 (xx), fp32 (aa), fp32 (bb), fp32 (pp), fp16 (xx).

Do not avg pool because different vectors (xx aa bb pp xx) in the state have very different meanings and ranges. You can probably remove pp.

I suggest first collecting the mean+stdev statistics of each channel of each vector, and normalizing all of them (note: the normalization should be data-independent and collected from various texts). Then train a linear classifier.
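A minimal sketch of that recipe in NumPy. The "states" here are synthetic stand-ins (one small-range vector like xx and one large-range vector like aa), not real RWKV states, and logistic regression via gradient descent is just one choice of linear classifier. In practice, collect the mean/std statistics once from a separate, varied corpus and reuse them.

```python
import numpy as np

def fit_channel_stats(states):
    """states: (n_texts, n_features), each state vector flattened and concatenated (no pooling)."""
    mean = states.mean(axis=0)
    std = states.std(axis=0) + 1e-6      # avoid division by zero for constant channels
    return mean, std

def normalize(states, mean, std):
    return (states - mean) / std

def train_linear_classifier(X, y, lr=0.1, steps=500):
    """Binary logistic regression by plain gradient descent."""
    w, b = np.zeros(X.shape[1]), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))
        grad = p - y
        w -= lr * X.T @ grad / len(y)
        b -= lr * grad.mean()
    return w, b

rng = np.random.default_rng(0)
n = 200
xx = rng.normal(0.0, 1.0, size=(n, 4))       # stand-in for a small-range fp16 vector
aa = rng.normal(0.0, 1000.0, size=(n, 4))    # stand-in for a large-range fp32 vector
y = (xx[:, 0] > 0).astype(float)             # toy label depending on one channel
states = np.concatenate([xx, aa], axis=1)

mean, std = fit_channel_stats(states)
X = normalize(states, mean, std)
w, b = train_linear_classifier(X, y)
accuracy = (((X @ w + b) > 0) == y).mean()
print(accuracy)
```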

Towards RWKV-5 (just to record some new ideas)

Latest Design

RWKV-5 is multi-head and here shows one head. There is also a LayerNorm for each head (hence actually GroupNorm).

$\begin{array}{|l|l|l|} \hline & \text{RWKV-4 with real-valued } k \,\&\, v \,\&\, u \,\&\, w & \text{RWKV-5 with matrix-valued } \mathrm{k}^{\dagger}\mathrm{v} \,\&\, \mathrm{u} \,\&\, \mathrm{w} \\ \hline \mathrm{y}_0 & \mathrm{r}_0 \frac{\mathrm{uk}_0 \mathrm{v}_0}{\mathrm{uk}_0} & \mathrm{r}_0\left(\mathrm{uk}_0^{\dagger} \mathrm{v}_0\right) \\ \hline \mathrm{y}_1 & \mathrm{r}_1 \frac{\mathrm{uk}_1 \mathrm{v}_1+\mathrm{k}_0 \mathrm{v}_0}{\mathrm{uk}_1+\mathrm{k}_0} & \mathrm{r}_1\left(\mathrm{uk}_1^{\dagger} \mathrm{v}_1+\mathrm{k}_0^{\dagger} \mathrm{v}_0\right) \\ \hline \mathrm{y}_2 & \mathrm{r}_2 \frac{\mathrm{uk}_2 \mathrm{v}_2+\mathrm{k}_1 \mathrm{v}_1+\mathrm{wk}_0 \mathrm{v}_0}{\mathrm{uk}_2+\mathrm{k}_1+\mathrm{wk}_0} & \mathrm{r}_2\left(\mathrm{uk}_2^{\dagger} \mathrm{v}_2+\mathrm{k}_1^{\dagger} \mathrm{v}_1+\mathrm{wk}_0^{\dagger} \mathrm{v}_0\right) \\ \hline \mathrm{y}_3 & \mathrm{r}_3 \frac{\mathrm{uk}_3 \mathrm{v}_3+\mathrm{k}_2 \mathrm{v}_2+\mathrm{wk}_1 \mathrm{v}_1+\mathrm{w}^2 \mathrm{k}_0 \mathrm{v}_0}{\mathrm{uk}_3+\mathrm{k}_2+\mathrm{wk}_1+\mathrm{w}^2 \mathrm{k}_0} & \mathrm{r}_3\left(\mathrm{uk}_3^{\dagger} \mathrm{v}_3+\mathrm{k}_2^{\dagger} \mathrm{v}_2+\mathrm{wk}_1^{\dagger} \mathrm{v}_1+\mathrm{w}^2 \mathrm{k}_0^{\dagger} \mathrm{v}_0\right) \\ \hline \end{array}$

$\left[\begin{array}{lll} \mathrm{y}_{20} & \cdots & \mathrm{y}_{2\mathrm{c}} \end{array}\right]=\left[\begin{array}{lll} \mathrm{r}_{20} & \cdots & \mathrm{r}_{2\mathrm{c}} \end{array}\right]\left(\left[\begin{array}{ccc} \mathrm{u}_{00} & \cdots & \mathrm{u}_{0\mathrm{c}} \\ \vdots & \ddots & \vdots \\ \mathrm{u}_{\mathrm{c}0} & \cdots & \mathrm{u}_{\mathrm{cc}} \end{array}\right]\left[\begin{array}{ccc} \mathrm{k}_{20}\mathrm{v}_{20} & \cdots & \mathrm{k}_{20}\mathrm{v}_{2\mathrm{c}} \\ \vdots & \ddots & \vdots \\ \mathrm{k}_{2\mathrm{c}}\mathrm{v}_{20} & \cdots & \mathrm{k}_{2\mathrm{c}}\mathrm{v}_{2\mathrm{c}} \end{array}\right]+\left[\begin{array}{ccc} \mathrm{k}_{10}\mathrm{v}_{10} & \cdots & \mathrm{k}_{10}\mathrm{v}_{1\mathrm{c}} \\ \vdots & \ddots & \vdots \\ \mathrm{k}_{1\mathrm{c}}\mathrm{v}_{10} & \cdots & \mathrm{k}_{1\mathrm{c}}\mathrm{v}_{1\mathrm{c}} \end{array}\right]+\left[\begin{array}{ccc} \mathrm{w}_{00} & \cdots & \mathrm{w}_{0\mathrm{c}} \\ \vdots & \ddots & \vdots \\ \mathrm{w}_{\mathrm{c}0} & \cdots & \mathrm{w}_{\mathrm{cc}} \end{array}\right]\left[\begin{array}{ccc} \mathrm{k}_{00}\mathrm{v}_{00} & \cdots & \mathrm{k}_{00}\mathrm{v}_{0\mathrm{c}} \\ \vdots & \ddots & \vdots \\ \mathrm{k}_{0\mathrm{c}}\mathrm{v}_{00} & \cdots & \mathrm{k}_{0\mathrm{c}}\mathrm{v}_{0\mathrm{c}} \end{array}\right]\right)$
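To check the table numerically, here is a NumPy sketch of one RWKV-5 head in RNN mode. It assumes u and w act elementwise along the key dimension (i.e. diagonal matrices), which is a simplification of the matrix notation above, and the random inputs are for illustration only. The recurrence S ← k_tᵀv_t + w S reproduces the explicit y_3 row of the table.

```python
import numpy as np

rng = np.random.default_rng(0)
T, C = 4, 3                                   # 4 tokens, head size 3
r, k, v = (rng.normal(size=(T, C)) for _ in range(3))
u = rng.uniform(0.5, 1.5, size=C)             # bonus for the current token, per key channel
w = rng.uniform(0.5, 0.99, size=C)            # decay, per key channel

# RNN mode: S accumulates the decayed outer products k_i^T v_i of past tokens
S = np.zeros((C, C))
y_rnn = []
for t in range(T):
    kv = np.outer(k[t], v[t])                 # k_t^† v_t, a C x C matrix
    y_rnn.append(r[t] @ (u[:, None] * kv + S))
    S = kv + w[:, None] * S                   # S_{t+1} = k_t^† v_t + w S_t
y_rnn = np.array(y_rnn)

# Explicit y_3 from the table: r_3 (u k_3^† v_3 + k_2^† v_2 + w k_1^† v_1 + w^2 k_0^† v_0)
kvs = [np.outer(k[i], v[i]) for i in range(T)]
y3 = r[3] @ (u[:, None] * kvs[3] + kvs[2] + w[:, None] * kvs[1] + (w ** 2)[:, None] * kvs[0])
print(np.allclose(y_rnn[3], y3))              # True
```

Unlike RWKV-4, there is no denominator: the per-head GroupNorm mentioned above takes over the normalization role.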

RWKV-6

Dynamic Mix & Dynamic Decay. Example (do this for both TimeMix & ChannelMix):

TIME_MIX_EXTRA_DIM = 32
self.time_mix_k_w1 = nn.Parameter(torch.empty(args.n_embd, TIME_MIX_EXTRA_DIM).uniform_(-0.01, 0.01))
self.time_mix_k_w2 = nn.Parameter(torch.zeros(TIME_MIX_EXTRA_DIM, args.n_embd))
self.time_mix_v_w1 = nn.Parameter(torch.empty(args.n_embd, TIME_MIX_EXTRA_DIM).uniform_(-0.01, 0.01))
self.time_mix_v_w2 = nn.Parameter(torch.zeros(TIME_MIX_EXTRA_DIM, args.n_embd))
self.time_mix_r_w1 = nn.Parameter(torch.empty(args.n_embd, TIME_MIX_EXTRA_DIM).uniform_(-0.01, 0.01))
self.time_mix_r_w2 = nn.Parameter(torch.zeros(TIME_MIX_EXTRA_DIM, args.n_embd))
self.time_mix_g_w1 = nn.Parameter(torch.empty(args.n_embd, TIME_MIX_EXTRA_DIM).uniform_(-0.01, 0.01))
self.time_mix_g_w2 = nn.Parameter(torch.zeros(TIME_MIX_EXTRA_DIM, args.n_embd))
...
time_mix_k = self.time_mix_k.view(1,1,-1) + (x @ self.time_mix_k_w1) @ self.time_mix_k_w2
time_mix_v = self.time_mix_v.view(1,1,-1) + (x @ self.time_mix_v_w1) @ self.time_mix_v_w2
time_mix_r = self.time_mix_r.view(1,1,-1) + (x @ self.time_mix_r_w1) @ self.time_mix_r_w2
time_mix_g = self.time_mix_g.view(1,1,-1) + (x @ self.time_mix_g_w1) @ self.time_mix_g_w2

xx = self.time_shift(x)
xk = x * time_mix_k + xx * (1 - time_mix_k)
xv = x * time_mix_v + xx * (1 - time_mix_v)
xr = x * time_mix_r + xx * (1 - time_mix_r)
xg = x * time_mix_g + xx * (1 - time_mix_g)
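A self-contained NumPy version of the dynamic-mix snippet for a single projection. `time_shift` is written as a pad-and-slice here (an assumption standing in for the module used in the repo), and the shapes are toy-sized. Because w2 starts at zero, the dynamic mix initially equals the static token-shift mix, and the data-dependent part is learned from there.

```python
import numpy as np

def time_shift(x):
    """Shift x of shape (B, T, C) back by one token; the first token sees zeros."""
    return np.concatenate([np.zeros_like(x[:, :1]), x[:, :-1]], axis=1)

def dynamic_mix(x, mix_base, w1, w2):
    """Per-token mix = static base + low-rank data-dependent correction."""
    mix = mix_base.reshape(1, 1, -1) + (x @ w1) @ w2     # (B, T, C)
    xx = time_shift(x)
    return x * mix + xx * (1 - mix)

rng = np.random.default_rng(0)
B, T, C, EXTRA = 2, 5, 16, 32                 # EXTRA plays the role of TIME_MIX_EXTRA_DIM
x = rng.normal(size=(B, T, C))
mix_base = rng.uniform(size=C)
w1 = rng.uniform(-0.01, 0.01, size=(C, EXTRA))
w2 = np.zeros((EXTRA, C))                     # zero init, as in the snippet above

xk = dynamic_mix(x, mix_base, w1, w2)
static = x * mix_base + time_shift(x) * (1 - mix_base)
print(np.allclose(xk, static))                # True at initialization
```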

RWKV-7

Use parallelized mode to quickly generate the state, then use a finetuned full RNN (the layers of token n can use outputs of all layers of token n-1) for sequential generation.

Some old ideas

  1. Now time decay is like 0.999^T (0.999 is learnable). Change it to something like (0.999^T + 0.1) where 0.1 is learnable too. The 0.1 part will be kept forever. Or, A^T + B^T + C = fast-decay + slow-decay + constant. Can even use different formulas (for example, K^2 instead of e^K for a decay component, or, without normalization).

  2. Use complex-valued decay (so, rotation instead of decay) in some channels.

  3. Inject some trainable and extrapolatable positional encoding?

  4. Aside from 2d rotation, we can try other Lie groups such as 3d rotation ( SO(3) ). Non-abelian RWKV lol.

  5. RWKV might be great on analog devices (search for Analog Matrix-vector multiplication & Photonic Matrix-vector multiplication). The RNN mode is very hardware-friendly (processing-in-memory). Can be a SNN too (https://github.com/ridgerchu/SpikeGPT). I wonder if it can be optimized for quantum computation.

  6. Trainable initial hidden state (xx aa bb pp xx).

  7. Layerwise (or even row/column-wise, elementwise) LR, and test Lion optimizer.

Vision Tasks

  1. I find it's good to add a 2d pos encoding:
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
...
x = x + pos_emb_x + pos_emb_y
  2. In a BPE language model, it's best to use [tokenShift of 1 token] (you can mix more tokens in a char-level English model). However, you can try [tokenShift of N (or N-1) (or N+1) tokens] if the image size is N x N, because that will be like mixing [the token above the current position (or the token above the to-be-predicted position)] with [current token]. You can try different tokenShift styles for "ATT" & "FFN", or mix different tokenShift styles, such as mixing [token A] with [token A-1] and [token A-(N-1)] etc.
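A small NumPy sketch of both vision ideas, assuming the image is flattened row-major into N*N tokens: summing the broadcast x/y tables gives one embedding per grid cell, and shifting by N tokens picks up the token directly above. The reshape step and the `token_shift` helper are my own illustration, since the snippet in item 1 does not show how the tables are flattened.

```python
import numpy as np

def token_shift(x, n):
    """Shift x of shape (T, C) back by n tokens along time, zero-filling the first n rows."""
    if n == 0:
        return x.copy()
    out = np.zeros_like(x)
    out[n:] = x[:-n]
    return out

N, C = 4, 3                                        # a 4 x 4 grid of image tokens, 3 channels
rng = np.random.default_rng(0)
pos_emb_x = rng.normal(size=(1, N, C))             # one entry per column
pos_emb_y = rng.normal(size=(N, 1, C))             # one entry per row
pos = (pos_emb_x + pos_emb_y).reshape(N * N, C)    # broadcast to (N, N, C), flatten row-major

x = rng.normal(size=(N * N, C)) + pos
above = token_shift(x, N)                          # shift by N = token directly above
row, col = 2, 1
t = row * N + col
print(np.allclose(above[t], x[(row - 1) * N + col]))               # True
print(np.allclose(pos[t], pos_emb_x[0, col] + pos_emb_y[row, 0]))  # True
```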

Misc

Maybe we can improve memorization by simply repeating the context (I guess 2 times is enough). Example: Reference -> Reference(again) -> Question -> Answer

Idea: Bytes-aware Embedding

The idea is to make sure each token in the vocab understands its length and raw UTF-8 bytes.

Let a = max(len(token)) for all tokens in vocab. Define AA : float[a][d_emb]

Let b = max(len_in_utf8_bytes(token)) for all tokens in vocab. Define BB : float[b][256][d_emb]

For each token X in vocab, let [x0, x1, …, xn] be its raw UTF-8 bytes. We will add some extra values to its embedding EMB(X):

EMB(X) += AA[len(X)] + BB[0][x0] + BB[1][x1] + … + BB[n][xn] (note: AA BB are learnable weights)

  • We can do this for the final Linear(d_emb, n_vocab) projection too.
  • We can use some small networks to generate AA and BB, for some extra regularization (for example, BB[m][xi] and BB[n][xi] should be related).
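A NumPy sketch of the extra embedding term. AA and BB are random here (they would be learnable parameters), and AA is given a + 1 rows so that AA[len(X)] is a valid index when len(X) = a, which the float[a][d_emb] definition above would not allow with 0-based indexing. Function and variable names are illustrative.

```python
import numpy as np

def bytes_aware_extra(vocab, d_emb, seed=0):
    """Extra embedding per token: AA[len(X)] + sum_i BB[i][byte_i]."""
    rng = np.random.default_rng(seed)
    a = max(len(tok) for tok in vocab)
    b = max(len(tok.encode("utf-8")) for tok in vocab)
    AA = rng.normal(scale=0.02, size=(a + 1, d_emb))   # a + 1 rows so AA[a] exists
    BB = rng.normal(scale=0.02, size=(b, 256, d_emb))
    extra = np.zeros((len(vocab), d_emb))
    for idx, tok in enumerate(vocab):
        raw = tok.encode("utf-8")
        extra[idx] = AA[len(tok)] + sum(BB[i][byte] for i, byte in enumerate(raw))
    return extra, AA, BB

vocab = ["a", " abc", "\u00e9"]            # "é" is 1 character but 2 UTF-8 bytes
extra, AA, BB = bytes_aware_extra(vocab, d_emb=8)
EMB = np.random.default_rng(1).normal(size=(len(vocab), 8))
EMB += extra                               # EMB(X) += AA[len(X)] + BB[0][x0] + ... + BB[n][xn]
print(EMB.shape)                           # (3, 8)
```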

Old Idea

I have an idea to improve tokenization. We can hardcode some channels to have meanings. Example:

Channel 0 = "space"

Channel 1 = "capitalize first letter"

Channel 2 = "capitalize all letters"

Therefore:

Embedding of "abc": [0, 0, 0, x0, x1, x2, ..]

Embedding of " abc": [1, 0, 0, x0, x1, x2, ..]

Embedding of " Abc": [1, 1, 0, x0, x1, x2, ..]

Embedding of "ABC": [0, 0, 1, x0, x1, x2, ..]

……

so they will share most of the embedding. And we can rapidly compute the output probability of all variations of "abc".

Note: the above method assumes that p(" xyz") / p("xyz") is the same for any "xyz", which can be wrong.

Better: define emb_space emb_capitalize_first emb_capitalize_all to be a function of emb.

Maybe the Best: let 'abc' ' abc' etc. share the last 90% of their embeddings.

At this moment, all our tokenizers spend too many vocab items to represent all variations of 'abc' ' abc' ' Abc' etc. Moreover, the model cannot discover that these are actually similar if some of these variations are rare in the dataset. The method here can improve this. I plan to test this in a new version of RWKV.

Idea: Better Initial States

Example (single-round Q & A):

  1. Generate the final state of all wiki documents.

  2. For any user Q, find the best wiki document, and use its final state as the initial state.

  3. Train a model to directly generate the optimal initial state for any user Q.

However, this can be trickier for multi-round Q & A 🙂
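The following toy sketch covers steps 1 and 2 for the single-round case. toy_rnn_final_state is a hypothetical stand-in for running RWKV over a document: it is a simple decaying sum, not RWKV. "Best wiki document" is approximated by word overlap with the question, which is also an assumption. Step 3 (training a model to generate the state directly) is not shown.

```python
D = 8  # toy state size (assumption)


def toy_rnn_final_state(text, decay=0.9):
    """Hypothetical stand-in for an RNN pass: returns the state after the last char."""
    state = [0.0] * D
    for ch in text:
        state = [decay * s for s in state]
        state[ord(ch) % D] += 1.0
    return state


def word_overlap(a, b):
    """Jaccard overlap of lowercase word sets (assumed retrieval score)."""
    wa, wb = set(a.lower().split()), set(b.lower().split())
    return len(wa & wb) / max(1, len(wa | wb))


# Step 1: precompute the final state of every document.
docs = ["cats are small furry animals", "the moon orbits the earth"]
doc_states = [toy_rnn_final_state(d) for d in docs]


def initial_state_for(question):
    """Step 2: pick the best document and reuse its final state as the initial state."""
    best = max(range(len(docs)), key=lambda i: word_overlap(question, docs[i]))
    return best, doc_states[best]
```

The returned state would then be loaded as the RNN's initial state before the model reads the question.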

How it works

RWKV is inspired by Apple's AFT (https://arxiv.org/abs/2105.14103).

Moreover, it uses a number of my tricks, such as:

  • SmallInitEmb: https://github.com/BlinkDL/SmallInitEmb (applicable to all transformers), which helps the embedding quality and stabilizes Post-LN (which is what I am using).

  • Token-shift: https://github.com/BlinkDL/RWKV-LM#token-shift-time-shift-mixing (applicable to all transformers), especially helpful for char-level models.

  • Head-QK: https://github.com/BlinkDL/RWKV-LM#the-head-qk-trick-learning-to-copy-and-avoid-tokens (applicable to all transformers). Note: it's helpful, but I disabled it in the Pile model to keep it 100% RNN.

  • Extra R-gate in the FFN (applicable to all transformers). I am also using reluSquared from Primer.

  • Better initialization: I init most of the matrices to ZERO (see RWKV_Init in https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v2-RNN/src/model.py).

  • You can transfer some parameters from a small model to a large model (note: I sort & smooth them too) for faster and better convergence (see https://www.reddit.com/r/MachineLearning/comments/umq908/r_rwkvv2rnn_a_parallelizable_rnn_with/).

  • My CUDA kernel: https://github.com/BlinkDL/RWKV-CUDA to speed up training.

The pseudocode (execution from top to bottom):

The a, b, c, d factors work together to build a time-decay curve: [X, 1, W, W^2, W^3, …].

Write out the formulas for "token at pos 2" and "token at pos 3" and you will get the idea:

  • a and b: EMAs of kv and k.
  • c and d: these are a and b combined with "self-attention".

kv / k is the memory mechanism. A token with high k can be remembered for a long time if W is close to 1 in that channel.

The R-gate is important for performance. k = info strength of this token (to be passed to future tokens). r = whether to apply the info to this token.
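Below is a per-channel sketch reconstructed from the curve description above. It is an assumption, not the author's pseudocode. a and b are decayed running sums of past kv and k, c and d add the current token with weight X, and the result is gated by sigmoid(r), the sigmoid receptance gate RWKV uses. The second function recomputes the same output with the explicit weights [X, 1, W, W^2, ...] so the two can be compared.

```python
import math


def rwkv_abcd(ks, vs, rs, W, X):
    """Per-channel recurrence giving time-decay weights [X, 1, W, W^2, ...].

    ks are positive key strengths (e.g. exp of the raw key), vs are values,
    rs are raw receptance values. Returns sigmoid(r) * c / d for each step.
    """
    a = b = 0.0  # decayed sums of past kv and past k
    outs = []
    for k, v, r in zip(ks, vs, rs):
        kv = k * v
        c = a + X * kv  # include the current token with weight X
        d = b + X * k
        outs.append(1.0 / (1.0 + math.exp(-r)) * c / d)
        a = W * a + kv  # previous token gets weight 1, older ones W, W^2, ...
        b = W * b + k
    return outs


def direct(ks, vs, rs, W, X):
    """Same output computed by explicitly weighting every earlier token."""
    outs = []
    for t in range(len(ks)):
        num, den = X * ks[t] * vs[t], X * ks[t]
        for i in range(t):
            weight = W ** (t - 1 - i)
            num += weight * ks[i] * vs[i]
            den += weight * ks[i]
        outs.append(1.0 / (1.0 + math.exp(-rs[t])) * num / den)
    return outs
```

When W is close to 1, a and b forget slowly, so a token with a large k keeps influencing c / d for many steps.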

RWKV-3 improvements

Use different trainable TimeMix factors for R / K / V in SA and FF layers. Example:

xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
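To show what these lines compute, here is a dependency-free sketch on a (T, C) sequence stored as nested lists. time_shift mimics nn.ZeroPad2d((0,0,1,-1)) on the time axis: it shifts the sequence one step later and puts zeros at t = 0. mix and mix_inputs are hypothetical helper names, and the mix factors in the usage lines are arbitrary.

```python
def time_shift(xs):
    """Shift a (T, C) sequence one step later in time, with zeros at t = 0."""
    return [[0.0] * len(xs[0])] + [list(x) for x in xs[:-1]]


def mix(x, xx, factor):
    """Per-channel interpolation: x * factor + xx * (1 - factor)."""
    return [xi * f + xxi * (1 - f) for xi, xxi, f in zip(x, xx, factor)]


def mix_inputs(xs, time_mix_k, time_mix_v, time_mix_r):
    """Separate trainable factors let R, K and V each pick their own blend."""
    xxs = time_shift(xs)
    xk = [mix(x, xx, time_mix_k) for x, xx in zip(xs, xxs)]
    xv = [mix(x, xx, time_mix_v) for x, xx in zip(xs, xxs)]
    xr = [mix(x, xx, time_mix_r) for x, xx in zip(xs, xxs)]
    return xk, xv, xr


xs = [[1.0, 2.0], [3.0, 4.0]]  # T = 2 tokens, C = 2 channels
xk, xv, xr = mix_inputs(xs, [1.0, 0.0], [0.5, 0.5], [0.0, 1.0])
```

A factor of 1 in a channel keeps only the current token, and 0 keeps only the previous token.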

Use preLN instead of postLN (more stable & faster convergence):

if self.layer_id == 0:
	x = self.ln0(x)
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))

Explaining the code for RWKV-3 GPT mode

The GPT mode – overview

The building blocks of RWKV-3 GPT mode are similar to those of a usual preLN GPT.

The only difference is an extra LN after embedding. Note you can absorb this LN into the embedding after finishing the training.

x = self.emb(idx)  # input: idx = token indices
x = self.ln_emb(x) # extra LN after embedding
x = x + self.att_0(self.ln_att_0(x)) # preLN
x = x + self.ffn_0(self.ln_ffn_0(x))
...
x = x + self.att_n(self.ln_att_n(x))
x = x + self.ffn_n(self.ln_ffn_n(x))
x = self.ln_head(x) # final LN before projection
x = self.head(x)    # output: x = logits
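LayerNorm acts on each token's vector independently, so ln_emb(emb(idx)) depends only on idx and can be precomputed once per vocabulary row. That is why the extra LN can be absorbed into the embedding. The dependency-free sketch below assumes the standard LayerNorm definition with eps = 1e-5 (the nn.LayerNorm default). layer_norm and absorb_ln_into_emb are hypothetical names.

```python
import math


def layer_norm(x, weight, bias, eps=1e-5):
    """LayerNorm over one vector (biased variance, as in nn.LayerNorm)."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    inv = 1.0 / math.sqrt(var + eps)
    return [(xi - mean) * inv * w + b for xi, w, b in zip(x, weight, bias)]


def absorb_ln_into_emb(emb_table, weight, bias):
    """Return a new embedding table equal to ln_emb applied to every row."""
    return [layer_norm(row, weight, bias) for row in emb_table]


emb_table = [[0.01, -0.02, 0.03], [0.5, 0.1, -0.4]]  # 2 tokens, d_emb = 3
ln_w, ln_b = [1.0, 2.0, 0.5], [0.0, 0.1, -0.1]
fused = absorb_ln_into_emb(emb_table, ln_w, ln_b)
idx = [1, 0, 1]
before = [layer_norm(emb_table[i], ln_w, ln_b) for i in idx]  # emb, then ln_emb
after = [fused[i] for i in idx]                              # single lookup
```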

It is important to initialize emb to tiny values, such as nn.init.uniform_(self.emb.weight, a=-1e-4, b=1e-4), to utilize my trick https://github.com/BlinkDL/SmallInitEmb.

For the 1.5B RWKV-3, I use the Adam optimizer (no wd, no dropout) on 8 * A100 40G.

batchSz = 32 * 896, ctxLen = 896. I am using tf32 so the batchSz is a bit small.

For the first 15B tokens, LR is fixed at 3e-4, and beta=(0.9, 0.99).

Then I set beta=(0.9, 0.999), and do an exponential decay of LR, reaching 1e-5 at 332B tokens.
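The schedule just described can be written as a function of tokens seen. This sketch makes two assumptions: the decay is exponential in token count, so the LR is the geometric interpolation between 3e-4 and 1e-5, and the LR stays at 1e-5 after 332B tokens. lr_and_betas is a hypothetical name.

```python
def lr_and_betas(tokens_seen, lr_start=3e-4, lr_end=1e-5,
                 fixed_until=15e9, decay_until=332e9):
    """Return (learning rate, Adam betas) for the 1.5B RWKV-3 schedule above."""
    if tokens_seen < fixed_until:
        return lr_start, (0.9, 0.99)  # constant LR phase
    progress = min(1.0, (tokens_seen - fixed_until) / (decay_until - fixed_until))
    lr = lr_start * (lr_end / lr_start) ** progress  # exponential decay
    return lr, (0.9, 0.999)


for t in [0, 15e9, 100e9, 332e9]:
    print(f"{t / 1e9:>5.0f}B tokens -> lr={lr_and_betas(t)[0]:.2e}")
```

Halfway through the decay phase (173.5B tokens), this schedule gives the geometric mean of the two LRs, about 5.5e-5.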

The GPT mode – ATT block

RWKV-3 does not have any attention in the usual sense, but we will call this block ATT anyway.

B, T, C = x.size() # x = (Batch,Time,Channel)

# Mix x with the previous timestep to produce xk, xv, xr
xx = self.time_shift(x) # self.time_shift = nn.ZeroPad2d((0,0,1,-1))
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

# Use xk, xv, xr to produce k, v, r
k = self.key(xk).transpose(-1, -2)
v = self.value(xv).transpose(-1, -2)
r = self.receptance(xr)
k = torch.clamp(k, max=60) # clamp k to avoid overflow
k = torch.exp(k)
kv = k * v

# Compute the W-curve = [e^(-n * e^time_decay), e^(-(n-1) * e^time_decay), ..., 1, e^(time_first)]
self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(x.device), self.time_first], dim=-1)
w = torch.exp(self.time_w)

# Use W to mix kv and k respectively. Add K_EPS to wk to avoid divide-by-zero
if RUN_DEVICE == 'cuda':
    wkv = TimeX.apply(w, kv, B,C,T, 0)
    wk = TimeX.apply(w, k, B,C,T, K_EPS)
else:
    w = w[:,-T:].unsqueeze(1)
    wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
    wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + K_EPS
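The single-channel, dependency-free sketch below mirrors the CPU branch above, to show what the W-curve means. It builds time_w as in the comment, then reproduces F.conv1d on the left-padded sequence (conv1d computes a cross-correlation). The result weights the current token by e^(time_first), the previous one by 1, and older ones by powers of W = e^(-e^time_decay). w_curve and causal_conv are hypothetical helper names, and the K_EPS value used is an assumption. This is an illustration, not the CUDA kernel.

```python
import math


def w_curve(T, time_decay, time_first):
    """[e^(-(T-2) e^decay), ..., e^(-e^decay), 1, e^time_first] for one channel."""
    time_curve = [-(T - 2 - i) for i in range(T - 1)]  # [-(T-2), ..., -1, 0]
    time_w = [math.exp(time_decay) * c for c in time_curve] + [time_first]
    return [math.exp(tw) for tw in time_w]


def causal_conv(w, seq):
    """Equivalent of F.conv1d(ZeroPad2d((T-1, 0, 0, 0))(seq), w) for one channel."""
    T = len(w)
    padded = [0.0] * (T - 1) + list(seq)
    return [sum(w[j] * padded[t + j] for j in range(T)) for t in range(len(seq))]


T, time_decay, time_first = 5, -0.5, 0.3
w = w_curve(T, time_decay, time_first)
k = [math.exp(x) for x in [0.1, -0.3, 1.2, 0.0, 0.7]]  # k = exp(clamped key)
v = [1.0, -2.0, 0.5, 3.0, -1.0]
wkv = causal_conv(w, [ki * vi for ki, vi in zip(k, v)])
wk = [x + 1e-9 for x in causal_conv(w, k)]  # K_EPS = 1e-9 is an assumption
```

Dividing wkv by wk gives, for each position, the decay-weighted average of v that the kv / k memory mechanism describes.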

Download the source code

Clone the project from the command line:

git clone https://github.com/BlinkDL/RWKV-LM.git
