
[BUG] can't run the quickstart #206

Closed
sunweice opened this issue Aug 23, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@sunweice

Describe the bug

I'm using RL4CO. When I followed the steps on this page (https://rl4co.readthedocs.io/en/latest/examples/1-quickstart/) for testing, I encountered an error. What should I do?

To Reproduce

At the “trainer.fit(model)” step, I get the following error:

val_file not set. Generating dataset instead
test_file not set. Generating dataset instead

ValueError Traceback (most recent call last)
Cell In[5], line 1
----> 1 trainer.fit(model)

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\utils\trainer.py:146, in RL4COTrainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
141 log.warning(
142 "Overriding gradient_clip_val to None for 'automatic_optimization=False' models"
143 )
144 self.gradient_clip_val = None
--> 146 super().fit(
147 model=model,
148 train_dataloaders=train_dataloaders,
149 val_dataloaders=val_dataloaders,
150 datamodule=datamodule,
151 ckpt_path=ckpt_path,
152 )

File ~\anaconda3\envs\RL4CO\lib\site-packages\lightning\pytorch\trainer\trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
536 self.state.status = TrainerStatus.RUNNING
537 self.training = True
--> 538 call._call_and_handle_interrupt(
539 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
540 )

File ~\anaconda3\envs\RL4CO\lib\site-packages\lightning\pytorch\trainer\call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
45 if trainer.strategy.launcher is not None:
46 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47 return trainer_fn(*args, **kwargs)
49 except _TunerExitException:
50 _call_teardown_hook(trainer)

File ~\anaconda3\envs\RL4CO\lib\site-packages\lightning\pytorch\trainer\trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
567 assert self.state.fn is not None
568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
569 self.state.fn,
570 ckpt_path,
571 model_provided=True,
572 model_connected=self.lightning_module is not None,
573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
576 assert self.state.stopped
577 self.training = False

File ~\anaconda3\envs\RL4CO\lib\site-packages\lightning\pytorch\trainer\trainer.py:943, in Trainer._run(self, model, ckpt_path)
940 log.debug(f"{self.__class__.__name__}: preparing data")
941 self._data_connector.prepare_data()
--> 943 call._call_setup_hook(self) # allow user to set up LightningModule in accelerator environment
944 log.debug(f"{self.__class__.__name__}: configuring model")
945 call._call_configure_model(self)

File ~\anaconda3\envs\RL4CO\lib\site-packages\lightning\pytorch\trainer\call.py:104, in _call_setup_hook(trainer)
102 _call_lightning_datamodule_hook(trainer, "setup", stage=fn)
103 _call_callback_hooks(trainer, "setup", stage=fn)
--> 104 _call_lightning_module_hook(trainer, "setup", stage=fn)
106 trainer.strategy.barrier("post_setup")

File ~\anaconda3\envs\RL4CO\lib\site-packages\lightning\pytorch\trainer\call.py:167, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
164 pl_module._current_fx_name = hook_name
166 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 167 output = fn(*args, **kwargs)
169 # restore current_fx when nested context
170 pl_module._current_fx_name = prev_fx_name

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\common\base.py:155, in RL4COLitModule.setup(self, stage)
153 self.dataloader_names = None
154 self.setup_loggers()
--> 155 self.post_setup_hook()

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\reinforce\reinforce.py:110, in REINFORCE.post_setup_hook(self, stage)
108 def post_setup_hook(self, stage="fit"):
109 # Make baseline taking model itself and train_dataloader from model as input
--> 110 self.baseline.setup(
111 self.policy,
112 self.env,
113 batch_size=self.val_batch_size,
114 device=get_lightning_device(self),
115 dataset_size=self.data_cfg["val_data_size"],
116 )

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\reinforce\baselines.py:117, in WarmupBaseline.setup(self, *args, **kw)
116 def setup(self, *args, **kw):
--> 117 self.baseline.setup(*args, **kw)

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\reinforce\baselines.py:174, in RolloutBaseline.setup(self, *args, **kw)
173 def setup(self, *args, **kw):
--> 174 self._update_policy(*args, **kw)

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\reinforce\baselines.py:187, in RolloutBaseline._update_policy(self, policy, env, batch_size, device, dataset_size, dataset)
183 self.dataset = env.dataset(batch_size=[dataset_size])
185 log.info("Evaluating baseline policy on evaluation dataset")
186 self.bl_vals = (
--> 187 self.rollout(self.policy, env, batch_size, device, self.dataset).cpu().numpy()
188 )
189 self.mean = self.bl_vals.mean()

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\reinforce\baselines.py:242, in RolloutBaseline.rollout(self, policy, env, batch_size, device, dataset)
238 return policy(batch, env, decode_type="greedy")["reward"]
240 dl = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
--> 242 rewards = torch.cat([eval_policy(batch) for batch in dl], 0)
243 return rewards

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\reinforce\baselines.py:242, in <listcomp>(.0)
238 return policy(batch, env, decode_type="greedy")["reward"]
240 dl = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
--> 242 rewards = torch.cat([eval_policy(batch) for batch in dl], 0)
243 return rewards

File ~\anaconda3\envs\RL4CO\lib\site-packages\torch\utils\data\dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
627 if self._sampler_iter is None:
628 # TODO(pytorch/pytorch#76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and
633 self._IterableDataset_len_called is not None and
634 self._num_yielded > self._IterableDataset_len_called:

File ~\anaconda3\envs\RL4CO\lib\site-packages\torch\utils\data\dataloader.py:673, in _SingleProcessDataLoaderIter._next_data(self)
671 def _next_data(self):
672 index = self._next_index() # may raise StopIteration
--> 673 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
674 if self._pin_memory:
675 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~\anaconda3\envs\RL4CO\lib\site-packages\torch\utils\data\_utils\fetch.py:55, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
53 else:
54 data = self.dataset[possibly_batched_index]
---> 55 return self.collate_fn(data)

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\data\dataset.py:37, in TensorDictDataset.collate_fn(batch)
34 @staticmethod
35 def collate_fn(batch: Union[dict, TensorDict]):
36 """Collate function compatible with TensorDicts that reassembles a list of dicts."""
---> 37 return TensorDict(
38 {key: torch.stack([b[key] for b in batch]) for key in batch[0].keys()},
39 batch_size=torch.Size([len(batch)]),
40 _run_checks=False,
41 )

File ~\anaconda3\envs\RL4CO\lib\site-packages\tensordict\_td.py:240, in TensorDict.__init__(self, source, batch_size, device, names, non_blocking, lock, **kwargs)
229 def __init__(
230 self,
231 source: T | dict[str, CompatibleType] = None,
(...)
237 **kwargs,
238 ) -> None:
239 if (source is not None) and kwargs:
--> 240 raise ValueError(
241 "Either a dictionary or a sequence of kwargs must be provided, not both."
242 )
243 source = source if not kwargs else kwargs
244 if names and is_dynamo_compiling():

ValueError: Either a dictionary or a sequence of kwargs must be provided, not both.
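The check that fires can be modeled without torch. As the last frame shows, TensorDict.__init__ takes a source dict plus a catch-all **kwargs and raises if both are non-empty. The sketch below is a simplified stand-in: FakeTensorDict is a hypothetical name, not the real tensordict class, and it keeps only that guard. It shows why a keyword the constructor does not declare by name, such as _run_checks=False, is swept into **kwargs and triggers the error.

```python
class FakeTensorDict:
    """Hypothetical, simplified stand-in for TensorDict.__init__ (not the real class).

    Reproduces only the guard from the traceback: keywords that are not named
    parameters are collected in **kwargs and treated as extra entries.
    """

    def __init__(self, source=None, batch_size=None, device=None, **kwargs):
        if (source is not None) and kwargs:
            raise ValueError(
                "Either a dictionary or a sequence of kwargs must be provided, not both."
            )
        # Take entries from kwargs if given, otherwise from the source dict.
        self.data = dict(kwargs) if kwargs else dict(source or {})
        self.batch_size = batch_size
        self.device = device


# Accepted: a source dict plus only named keyword arguments.
ok = FakeTensorDict({"locs": [[0.1, 0.2]]}, batch_size=[1])

# Rejected: _run_checks is not a named parameter here, so it lands in **kwargs.
try:
    FakeTensorDict({"locs": [[0.1, 0.2]]}, batch_size=[1], _run_checks=False)
    message = None
except ValueError as err:
    message = str(err)

print(message)
```

Dropping the extra keyword leaves only the source dict and named arguments, so the guard no longer fires.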

@sunweice sunweice added the bug Something isn't working label Aug 23, 2024
@sunweice
Author

@staticmethod
def collate_fn(batch: Union[dict, TensorDict]):
    """Collate function compatible with TensorDicts that reassembles a list of dicts."""
    return TensorDict(
        {key: torch.stack([b[key] for b in batch]) for key in batch[0].keys()},
        batch_size=torch.Size([len(batch)]),
        _run_checks=False,  # newer TensorDict versions collect this in **kwargs and raise ValueError
    )

It seems I've found the reason. In pytorch/tensordict#175 (comment), it's mentioned that TensorDict does not support the _run_checks parameter. After I deleted _run_checks=False, the code ran.
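For context, collate_fn turns the list of per-instance entries produced by the dataset into one batched TensorDict; removing _run_checks=False leaves that reassembly unchanged. The same reassembly can be sketched with plain Python lists standing in for tensors (collate_dicts is a hypothetical helper, not part of RL4CO, and list-building here only approximates torch.stack):

```python
from typing import Any, Dict, List


def collate_dicts(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Torch-free sketch of collate_fn's reassembly step.

    Builds one dict with the same keys as the first sample, where each value
    is the list of per-sample values (standing in for torch.stack along a new
    leading batch dimension).
    """
    return {key: [sample[key] for sample in batch] for key in batch[0].keys()}


samples = [
    {"locs": [0.0, 0.5], "demand": 1},
    {"locs": [0.3, 0.9], "demand": 2},
]
batched = collate_dicts(samples)
batch_size = len(samples)  # mirrors batch_size=torch.Size([len(batch)])
print(batched, batch_size)
```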

@fedebotu
Member

@sunweice yep, that's correct! We recently updated the development version of RL4CO to make it compatible with the latest TensorDict (only that change was needed).
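Since the failure depends on which RL4CO and TensorDict releases are installed together, it can help to check the installed versions when this error appears. A small stdlib-only sketch, assuming the distribution names rl4co, tensordict, torch and lightning (installed_versions is a hypothetical helper, not part of RL4CO):

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("rl4co", "tensordict", "torch", "lightning")):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


print(installed_versions())
```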

Note that we expect to officially release the latest RL4CO version within the coming month, with lots of updates!
