Skip to content

vllm.model_executor.layers.fused_moe.oracle.nvfp4

FLASHINFER_NVFP4_MOE_BACKENDS module-attribute

FLASHINFER_NVFP4_MOE_BACKENDS = [
    FLASHINFER_CUTLASS,
    FLASHINFER_TRTLLM,
    FLASHINFER_CUTEDSL,
]

fi_2_vllm_backend_map module-attribute

logger module-attribute

logger = init_logger(__name__)

NvFp4MoeBackend

Bases: Enum

Source code in vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
class NvFp4MoeBackend(Enum):
    """Kernel backends available for NvFP4-quantized fused MoE layers."""

    # FlashInfer-provided kernel families.
    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
    FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
    FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
    # Kernels shipped with vLLM itself.
    VLLM_CUTLASS = "VLLM_CUTLASS"
    MARLIN = "MARLIN"

FLASHINFER_CUTEDSL class-attribute instance-attribute

FLASHINFER_CUTEDSL = 'FLASHINFER_CUTEDSL'

FLASHINFER_CUTLASS class-attribute instance-attribute

FLASHINFER_CUTLASS = 'FLASHINFER_CUTLASS'

FLASHINFER_TRTLLM class-attribute instance-attribute

FLASHINFER_TRTLLM = 'FLASHINFER_TRTLLM'

MARLIN class-attribute instance-attribute

MARLIN = 'MARLIN'

VLLM_CUTLASS class-attribute instance-attribute

VLLM_CUTLASS = 'VLLM_CUTLASS'

backend_2_kernel_cls

backend_2_kernel_cls(
    backend: NvFp4MoeBackend,
) -> type[FusedMoEPermuteExpertsUnpermute]
Source code in vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
def backend_2_kernel_cls(
    backend: NvFp4MoeBackend,
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
    """Return the expert-kernel class implementing `backend`.

    Imports are deliberately lazy so only the selected backend's module
    is loaded. Raises NotImplementedError for TRTLLM (no modular-kernel
    class) and ValueError for an unrecognized backend.
    """
    if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
        raise NotImplementedError

    if backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )

        return FlashInferExperts

    if backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
            FlashInferCuteDSLExperts,
        )

        return FlashInferCuteDSLExperts

    if backend == NvFp4MoeBackend.VLLM_CUTLASS:
        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
            CutlassExpertsFp4,
        )

        return CutlassExpertsFp4

    if backend == NvFp4MoeBackend.MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )

        return MarlinExperts

    raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")

convert_to_nvfp4_moe_kernel_format

convert_to_nvfp4_moe_kernel_format(
    nvfp4_backend: NvFp4MoeBackend,
    layer: Module,
    w13: Tensor,
    w13_scale: Tensor,
    w13_scale_2: Tensor,
    a13_scale: Tensor | None,
    w2: Tensor,
    w2_scale: Tensor,
    w2_scale_2: Tensor,
    a2_scale: Tensor | None,
    is_act_and_mul: bool,
) -> tuple[
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    Tensor,
]
Source code in vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
def convert_to_nvfp4_moe_kernel_format(
    nvfp4_backend: NvFp4MoeBackend,
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    a13_scale: torch.Tensor | None,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a2_scale: torch.Tensor | None,
    is_act_and_mul: bool,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Repack NvFP4 MoE weights/scales into the layout `nvfp4_backend` expects.

    Returns the (possibly re-packed) tensors in the same order they were
    passed in: w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale,
    w2_scale_2, a2_scale.
    """
    uses_fi_or_cutlass = (
        nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
        or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
    )

    if uses_fi_or_cutlass:
        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = prepare_nvfp4_moe_layer_for_fi_or_cutlass(
            backend=nvfp4_backend,
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            a13_scale=a13_scale,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
            a2_scale=a2_scale,
            is_act_and_mul=is_act_and_mul,
        )
    elif nvfp4_backend == NvFp4MoeBackend.MARLIN:
        # Marlin is weight-only (w4a16): activation scales are not used.
        a13_scale = None
        a2_scale = None
        (
            w13,
            w13_scale,
            w13_scale_2,
            w2,
            w2_scale,
            w2_scale_2,
        ) = prepare_nvfp4_moe_layer_for_marlin(
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
        )
    else:
        raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")

    return (
        w13,
        w13_scale,
        w13_scale_2,
        a13_scale,
        w2,
        w2_scale,
        w2_scale_2,
        a2_scale,
    )

is_global_sf_supported_for_nvfp4_backend

is_global_sf_supported_for_nvfp4_backend(
    backend: NvFp4MoeBackend,
) -> bool
Source code in vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
    """Whether `backend` supports quantizing with scaling factors of all
    experts in Expert Parallel mode when the experts are not all on the
    same rank.

    Only the FlashInfer-based backends support this.
    """
    return any(backend is fi for fi in FLASHINFER_NVFP4_MOE_BACKENDS)

make_nvfp4_moe_kernel

make_nvfp4_moe_kernel(
    layer: FusedMoE,
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[FusedMoEPermuteExpertsUnpermute],
) -> FusedMoEModularKernel
Source code in vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
def make_nvfp4_moe_kernel(
    layer: "FusedMoE",
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> mk.FusedMoEModularKernel:
    """Assemble a FusedMoEModularKernel (prepare/finalize + experts) for NvFP4."""
    # Create Prepare/Finalize.
    defer_quant = experts_cls.should_pf_defer_input_quant(
        moe_config, moe_quant_config
    )
    pf = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=None,  # TODO: init routing tables here?
        defer_input_quant=defer_quant,
        allow_new_interface=True,
    )
    assert pf is not None

    logger.info_once("Using %s", pf.__class__.__name__)

    # Create Experts in whichever activation format the PF produces.
    if pf.activation_format != mk.FusedMoEActivationFormat.Standard:
        max_tokens = pf.max_num_tokens_per_rank()
        assert max_tokens is not None
        experts = experts_cls.make_batched_experts(
            moe_config=moe_config,
            quant_config=moe_quant_config,
            max_num_tokens=max_tokens,
            num_dispatchers=pf.num_dispatchers(),
        )
    else:
        experts = experts_cls.make_standard_experts(
            moe_config=moe_config,
            quant_config=moe_quant_config,
        )

    # NOTE(rob): we only want the ModularKernel to control the SharedExpert
    # if we are using all2all (for SBO). Need to make a change somewhere
    # else to prevent double running the Shared Expert.
    # This needs to be refactored.
    shared = None
    if moe_config.moe_parallel_config.use_all2all_kernels:
        shared = getattr(layer, "shared_expert", None)

    return mk.FusedMoEModularKernel(
        pf,
        experts,
        shared_experts=shared,
        moe_parallel_config=moe_config.moe_parallel_config,
    )

make_nvfp4_moe_quant_config

make_nvfp4_moe_quant_config(
    backend: NvFp4MoeBackend,
    w13_scale: Tensor,
    w2_scale: Tensor,
    w13_scale_2: Tensor,
    w2_scale_2: Tensor,
    a13_scale: Tensor,
    a2_scale: Tensor,
) -> FusedMoEQuantConfig | None
Source code in vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
def make_nvfp4_moe_quant_config(
    backend: NvFp4MoeBackend,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a13_scale: torch.Tensor,
    a2_scale: torch.Tensor,
) -> FusedMoEQuantConfig | None:
    """Build the FusedMoEQuantConfig for `backend`; None if unsupported."""
    # TRTLLM does not consume a FusedMoEQuantConfig.
    if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
        return None

    if backend == NvFp4MoeBackend.MARLIN:
        # Weight-only path (w4a16): no activation global scales.
        return nvfp4_w4a16_moe_quant_config(
            g1_alphas=w13_scale_2,
            g2_alphas=w2_scale_2,
            w1_scale=w13_scale,
            w2_scale=w2_scale,
        )

    # w4a4 path: fold activation global scales into the per-gemm alphas.
    return nvfp4_moe_quant_config(
        g1_alphas=a13_scale * w13_scale_2,
        g2_alphas=a2_scale * w2_scale_2,
        a1_gscale=(1.0 / a13_scale),
        a2_gscale=(1.0 / a2_scale),
        w1_scale=w13_scale,
        w2_scale=w2_scale,
    )

select_nvfp4_moe_backend

select_nvfp4_moe_backend(
    config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> tuple[
    NvFp4MoeBackend,
    type[FusedMoEPermuteExpertsUnpermute] | None,
]

Select the primary NvFP4 MoE backend. Note: shape-specific fallbacks may still occur at runtime.

Source code in vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
def select_nvfp4_moe_backend(
    config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
    """
    Select the primary NvFP4 MoE backend.

    Returns the chosen backend and its kernel class (None for the TRTLLM
    backend, which does not go through the modular-kernel interface).
    Note: Shape-specific fallbacks may still occur at runtime.
    """

    # NOTE(rob): this is kind of a hack. We need to peek into
    # the prepare-finalize selection to determine if we are using
    # the batched or standard expert format.
    use_batched = (
        config.moe_parallel_config.use_deepep_ll_kernels
        or config.moe_parallel_config.use_pplx_kernels
    )
    activation_format = (
        mk.FusedMoEActivationFormat.BatchedExperts
        if use_batched
        else mk.FusedMoEActivationFormat.Standard
    )

    def _make_log_backend(backend: NvFp4MoeBackend) -> str:
        return f"Using '{backend.value}' backend for NvFp4 MoE"

    def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str:
        if reason:
            return (
                f"NvFP4 MoE backend '{backend.value}' does not support the "
                f"deployment configuration since {reason}."
            )
        else:
            return (
                f"NvFP4 MoE backend '{backend.value}' does not support the "
                "deployment configuration."
            )

    def _return_or_raise(
        backend: NvFp4MoeBackend,
        config: FusedMoEConfig,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
        activation_format: mk.FusedMoEActivationFormat,
    ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
        # Validate an explicitly-requested backend; fail loudly if unsupported.
        k_cls = backend_2_kernel_cls(backend)
        supported, reason = k_cls.is_supported_config(
            k_cls, config, weight_key, activation_key, activation_format
        )
        if supported:
            logger.info_once(_make_log_backend(backend))
            return backend, k_cls
        raise ValueError(_make_log_unsupported(backend, reason))

    # NOTE: the kernels are selected in the following order.
    AVAILABLE_BACKENDS = [
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        NvFp4MoeBackend.FLASHINFER_CUTLASS,
        NvFp4MoeBackend.MARLIN,
        # NvFp4MoeBackend.VLLM_CUTLASS,
    ]

    if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
        if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
            # If the user rejects FlashInfer remove those backends.
            for fi_backend in FLASHINFER_NVFP4_MOE_BACKENDS:
                AVAILABLE_BACKENDS.remove(fi_backend)

        elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
            # If user is explicit about backend, validate it.
            fi_backend = get_flashinfer_moe_backend()

            if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
                backend = NvFp4MoeBackend.FLASHINFER_TRTLLM
                supported, reason = is_supported_config_trtllm(
                    config, weight_key, activation_key, activation_format
                )
                if supported:
                    logger.info_once(_make_log_backend(backend))
                    # TRTLLM has no modular-kernel class.
                    return backend, None
                else:
                    raise ValueError(_make_log_unsupported(backend, reason))
            else:
                backend = fi_2_vllm_backend_map[fi_backend]
                return _return_or_raise(
                    backend, config, weight_key, activation_key, activation_format
                )
        else:
            # If the user is not explicit about the backend, try each.
            for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
                k_cls = backend_2_kernel_cls(backend)
                # BUGFIX: is_supported_config returns a (supported, reason)
                # tuple (see every other call site); the tuple itself is
                # always truthy, so it must be unpacked before testing.
                # Previously the first FlashInfer backend was always picked.
                supported, _reason = k_cls.is_supported_config(
                    k_cls, config, weight_key, activation_key, activation_format
                )
                if supported:
                    logger.info_once(_make_log_backend(backend))
                    return backend, k_cls

            raise NotImplementedError(
                "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
                "FlashInfer NVFP4 MoE backend supports the configuration."
            )

    if envs.VLLM_TEST_FORCE_FP8_MARLIN:
        backend = NvFp4MoeBackend.MARLIN
        return _return_or_raise(
            backend, config, weight_key, activation_key, activation_format
        )

    # Select kernels in order of backend.
    for backend in AVAILABLE_BACKENDS:
        if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            # TRTLLM has no kernel class; it uses a bespoke support check.
            k_cls = None  # type: ignore[assignment]
            supported, reason = is_supported_config_trtllm(
                config,
                weight_key,
                activation_key,
                activation_format,
            )
        else:
            k_cls = backend_2_kernel_cls(backend)
            supported, reason = k_cls.is_supported_config(
                k_cls,
                config,
                weight_key,
                activation_key,
                activation_format,
            )

        if supported:
            logger.info_once(_make_log_backend(backend), scope="local")
            return backend, k_cls
        else:
            logger.info_once(_make_log_unsupported(backend, reason), scope="local")

    raise NotImplementedError(
        "No NvFp4 MoE backend supports the deployment configuration."
    )