Skip to content

Commit

Permalink
Added torch compile row to pytorch install table
Browse files Browse the repository at this point in the history
  • Loading branch information
suryasidd committed Oct 9, 2024
1 parent 78250ba commit bf51948
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 0 deletions.
115 changes: 115 additions & 0 deletions _includes/quick-start-module.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ var opts = {
pm: 'pip',
language: 'python',
ptbuild: 'stable',
'torch-compile': null
};

var supportedCloudPlatforms = [
Expand All @@ -34,6 +35,7 @@ var package = $(".package > .option");
var language = $(".language > .option");
var cuda = $(".cuda > .option");
var ptbuild = $(".ptbuild > .option");
var torchCompile = $(".torch-compile > .option")

os.on("click", function() {
selectedOption(os, this, "os");
Expand All @@ -50,6 +52,9 @@ cuda.on("click", function() {
ptbuild.on("click", function() {
selectedOption(ptbuild, this, "ptbuild")
});
torchCompile.on("click", function() {
selectedOption(torchCompile, this, "torch-compile")
});

// Pre-select user's operating system
$(function() {
Expand Down Expand Up @@ -168,6 +173,110 @@ function changeAccNoneName(osname) {
}
}

function getIDFromBackend(backend) {
const idTobackendMap = {
inductor: 'inductor',
cgraphs : 'cudagraphs',
onnxrt: 'onnxrt',
openvino: 'openvino',
tensorrt: 'tensorrt',
tvm: 'tvm',
};
return idTobackendMap[backend];
}

function getPmCmd(backend) {
const pmCmd = {
onnxrt: 'onnxruntime',
tvm: 'apache-tvm',
openvino: 'openvino',
tensorrt: 'torch-tensorrt',
};
return pmCmd[backend];
}

function getImportCmd(backend) {
const importCmd = {
onnxrt: 'import onnxruntime',
tvm: 'import tvm',
openvino: 'import openvino.torch',
tensorrt: 'import torch_tensorrt'
}
return importCmd[backend];
}

function getInstallCommand(optionID) {
backend = getIDFromBackend(optionID);
pmCmd = getPmCmd(optionID);
finalCmd = "";
if (opts.pm == "pip") {
finalCmd = `pip3 install ${pmCmd}`;
}
else if (opts.pm == "conda") {
finalCmd = `conda install ${pmCmd}`;
}
return finalCmd;
}

function getTorchCompileUsage(optionId) {
backend = getIDFromBackend(optionId);
importCmd = "<br>" + getImportCmd(optionId) + "<br>";
finalCmd = "";
tcUsage = "# Torch Compile usage: ";
backendCmd = `torch.compile(model, backend="${backend}")`;
libtorchCmd = `# Torch compile ${backend} not supported with Libtorch`;

if (opts.pm == "libtorch") {
return libtorchCmd;
}
if (backend == "openvino") {
if (opts.pm == "source") {
finalCmd += "# Follow instructions at this URL to build openvino from source: https://github.com/openvinotoolkit/openvino/blob/master/docs/dev/build.md" + "<br>" ;
tcUsage += importCmd;
}
else if (opts.pm == "conda") {
tcUsage += importCmd;
}
if (opts.os == "windows" && !tcUsage.includes(importCmd)) {
tcUsage += importCmd;
}
}
else{
tcUsage += importCmd;
}
if (backend == "onnxrt") {
if (opts.pm == "source") {
finalCmd += "# Follow instructions at this URL to build onnxruntime from source: https://onnxruntime.ai/docs/build" + "<br>" ;
}
}
if (backend == "tvm") {
if (opts.pm == "source") {
finalCmd += "# Follow instructions at this URL to build tvm from source: https://tvm.apache.org/docs/install/from_source.html" + "<br>" ;
}
}
if (backend == "tensorrt") {
if (opts.pm == "source") {
finalCmd += "# Follow instructions at this URL to build tensorrt from source: https://pytorch.org/TensorRT/getting_started/installation.html#compiling-from-source" + "<br>" ;
}
}
finalCmd += tcUsage + backendCmd;
return finalCmd
}

function addTorchCompileCommandNote(selectedOptionId) {

if (!selectedOptionId) {
return;
}

$("#command").append(
`<pre> ${getInstallCommand(selectedOptionId)} </pre>`
);
$("#command").append(
`<pre> ${getTorchCompileUsage(selectedOptionId)} </pre>`
);
}

function selectedOption(option, selection, category) {
$(option).removeClass("selected");
$(selection).addClass("selected");
Expand Down Expand Up @@ -208,13 +317,19 @@ function selectedOption(option, selection, category) {
changeVersion(opts.ptbuild);
//make sure unsupported platforms are disabled
disableUnsupportedPlatforms(opts.os);
} else if (category === "torch-compile") {
if (selection.id === previousSelection) {
$(selection).removeClass("selected");
opts[category] = null;
}
}
commandMessage(buildMatcher());
if (category === "os") {
disableUnsupportedPlatforms(opts.os);
display(opts.os, 'installation', 'os');
}
changeAccNoneName(opts.os);
addTorchCompileCommandNote(opts['torch-compile'])
}

function display(selection, id, category) {
Expand Down
28 changes: 28 additions & 0 deletions _includes/quick_start_local.html
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
<div class="col-md-12 title-block">
<div class="option-text">Compute Platform</div>
</div>
<div class="col-md-12 title-block">
<div class="option-text">Torch Compile</div>
</div>
<div class="col-md-12 title-block command-block">
<div class="option-text command-text">Run this Command:</div>
</div>
Expand Down Expand Up @@ -103,6 +106,31 @@
<div class="option-text">CPU</div>
</div>
</div>
<div class="row torch-compile">
<!-- Section Label -->
<div class="col-md-12 title-block mobile-heading">
<div class="option-text">Torch Compile</div>
</div>
<!-- Section Label -->
<div class="col-md-2 option block version" id="inductor">
<div class="option-text">Inductor</div>
</div>
<div class="col-md-2 option block version" id="cgraphs">
<div class="option-text">CUDA Graphs</div>
</div>
<div class="col-md-2 option block version" id="openvino">
<div class="option-text">OpenVINO</div>
</div>
<div class="col-md-2 option block version" id="onnxrt">
<div class="option-text">ONNX Runtime</div>
</div>
<div class="col-md-2 option block version" id="tensorrt">
<div class="option-text">TensorRT</div>
</div>
<div class="col-md-2 option block version" id="tvm">
<div class="option-text">TVM</div>
</div>
</div>
<div class="row">
<div class="col-md-12 title-block command-mobile-heading">
<div class="option-text">Run this Command:</div>
Expand Down

0 comments on commit bf51948

Please sign in to comment.