
ValueError: implementation='triton' is unsupported on this GPU generation. #90

Closed
jxwangxl opened this issue Nov 21, 2024 · 5 comments
Labels: question (Further information is requested)

jxwangxl commented Nov 21, 2024

$ docker run -it     --volume /mnt/d/home/alphafold3/af_input:/root/af_input     --volume /mnt/d/home/alphafold3/af_output:/root/af_output     --volume /mnt/e/AlphaFold/data_alphafold3/models:/root/models     --volume /mnt/e/AlphaFold/data_alphafold3:/root/public_databases     --gpus all     alphafold3     python run_alphafold.py     --json_path=/root/af_input/fold_input.json     --model_dir=/root/models     --output_dir=/root/af_output
I1120 09:32:09.812681 139963233902592 folding_input.py:1044] Detected /root/af_input/fold_input.json is an AlphaFold 3 JSON since the top-level is not a list.
Running AlphaFold 3. Please note that standard AlphaFold 3 model parameters are
only available under terms of use provided at
https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.
If you do not agree to these terms and are using AlphaFold 3 derived model
parameters, cancel execution of AlphaFold 3 inference with CTRL-C, and do not
use the model parameters.
I1120 09:32:10.553918 139963233902592 xla_bridge.py:895] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1120 09:32:10.559973 139963233902592 xla_bridge.py:895] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
Found local devices: [CudaDevice(id=0)]
Building model from scratch...
Processing 1 fold inputs.
Processing fold input 2PV7
Checking we can load the model parameters...
Running data pipeline...
Processing chain A
I1120 09:42:09.817233 139963233902592 pipeline.py:40] Getting protein MSAs for sequence YDFTNCDFEKIKAAYLSTISKDLITYMSGTKSTEFNNTVSCSNRPHCLTEIQSLTFNPTAGCASLAKEMFAMKTKAALAIWCPGYSETQINATQAMKKRRKRKVTTNKCLEQVSQLQGLWRRFNRPLLKQQ
I1120 09:42:09.879023 139958312490560 jackhmmer.py:78] Query sequence: YDFTNCDFEKIKAAYLSTISKDLITYMSGTKSTEFNNTVSCSNRPHCLTEIQSLTFNPTAGCASLAKEMFAMKTKAALAIWCPGYSETQINATQAMKKRRKRKVTTNKCLEQVSQLQGLWRRFNRPLLKQQ
I1120 09:42:09.883871 139958312490560 subprocess_utils.py:68] Launching subprocess "/hmmer/bin/jackhmmer -o /dev/null -A /tmp/tmpdmgibfxi/output.sto --noali --F1 0.0005 --F2 5e-05 --F3 5e-07 --cpu 8 -N 1 -E 0.0001 --incE 0.0001 /tmp/tmpdmgibfxi/query.fasta /root/public_databases/mgy_clusters_2022_05.fa"
I1120 09:42:10.085490 139958304097856 jackhmmer.py:78] Query sequence: YDFTNCDFEKIKAAYLSTISKDLITYMSGTKSTEFNNTVSCSNRPHCLTEIQSLTFNPTAGCASLAKEMFAMKTKAALAIWCPGYSETQINATQAMKKRRKRKVTTNKCLEQVSQLQGLWRRFNRPLLKQQ
I1120 09:42:10.085645 139958320883264 jackhmmer.py:78] Query sequence: YDFTNCDFEKIKAAYLSTISKDLITYMSGTKSTEFNNTVSCSNRPHCLTEIQSLTFNPTAGCASLAKEMFAMKTKAALAIWCPGYSETQINATQAMKKRRKRKVTTNKCLEQVSQLQGLWRRFNRPLLKQQ
I1120 09:42:10.085771 139957683353152 jackhmmer.py:78] Query sequence: YDFTNCDFEKIKAAYLSTISKDLITYMSGTKSTEFNNTVSCSNRPHCLTEIQSLTFNPTAGCASLAKEMFAMKTKAALAIWCPGYSETQINATQAMKKRRKRKVTTNKCLEQVSQLQGLWRRFNRPLLKQQ
I1120 09:42:10.086443 139958304097856 subprocess_utils.py:68] Launching subprocess "/hmmer/bin/jackhmmer -o /dev/null -A /tmp/tmpmpnk0lcn/output.sto --noali --F1 0.0005 --F2 5e-05 --F3 5e-07 --cpu 8 -N 1 -E 0.0001 --incE 0.0001 /tmp/tmpmpnk0lcn/query.fasta /root/public_databases/bfd-first_non_consensus_sequences.fasta"
I1120 09:42:10.087563 139958320883264 subprocess_utils.py:68] Launching subprocess "/hmmer/bin/jackhmmer -o /dev/null -A /tmp/tmp6hnqx_y7/output.sto --noali --F1 0.0005 --F2 5e-05 --F3 5e-07 --cpu 8 -N 1 -E 0.0001 --incE 0.0001 /tmp/tmp6hnqx_y7/query.fasta /root/public_databases/uniref90_2022_05.fa"
I1120 09:42:10.088450 139957683353152 subprocess_utils.py:68] Launching subprocess "/hmmer/bin/jackhmmer -o /dev/null -A /tmp/tmpu3anvw5q/output.sto --noali --F1 0.0005 --F2 5e-05 --F3 5e-07 --cpu 8 -N 1 -E 0.0001 --incE 0.0001 /tmp/tmpu3anvw5q/query.fasta /root/public_databases/uniprot_all_2021_04.fa"
I1120 10:27:47.869025 139958304097856 subprocess_utils.py:97] Finished Jackhmmer in 2737.782 seconds
I1120 12:00:20.748736 139958320883264 subprocess_utils.py:97] Finished Jackhmmer in 8290.661 seconds
I1120 12:58:05.194634 139957683353152 subprocess_utils.py:97] Finished Jackhmmer in 11755.106 seconds
I1121 01:09:48.787940 139958312490560 subprocess_utils.py:97] Finished Jackhmmer in 55658.904 seconds
I1121 01:09:48.792834 139963233902592 pipeline.py:73] Getting protein MSAs took 55658.97 seconds for sequence YDFTNCDFEKIKAAYLSTISKDLITYMSGTKSTEFNNTVSCSNRPHCLTEIQSLTFNPTAGCASLAKEMFAMKTKAALAIWCPGYSETQINATQAMKKRRKRKVTTNKCLEQVSQLQGLWRRFNRPLLKQQ
I1121 01:09:48.792947 139963233902592 pipeline.py:79] Deduplicating MSAs and getting protein templates for sequence YDFTNCDFEKIKAAYLSTISKDLITYMSGTKSTEFNNTVSCSNRPHCLTEIQSLTFNPTAGCASLAKEMFAMKTKAALAIWCPGYSETQINATQAMKKRRKRKVTTNKCLEQVSQLQGLWRRFNRPLLKQQ
I1121 01:09:48.982047 139958312490560 subprocess_utils.py:68] Launching subprocess "/hmmer/bin/hmmbuild --informat stockholm --hand --amino /tmp/tmpzca3drih/output.hmm /tmp/tmpzca3drih/query.msa"
I1121 01:09:49.071169 139958312490560 subprocess_utils.py:97] Finished Hmmbuild in 0.089 seconds
I1121 01:09:49.072391 139958312490560 subprocess_utils.py:68] Launching subprocess "/hmmer/bin/hmmsearch --noali --cpu 8 --F1 0.1 --F2 0.1 --F3 0.1 -E 100 --incE 100 --domE 100 --incdomE 100 -A /tmp/tmpal0uczxe/output.sto /tmp/tmpal0uczxe/query.hmm /root/public_databases/pdb_seqres_2022_09_28.fasta"
I1121 01:10:10.507845 139958312490560 subprocess_utils.py:97] Finished Hmmsearch in 21.435 seconds
I1121 01:32:03.485082 139963233902592 pipeline.py:108] Deduplicating MSAs and getting protein templates took 1334.69 seconds for sequence YDFTNCDFEKIKAAYLSTISKDLITYMSGTKSTEFNNTVSCSNRPHCLTEIQSLTFNPTAGCASLAKEMFAMKTKAALAIWCPGYSETQINATQAMKKRRKRKVTTNKCLEQVSQLQGLWRRFNRPLLKQQ
I1121 01:32:03.485309 139963233902592 pipeline.py:115] Filtering protein templates for sequence YDFTNCDFEKIKAAYLSTISKDLITYMSGTKSTEFNNTVSCSNRPHCLTEIQSLTFNPTAGCASLAKEMFAMKTKAALAIWCPGYSETQINATQAMKKRRKRKVTTNKCLEQVSQLQGLWRRFNRPLLKQQ
I1121 01:32:03.494862 139963233902592 pipeline.py:124] Filtering protein templates took 0.01 seconds for sequence YDFTNCDFEKIKAAYLSTISKDLITYMSGTKSTEFNNTVSCSNRPHCLTEIQSLTFNPTAGCASLAKEMFAMKTKAALAIWCPGYSETQINATQAMKKRRKRKVTTNKCLEQVSQLQGLWRRFNRPLLKQQ
Processing chain A took 56993.81 seconds
Output directory: /root/af_output/2pv7
Writing model input JSON to /root/af_output/2pv7
Predicting 3D structure for 2PV7 for seed(s) (1,)...
Featurising data for seeds (1,)...
Featurising 2PV7 with rng_seed 1.
I1121 01:32:11.204071 139963233902592 pipeline.py:160] processing 2PV7, random_seed=1
Featurising 2PV7 with rng_seed 1 took 2.34 seconds.
Featurising data for seeds (1,) took  9.89 seconds.
Running model inference for seed 1...
Traceback (most recent call last):
  File "/app/alphafold/run_alphafold.py", line 678, in <module>
    app.run(main)
  File "/alphafold3_venv/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/alphafold3_venv/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 663, in main
    process_fold_input(
  File "/app/alphafold/run_alphafold.py", line 542, in process_fold_input
    all_inference_results = predict_structure(
                            ^^^^^^^^^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 375, in predict_structure
    result = model_runner.run_inference(example, rng_key)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 313, in run_inference
    result = self._model(rng_key, featurised_example)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/transform.py", line 183, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 294, in forward_fn
    result = self._model_class(self._model_config)(batch)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/module.py", line 464, in wrapped
    out = f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/module.py", line 305, in run_interceptors
    return bound_method(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/diffusion/model.py", line 279, in __call__
    embeddings, _ = hk.fori_loop(0, num_iter, recycle_body, (embeddings, key))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/stateful.py", line 697, in fori_loop
    state, val = jax.lax.fori_loop(lower, upper, pure_body_fun, init_val)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/stateful.py", line 677, in pure_body_fun
    val = body_fun(i, val)
          ^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/diffusion/model.py", line 252, in recycle_body
    embeddings = embedding_module(
                 ^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/module.py", line 464, in wrapped
    out = f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/module.py", line 305, in run_interceptors
    return bound_method(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/diffusion/model.py", line 734, in __call__
    pair_activations, key = self._embed_template_pair(
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/diffusion/model.py", line 644, in _embed_template_pair
    template_act = template_fn(
                   ^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/module.py", line 464, in wrapped
    out = f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/module.py", line 305, in run_interceptors
    return bound_method(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/diffusion/template_modules.py", line 194, in __call__
    summed_template_embeddings, _ = hk.scan(
                                    ^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/stateful.py", line 643, in scan
    (carry, state), ys = jax.lax.scan(
                         ^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/stateful.py", line 626, in stateful_fun
    carry, out = f(carry, x)
                 ^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/diffusion/template_modules.py", line 182, in scan_fn
    embedding = template_embedder(
                ^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/module.py", line 464, in wrapped
    out = f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/module.py", line 305, in run_interceptors
    return bound_method(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/diffusion/template_modules.py", line 348, in __call__
    act = template_stack(act)
          ^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/layer_stack.py", line 403, in wrapped
    ret = mod(x=args, **kwargs)[0]
          ^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/layer_stack.py", line 211, in __call__
    carry, (zs, states) = jax.lax.scan(
                          ^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/layer_stack.py", line 195, in layer
    (out_x, z), state = apply_fn(
                        ^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/layer_stack.py", line 249, in _call_wrapped
    ret = self._f(*x, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/diffusion/template_modules.py", line 341, in template_iteration_fn
    return modules.PairFormerIteration(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/module.py", line 464, in wrapped
    out = f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/module.py", line 305, in run_interceptors
    return bound_method(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/diffusion/modules.py", line 481, in __call__
    act += GridSelfAttention(
           ^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/module.py", line 464, in wrapped
    out = f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/module.py", line 305, in run_interceptors
    return bound_method(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/diffusion/modules.py", line 228, in __call__
    act = mapping.inference_subbatch(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/components/mapping.py", line 295, in inference_subbatch
    output = sharded_module(*batched_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/components/mapping.py", line 148, in mapped_fn
    remainder_shape_dtype = hk.eval_shape(
                            ^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/stateful.py", line 933, in eval_shape
    out_shape = jax.eval_shape(stateless_fun, internal_state(), *args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/haiku/_src/stateful.py", line 929, in stateless_fun
    out = fun(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/components/mapping.py", line 146, in apply_fun_to_slice
    return fun(*input_slice, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/components/mapping.py", line 279, in run_module
    res = module(*args)
          ^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/model/diffusion/modules.py", line 170, in _attention
    weighted_avg = attention.dot_product_attention(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/alphafold3_venv/lib/python3.11/site-packages/alphafold3/jax/attention/attention.py", line 127, in dot_product_attention
    raise ValueError(
ValueError: implementation='triton' is unsupported on this GPU generation.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Augustin-Zidek added the question label on Nov 21, 2024
Augustin-Zidek (Collaborator) commented Nov 21, 2024

This error means that the GPU you are running on is too old and does not support the Triton implementation of flash attention. Which GPU are you running on?

You can try running with --flash_attention_implementation=xla, but we are seeing bad numerics on GPUs with compute capability 7.x; see #59.

The safest option would be to upgrade the GPU.
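
To answer the "which GPU" question locally, you can ask JAX what it detects. A minimal sketch, assuming a CUDA build of JAX where GPU devices expose a compute_capability attribute (available in recent JAX releases):

import jax

# Print the accelerators JAX sees and their CUDA compute capability.
# Triton flash attention needs roughly capability 8.0 (Ampere) or newer,
# so anything reporting 7.x or lower will hit the error above.
for device in jax.local_devices():
    print(device.device_kind, getattr(device, 'compute_capability', 'n/a'))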

jxwangxl (Author)
> This error means that the GPU you are running on is too old and does not support the Triton implementation of flash attention. Which GPU are you running on?
>
> You can try running with --flash_attention_implementation=xla, but we are seeing bad numerics on GPUs with compute capability 7.x; see #59.
>
> The safest option would be to upgrade the GPU.

Thanks, I have tried "--flash_attention_implementation=xla" and it works.

Augustin-Zidek (Collaborator)
Great to hear, glad that this fixed the issue!

FilipeMaia commented Jan 16, 2025

Since run_alphafold.py already does a compatibility check (because of #59) to guard against a missing --xla_disable_hlo_passes=custom-kernel-fusion-rewriter on 7.x cards, wouldn't it make sense to also check --flash_attention_implementation=xla and fail early otherwise? Or potentially even set it automatically and warn the user?
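
A rough sketch of what such a check-or-override could look like; this is not the actual change that landed in f407412, and the function name, the 8.0 threshold, and the use of absl logging are assumptions for illustration:

from absl import logging
import jax


def resolve_flash_attention_implementation(requested: str) -> str:
  """Falls back to the XLA attention path on GPUs that cannot run Triton."""
  gpus = [d for d in jax.local_devices() if d.platform == 'gpu']
  if requested != 'triton' or not gpus:
    return requested
  capability = float(getattr(gpus[0], 'compute_capability', '0.0'))
  if capability < 8.0:  # Assumed Triton requirement: Ampere (8.0) or newer.
    logging.warning(
        'flash_attention_implementation=triton is unsupported on compute '
        'capability %.1f GPUs; falling back to the xla implementation.',
        capability,
    )
    return 'xla'
  return requested

Failing early with a clear error instead of silently overriding would be the stricter variant of the same check.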

Augustin-Zidek (Collaborator)
Good point, @FilipeMaia. Done in f407412.
