
[Web] An error occurred during model execution: "TypeError: Cannot read properties of undefined (reading 'apply')". #15719

Closed
xenova opened this issue Apr 27, 2023 · 12 comments
Labels
model:transformer issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc. platform:web issues related to ONNX Runtime web; typically submitted using template

Comments


xenova commented Apr 27, 2023

Describe the issue

I am unable to use the latest build as the backend for Transformers.js due to the following error:

Uncaught (in promise) TypeError: Cannot read properties of undefined (reading 'apply')
    at Ga (ort.js:22426:361)
    at ort-wasm-simd.wasm:0x8c2802
    at Pa.b.<computed> (ort.js:22436:150)
    at c._OrtRun (ort.js:22452:317)
    at Object.run (ort.js:21905:1)
    at run (ort.js:21229:1)
    at OnnxruntimeWebAssemblySessionHandler.run (ort.js:21395:1)
    at InferenceSession.run (ort.js:351:1)

Unfortunately, it's quite difficult to tell what's going wrong as the minification doesn't help debugging.

To reproduce

  1. Build from source, following instructions here. For convenience, here are my dist files: dist.zip
  2. Load model from here: https://huggingface.co/Xenova/transformers.js/tree/main/quantized/t5-small/seq2seq-lm-with-past
  3. Run with this input:
    let input = {
        attention_mask: new Tensor(
            'int64',
            new BigInt64Array([1n, 1n, 1n, 1n, 1n, 1n, 1n, 1n, 1n, 1n, 1n, 1n]),
            [1, 12]
        ),
        input_ids: new Tensor(
            'int64',
            new BigInt64Array([13959n, 1566n, 12n, 2379n, 10n, 8774n, 6n, 149n, 33n, 25n, 58n, 1n]),
            [1, 12]
        )
    }

Running the session with this input throws the above error; it occurs for both WASM and WebGPU. A minimal driver is sketched below.
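
For reference, a minimal driver for the repro above (a sketch; the encoder_model.onnx file name and resolve URL are assumptions based on the repo layout):

let url = 'https://huggingface.co/Xenova/transformers.js/resolve/main/quantized/t5-small/seq2seq-lm-with-past/encoder_model.onnx';
let buffer = await (await fetch(url)).arrayBuffer();

let session = await ort.InferenceSession.create(new Uint8Array(buffer), {
    executionProviders: ['wasm'],
});

// The TypeError above is thrown during this call:
let output = await session.run(input);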

Urgency

Relatively urgent: if this is a problem with the build itself, it will push back the WebGPU release.

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

2efb75b

Execution Provider

WebGL, WASM, Other / Unknown


xenova commented Apr 27, 2023

Here's a full demo which shows the problem. Just point the first script tag to your ort.js file and place your wasm files in the correct folder.

<script src="http://localhost:8080/dist/ort.js"></script>

<script>

    document.addEventListener('DOMContentLoaded', async () => {

        // Load model
        let url = 'https://huggingface.co/Xenova/bert-base-cased_web/resolve/main/onnx/model_quantized.onnx'
        let model = await fetch(url);
        let buffer = await model.arrayBuffer();
        let array = new Uint8Array(buffer);

        // Create a new session
        let session = await ort.InferenceSession.create(array, {
            executionProviders: ['wasm'],
        })

        // Construct model input
        let input = {
            attention_mask: new ort.Tensor(
                'int64',
                new BigInt64Array([1n, 1n, 1n, 1n, 1n, 1n, 1n, 1n, 1n]),
                [1, 9],
            ),
            input_ids: new ort.Tensor(
                'int64',
                // The goal of life is [MASK].
                new BigInt64Array([101n, 1109n, 2273n, 1104n, 1297n, 1110n, 103n, 119n, 102n]),
                [1, 9],
            ),
            token_type_ids: new ort.Tensor(
                'int64',
                new BigInt64Array([0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n, 0n]),
                [1, 9],
            ),
        }

        // Run the model with the input
        let output = await session.run(input);

        let logits = output.logits;
        let [batchSize, sequenceLength, vocabSize] = logits.dims;

        // Get logits of masked token
        let maskIndex = 6; // position of the [MASK] token (id 103) in input_ids
        let masked = logits.data.slice(maskIndex * vocabSize, (maskIndex + 1) * vocabSize);

        let maxIndex = 0;
        let maxValue = masked[0];
        for (let i = 1; i < masked.length; ++i) {
            if (masked[i] > maxValue) {
                maxValue = masked[i];
                maxIndex = i;
            }
        }

        // Display output
        // maxIndex === 8115, maxValue === 8.050320625305176
        alert(`maxIndex=${maxIndex}, maxValue=${maxValue}`)
        alert(`data=${logits.data.slice(0, 100)}`)
    });

</script>

produces:

ort-wasm.js:28 Uncaught (in promise) TypeError: Cannot read properties of undefined (reading 'apply')
    at Ga (ort-wasm.js:28:361)
    at ort-wasm-simd.wasm:0x8c2802
    at Pa.b.<computed> (ort-wasm.js:38:150)
    at c._OrtRun (ort-wasm.js:54:317)
    at Object.run (wasm-core-impl.ts:213:28)
    at run (proxy-wrapper.ts:221:17)
    at OnnxruntimeWebAssemblySessionHandler.run (session-handler.ts:82:18)
    at InferenceSession.run (inference-session-impl.js:91:1)
    at HTMLDocument.<anonymous> (webgpu.html:39:28)


fs-eire commented Apr 27, 2023

I am aware of this issue. It might be a long story to explain... I will try to update with more info later.

To work around it, disable SIMD and threading:

ort.env.wasm.simd = false;
ort.env.wasm.numThreads = 1;

This also unblocks you to try the WebGPU EP.
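
A minimal sketch of the workaround in context (note that the flags must be set before the first InferenceSession.create call; modelBytes is a placeholder for the model buffer):

ort.env.wasm.simd = false;
ort.env.wasm.numThreads = 1;

// The flags above must be set before this call: the wasm backend is
// initialized on first use and reads them at that point.
let session = await ort.InferenceSession.create(modelBytes, {
    executionProviders: ['wasm'],
});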


xenova commented Apr 27, 2023

Okay, that made it work for WASM (albeit significantly slower).

However, WebGPU didn't work due to the following issue:

Non-zero status code returned while running Transpose node. Name:'/block.0/layer.0/SelfAttention/Transpose_3' Status Message: Failed to run JSEP kernel

An error occurred during model execution: "Error: failed to call OrtRun(). error code = 1."

I assume it's just an unsupported operation?


xenova commented Apr 27, 2023

I tried with a different model for image classification, and although it works, it is slower than the WASM implementation :/

WASM: 13807.7 milliseconds
WebGPU: 14620.1 milliseconds
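
For context, timings like these can be taken around session.run(); a sketch, not the exact benchmark used:

let start = performance.now();
let output = await session.run(input);
console.log(`inference took ${(performance.now() - start).toFixed(1)} ms`);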

I assume that disabling SIMD has this effect?

It also gives this warning:

ort-wasm.js:49 2023-04-27 21:26:05.920300 [W:onnxruntime:, session_state.cc:1169 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.


fs-eire commented Apr 27, 2023

A few tools can be used to take a deeper look at it (a sketch putting them together follows below):

  • env.logLevel = 'verbose'; env.debug = true; - This makes onnxruntime-web output logs that are helpful for analyzing the execution, including which operators run on WebGPU and which fall back to CPU. To recover performance lost to fallback we need to improve the operator coverage; I can help implement the missing ops.
  • env.webgpu.profilingMode = 'default'; - This outputs quite a lot of console logs, one per WebGPU shader run; by aggregating and analyzing those we can tell which shaders are slow.
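
Putting the two together (a sketch; env is ort.env from the bundle, and modelBytes/input are placeholders):

// Set before creating the session so the flags take effect.
ort.env.logLevel = 'verbose';
ort.env.debug = true;
ort.env.webgpu.profilingMode = 'default';

let session = await ort.InferenceSession.create(modelBytes, {
    executionProviders: ['webgpu'],
});

// Verbose logs show which EP each operator was assigned to; the
// profiling logs (one per shader run) appear during this call.
await session.run(input);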


xenova commented Apr 27, 2023


Great, thanks! Would you like me to run it and send the logs here?


fs-eire commented Apr 28, 2023

Please send out the logs. If a model fails, there's an operator coverage issue, or performance is bad, create new issues so that I can track them more easily.


xenova commented Apr 28, 2023

Here are the (very long) logs. I'm not too sure how to interpret them, but I can see that some of the later operations take around 2 seconds for some reason.


xenova commented May 3, 2023

Hi. Were you able to find the error?


fs-eire commented May 3, 2023

There are 3 issues being discussed in this thread: the original "apply" TypeError from the SIMD build, the Transpose node failing to run as a JSEP kernel, and WebGPU being slower than WASM.

I tracked the latter 2 issues in #15796


xenova commented May 3, 2023

Okay, great, thanks! I'll continue the conversation for the last 2 issues there. This issue can be closed when #15780 is merged.


fs-eire commented May 4, 2023

The PR is merged.
