-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from allenai/2021-challenge
2021 AI2-THOR Rearrangement Challenge updates.
- Loading branch information
Showing
45 changed files
with
6,753 additions
and
1,745,529 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from abc import ABC | ||
from typing import Optional, Dict, Sequence | ||
|
||
from allenact.base_abstractions.sensor import SensorSuite, Sensor, DepthSensor | ||
|
||
from baseline_configs.rearrange_base import RearrangeBaseExperimentConfig | ||
from rearrange.sensors import ( | ||
RGBRearrangeSensor, | ||
UnshuffledRGBRearrangeSensor, | ||
) | ||
from rearrange.tasks import RearrangeTaskSampler | ||
|
||
|
||
class OnePhaseRGBBaseExperimentConfig(RearrangeBaseExperimentConfig, ABC): | ||
SENSORS = [ | ||
RGBRearrangeSensor( | ||
height=RearrangeBaseExperimentConfig.SCREEN_SIZE, | ||
width=RearrangeBaseExperimentConfig.SCREEN_SIZE, | ||
use_resnet_normalization=True, | ||
uuid=RearrangeBaseExperimentConfig.EGOCENTRIC_RGB_UUID, | ||
), | ||
UnshuffledRGBRearrangeSensor( | ||
height=RearrangeBaseExperimentConfig.SCREEN_SIZE, | ||
width=RearrangeBaseExperimentConfig.SCREEN_SIZE, | ||
use_resnet_normalization=True, | ||
uuid=RearrangeBaseExperimentConfig.UNSHUFFLED_RGB_UUID, | ||
), | ||
] | ||
|
||
@classmethod | ||
def make_sampler_fn( | ||
cls, | ||
stage: str, | ||
force_cache_reset: bool, | ||
allowed_scenes: Optional[Sequence[str]], | ||
seed: int, | ||
scene_to_allowed_rearrange_inds: Optional[Dict[str, Sequence[int]]] = None, | ||
x_display: Optional[str] = None, | ||
sensors: Optional[Sequence[Sensor]] = None, | ||
**kwargs, | ||
) -> RearrangeTaskSampler: | ||
"""Return a RearrangeTaskSampler.""" | ||
if "mp_ctx" in kwargs: | ||
del kwargs["mp_ctx"] | ||
return RearrangeTaskSampler.from_fixed_dataset( | ||
run_walkthrough_phase=False, | ||
run_unshuffle_phase=True, | ||
stage=stage, | ||
allowed_scenes=allowed_scenes, | ||
scene_to_allowed_rearrange_inds=scene_to_allowed_rearrange_inds, | ||
rearrange_env_kwargs=dict( | ||
force_cache_reset=force_cache_reset, | ||
**cls.REARRANGE_ENV_KWARGS, | ||
controller_kwargs={ | ||
"x_display": x_display, | ||
**cls.THOR_CONTROLLER_KWARGS, | ||
"renderDepthImage": any( | ||
isinstance(s, DepthSensor) for s in cls.SENSORS | ||
), | ||
}, | ||
), | ||
seed=seed, | ||
sensors=SensorSuite(cls.SENSORS) | ||
if sensors is None | ||
else SensorSuite(sensors), | ||
max_steps=cls.MAX_STEPS, | ||
discrete_actions=cls.actions(), | ||
require_done_action=cls.REQUIRE_DONE_ACTION, | ||
force_axis_aligned_start=cls.FORCE_AXIS_ALIGNED_START, | ||
**kwargs, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from baseline_configs.one_phase.one_phase_rgb_il_base import ( | ||
OnePhaseRGBILBaseExperimentConfig, | ||
) | ||
|
||
|
||
class OnePhaseRGBDaggerExperimentConfig(OnePhaseRGBILBaseExperimentConfig): | ||
USE_RESNET_CNN = False | ||
IL_PIPELINE_TYPE = "40proc" | ||
|
||
@classmethod | ||
def tag(cls) -> str: | ||
return f"OnePhaseRGBDagger_{cls.IL_PIPELINE_TYPE}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
from typing import Tuple, Sequence, Optional, Dict, Any | ||
|
||
import torch | ||
from allenact.algorithms.onpolicy_sync.losses.imitation import Imitation | ||
from allenact.base_abstractions.sensor import ExpertActionSensor | ||
from allenact.utils.experiment_utils import PipelineStage | ||
from allenact.utils.misc_utils import all_unique | ||
|
||
from baseline_configs.one_phase.one_phase_rgb_base import ( | ||
OnePhaseRGBBaseExperimentConfig, | ||
) | ||
from baseline_configs.rearrange_base import RearrangeBaseExperimentConfig | ||
|
||
|
||
class StepwiseLinearDecay: | ||
def __init__(self, cumm_steps_and_values: Sequence[Tuple[int, float]]): | ||
assert len(cumm_steps_and_values) >= 1 | ||
|
||
self.steps_and_values = list(sorted(cumm_steps_and_values)) | ||
self.steps = [steps for steps, _ in cumm_steps_and_values] | ||
self.values = [value for _, value in cumm_steps_and_values] | ||
|
||
assert all_unique(self.steps) | ||
assert all(0 <= v <= 1 for v in self.values) | ||
|
||
def __call__(self, epoch: int) -> float: | ||
"""Get the value for the input number of steps.""" | ||
if epoch <= self.steps[0]: | ||
return self.values[0] | ||
elif epoch >= self.steps[-1]: | ||
return self.values[-1] | ||
else: | ||
# TODO: Binary search would be more efficient but seems overkill | ||
for i, (s0, s1) in enumerate(zip(self.steps[:-1], self.steps[1:])): | ||
if epoch < s1: | ||
p = (epoch - s0) / (s1 - s0) | ||
v0 = self.values[i] | ||
v1 = self.values[i + 1] | ||
return p * v1 + (1 - p) * v0 | ||
|
||
|
||
def il_training_params(label: str, training_steps: int): | ||
use_lr_decay = False | ||
|
||
if label == "80proc": | ||
lr = 3e-4 | ||
num_train_processes = 80 | ||
num_steps = 64 | ||
dagger_steps = min(int(1e6), training_steps // 10) | ||
bc_tf1_steps = min(int(1e5), training_steps // 10) | ||
update_repeats = 3 | ||
num_mini_batch = 2 if torch.cuda.is_available() else 1 | ||
|
||
elif label == "40proc": | ||
lr = 3e-4 | ||
num_train_processes = 40 | ||
num_steps = 64 | ||
dagger_steps = min(int(1e6), training_steps // 10) | ||
bc_tf1_steps = min(int(1e5), training_steps // 10) | ||
update_repeats = 3 | ||
num_mini_batch = 1 | ||
|
||
elif label == "40proc-longtf": | ||
lr = 3e-4 | ||
num_train_processes = 40 | ||
num_steps = 64 | ||
dagger_steps = min(int(5e6), training_steps // 10) | ||
bc_tf1_steps = min(int(5e5), training_steps // 10) | ||
update_repeats = 3 | ||
num_mini_batch = 1 | ||
|
||
else: | ||
raise NotImplementedError | ||
|
||
return dict( | ||
lr=lr, | ||
num_steps=num_steps, | ||
num_mini_batch=num_mini_batch, | ||
update_repeats=update_repeats, | ||
use_lr_decay=use_lr_decay, | ||
num_train_processes=num_train_processes, | ||
dagger_steps=dagger_steps, | ||
bc_tf1_steps=bc_tf1_steps, | ||
) | ||
|
||
|
||
class OnePhaseRGBILBaseExperimentConfig(OnePhaseRGBBaseExperimentConfig): | ||
SENSORS = [ | ||
*OnePhaseRGBBaseExperimentConfig.SENSORS, | ||
ExpertActionSensor(len(RearrangeBaseExperimentConfig.actions())), | ||
] | ||
|
||
IL_PIPELINE_TYPE: Optional[str] = None | ||
|
||
@classmethod | ||
def _training_pipeline_info(cls, **kwargs) -> Dict[str, Any]: | ||
"""Define how the model trains.""" | ||
|
||
training_steps = cls.TRAINING_STEPS | ||
params = cls._use_label_to_get_training_params() | ||
bc_tf1_steps = params["bc_tf1_steps"] | ||
dagger_steps = params["dagger_steps"] | ||
|
||
return dict( | ||
named_losses=dict(imitation_loss=Imitation()), | ||
pipeline_stages=[ | ||
PipelineStage( | ||
loss_names=["imitation_loss"], | ||
max_stage_steps=training_steps, | ||
teacher_forcing=StepwiseLinearDecay( | ||
cumm_steps_and_values=[ | ||
(bc_tf1_steps, 1.0), | ||
(bc_tf1_steps + dagger_steps, 0.0), | ||
] | ||
), | ||
) | ||
], | ||
**params | ||
) | ||
|
||
@classmethod | ||
def num_train_processes(cls) -> int: | ||
return cls._use_label_to_get_training_params()["num_train_processes"] | ||
|
||
@classmethod | ||
def _use_label_to_get_training_params(cls): | ||
return il_training_params( | ||
label=cls.IL_PIPELINE_TYPE.lower(), training_steps=cls.TRAINING_STEPS | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from typing import Dict, Any | ||
|
||
from allenact.algorithms.onpolicy_sync.losses import PPO | ||
from allenact.algorithms.onpolicy_sync.losses.ppo import PPOConfig | ||
from allenact.utils.experiment_utils import LinearDecay, PipelineStage | ||
|
||
from baseline_configs.one_phase.one_phase_rgb_base import ( | ||
OnePhaseRGBBaseExperimentConfig, | ||
) | ||
|
||
|
||
class OnePhasePPORGBExperimentConfig(OnePhaseRGBBaseExperimentConfig): | ||
USE_RESNET_CNN = False | ||
|
||
@classmethod | ||
def tag(cls) -> str: | ||
return "OnePhaseRGBPPO" | ||
|
||
@classmethod | ||
def num_train_processes(cls) -> int: | ||
return 40 | ||
|
||
@classmethod | ||
def _training_pipeline_info(cls, **kwargs) -> Dict[str, Any]: | ||
"""Define how the model trains.""" | ||
|
||
training_steps = cls.TRAINING_STEPS | ||
return dict( | ||
named_losses=dict( | ||
ppo_loss=PPO(clip_decay=LinearDecay(training_steps), **PPOConfig) | ||
), | ||
pipeline_stages=[ | ||
PipelineStage(loss_names=["ppo_loss"], max_stage_steps=training_steps,) | ||
], | ||
num_steps=64, | ||
num_mini_batch=1, | ||
update_repeats=3, | ||
use_lr_decay=True, | ||
lr=3e-4, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from baseline_configs.one_phase.one_phase_rgb_il_base import ( | ||
OnePhaseRGBILBaseExperimentConfig, | ||
) | ||
|
||
|
||
class OnePhaseRGBCompassResNetDaggerExperimentConfig(OnePhaseRGBILBaseExperimentConfig): | ||
USE_RESNET_CNN = True | ||
IL_PIPELINE_TYPE = "40proc" | ||
|
||
@classmethod | ||
def tag(cls) -> str: | ||
return f"OnePhaseRGBResNetDagger_{cls.IL_PIPELINE_TYPE}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from baseline_configs.one_phase.one_phase_rgb_ppo import OnePhasePPORGBExperimentConfig | ||
|
||
|
||
class OnePhasePPORGBResNetExperimentConfig(OnePhasePPORGBExperimentConfig): | ||
USE_RESNET_CNN = True | ||
|
||
@classmethod | ||
def tag(cls) -> str: | ||
return "OnePhaseRGBResNetPPO" |
Oops, something went wrong.