#!/usr/bin/env bash
# Run the MegaBlocks ROCm kernel test suite against the staged build tree.
set -euo pipefail

# Resolve the kernel directory from this script's location and work from there.
KERNEL_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
cd "$KERNEL_DIR"
export KERNEL_DIR

# Determine the build variant (e.g. torch*-rocm64-*), preferring the helper in
# kernels.utils and falling back to scanning build/ directly.
detect_variant() {
  python - <<'PY'
import os
import pathlib

root = pathlib.Path(os.environ["KERNEL_DIR"])
build_dir = root / "build"
variant = None

try:
    from kernels.utils import build_variant as _build_variant
except Exception:
    _build_variant = None

if _build_variant is not None:
    try:
        variant = _build_variant()
    except Exception:
        variant = None

if variant is None:
    # glob() returns a generator, which is always truthy; materialize each
    # pattern so the CUDA fallback is actually reachable when no ROCm build
    # directory exists.
    candidates = sorted(build_dir.glob("torch*-rocm64-*")) or sorted(build_dir.glob("torch*-cu*"))
    if candidates:
        variant = candidates[0].name

if variant is None:
    raise SystemExit("Could not determine MegaBlocks build variant. Run build.py first.")

print(variant)
PY
}

VARIANT=$(detect_variant)
STAGED_DIR="$KERNEL_DIR/build/$VARIANT"

# Locate the staged extension library under either of the layouts build.py may
# produce.
find_staged_lib() {
  local base="$1"
  local candidates=(
    "$base/_megablocks_rocm.so"
    "$base/megablocks/_megablocks_rocm.so"
  )
  for path in "${candidates[@]}"; do
    if [[ -f "$path" ]]; then
      echo "$path"
      return 0
    fi
  done
  return 1
}

# Rebuild once if the staged extension is missing, then re-detect the variant.
STAGED_LIB=$(find_staged_lib "$STAGED_DIR") || true
if [[ -z "${STAGED_LIB:-}" ]]; then
  echo "Staged ROCm extension not found under $STAGED_DIR; rebuilding kernels..."
  python build.py
  VARIANT=$(detect_variant)
  STAGED_DIR="$KERNEL_DIR/build/$VARIANT"
  STAGED_LIB=$(find_staged_lib "$STAGED_DIR") || true
  if [[ -z "${STAGED_LIB:-}" ]]; then
    echo "ERROR: build.py completed but no extension was found under $STAGED_DIR" >&2
    exit 1
  fi
fi

# Prepend the staged tree to PYTHONPATH. Avoid a trailing ':' when PYTHONPATH
# is unset, since an empty entry makes Python add the current directory.
export PYTHONPATH="$STAGED_DIR${PYTHONPATH:+:$PYTHONPATH}"
echo "Using MegaBlocks build variant: $VARIANT"

# Count visible accelerators via torch (covers both HIP and CUDA backends).
declare -i GPU_COUNT
GPU_COUNT=$(python - <<'PY'
import torch

print(torch.cuda.device_count() if torch.cuda.is_available() else 0)
PY
)

if (( GPU_COUNT == 0 )); then
  echo "ERROR: No HIP/CUDA GPUs detected. Tests require at least one visible accelerator." >&2
  exit 1
fi
echo "Detected $GPU_COUNT visible GPU(s)."

log() {
  echo
  echo "==> $1"
}

# Run one labelled pytest invocation, echoing the exact command via set -x.
# Under set -e a failing test aborts the whole script (fail fast); the braced
# `set +x` keeps its own trace line out of the output.
run_pytest() {
  local label="$1"
  shift
  log "$label"
  set -x
  "$@"
  { set +x; } 2>/dev/null || true
}

# Environment blocks for the three visibility configurations. Both HIP_ and
# CUDA_VISIBLE_DEVICES are set so the same script works on ROCm and CUDA.
SINGLE_GPU_ENV=(HIP_VISIBLE_DEVICES=0 CUDA_VISIBLE_DEVICES=0 WORLD_SIZE=1)
MULTI2_GPU_ENV=(HIP_VISIBLE_DEVICES=0,1 CUDA_VISIBLE_DEVICES=0,1 WORLD_SIZE=2)
MULTI8_GPU_ENV=(HIP_VISIBLE_DEVICES=$(seq -s, 0 7) CUDA_VISIBLE_DEVICES=$(seq -s, 0 7) WORLD_SIZE=8)

SINGLE_TESTS=(
  "test_mb_moe.py"
  "test_mb_moe_shared_expert.py"
  "layer_test.py"
  "test_gg.py"
  "ops_test.py"
)

for test in "${SINGLE_TESTS[@]}"; do
  run_pytest "Single-GPU pytest ${test}" \
    env "${SINGLE_GPU_ENV[@]}" python -m pytest "tests/${test}" -q
done
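
# Hypothetical convenience, not part of the original flow: append ad-hoc
# single-GPU test files via an EXTRA_SINGLE_TESTS variable (an assumed name,
# as is the example file below) without editing SINGLE_TESTS, e.g.
#   EXTRA_SINGLE_TESTS="soft_moe_test.py" bash <this script>
# Word splitting on the unquoted expansion is intended; the loop is a no-op
# when the variable is unset or empty.
for test in ${EXTRA_SINGLE_TESTS:-}; do
  run_pytest "Single-GPU pytest ${test} (extra)" \
    env "${SINGLE_GPU_ENV[@]}" python -m pytest "tests/${test}" -q
done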

if (( GPU_COUNT >= 2 )); then
  run_pytest "Distributed layer smoke (2 GPUs)" \
    env "${MULTI2_GPU_ENV[@]}" python -m pytest "tests/parallel_layer_test.py::test_megablocks_moe_mlp_functionality" -q
else
  log "Skipping 2-GPU distributed layer test (requires >=2 GPUs, detected ${GPU_COUNT})."
fi

# The shared-expert multi tests are parametrized by world size; the [1] cases
# run everywhere, the [8] cases only on a full 8-GPU node. Node ids are
# single-quoted so the brackets are not treated as globs.
run_pytest "Shared expert functionality (world_size=1)" \
  env "${SINGLE_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_functionality[1]' -q

run_pytest "Shared expert weighted sum (world_size=1)" \
  env "${SINGLE_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_weighted_sum[1]' -q

if (( GPU_COUNT >= 8 )); then
  run_pytest "Shared expert functionality (world_size=8)" \
    env "${MULTI8_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_functionality[8]' -q
  run_pytest "Shared expert weighted sum (world_size=8)" \
    env "${MULTI8_GPU_ENV[@]}" python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_weighted_sum[8]' -q
else
  log "Skipping 8-GPU shared expert tests (requires >=8 GPUs, detected ${GPU_COUNT})."
fi

echo
echo "All requested tests completed."
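
# Manual reproduction sketch (assumes the staged layout produced above): to
# rerun one parametrized case by hand outside this script, something like
#   PYTHONPATH="build/<variant>" HIP_VISIBLE_DEVICES=0 CUDA_VISIBLE_DEVICES=0 WORLD_SIZE=1 \
#     python -m pytest 'tests/test_mb_moe_shared_expert_multi.py::test_shared_expert_distributed_functionality[1]' -q
# from the kernel directory should match what run_pytest executes.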