diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 0edc5c684..e6219500d 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -457,6 +457,9 @@ struct Args { /// The embedding dimension to use for the model. #[clap(long, env)] embedding_dim: Option, + + #[clap(long, env)] + disable_sgmv: bool, } #[derive(Debug)] @@ -500,6 +503,7 @@ fn shard_manager( shutdown: Arc, _shutdown_sender: mpsc::Sender<()>, embedding_dim: Option, + disable_sgmv: bool, ) { // Enter shard-manager tracing span let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered(); @@ -640,6 +644,10 @@ fn shard_manager( envs.push(("FLASH_INFER".into(), "1".into())); } + if disable_sgmv { + envs.push(("DISABLE_SGMV".into(), "1".into())) + } + // Safetensors load fast envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); @@ -1088,6 +1096,7 @@ fn spawn_shards( let merge_adapter_weights = args.merge_adapter_weights; let backend = args.backend; let embedding_dim = args.embedding_dim; + let disable_sgmv = args.disable_sgmv; thread::spawn(move || { shard_manager( model_id, @@ -1123,6 +1132,7 @@ fn spawn_shards( shutdown, shutdown_sender, embedding_dim, + disable_sgmv, ) }); }