Missed optimization: a[mask] = b -> a = torch.where(mask, b, a) #4248

Open · pscollins opened this issue Nov 24, 2022 · 5 comments

Labels: dynamism Dynamic Shape Features · nostale Do not consider for staleness

@pscollins
🐛 Bug

I'm trying to get an existing model running under pytorch/XLA that uses the construct a[mask] = b frequently, and this construct seems to be a bottleneck. I'm guessing that it (as far as I can tell, unnecessarily) becomes something like "create a slice with a dynamic shape, then write to it" and triggers recompilation. I'm currently running on CPU.

To Reproduce

I tried to demonstrate the issue with a microbenchmark:

import timeit

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


DEST_SHAPE = (100000,)
THRESHOLD = .5

ITERS = 1000

def test_store_getitem(device):
    dest = torch.rand(DEST_SHAPE, device=device)
    mask = torch.rand(DEST_SHAPE, device=device) > THRESHOLD

    dest[mask] = 1.0
    if device != 'cpu':
        xm.mark_step()

def test_store_where(device):
    dest = torch.rand(DEST_SHAPE, device=device)
    mask = torch.rand(DEST_SHAPE, device=device) > THRESHOLD

    dest = torch.where(mask, 1.0, dest)
    if device != 'cpu':
        xm.mark_step()


def main():
    xla_device = xm.xla_device()
    cpu_device = 'cpu'

    scope = dict(globals(), **locals())
    print('CPU: getitem')
    print(timeit.timeit('test_store_getitem(cpu_device)', number=ITERS, globals=scope))
    print('CPU: where')
    print(timeit.timeit('test_store_where(cpu_device)', number=ITERS, globals=scope))

    print('XLA/CPU: getitem')
    print(timeit.timeit('test_store_getitem(xla_device)', number=ITERS, globals=scope))
    print(met.metrics_report())
    print('XLA/CPU: where')
    print(timeit.timeit('test_store_where(xla_device)', number=ITERS, globals=scope))
    print(met.metrics_report())


if __name__ == '__main__':
    main()

Output (times only):

CPU: getitem
1.335198831744492
CPU: where
1.4740506513044238
XLA/CPU: getitem
84.45449290703982
XLA/CPU: where
1.3696353798732162

Grabbing some snippets from the two metrics outputs, the first call reports:

Counter: CachedCompile
  Value: 505
Counter: CreateXlaTensor
  Value: 11000
Counter: DestroyXlaTensor
  Value: 11000
Counter: DeviceDataCacheMiss
  Value: 1
Counter: OpByOpCompileCacheMiss
  Value: 12
Counter: UncachedCompile
  Value: 495
Metric: CompileTime
  TotalSamples: 496
  Accumulator: 01m11s847ms893.248us
  ValueRate: 840ms091.405us / second
  Rate: 5.88149 / second
  Percentiles: 1%=136ms016.889us; 5%=138ms843.705us; 10%=139ms704.928us; 20%=140ms990.267us; 50%=142ms292.030us; 80%=145ms245.633us; 90%=149ms023.811us; 95%=151ms380.062us; 99%=156ms496.857us

and the second reports:

Counter: CachedCompile
  Value: 1504
Counter: CreateXlaTensor
  Value: 19000
Counter: DestroyXlaTensor
  Value: 19000
Counter: DeviceDataCacheMiss
  Value: 1
Counter: OpByOpCompileCacheMiss
  Value: 12
Counter: UncachedCompile
  Value: 496
Metric: CompileTime
  TotalSamples: 497
  Accumulator: 01m11s980ms339.237us
  ValueRate: 840ms942.688us / second
  Rate: 5.88123 / second
  Percentiles: 1%=136ms856.889us; 5%=138ms782.989us; 10%=139ms595.450us; 20%=140ms977.095us; 50%=142ms271.090us; 80%=145ms245.633us; 90%=149ms023.811us; 95%=151ms380.062us; 99%=156ms496.857us

Expected behavior

Ideally, the getitem and where variants should have roughly equal performance.

Environment

  • Reproducible on XLA backend [CPU/TPU]: CPU
  • torch_xla version: torch_xla/version.py shows:
# Autogenerated file, do not edit!
__version__ = '1.14'
__xla_gitrev__ = 'f790bc8ac411a8e6903b89adf7610b812996537b'
__torch_gitrev__ = '7b0d577c226fae78f377b26feab4122c4203ad59'
@JackCaoG JackCaoG self-assigned this Nov 29, 2022
@JackCaoG (Collaborator)

I think there are a couple of issues in the getitem approach:

  1. dest[mask] = 1.0 is actually going to trigger an early execution of the mask before the mark_step.
  2. dest[mask] = 1.0 is lowered into an aten::index_put that takes the indices it wants to update. The issue here is that the index size is data dependent:
> /test/test_where.py(19)test_store_getitem()
-> dest[mask] = 1.0
(Pdb) print(torch_xla._XLAC._get_xla_tensors_text([dest]))
IR {
  %0 = s64[] xla::device_data(), location=test_store_getitem@test_where.py:15, device=CPU:0
  %1 = s64[] prim::Constant(), location=test_store_getitem@test_where.py:15, value=214013
  %2 = s64[] aten::mul(%1, %0), location=test_store_getitem@test_where.py:15
  %3 = s64[] prim::Constant(), location=test_store_getitem@test_where.py:15, value=2531011
  %4 = s64[] aten::add(%3, %2), location=test_store_getitem@test_where.py:15
  %5 = f32[] prim::Constant(), location=test_store_getitem@test_where.py:15, value=1
  %6 = f32[] prim::Constant(), location=test_store_getitem@test_where.py:15, value=0
  %7 = f32[100000]{0} aten::uniform(%6, %5, %4), location=test_store_getitem@test_where.py:15, ROOT=0
}

(Pdb) n
> /test/test_where.py(20)test_store_getitem()
-> if device != 'cpu':
(Pdb) print(torch_xla._XLAC._get_xla_tensors_text([dest]))
IR {
  %0 = f32[] prim::Constant(), location=test_store_getitem@test_where.py:19, value=1
  %1 = s64[49896,1]{1,0} xla::device_data(), location=test_store_getitem@test_where.py:19, device=CPU:0
  %2 = s64[49896,1]{1,0} aten::view(%1), location=test_store_getitem@test_where.py:19, output_size=(49896, 1)
  %3 = s64[49896]{0} aten::view(%2), location=test_store_getitem@test_where.py:19, output_size=(49896)
  %4 = s64[] xla::device_data(), location=test_store_getitem@test_where.py:19, device=CPU:0
  %5 = s64[49896]{0} aten::expand(%4), location=test_store_getitem@test_where.py:19, size=(49896)
  %6 = s64[49896]{0} aten::add(%3, %5), location=test_store_getitem@test_where.py:19
  %7 = s64[] prim::Constant(), location=test_store_getitem@test_where.py:19, value=0
  %8 = pred[49896]{0} aten::lt(%3, %7), location=test_store_getitem@test_where.py:19
  %9 = s64[49896]{0} aten::where(%8, %6, %3), location=test_store_getitem@test_where.py:19
  %10 = s64[49896,1]{1,0} aten::stack(%9), location=test_store_getitem@test_where.py:19, dim=1
  %11 = s64[] xla::device_data(), location=test_store_getitem@test_where.py:15, device=CPU:0
  %12 = s64[] prim::Constant(), location=test_store_getitem@test_where.py:15, value=214013
  %13 = s64[] aten::mul(%12, %11), location=test_store_getitem@test_where.py:15
  %14 = s64[] prim::Constant(), location=test_store_getitem@test_where.py:15, value=2531011
  %15 = s64[] aten::add(%14, %13), location=test_store_getitem@test_where.py:15
  %16 = f32[] prim::Constant(), location=test_store_getitem@test_where.py:15, value=1
  %17 = f32[] prim::Constant(), location=test_store_getitem@test_where.py:15, value=0
  %18 = f32[100000]{0} aten::uniform(%17, %16, %15), location=test_store_getitem@test_where.py:15
  %19 = f32[100000]{0} aten::index_put(%18, %10, %0), location=test_store_getitem@test_where.py:19, start_dim=0, accumulate=0
  %20 = f32[100000]{0} aten::permute(%19), location=test_store_getitem@test_where.py:19, dims=(0), ROOT=0
}

(Pdb) mask.size()
torch.Size([100000])

In the example above, %1 = s64[49896,1] can be any other size, which causes constant recompilation. I don't think there is much the pytorch/xla team can do here: dest[mask] = 1.0 triggers the execution and converts the mask tensor into the indices of its true elements, and this logic lives in upstream pytorch. We can maybe open a feature request for pytorch to change the way it dispatches the [] operator.
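
As a workaround in the meantime, the rewrite from the issue title can be applied by hand. A minimal sketch, assuming b is a scalar or a tensor broadcastable to a's shape (the general a[mask] = b with a flat, mask-sized b is not covered here):

import torch

def masked_assign(a, mask, b):
    # Static-shape equivalent of `a[mask] = b` for scalar/broadcastable b.
    # torch.where keeps every tensor at a fixed size, so XLA can reuse one
    # compiled graph instead of recompiling per data-dependent index count.
    return torch.where(mask, b, a)

a = torch.rand(100000)
mask = torch.rand(100000) > 0.5
a = masked_assign(a, mask, torch.tensor(1.0))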

@pscollins (Author)

Thanks for the reply. I guess pytorch doesn't provide any way for you to intercept calls at the Python torch.Tensor level?

Is this the sort of thing that should be fixed by the planned dynamic shapes support in #3884?

@JackCaoG (Collaborator) commented Dec 3, 2022

I don't think dynamic shape will help out of the box here, since there is a materialization of the index tensor to query the true indices. If we figure out which op does the

op(mask[100000]) -> index[49896]

conversion, I think we can make that op support dynamic shapes, and then we should be good. @ezyang in case you know which op/code does that.
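
For reference, the boolean-mask-to-index conversion appears to go through nonzero (the s64[49896,1] device_data in the IR above matches its output shape), whose output size depends on the data:

import torch

mask = torch.rand(100000) > 0.5
idx = torch.nonzero(mask)  # shape (k, 1) with k = int(mask.sum()), data dependent
print(idx.shape)           # e.g. torch.Size([49896, 1]); varies run to run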

@ezyang (Collaborator) commented Dec 5, 2022

In fact, dynamo would help you here, as it wouldn't immediately force the dynamic shape calculation when you apply the mask; later, you can figure out that the mask is only used in a setter context and optimize it into a where calculation. However, I don't think we do this optimization yet.
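
To illustrate the target of such an optimization, a hand-written before/after sketch (this is not an existing dynamo pass; masked_fill is an assumed static-shape alternative for the scalar case):

import torch

mask = torch.rand(100000) > 0.5
dest = torch.rand(100000)

# Setter context dynamo could detect:
dest[mask] = 1.0                     # lowers to a data-dependent index_put

# Equivalent static-shape forms it could rewrite to:
dest = torch.where(mask, 1.0, dest)  # no data-dependent shapes
dest = dest.masked_fill(mask, 1.0)   # same result for a scalar fill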

@JackCaoG JackCaoG added the nostale Do not consider for staleness label Dec 5, 2022
@JackCaoG (Collaborator) commented Dec 5, 2022

sg, I guess we can revisit this issue later.
