
onnx.Loop #332

Open
Tracked by #215
renxida opened this issue Jan 10, 2024 · 20 comments

@renxida
Contributor

renxida commented Jan 10, 2024

This contains some notes from @renxida on onnx.If

@PhaneeshB is currently working on onnx.Loop

Find the If in: retinanet_resnet50_fpn_vaiq_int8 model support #199

@stellaraccident
Contributor

(you'll need to extend the importer for this as it is a "special" op)

@renxida
Copy link
Contributor Author

renxida commented Jan 10, 2024

> (you'll need to extend the importer for this as it is a "special" op)

I suspected I'd need more infra for this. I have a simple ONNX test case and will meet with Rob to figure this out.

@renxida
Contributor Author

renxida commented Jan 11, 2024

@renxida
Contributor Author

renxida commented Jan 11, 2024

Deprioritizing this now.

renxida changed the title from "If" to "If and Loop (ops that use compute graph sections as inputs)" on Jan 19, 2024
@renxida
Contributor Author

renxida commented Jan 24, 2024

Back on it.

@renxida
Contributor Author

renxida commented Jan 24, 2024

Todo list for now:

  • find the torch aten / prim ops and types for the If op, the Loop op, and graph regions
  • find todo items for extending the importer

@renxida
Contributor Author

renxida commented Jan 24, 2024

Torch ops to translate to:

PrimLoopOp

PrimIfOp

    This op (together with prim.If.yield) defines a conditional control flow
    construct. It is analogous to `scf.if` for MLIR folks that are familiar
    with that. The main differences from that op are:

    - `!torch.bool` condition value.
    - The "else" region is always present. This is reflective of invariants of
      the TorchScript IR.
    - No special prettiness for the "no yielded values" case. These are
      interesting for modeling mostly-non-SSA programs, but TorchScript IR
      is already in SSA form.

PrimIfYieldOp

Methods / types of interest:

mlir::Region

PrimIfOp::getThenRegion
PrimIfOp::getElseRegion
PrimIfOp::getSuccessorRegions
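A minimal Python sketch of the `prim.If` semantics quoted above, assuming each region is modeled as a plain callable that returns the values its `prim.If.yield` would yield. `prim_if` and the `Region` alias are hypothetical names for illustration, not torch-mlir API.

```python
from typing import Callable, Tuple

# Hypothetical stand-in for a region: a zero-argument callable that returns
# the tuple of values its prim.If.yield terminator would yield.
Region = Callable[[], Tuple[object, ...]]


def prim_if(cond: bool, then_region: Region, else_region: Region,
            num_results: int) -> Tuple[object, ...]:
    """Evaluate a prim.If-like construct.

    Both regions are required (the "else" region is always present), and the
    region that runs must yield exactly `num_results` values.
    """
    if not isinstance(cond, bool):
        raise TypeError("condition must be a bool (cf. !torch.bool)")
    yielded = (then_region if cond else else_region)()
    if len(yielded) != num_results:
        raise ValueError(
            f"region yielded {len(yielded)} values, expected {num_results}")
    return yielded


# Usage: choose between two yielded pairs.
print(prim_if(True, lambda: (1, 2), lambda: (3, 4), num_results=2))  # prints (1, 2)
```

Requiring both regions to yield the same number of values is what lets the op have a fixed result list regardless of which branch runs.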

@renxida
Contributor Author

renxida commented Jan 24, 2024

Rob's Guide for Implementing If and Loop Constructs with Regions in ONNX Lowering

TLDR:

  • Regions are lists of related blocks. A region by itself provides no methods and imposes no constraints; the ops that own regions build on the region type to enforce things like single entry, single exit, result types, or specific return / yield terminators.
  • Our ONNX import doesn't provide per-op classes. Every ONNX op is represented as the same generic op (`torch.operator`) with an attribute telling us which ONNX op it is, so we interface with that directly.
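Since every ONNX op arrives as the same generic op with its name as data, lowering boils down to dispatching on that name. A minimal Python sketch of the idea; `GenericOp`, `on_op`, and `lower` are hypothetical stand-ins (the real torch-mlir patterns are C++), and each handler just returns the name of the torch op it would produce.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class GenericOp:
    # Hypothetical model of `torch.operator "onnx.Loop"(...)`: a single op kind
    # whose ONNX op name is carried as data rather than as a class.
    name: str
    operands: List[str]
    regions: List[list] = field(default_factory=list)


Handler = Callable[[GenericOp], str]
_PATTERNS: Dict[str, Handler] = {}


def on_op(onnx_name: str) -> Callable[[Handler], Handler]:
    """Register a lowering handler for one ONNX op name (hypothetical helper)."""
    def register(fn: Handler) -> Handler:
        _PATTERNS["onnx." + onnx_name] = fn
        return fn
    return register


@on_op("If")
def lower_if(op: GenericOp) -> str:
    if len(op.regions) != 2:
        raise ValueError("onnx.If should carry then/else regions")
    return "torch.prim.If"


@on_op("Loop")
def lower_loop(op: GenericOp) -> str:
    if len(op.regions) != 1:
        raise ValueError("onnx.Loop should carry a single body region")
    return "torch.prim.Loop"


def lower(op: GenericOp) -> str:
    # Dispatch on the name attribute, since there is no per-op class to match on.
    handler = _PATTERNS.get(op.name)
    if handler is None:
        raise NotImplementedError(f"no lowering registered for {op.name}")
    return handler(op)
```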

Some concepts

Block: A basic block with a single entry and exit point. Fundamental unit within a region.
Region:

  • What regions are: a collection of connected basic blocks. Analogous to a function body, a region can contain any number of blocks and represents a procedural unit of logic within an operation.
  • Why regions: they give control flow constructs like If and Loop a structured representation. Regions encapsulate a collection of connected blocks inside the op that owns them, which makes common optimizations such as common subexpression elimination (CSE) and dead code elimination (DCE) easier to apply.
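To make the "region is just a list of blocks; the parent adds the rules" point concrete, a small Python sketch. `Op`, `Block`, `Region`, and `verify_single_block_with_terminator` are hypothetical illustration types, not MLIR's actual classes; the check shows the kind of constraint an op like `prim.If` places on its regions.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Op:
    name: str


@dataclass
class Block:
    # Straight-line list of ops; by convention the last op is the terminator.
    ops: List[Op] = field(default_factory=list)


@dataclass
class Region:
    # Just a list of blocks; a region enforces nothing on its own.
    blocks: List[Block] = field(default_factory=list)


def verify_single_block_with_terminator(region: Region, terminator: str) -> None:
    """A rule a parent op might impose on its region (hypothetical helper):
    exactly one block, and that block ends with `terminator`."""
    if len(region.blocks) != 1:
        raise ValueError(f"expected 1 block, got {len(region.blocks)}")
    ops = region.blocks[0].ops
    if not ops or ops[-1].name != terminator:
        raise ValueError(f"block must end with {terminator}")
```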

How Linalg and SCF use Regions

  • Linalg operations: in Linalg, a region plays the role of the map or reduce function. A Linalg operation is a generic operation that can express both map-like and reduce-like computations by applying its region at each point of the iteration space.
  • Structured Control Flow (SCF) dialect: regions are central to optimizing code that uses the SCF dialect. They make it possible to reason about conditions and enable dead code elimination (DCE) and common subexpression elimination (CSE).
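A heavily simplified Python model of the Linalg point: one generic op whose region (`body` here) is applied per element, with the iterator types deciding whether it acts as a map or a reduction. `generic_2d` is a hypothetical name. It ignores indexing maps, multiple operands, and everything else a real `linalg.generic` carries.

```python
from typing import Callable, List, Sequence


def generic_2d(x: Sequence[Sequence[float]],
               iterator_types: Sequence[str],
               init: float,
               body: Callable[[float, float], float]):
    """Simplified model of a linalg.generic over a 2-D space (i, j).

    `body` stands in for the op's region: it receives one input element and
    the current output element, and returns the new output element.
      ["parallel", "parallel"]  -> map:    out[i][j] = body(x[i][j], init)
      ["parallel", "reduction"] -> reduce: out[i] = fold of body over j
    """
    kinds = list(iterator_types)
    if kinds == ["parallel", "parallel"]:
        # Every (i, j) point is independent, so this is a pure map.
        return [[body(v, init) for v in row] for row in x]
    if kinds == ["parallel", "reduction"]:
        # j is a reduction dimension: carry the output element across j.
        out: List[float] = []
        for row in x:
            acc = init
            for v in row:
                acc = body(v, acc)
            out.append(acc)
        return out
    raise ValueError(f"unsupported iterator_types: {kinds}")


x = [[1.0, 2.0], [3.0, 4.0]]
print(generic_2d(x, ["parallel", "parallel"], 0.0, lambda a, acc: 2 * a))   # map: doubles each element
print(generic_2d(x, ["parallel", "reduction"], 0.0, lambda a, acc: acc + a))  # reduce: row sums
```

The same body signature serves both cases; only the iterator types change how the op drives it.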

@renxida
Contributor Author

renxida commented Jan 24, 2024

Rob gave this reference for manipulating regions:

https://github.com/llvm/llvm-project/blob/main/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp

Inline if case

Rewrite terminators: the region's parent op defines the rules for what terminators the region has, e.g. "must have one terminator block that ends in a particular op"; functions, for instance, must end with a return.

e.g. in SCF, `scf.if` has two regions (then / else), and each has to end with `scf.yield`. What's getting returned is given in the yield.

e.g. in `scf.while`, there are two regions: a "before" region that computes the conditional, and an "after" (body) region that ends with a yield.

  • while-style loops: the condition is computed in the before region, and the loop body goes in the after region.
  • do-while-style loops: the body goes in the before region, and the after region just yields.

In `scf.while`, the before region's terminator (`scf.condition`) is both the conditional and the op that forwards values to the after region.
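A rough Python model of the `scf.while` structure from these notes. It assumes the before region is a callable returning `(continue?, forwarded values)`, which stands in for `scf.condition`, and the after region is a callable returning the values passed to `scf.yield`. `scf_while` is a hypothetical helper, not MLIR API.

```python
from typing import Callable, Tuple

Values = Tuple[object, ...]


def scf_while(init: Values,
              before: Callable[..., Tuple[bool, Values]],
              after: Callable[..., Values]) -> Values:
    """Interpret scf.while-style control flow (simplified sketch)."""
    args = init
    while True:
        keep_going, forwarded = before(*args)  # before region, ending in scf.condition
        if not keep_going:
            return forwarded                   # condition false: forwarded values are the results
        args = after(*forwarded)               # after region, ending in scf.yield


# while-style: condition in `before`, body in `after` (sum of 0..2).
print(scf_while((0, 0), lambda i, s: (i < 3, (i, s)), lambda i, s: (i + 1, s + i)))  # prints (3, 3)
# do-while-style: body in `before`, `after` just passes values through.
print(scf_while((10,), lambda i: (i + 1 < 3, (i + 1,)), lambda i: (i,)))  # prints (11,)
```

In the do-while case the body runs once even though the condition is false from the start, which is the behavior the before-region placement buys.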

@renxida
Contributor Author

renxida commented Jan 24, 2024

onnx binder: binder.op

binder.op->getRegions() for a regionrange

RegionRange.size()

@renxida
Contributor Author

renxida commented Jan 24, 2024

Move the regions to torch, then translate the ops inside the regions (including the yield terminators).
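A hypothetical Python sketch of that plan for `onnx.If`, on a toy IR where each region is a single block (a list of ops). It moves the regions onto a new `torch.prim.If` and swaps the generic `torch.operator_terminator` for `torch.prim.If.yield`. `Op` and `convert_onnx_if` are illustration names; the real lowering works on MLIR regions in C++ and also has to turn the tensor condition into a `!torch.bool`, which this sketch skips.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Op:
    name: str
    operands: List[str] = field(default_factory=list)
    # Simplification: each region is a single block, i.e. a list of ops.
    regions: List[List["Op"]] = field(default_factory=list)


def convert_onnx_if(op: Op) -> Op:
    """Rebuild a generic onnx.If as torch.prim.If (sketch).

    1. Move (not copy) the then/else regions onto the new op.
    2. Replace each region's generic terminator with the yield op the new
       parent requires. The ops inside the regions would then be lowered by
       the ordinary per-op patterns (not shown).
    """
    if op.name != "onnx.If" or len(op.regions) != 2:
        raise ValueError("expected onnx.If with then/else regions")
    new_op = Op("torch.prim.If", list(op.operands))
    new_op.regions, op.regions = op.regions, []  # move, leaving the old op empty
    for block in new_op.regions:
        if not block or block[-1].name != "torch.operator_terminator":
            raise ValueError("region must end with torch.operator_terminator")
        block[-1] = Op("torch.prim.If.yield", block[-1].operands)
    return new_op
```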

renxida changed the title from "If and Loop (ops that use compute graph sections as inputs)" to "ONNX Control Flow Ops (onnx.If and onnx.Loop)" on Jan 24, 2024
@renxida
Contributor Author

renxida commented Jan 24, 2024

Running this to test the If import:

wget -O onnx_conditional_example.py https://gist.githubusercontent.com/renxida/6e859dbfab286916dd8b99542c0a2332/raw/e38d1558a2bdc9553a14338ca84c8f1ae4bb40b5/onnx_conditional_example.py
python onnx_conditional_example.py
PYTHONPATH=~/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir python -m torch_mlir.tools.import_onnx ./conditional_example.onnx -o conditional_example.mlir

Current failure:

PYTHONPATH=~/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir python -m torch_mlir.tools.import_onnx ./conditional_example.onnx -o conditional_example.mlir
/home/azureuser/miniconda/lib/python3.11/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
Traceback (most recent call last):
  File "/home/azureuser/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/extras/onnx_importer.py", line 291, in import_node
    input_values.append(self._nv_map[input_name])
                        ~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'cond'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/azureuser/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/tools/import_onnx/__main__.py", line 77, in <module>
    _cli_main()
  File "/home/azureuser/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/tools/import_onnx/__main__.py", line 73, in _cli_main
    sys.exit(main(parse_arguments()))
             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/azureuser/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/tools/import_onnx/__main__.py", line 37, in main
    imp.import_all()
  File "/home/azureuser/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/extras/onnx_importer.py", line 262, in import_all
    self.import_node(node)
  File "/home/azureuser/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/extras/onnx_importer.py", line 293, in import_node
    raise OnnxImportError(
torch_mlir.extras.onnx_importer.OnnxImportError: Non topologically produced ONNX node input 'cond': input: "cond"
output: "then_res"
op_type: "If"
attribute {
  name: "then_branch"
  type: GRAPH
  g {
    node {
      output: "then_out"
      op_type: "Constant"
      attribute {
        name: "value"
        type: TENSOR
        t {
          dims: 5
          data_type: 1
          raw_data: "\000\000\200?\000\000\000@\000\000@@\000\000\200@\000\000\240@"
        }
      }
    }
    name: "then_body"
    output {
      name: "then_out"
      type {
        tensor_type {
          elem_type: 1
          shape {
            dim {
              dim_value: 5
            }
          }
        }
      }
    }
  }
}
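The error comes from the importer looking up each node input in its table of already-produced values (`self._nv_map` in the traceback). A simplified, hypothetical Python sketch of that kind of check, not the importer's actual code: a node input must be a graph input, an earlier node's output, or (for `If` / `Loop` subgraphs, which ONNX lets reference names from the enclosing graph) an outer-scope value. The tuple-based `Node` and `import_graph` are illustration-only.

```python
from typing import Dict, Iterable, List, Optional, Tuple


class OnnxImportError(Exception):
    pass


# Hypothetical simplified node: (op_type, input names, output names).
Node = Tuple[str, List[str], List[str]]


def import_graph(graph_inputs: Iterable[str], nodes: List[Node],
                 outer_scope: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Walk nodes in order; every input must already be a known value."""
    nv_map: Dict[str, str] = dict(outer_scope or {})
    nv_map.update({name: f"%{name}" for name in graph_inputs})
    for op_type, inputs, outputs in nodes:
        for name in inputs:
            if name not in nv_map:
                raise OnnxImportError(
                    f"Non topologically produced ONNX node input {name!r} ({op_type})")
        for name in outputs:
            nv_map[name] = f"%{name}"
    return nv_map
```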

@renxida
Contributor Author

renxida commented Apr 10, 2024

continuing here: llvm/torch-mlir#3136

@AmosLewis
Contributor

AmosLewis commented May 1, 2024

Support onnx.If: llvm/torch-mlir#2825

renxida changed the title from "ONNX Control Flow Ops (onnx.If and onnx.Loop)" to "onnx.Loop" on May 1, 2024
@PhaneeshB
Contributor

Support: llvm/torch-mlir#3408
Test: nod-ai/SHARK-TestSuite#248

@AmosLewis
Contributor

Loop failed again for the DLRM PyTorch model.

AmosLewis reopened this on Jun 12, 2024
@PhaneeshB
Contributor

PhaneeshB commented Jun 25, 2024

llvm/torch-mlir#3408
EDIT: PR Merged, closing issue

@pdhirajkumarprasad

module {
  func.func @tf2onnx(%arg0: !torch.vtensor<[?,?,?,3],ui8>, %arg33: !torch.vtensor<[],si64>, %arg32: !torch.vtensor<[],i1>, %arg31: !torch.vtensor<[],si32>, %arg41: !torch.vtensor<[4],si64>, %arg42: !torch.vtensor<[],i1>, %arg43: !torch.vtensor<[],si32>, %arg44: !torch.vtensor<[3,300,300],f32>, %arg45: !torch.vtensor<[3],si32>) -> (!torch.vtensor<[],si32>, !torch.vtensor<[?,3,300,300],f32>, !torch.vtensor<[?,3],si32>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} {
    %621:3 = torch.operator "onnx.Loop"(%arg33, %arg32, %arg31) : (!torch.vtensor<[],si64>, !torch.vtensor<[],i1>, !torch.vtensor<[],si32>) -> (!torch.vtensor<[],si32>, !torch.vtensor<[?,3,300,300],f32>, !torch.vtensor<[?,3],si32>) {
    ^bb0(%arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],i1>, %arg3: !torch.vtensor<[],si32>):
      torch.operator_terminator %arg42, %arg43, %arg44, %arg45 : !torch.vtensor<[],i1>, !torch.vtensor<[],si32>, !torch.vtensor<[3,300,300],f32>, !torch.vtensor<[3],si32>
    }
    return %621#0, %621#1, %621#2  : !torch.vtensor<[],si32>, !torch.vtensor<[?,3,300,300],f32>, !torch.vtensor<[?,3],si32>
  }
}
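Reading the IR above against the ONNX `Loop` spec: the op takes a trip count `M` (`%arg33`), an initial condition (`%arg32`), and N loop-carried values (here just `%arg31`). The body receives (iteration number, condition, carried values) and yields (condition, carried values, K scan outputs), and the op's results are the N final carried values followed by the K scan outputs, each stacked along a new leading axis. Here N = 1 and K = 2, which is why the terminator yields four values while the op has three results, and why `[3,300,300]` becomes `[?,3,300,300]`. A list-based Python reference model of those semantics, assuming both `M` and `cond` are provided; `onnx_loop` is a hypothetical name, not torch-mlir's lowering.

```python
from typing import Callable, List, Sequence, Tuple

# Body signature: (iteration_num, cond, carried) -> (cond, carried, scan_values)
Body = Callable[[int, bool, List[object]], Tuple[bool, List[object], List[object]]]


def onnx_loop(max_trip_count: int, cond: bool, v_initial: Sequence[object],
              num_scan_outputs: int, body: Body
              ) -> Tuple[List[object], List[List[object]]]:
    """Reference model of ONNX Loop when both M and cond are given.

    Returns (final loop-carried values, scan outputs). Each scan output is
    the list of per-iteration values, i.e. stacked along a new leading axis.
    """
    carried = list(v_initial)
    scans: List[List[object]] = [[] for _ in range(num_scan_outputs)]
    i, keep_going = 0, cond
    while i < max_trip_count and keep_going:
        keep_going, new_carried, per_iter = body(i, keep_going, carried)
        if len(new_carried) != len(carried) or len(per_iter) != num_scan_outputs:
            raise ValueError("body yielded the wrong number of values")
        carried = list(new_carried)
        for acc, value in zip(scans, per_iter):
            acc.append(value)
        i += 1
    return carried, scans


# Same arity as the IR above: 1 carried counter, 2 scan outputs, 4 iterations.
final, scans = onnx_loop(4, True, [0], 2,
                         lambda i, c, v: (True, [v[0] + 1], [[0.0] * 3, [i] * 3]))
print(final, len(scans[0]), scans[1])
```

The leading dimension of each scan output equals the number of iterations actually run, which is only known at runtime; that is where the `?` in the result types comes from.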

@pdhirajkumarprasad

ir.txt (complete IR)

@kumardeepakamd

@PhaneeshB is onnx.Loop not done yet? Should this issue be closed?
