AMD

If you're using the runpod backend or have set up SSH fleets with on-prem AMD hardware, you can run workloads on AMD GPUs.
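
For on-prem machines, an SSH fleet configuration could look like the minimal sketch below. The fleet name, user, identity file, and host address are placeholders, and the hosts are assumed to already have ROCm drivers installed.

type: fleet
name: amd-ssh-fleet

# Placeholder SSH credentials and host address; replace with your own
ssh_config:
  user: ubuntu
  identity_file: ~/.ssh/id_rsa
  hosts:
    - 192.168.100.10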

Deployment

You can use any serving framework, such as TGI or vLLM. Below are examples of services that deploy Llama 3.1 70B in FP16 using TGI and vLLM.

type: service
name: amd-service-tgi

# Using TGI's official ROCm Docker image
image: ghcr.io/huggingface/text-generation-inference:sha-a379d55-rocm

env:
  - HF_TOKEN
  - MODEL_ID=meta-llama/Meta-Llama-3.1-70B-Instruct
  - TRUST_REMOTE_CODE=true
  - ROCM_USE_FLASH_ATTN_V2_TRITON=true
commands:
  - text-generation-launcher --port 8000
port: 8000
# Register the model
model: meta-llama/Meta-Llama-3.1-70B-Instruct

# Uncomment to leverage spot instances
#spot_policy: auto

resources:
  gpu: MI300X
  disk: 150GB

type: service
name: llama31-service-vllm-amd

# Using RunPod's ROCm Docker image
image: runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04
# Required environment variables
env:
  - HF_TOKEN
  - MODEL_ID=meta-llama/Meta-Llama-3.1-70B-Instruct
  - MAX_MODEL_LEN=126192
# Commands of the service
commands:
  - export PATH=/opt/conda/envs/py_3.10/bin:$PATH
  - wget https://github.com/ROCm/hipBLAS/archive/refs/tags/rocm-6.1.0.zip
  - unzip rocm-6.1.0.zip
  - cd hipBLAS-rocm-6.1.0
  - python rmake.py
  - cd ..
  - git clone https://github.com/vllm-project/vllm.git
  - cd vllm
  - pip install triton
  - pip uninstall torch -y
  - pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
  - pip install /opt/rocm/share/amd_smi
  - pip install --upgrade numba scipy huggingface-hub[cli]
  - pip install "numpy<2"
  - pip install -r requirements-rocm.txt
  - wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib
  - rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so*
  - export PYTORCH_ROCM_ARCH="gfx90a;gfx942"
  - wget https://dstack-binaries.s3.amazonaws.com/vllm-0.6.0%2Brocm614-cp310-cp310-linux_x86_64.whl
  - pip install vllm-0.6.0+rocm614-cp310-cp310-linux_x86_64.whl
  - vllm serve $MODEL_ID --max-model-len $MAX_MODEL_LEN --port 8000
# Service port
port: 8000
# Register the model
model: meta-llama/Meta-Llama-3.1-70B-Instruct

# Uncomment to leverage spot instances
#spot_policy: auto

resources:
  gpu: MI300X
  disk: 200GB

Note that the maximum size of vLLM's KV cache is 126192, so MAX_MODEL_LEN must be set to 126192. Adding /opt/conda/envs/py_3.10/bin to PATH ensures we use the Python 3.10 environment required by the pre-built binaries, which are compiled specifically for this version.

To speed up the vLLM-ROCm installation, we use a pre-built binary from S3. You can find the task that builds and uploads the binary in examples/deployment/vllm/amd/.
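
For reference, a task that builds and uploads such a wheel could look roughly like the sketch below (not the exact task from the repository). The bucket name is a placeholder, AWS credentials are assumed to be passed as environment variables, and the build may additionally need the same hipBLAS and PyTorch-ROCm preparation steps as the service above.

type: task
name: build-vllm-rocm-wheel

image: runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04
# Placeholder credentials for uploading the wheel to your own bucket
env:
  - AWS_ACCESS_KEY_ID
  - AWS_SECRET_ACCESS_KEY
commands:
  - export PATH=/opt/conda/envs/py_3.10/bin:$PATH
  # The same hipBLAS and PyTorch-ROCm preparation steps as in the service above may be needed here
  - git clone https://github.com/vllm-project/vllm.git
  - cd vllm
  - pip install -r requirements-rocm.txt
  - export PYTORCH_ROCM_ARCH="gfx90a;gfx942"
  # Build a wheel instead of installing in place
  - python setup.py bdist_wheel
  # Upload the wheel to your own bucket ("my-bucket" is a placeholder)
  - pip install awscli
  - aws s3 cp dist/*.whl s3://my-bucket/

resources:
  gpu: MI300X
  disk: 200GB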

Docker image

If you want to use AMD GPUs, specifying image is currently required. The image must include ROCm drivers.

To request multiple GPUs, specify the quantity after the GPU name, separated by a colon, e.g., MI300X:4.
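
For example, the following resources block requests four MI300X GPUs:

resources:
  # Request four MI300X GPUs on one instance
  gpu: MI300X:4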

Fine-tuning

Below is an example of LoRA fine-tuning Llama 3.1 8B using TRL and the mlabonne/guanaco-llama2-1k dataset.

type: task
name: trl-amd-llama31-train

# Using RunPod's ROCm Docker image
image: runpod/pytorch:2.1.2-py3.10-rocm6.1-ubuntu22.04

# Required environment variables
env:
  - HF_TOKEN
# Commands of the task
commands:
  - export PATH=/opt/conda/envs/py_3.10/bin:$PATH
  - git clone https://github.com/ROCm/bitsandbytes
  - cd bitsandbytes
  - git checkout rocm_enabled
  - pip install -r requirements-dev.txt
  - cmake -DBNB_ROCM_ARCH="gfx942" -DCOMPUTE_BACKEND=hip -S  .
  - make
  - pip install .
  - pip install trl
  - pip install peft
  - pip install transformers datasets huggingface-hub scipy
  - cd ..
  - python examples/fine-tuning/trl/amd/train.py

# Uncomment to leverage spot instances
#spot_policy: auto

resources:
  gpu: MI300X
  disk: 150GB

Below is an example of fine-tuning Llama 3.1 8B using Axolotl and the tatsu-lab/alpaca dataset.

type: task
name: axolotl-amd-llama31-train

# Using RunPod's ROCm Docker image
image: runpod/pytorch:2.1.2-py3.10-rocm6.0.2-ubuntu22.04
# Required environment variables
env:
  - HF_TOKEN
# Commands of the task
commands:
  - export PATH=/opt/conda/envs/py_3.10/bin:$PATH
  - pip uninstall torch torchvision torchaudio -y
  - python3 -m pip install --pre torch==2.3.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0/
  - git clone https://github.com/OpenAccess-AI-Collective/axolotl
  - cd axolotl
  - git checkout d4f6c65
  - pip install -e .
  - cd ..
  - wget https://dstack-binaries.s3.amazonaws.com/flash_attn-2.0.4-cp310-cp310-linux_x86_64.whl
  - pip install flash_attn-2.0.4-cp310-cp310-linux_x86_64.whl
  - wget https://dstack-binaries.s3.amazonaws.com/xformers-0.0.26-cp310-cp310-linux_x86_64.whl
  - pip install xformers-0.0.26-cp310-cp310-linux_x86_64.whl
  - git clone --recurse https://github.com/ROCm/bitsandbytes
  - cd bitsandbytes
  - git checkout rocm_enabled
  - pip install -r requirements-dev.txt
  - cmake -DBNB_ROCM_ARCH="gfx942" -DCOMPUTE_BACKEND=hip -S  .
  - make
  - pip install .
  - cd ..
  - accelerate launch -m axolotl.cli.train axolotl/examples/llama-3/fft-8b.yaml

# Uncomment to leverage spot instances
#spot_policy: auto

resources:
  gpu: MI300X
  disk: 150GB

Note that to support ROCm, we need to check out commit d4f6c65. You can find the installation instructions in rocm-blogs.

To speed up the installation of flash-attention and xformers, we use pre-built binaries uploaded to S3. You can find the tasks that build and upload the binaries in examples/fine-tuning/axolotl/amd/.

Running a configuration

Once the configuration is ready, run dstack apply -f <configuration file>, and dstack will automatically provision the cloud resources and run the configuration.

$ HF_TOKEN=...
$ dstack apply -f examples/deployment/vllm/amd/.dstack.yml

Source code

The source code of these examples can be found in examples/deployment/tgi/amd, examples/deployment/vllm/amd, examples/fine-tuning/axolotl/amd, and examples/fine-tuning/trl/amd.

What's next?

  1. Browse TGI, vLLM, Axolotl, TRL, and ROCm Bitsandbytes.
  2. Check dev environments, tasks, and services (a minimal dev environment sketch for an AMD GPU is shown below).
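
For instance, a dev environment on an AMD GPU could look like this minimal sketch; the image and GPU mirror the examples above, while the name and IDE choice are arbitrary.

type: dev-environment
name: amd-dev

# An image with ROCm drivers is required for AMD
image: runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04
ide: vscode

resources:
  gpu: MI300X
  disk: 150GB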