Error Running Basic Test Script with v2 Tag (Works With v1) #35

Closed
mdr223 opened this issue May 25, 2022 · 2 comments

Comments

@mdr223

mdr223 commented May 25, 2022

Hi there,

I recently tried upgrading my S4 setup/environment to the v2 tag, but ran into the following error when running the basic test script:

(base) ray@test-python:~/state-spaces$ python -m train wandb=null pipeline=mnist model=s4
CONFIG
├── train
│   └── seed: 0
│       interval: epoch
│       monitor: val/accuracy
│       mode: max
│       ema: 0.0
│       test: false
│       debug: false
│       ignore_warnings: false
│       state:
│         mode: null
│         chunk_len: null
│         overlap_len: null
│         n_context: 0
│         n_context_eval: 0
│       sweep: null
│       group: null
│       benchmark_step: false
│       benchmark_step_k: 1
│       benchmark_step_T: 1
│       checkpoint_path: null
│       visualizer: filters
│       disable_dataset: false
│
├── wandb
│   └── None
├── trainer
│   └── gpus: 1
│       accumulate_grad_batches: 1
│       max_epochs: 200
│       gradient_clip_val: 0.0
│       log_every_n_steps: 10
│       limit_train_batches: 1.0
│       limit_val_batches: 1.0
│       weights_summary: top
│       progress_bar_refresh_rate: 1
│       track_grad_norm: -1
│       resume_from_checkpoint: null
│
├── loader
│   └── batch_size: 50
│       num_workers: 4
│       pin_memory: true
│       drop_last: true
│       train_resolution: 1
│       eval_resolutions:
│       - 1
│
├── dataset
│   └── _name_: mnist
│       permute: true
│       val_split: 0.1
│       seed: 42
│
├── task
│   └── _name_: base
│       loss: cross_entropy
│       metrics:
│       - accuracy
│       torchmetrics: null
│
├── optimizer
│   └── _name_: adamw
│       lr: 0.001
│       weight_decay: 0.0
│
├── scheduler
│   └── _name_: plateau
│       mode: max
│       factor: 0.2
│       patience: 20
│       min_lr: 0.0
│
├── encoder
│   └── linear
├── decoder
│   └── _name_: sequence
│       mode: pool
│
├── model
│   └── layer:
│         _name_: s4
│         d_state: 64
│         channels: 1
│         bidirectional: false
│         activation: gelu
│         postact: null
│         hyper_act: null
│         dropout: 0.0
│         measure: legs
│         rank: 1
│         dt_min: 0.001
│         dt_max: 0.1
│         trainable:
│           dt: true
│           A: true
│           P: true
│           B: true
│         lr: 0.001
│         length_correction: true
│         tie_state: true
│         hurwitz: true
│         resample: false
│         deterministic: false
│         l_max: 784
│         verbose: false
│       _name_: model
│       prenorm: false
│       transposed: true
│       n_layers: 4
│       d_model: 256
│       residual: R
│       pool:
│         _name_: sample
│         pool: 1
│         expand: 1
│       norm: layer
│       dropout: 0.0
│
└── callbacks
    └── learning_rate_monitor:
          logging_interval: epoch
        timer:
          step: true
          inter_step: false
          epoch: true
          val: true
        params:
          total: true
          trainable: true
          fixed: true
        model_checkpoint:
          monitor: val/accuracy
          mode: max
          save_top_k: 1
          save_last: true
          dirpath: checkpoints/
          filename: val/accuracy
          auto_insert_metric_name: false
          verbose: true

Global seed set to 0
[2022-05-25 13:40:50,814][__main__][INFO] - Instantiating callback <src.callbacks.timer.Timer>
[2022-05-25 13:40:50,815][__main__][INFO] - Instantiating callback <src.callbacks.params.ParamsLog>
[2022-05-25 13:40:50,816][__main__][INFO] - Instantiating callback <pytorch_lightning.callbacks.ModelCheckpoint>
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
[2022-05-25 13:40:50,848][torch.distributed.nn.jit.instantiator][INFO] - Created a temporary directory at /tmp/tmpm51hqe7x
[2022-05-25 13:40:50,849][torch.distributed.nn.jit.instantiator][INFO] - Writing /tmp/tmpm51hqe7x/_remote_module_non_sriptable.py
Error executing job with overrides: ['wandb=null', 'pipeline=mnist', 'model=s4']
Traceback (most recent call last):
  File "/home/ray/state-spaces/train.py", line 553, in main
    train(config)
  File "/home/ray/state-spaces/train.py", line 498, in train
    trainer.fit(model)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 768, in fit
    self._call_and_handle_interrupt(
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 721, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1172, in _run
    self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1492, in _call_setup_hook
    self._call_lightning_module_hook("setup", stage=fn)
  File "/home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1593, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/ray/state-spaces/train.py", line 74, in setup
    self.model = utils.instantiate(registry.model, self.hparams.model)
  File "/home/ray/state-spaces/src/utils/config.py", line 99, in instantiate
    return obj()
  File "/home/ray/state-spaces/src/models/sequence/model.py", line 69, in __init__
    block = SequenceResidualBlock(d, l+1, prenorm=prenorm, dropout=dropout, layer=layer, residual=residual, norm=norm, pool=pool)
  File "/home/ray/state-spaces/src/models/sequence/block.py", line 36, in __init__
    self.layer = utils.instantiate(registry.layer, layer, d_input)
  File "/home/ray/state-spaces/src/utils/config.py", line 99, in instantiate
    return obj()
  File "/home/ray/state-spaces/src/models/sequence/ss/s4.py", line 86, in __init__
    self.kernel = HippoSSKernel(self.h, N=self.n, L=l_max, channels=channels, verbose=verbose, **kernel_args)
  File "/home/ray/state-spaces/src/models/sequence/ss/kernel.py", line 712, in __init__
    self.kernel = SSKernelNPLR(
  File "/home/ray/state-spaces/src/models/sequence/ss/kernel.py", line 217, in __init__
    self.C = nn.Parameter(_c2r(_resolve_conj(C)))
RuntimeError: view_as_real doesn't work on unresolved conjugated tensors.  To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.

Is this something you've seen before? I'd be happy to provide a fuller description of my package versions, system architecture, etc., if you let me know what would help get to the bottom of this bug.

Best,
Matthew
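
The RuntimeError at the bottom of the traceback can be reproduced outside the repository. Since PyTorch 1.10, `Tensor.conj()` on a complex tensor returns a lazy view with the conjugate bit set, and `torch.view_as_real` rejects such a view until it is materialized with `resolve_conj()`, which is what the error message suggests. The sketch below assumes PyTorch 1.10 or later for the failing call; `c2r_resolved` is a hypothetical helper for illustration, not the repository's `_c2r` or `_resolve_conj`.

```python
import torch


def c2r_resolved(x: torch.Tensor) -> torch.Tensor:
    """View a complex tensor as real, materializing any pending conjugation.

    Hypothetical helper for illustration; not the repository's _c2r/_resolve_conj.
    """
    # Tensor.resolve_conj() only exists from PyTorch 1.10 on; on older versions
    # conj() already returns a materialized tensor, so there is nothing to resolve.
    if hasattr(x, "resolve_conj"):
        x = x.resolve_conj()
    return torch.view_as_real(x)  # shape (..., 2): real and imaginary parts


C = torch.tensor([1 + 2j, 3 - 4j])  # complex64

try:
    # On PyTorch >= 1.10, C.conj() is a lazy view with the conjugate bit set,
    # which view_as_real rejects with the RuntimeError seen in the traceback.
    torch.view_as_real(C.conj())
except RuntimeError as err:
    print("RuntimeError:", err)

print(c2r_resolved(C.conj()))
```

Feature-detecting `resolve_conj` with `hasattr` sidesteps parsing `torch.__version__`, at the cost of not recording which PyTorch release introduced the change.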

@albertfgu
Contributor

Sorry for responding so late! What version of PyTorch are you on? This problem arises from a change introduced in PyTorch 1.10, and it should be handled by this line: https://github.com/HazyResearch/state-spaces/blob/6cbc09aeeebfe72b7bde7897ef157cf63fd12721/src/models/sequence/ss/kernel.py#L67

I'm not sure why you're having issues. I can try to reproduce it if you tell me which version of PyTorch you're using.
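
Because the fix depends on which PyTorch version is installed, a version-dependent guard is the natural place to look. The contents of the linked line are not reproduced here; the following is a hypothetical, stdlib-only sketch of how such a guard might decide whether a given `torch.__version__` string is 1.10 or later. The names `parse_major_minor` and `needs_resolve_conj` are illustrative assumptions, not the repository's code.

```python
import re
from typing import Tuple

# Hypothetical helpers for illustration; not the code at the linked line.


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Return (major, minor) from a PyTorch-style version string.

    Tolerates local and pre-release suffixes such as "1.11.0+cu113",
    "1.12.0.dev20220520", or "1.10.0a0".
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def needs_resolve_conj(version: str) -> bool:
    """True if conj() is lazy in this version (PyTorch 1.10 and later),
    so a conjugated tensor must be resolved before view_as_real."""
    return parse_major_minor(version) >= (1, 10)


if __name__ == "__main__":
    for v in ["1.9.1", "1.10.0a0", "1.11.0+cu113", "2.0.1"]:
        print(v, needs_resolve_conj(v))
```

Comparing raw version strings is error-prone: in Python, `"1.9" >= "1.10"` is `True` because strings compare character by character, so parsing into integer tuples (or using `packaging.version.Version`) is the safer choice.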

@albertfgu
Contributor

Closing this issue. Please file a new issue if problems persist with the latest version (V3).
