# Colossal-RL: Democratizing Reinforcement Learning with Colossal-AI
## Overview
In this section, we introduce how to run reinforcement learning with Colossal-AI to train your own RL model. We support GRPO (Group Relative Policy Optimization), the main algorithm used to train the DeepSeek-R1 model. GRPO is a variant of Proximal Policy Optimization (PPO) that enhances mathematical reasoning ability while reducing the memory usage of PPO.
## Data
The following example illustrates how training data should be constructed. We accept JSONL format with each line having the following structure:
```json
{
    "messages": {
        "role": "user",
        "content": "Let \\[f(x) = \\left\\{\n\\begin{array}{cl} ax+3, &\\text{ if }x>2, \\\\\nx-5 &\\text{ if } -2 \\le x \\le 2, \\\\\n2x-b &\\text{ if } x <-2.\n\\end{array}\n\\right.\\]Find $a+b$ if the piecewise function is continuous (which means that its graph can be drawn without lifting your pencil from the paper)."
    },
    "gt_answer": "0"
}
```
- `content`: The question text, normally a math problem
- `gt_answer`: The ground-truth answer
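As a minimal sketch, the snippet below shows one way to convert question/answer pairs into this JSONL format. The sample data and the output file name are illustrative placeholders, not part of Colossal-RL.

```python
import json

# Illustrative question/answer pairs; replace with your own data source.
samples = [
    {"question": "What is 3 + 4?", "answer": "7"},
    {"question": "Compute 12 * 12.", "answer": "144"},
]

# Each line is a standalone JSON object with "messages" and "gt_answer",
# matching the structure shown above.
with open("train.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        record = {
            "messages": {
                "role": "user",
                "content": sample["question"],
            },
            "gt_answer": sample["answer"],
        }
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
```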
## Training Script
Users can adjust the training parameters to train their own models. We suggest changing the following parameters:
Parameter | Description |
---|---|
-m , --model | Local model path |
-d , --dataset | Local data path |
-s , --system-prompt | System prompt to construct the dataset |
-p , --project | Project name for Wandb |
-g , --num-generations | Number of generations per prompt |
-e , --num-episodes | Number of episodes |
-lr , --learning-rate | Learning rate for GRPO |
-kl , --kl-coeff | KL penalty coefficient for GRPO |
-si , --save-interval | Interval for saving checkpoints |
-sd , --save-dir | Directory for saving checkpoints |
-mnt , --max-new-tokens | Max length for generation |
-mpt , --max-prompt-tokens | Max length for prompt |
-temp , --temperature | Temperature for sampling |
-topk , --top-k | Top k for sampling |
-topp , --top-p | Top p for sampling |
-ibs , --inference-batch-size | Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model |
-imbs , --inference-microbatch-size | Effective batch size for the inference backend to run generation. Please select based on memory constraints |
-tbs , --train-batch-size | Number of unique prompts used to update the policy model per step per dp group. Gradients are accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples |
-tMbs , --train-minibatch-size | Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Ensure that tMbs * g >= tmbs |
-tmbs , --train-microbatch-size | Effective batch size per dp group for the forward and backward passes. Please select based on the available memory |
For the other parameters, we suggest keeping the defaults to avoid unexpected issues. Colossal-RL is under intensive development, and more features will be released soon.
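The batch-size parameters above must satisfy a few consistency constraints: ibs must be divisible by tbs, and tMbs * g must be at least tmbs. The snippet below is a quick sanity check of those relationships; the numeric values are hypothetical placeholders, not recommendations.

```python
# Hypothetical values, chosen only to illustrate the constraints described above.
dp_size = 4        # number of data-parallel groups on the trainer side
g = 8              # --num-generations: completions sampled per prompt
ibs = 64           # --inference-batch-size
tbs = 32           # --train-batch-size
tMbs = 4           # --train-minibatch-size
tmbs = 8           # --train-microbatch-size

# The inference batch size must be divisible by the train batch size; the
# inference backend weights are then synced every ibs / tbs policy-update steps.
assert ibs % tbs == 0, "ibs must be divisible by tbs"
sync_interval = ibs // tbs

# Each training minibatch must yield at least one full microbatch of samples.
assert tMbs * g >= tmbs, "require tMbs * g >= tmbs"

# Samples consumed per policy-update step across all dp groups.
samples_per_step = tbs * g * dp_size

print(f"weight sync every {sync_interval} steps, {samples_per_step} samples per step")
```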
## Template
A few example configurations are provided to help users quickly start training.
Model size | GPU type | Num GPUs | Policy | Max new tokens | Max prompt tokens | Train micro batch size | Max CUDA memory (reported in terminal) |
---|---|---|---|---|---|---|---|
3B | H20 (98 GB) | 4 | Producer: 2, Consumer: 2, ZeRO-2 | 1024 * 4 - 512 | 512 | 2 | ~55 GB |
3B | H20 (98 GB) | 4 | Producer: 2, Consumer: 2, ZeRO-2 | 1024 * 4 - 512 | 512 | 4 | ~72 GB |
3B | H20 (98 GB) | 4 | Producer: 2, Consumer: 2, ZeRO-2 | 1024 * 8 - 512 | 512 | 2 | ~72 GB |
3B | H200 (140 GB) | 8 | Producer: 4, Consumer: 4, ZeRO-2 | 1024 * 8 - 512 | 512 | 2 | ~50 GB |
3B | H200 (140 GB) | 8 | Producer: 4, Consumer: 4, ZeRO-2 | 1024 * 8 - 512 | 512 | 4 | ~60 GB |
3B | H200 (140 GB) | 8 | Producer: 4, Consumer: 4, ZeRO-2 | 1024 * 8 - 512 | 512 | 8 | ~100 GB |
7B | H200 (140 GB) | 8 | Producer: 4, Consumer: 4, ZeRO-2 | 1024 * 8 - 512 | 512 | 2 | ~85 GB |
7B | H200 (140 GB) | 8 | Producer: 4, Consumer: 4, ZeRO-2 | 1024 * 8 - 512 | 512 | 4 | ~100 GB |
7B | H200 (140 GB) | 8 | Producer: 4, Consumer: 4, ZeRO-2 | 1024 * 8 - 512 | 512 | 8 | ~138 GB |
14B | H200 (140 GB) | 8 | Producer: 4, Consumer: 4, ZeRO-2 | 1024 * 8 - 512 | 512 | 2 | ~140 GB |
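For clarity, the "Max new tokens" entries are arithmetic expressions. The sketch below shows how they evaluate when interpreted as the documented max-new-tokens and max-prompt-tokens values; the assumption that prompt and generation tokens together fill a 4096- or 8192-token budget is ours, not stated in the table.

```python
# Evaluating the "Max new tokens" expressions from the template table.
# Assumption: prompt tokens + new tokens add up to a power-of-two token budget.
max_prompt_tokens = 512                  # "Max prompt tokens" column
max_new_tokens_small = 1024 * 4 - 512    # 3584 new tokens (3B on 4x H20 rows)
max_new_tokens_large = 1024 * 8 - 512    # 7680 new tokens (all 8-GPU rows)

assert max_new_tokens_small + max_prompt_tokens == 4096
assert max_new_tokens_large + max_prompt_tokens == 8192
print(max_new_tokens_small, max_new_tokens_large)
```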