@MULTIMODAL_REGISTRY.register_processor(
    Qwen3VLMultiModalProcessor,
    info=Qwen3_5ProcessingInfo,
    dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
    """Qwen3.5 model for conditional generation over text, image and video.

    The model is composed of a ``Qwen3_VisionTransformer`` (``self.visual``)
    and a ``Qwen3_5ForCausalLM`` (``self.language_model``). Multimodal
    pruning (EVS) is disabled for this model.
    """

    # Qwen3.5 does not support multimodal pruning (EVS).
    supports_multimodal_pruning = False

    # Parent mapping extended with the fused GDN projections.
    packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
        "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
        "in_proj_ba": ["in_proj_b", "in_proj_a"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model") -> None:
        """Build the vision tower and the language model.

        Args:
            vllm_config: Engine configuration; the HF config, quantization
                config, LoRA config and multimodal config are read from it.
            prefix: Name prefix used when naming the sub-modules.
        """
        # Protocols have no __init__ method, so call nn.Module.__init__
        # directly instead of going through the parent class initializer.
        nn.Module.__init__(self)
        self.update_packed_mapping(enable_lora=vllm_config.lora_config is not None)

        config: Qwen3_5Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        # Qwen3.5 does not support multimodal pruning (EVS).
        self.is_multimodal_pruning_enabled = False

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.visual = Qwen3_VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = Qwen3_5ForCausalLM(
                vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def update_packed_mapping(self, enable_lora: bool) -> None:
        """Install a per-instance packed-module mapping when LoRA is enabled.

        When ``enable_lora`` is false this is a no-op and the class-level
        ``packed_modules_mapping`` stays in effect. Otherwise the instance
        gets its own copy of the class-level mapping in which the fused
        ``in_proj_qkvz`` entry is replaced by separate single-member
        entries for ``in_proj_qkv`` and ``in_proj_z``.

        Args:
            enable_lora: Whether a LoRA config is present.
        """
        # When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z
        # instead of the merged in_proj_qkvz, so the mapping must match.
        if not enable_lora:
            return

        base = getattr(Qwen3_5ForConditionalGeneration, "packed_modules_mapping", {})
        # Copy the lists too, so the class-level mapping is never mutated.
        self.packed_modules_mapping = {k: list(v) for k, v in base.items()}
        self.packed_modules_mapping.pop("in_proj_qkvz", None)
        self.packed_modules_mapping["in_proj_qkv"] = ["in_proj_qkv"]
        # Register in_proj_z as well: both halves of the former fused
        # projection need an entry once the fused one is removed.
        self.packed_modules_mapping["in_proj_z"] = ["in_proj_z"]

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed token ids and merge in multimodal embeddings when provided.

        Args:
            input_ids: Token ids to embed.
            multimodal_embeddings: Embeddings to merge into the text
                embeddings. When ``None`` or empty, only the text
                embeddings are returned.
            is_multimodal: Tensor passed to the text-embedding helper and,
                when multimodal embeddings are present, to the merge helper.
                It is validated by ``_require_is_multimodal`` before merging.

        Returns:
            The input embeddings, with multimodal embeddings merged in if
            any were given.
        """
        inputs_embeds = self._embed_text_input_ids(
            input_ids,
            self.language_model.embed_input_ids,
            is_multimodal=is_multimodal,
        )

        # Text-only input: nothing to merge.
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        is_multimodal = _require_is_multimodal(is_multimodal)

        inputs_embeds = _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

        return inputs_embeds

    def recompute_mrope_positions(self, *args, **kwargs):
        """Always raise: Qwen3.5 does not support multimodal pruning (EVS).

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError(
            "Qwen3.5 does not support multimodal pruning (EVS). "
            "recompute_mrope_positions should never be called."
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen3.5.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen3VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors from previous pipeline
                stages.
            inputs_embeds: Pre-computed input embeddings.
            **kwargs: Additional keyword arguments including:
                - pixel_values: Pixel values to be fed to a model.
                    `None` if no images are passed.
                - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in
                    LLM. `None` if no images are passed.
                - pixel_values_videos: Pixel values of videos to be fed to a
                    model. `None` if no videos are passed.
                - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in
                    LLM. `None` if no videos are passed.

        Returns:
            The output of the inner language model.
        """
        # When intermediate tensors are supplied, the pre-computed
        # embeddings are discarded before calling the language model.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping names prefixed with ``mtp.``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of names returned by the weights loader.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["mtp."],
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the gated-delta-net state dtypes for ``vllm_config``.

        The model dtype and the two mamba cache dtype settings from the
        cache config are forwarded to ``MambaStateDtypeCalculator``.
        """
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return the gated-delta-net state shapes for ``vllm_config``.

        The shapes are computed by ``MambaStateShapeCalculator`` from the
        tensor-parallel size, the linear-attention head settings of the HF
        text config, and the number of speculative tokens (0 when no
        speculative config is set).
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_text_config
        tp_size = parallel_config.tensor_parallel_size
        num_spec = (
            vllm_config.speculative_config.num_speculative_tokens
            if vllm_config.speculative_config
            else 0
        )
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            tp_size,
            hf_config.linear_num_key_heads,
            hf_config.linear_num_value_heads,
            hf_config.linear_key_head_dim,
            hf_config.linear_value_head_dim,
            hf_config.linear_conv_kernel_dim,
            num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Return the gated-delta-net state copy functions."""
        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()