Skip to content

Commit

Permalink
Merge pull request #763 from deftdawg/amdgpu
Browse files Browse the repository at this point in the history
AMD/ROCm: Changes required to detect and run inference on AMD GPUs
  • Loading branch information
AlexCheema authored Mar 6, 2025
2 parents 017bf93 + f98d9ba commit 854f515
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
15 changes: 6 additions & 9 deletions exo/topology/device_capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,22 +198,19 @@ async def linux_device_capabilities() -> DeviceCapabilities:
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
)
elif Device.DEFAULT == "AMD":
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
from pyrsmi import rocml
import pyamdgpuinfo

rocml.smi_initialize()
gpu_name = rocml.smi_get_device_name(0).upper()
gpu_memory_info = rocml.smi_get_device_memory_total(0)
gpu_raw_info = pyamdgpuinfo.get_gpu(0)
gpu_name = gpu_raw_info.name
gpu_memory_info = gpu_raw_info.memory_info["vram_size"]

if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")

rocml.smi_shutdown()

return DeviceCapabilities(
model="Linux Box ({gpu_name})",
model="Linux Box (" + gpu_name + ")",
chip=gpu_name,
memory=gpu_memory_info // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
)

else:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"prometheus-client==0.20.0",
"protobuf==5.28.1",
"psutil==6.0.0",
"pyamdgpuinfo==2.1.6;platform_system=='Linux'",
"pydantic==2.9.2",
"requests==2.32.3",
"rich==13.7.1",
Expand Down

0 comments on commit 854f515

Please sign in to comment.