diff --git a/allenact/algorithms/onpolicy_sync/engine.py b/allenact/algorithms/onpolicy_sync/engine.py index d2c6b2317..f478d9025 100644 --- a/allenact/algorithms/onpolicy_sync/engine.py +++ b/allenact/algorithms/onpolicy_sync/engine.py @@ -8,16 +8,9 @@ import traceback from functools import partial from multiprocessing.context import BaseContext -from typing import ( - Optional, - Any, - Dict, - Union, - List, - cast, - Sequence, -) +from typing import Any, Dict, List, Optional, Sequence, Union, cast +import filelock import torch import torch.distributed as dist # type: ignore import torch.distributions # type: ignore @@ -27,7 +20,9 @@ # noinspection PyProtectedMember from torch._C._distributed_c10d import ReduceOp -from allenact.algorithms.onpolicy_sync.misc import TrackingInfoType, TrackingInfo +from allenact.algorithms.onpolicy_sync.misc import TrackingInfo, TrackingInfoType +from allenact.base_abstractions.sensor import Sensor +from allenact.utils.misc_utils import str2bool from allenact.utils.model_utils import md5_hash_of_state_dict try: @@ -43,37 +38,35 @@ from allenact.algorithms.onpolicy_sync.storage import ( ExperienceStorage, MiniBatchStorageMixin, - StreamingStorageMixin, RolloutStorage, + StreamingStorageMixin, ) from allenact.algorithms.onpolicy_sync.vector_sampled_tasks import ( - VectorSampledTasks, + COMPLETE_TASK_CALLBACK_KEY, COMPLETE_TASK_METRICS_KEY, SingleProcessVectorSampledTasks, + VectorSampledTasks, ) +from allenact.base_abstractions.distributions import TeacherForcingDistr from allenact.base_abstractions.experiment_config import ExperimentConfig, MachineParams from allenact.base_abstractions.misc import ( - RLStepResult, - Memory, ActorCriticOutput, GenericAbstractLoss, + Memory, + RLStepResult, ) -from allenact.base_abstractions.distributions import TeacherForcingDistr from allenact.utils import spaces_utils as su from allenact.utils.experiment_utils import ( - set_seed, - TrainingPipeline, LoggingPackage, PipelineStage, - set_deterministic_cudnn, ScalarMeanTracker, StageComponent, + TrainingPipeline, + set_deterministic_cudnn, + set_seed, ) from allenact.utils.system import get_logger -from allenact.utils.tensor_utils import ( - batch_observations, - detach_recursively, -) +from allenact.utils.tensor_utils import batch_observations, detach_recursively from allenact.utils.viz_utils import VizSuite try: @@ -82,9 +75,9 @@ # noinspection PyPackageRequirements import pydevd - DEBUGGING = True + DEBUGGING = str2bool(os.getenv("ALLENACT_DEBUG", "true")) except ImportError: - DEBUGGING = False + DEBUGGING = str2bool(os.getenv("ALLENACT_DEBUG", "false")) DEBUG_VST_TIMEOUT: Optional[int] = (lambda x: int(x) if x is not None else x)( os.getenv("ALLENACT_DEBUG_VST_TIMEOUT", None) @@ -115,6 +108,7 @@ def __init__( ], # to write/read (trainer/evaluator) ready checkpoints checkpoints_dir: str, mode: str = "train", + callback_sensors: Optional[Sequence[Sensor]] = None, seed: Optional[int] = None, deterministic_cudnn: bool = False, mp_ctx: Optional[BaseContext] = None, @@ -126,7 +120,7 @@ def __init__( deterministic_agents: bool = False, max_sampler_processes_per_worker: Optional[int] = None, initial_model_state_dict: Optional[Union[Dict[str, Any], int]] = None, - try_restart_after_task_timeout: bool = False, + try_restart_after_task_error: bool = False, **kwargs, ): """Initializer. 
@@ -154,7 +148,7 @@ def __init__( self.device = torch.device("cpu") if device == -1 else torch.device(device) # type: ignore self.distributed_ip = distributed_ip self.distributed_port = distributed_port - self.try_restart_after_task_timeout = try_restart_after_task_timeout + self.try_restart_after_task_error = try_restart_after_task_error self.mode = mode.lower().strip() assert self.mode in [ @@ -163,6 +157,7 @@ def __init__( TEST_MODE_STR, ], 'Only "train", "valid", "test" modes supported' + self.callback_sensors = callback_sensors self.deterministic_cudnn = deterministic_cudnn if self.deterministic_cudnn: set_deterministic_cudnn() @@ -245,6 +240,9 @@ def __init__( port=self.distributed_port, world_size=self.num_workers, is_master=self.worker_id == 0, + timeout=datetime.timedelta( + seconds=3 * (DEBUG_VST_TIMEOUT if DEBUGGING else 1 * 60) + 300 + ), ) cpu_device = self.device == torch.device("cpu") # type:ignore @@ -277,6 +275,7 @@ def __init__( # Keeping track of metrics during training/inference self.single_process_metrics: List = [] + self.single_process_task_callback_data: List = [] @property def vector_tasks( @@ -310,12 +309,13 @@ def vector_tasks( self._vector_tasks = VectorSampledTasks( make_sampler_fn=self.config.make_sampler_fn, sampler_fn_args=self.get_sampler_fn_args(seeds), + callback_sensors=self.callback_sensors, multiprocessing_start_method="forkserver" if self.mp_ctx is None else None, mp_ctx=self.mp_ctx, max_processes=self.max_sampler_processes_per_worker, - read_timeout=DEBUG_VST_TIMEOUT if DEBUGGING else 5 * 60, + read_timeout=DEBUG_VST_TIMEOUT if DEBUGGING else 1 * 60, ) return self._vector_tasks @@ -588,14 +588,17 @@ def collect_step_across_all_task_samplers( # Save after task completion metrics for step_result in outputs: - if ( - step_result.info is not None - and COMPLETE_TASK_METRICS_KEY in step_result.info - ): - self.single_process_metrics.append( - step_result.info[COMPLETE_TASK_METRICS_KEY] - ) - del step_result.info[COMPLETE_TASK_METRICS_KEY] + if step_result.info is not None: + if COMPLETE_TASK_METRICS_KEY in step_result.info: + self.single_process_metrics.append( + step_result.info[COMPLETE_TASK_METRICS_KEY] + ) + del step_result.info[COMPLETE_TASK_METRICS_KEY] + if COMPLETE_TASK_CALLBACK_KEY in step_result.info: + self.single_process_task_callback_data.append( + step_result.info[COMPLETE_TASK_CALLBACK_KEY] + ) + del step_result.info[COMPLETE_TASK_CALLBACK_KEY] rewards: Union[List, torch.Tensor] observations, rewards, dones, infos = [list(x) for x in zip(*outputs)] @@ -856,6 +859,36 @@ def deterministic_seeds(self) -> None: ) # use the latest seed for workers and update rng state self.vector_tasks.set_seeds(seeds) + def save_error_data(self, batch: Dict[str, Any]) -> str: + model_path = os.path.join( + self.checkpoints_dir, + "error_for_exp_{}__stage_{:02d}__steps_{:012d}.pt".format( + self.experiment_name, + self.training_pipeline.current_stage_index, + self.training_pipeline.total_steps, + ), + ) + with filelock.FileLock( + os.path.join(self.checkpoints_dir, "error.lock"), timeout=60 + ): + if not os.path.exists(model_path): + save_dict = { + "model_state_dict": self.actor_critic.state_dict(), # type:ignore + "total_steps": self.training_pipeline.total_steps, # Total steps including current stage + "optimizer_state_dict": self.optimizer.state_dict(), # type: ignore + "training_pipeline_state_dict": self.training_pipeline.state_dict(), + "trainer_seed": self.seed, + "batch": batch, + } + + if self.lr_scheduler is not None: + save_dict["scheduler_state"] = 
cast( + _LRScheduler, self.lr_scheduler + ).state_dict() + + torch.save(save_dict, model_path) + return model_path + def checkpoint_save(self, pipeline_stage_index: Optional[int] = None) -> str: model_path = os.path.join( self.checkpoints_dir, @@ -1124,12 +1157,21 @@ def single_batch_generator(streaming_storage: StreamingStorageMixin): bsize = batch["bsize"] if actor_critic_output_for_batch is None: - actor_critic_output_for_batch, _ = self.actor_critic( - observations=batch["observations"], - memory=batch["memory"], - prev_actions=batch["prev_actions"], - masks=batch["masks"], - ) + + try: + actor_critic_output_for_batch, _ = self.actor_critic( + observations=batch["observations"], + memory=batch["memory"], + prev_actions=batch["prev_actions"], + masks=batch["masks"], + ) + except ValueError: + save_path = self.save_error_data(batch=batch) + get_logger().error( + f"Encountered a value error! Likely because of nans in the output/input." + f" Saving all error information to {save_path}." + ) + raise loss_return = loss.loss( step_count=self.step_count, @@ -1287,6 +1329,10 @@ def aggregate_and_send_logging_package( self.aggregate_task_metrics(logging_pkg=logging_pkg) + for callback_dict in self.single_process_task_callback_data: + logging_pkg.task_callback_data.append(callback_dict) + self.single_process_task_callback_data = [] + if self.mode == TRAIN_MODE_STR: # Technically self.mode should always be "train" here (as this is the training engine), # this conditional is defensive @@ -1450,9 +1496,9 @@ def run_pipeline(self, valid_on_initial_weights: bool = False): rollout_storage_uuid=self.training_pipeline.rollout_storage_uuid, uuid_to_storage=uuid_to_storage, ) - except TimeoutError: + except (TimeoutError, EOFError) as e: if ( - not self.try_restart_after_task_timeout + not self.try_restart_after_task_error ) or self.mode != TRAIN_MODE_STR: # Apparently you can just call `raise` here and doing so will just raise the exception as though # it was not caught (so the stacktrace isn't messed up) @@ -1465,9 +1511,10 @@ def run_pipeline(self, valid_on_initial_weights: bool = False): else: get_logger().warning( f"[{self.mode} worker {self.worker_id}] `vector_tasks` appears to have crashed during" - f" training as it has timed out. You have set `try_restart_after_task_timeout` to `True` so" - f" we will attempt to restart these tasks from the beginning. USE THIS FEATURE AT YOUR OWN" - f" RISK. Timeout exception:\n{traceback.format_exc()}." + f" training due to an {type(e).__name__} error. You have set" + f" `try_restart_after_task_error` to `True` so we will attempt to restart these tasks from" + f" the beginning. USE THIS FEATURE AT YOUR OWN" + f" RISK. Exception:\n{traceback.format_exc()}." 
)
                self.vector_tasks.close()
                self._vector_tasks = None
@@ -1852,6 +1899,10 @@ def run_eval(

         self.aggregate_task_metrics(logging_pkg=logging_pkg)

+        for callback_dict in self.single_process_task_callback_data:
+            logging_pkg.task_callback_data.append(callback_dict)
+        self.single_process_task_callback_data = []
+
         logging_pkg.viz_data = (
             visualizer.read_and_reset() if visualizer is not None else None
         )
diff --git a/allenact/algorithms/onpolicy_sync/runner.py b/allenact/algorithms/onpolicy_sync/runner.py
index 51811b111..a503c608d 100644
--- a/allenact/algorithms/onpolicy_sync/runner.py
+++ b/allenact/algorithms/onpolicy_sync/runner.py
@@ -2,6 +2,8 @@
 import copy
 import enum
 import glob
+import importlib.util
+import inspect
 import itertools
 import json
 import math
@@ -17,36 +19,39 @@
 from collections import defaultdict
 from multiprocessing.context import BaseContext
 from multiprocessing.process import BaseProcess
-from typing import Optional, Dict, Union, Tuple, Sequence, List, Any
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union

 import filelock
 import numpy as np
 import torch
 import torch.multiprocessing as mp
 from setproctitle import setproctitle as ptitle
+from torch.distributions.utils import lazy_property

 from allenact.algorithms.onpolicy_sync.engine import (
-    OnPolicyTrainer,
-    OnPolicyInference,
+    TEST_MODE_STR,
     TRAIN_MODE_STR,
     VALID_MODE_STR,
-    TEST_MODE_STR,
+    OnPolicyInference,
     OnPolicyRLEngine,
+    OnPolicyTrainer,
 )
+from allenact.base_abstractions.callbacks import Callback
 from allenact.base_abstractions.experiment_config import ExperimentConfig, MachineParams
+from allenact.base_abstractions.sensor import Sensor
 from allenact.utils.experiment_utils import (
+    LoggingPackage,
     ScalarMeanTracker,
     set_deterministic_cudnn,
     set_seed,
-    LoggingPackage,
 )
 from allenact.utils.misc_utils import (
+    NumpyJSONEncoder,
     all_equal,
     get_git_diff_of_project,
-    NumpyJSONEncoder,
 )
 from allenact.utils.model_utils import md5_hash_of_state_dict
-from allenact.utils.system import get_logger, find_free_port
+from allenact.utils.system import find_free_port, get_logger
 from allenact.utils.tensor_utils import SummaryWriter
 from allenact.utils.viz_utils import VizSuite
@@ -86,14 +91,17 @@ def __init__(
         disable_tensorboard: bool = False,
         disable_config_saving: bool = False,
         distributed_ip_and_port: str = "127.0.0.1:0",
+        distributed_preemption_threshold: float = 0.7,
         machine_id: int = 0,
         save_dir_fmt: SaveDirFormat = SaveDirFormat.FLAT,
+        callbacks_paths: Optional[str] = None,
     ):
         self.config = config
         self.output_dir = output_dir
         self.loaded_config_src_files = loaded_config_src_files
         self.seed = seed if seed is not None else random.randint(0, 2 ** 31 - 1)
         self.deterministic_cudnn = deterministic_cudnn
+        self.distributed_preemption_threshold = distributed_preemption_threshold
         if multiprocessing_start_method == "default":
             if torch.cuda.is_available():
                 multiprocessing_start_method = "forkserver"
@@ -137,6 +145,8 @@ def __init__(

         self.save_dir_fmt = save_dir_fmt

+        self.callbacks = self.setup_callback_classes(callbacks_paths)
+
     @property
     def local_start_time_str(self) -> str:
         if self._local_start_time_str is None:
@@ -178,6 +188,31 @@ def init_context(

         return mp_ctx

+    def setup_callback_classes(self, callbacks: Optional[str]) -> Set[Callback]:
+        """Instantiate and set up all `Callback` subclasses found in a
+        comma-separated list of files."""
+        if callbacks == "" or callbacks is None:
+            return set()
+
+        setup_dict = dict(name=self.experiment_name, config=self.config, mode=self.mode)
+        callback_classes = 
set() + files = callbacks.split(",") + for filename in files: + module_path = filename.replace("/", ".") + if module_path.endswith(".py"): + module_path = module_path[:-3] + module = importlib.import_module(module_path) + classes = inspect.getmembers(module, inspect.isclass) + + for mod_class in classes: + if issubclass(mod_class[1], Callback) and mod_class[1] != Callback: + # NOTE: initialize the callback class + inst_class = mod_class[1]() + inst_class.setup(**setup_dict) + callback_classes.add(inst_class) + + return callback_classes + def _acquire_unique_local_start_time_string(self) -> str: """Creates a (unique) local start time string for this experiment. @@ -311,6 +346,15 @@ def init_worker(engine_class, args, kwargs): finally: return worker + @lazy_property + def _get_callback_sensors(self) -> List[Sensor]: + callback_sensors: List[Sensor] = [] + for c in self.callbacks: + sensors = c.callback_sensors() + if sensors is not None: + callback_sensors.extend(sensors) + return callback_sensors + @staticmethod def train_loop( id: int = 0, @@ -411,6 +455,7 @@ def start_train( save_ckpt_after_every_pipeline_stage: bool = True, collect_valid_results: bool = False, valid_on_initial_weights: bool = False, + try_restart_after_task_error: bool = False, ): self._initialize_start_train_or_start_test() @@ -452,6 +497,7 @@ def start_train( restart_pipeline=restart_pipeline, experiment_name=self.experiment_name, config=self.config, + callback_sensors=self._get_callback_sensors, results_queue=self.queues["results"], checkpoints_queue=self.queues["checkpoints"] if self.running_validation @@ -470,7 +516,9 @@ def start_train( if model_hash is None else model_hash, first_local_worker_id=worker_ids[0], + distributed_preemption_threshold=self.distributed_preemption_threshold, valid_on_initial_weights=valid_on_initial_weights, + try_restart_after_task_error=try_restart_after_task_error, ) train: BaseProcess = self.mp_ctx.Process( target=self.train_loop, kwargs=training_kwargs, @@ -507,6 +555,7 @@ def start_train( args=(0,), kwargs=dict( config=self.config, + callback_sensors=self._get_callback_sensors, results_queue=self.queues["results"], checkpoints_queue=self.queues["checkpoints"], seed=12345, # TODO allow same order for randomly sampled tasks? Is this any useful anyway? @@ -569,6 +618,9 @@ def start_test( assert ( self.machine_id == 0 ), f"Received `machine_id={self.machine_id} for test. Only one machine supported." + assert isinstance( + checkpoint_path_dir_or_pattern, str + ), "Must provide a --checkpoint path or pattern to test on." self.extra_tag += ( "__" * (len(self.extra_tag) > 0) + "enforced_test_expert" @@ -590,6 +642,7 @@ def start_test( args=(tester_it,), kwargs=dict( config=self.config, + callback_sensors=self._get_callback_sensors, results_queue=self.queues["results"], checkpoints_queue=self.queues["checkpoints"], seed=12345, # TODO allow same order for randomly sampled tasks? Is this any useful anyway? 
@@ -810,8 +863,10 @@ def save_project_state(self): break get_logger().info(f"Config files saved to {base_dir}") + for callback in self.callbacks: + callback.after_save_project_state(base_dir=base_dir) - def process_eval_package( + def process_valid_package( self, log_writer: Optional[SummaryWriter], pkg: LoggingPackage, @@ -824,30 +879,42 @@ def process_eval_package( num_tasks = pkg.num_non_empty_metrics_dicts_added metric_means = pkg.metrics_tracker.means() + callback_metric_means = dict() + tasks_callback_data = pkg.task_callback_data mode = pkg.mode + assert mode == "valid" + num_tasks_key = f"{mode}-misc/num_tasks_evaled" if log_writer is not None: - log_writer.add_scalar( - f"{mode}-misc/num_tasks_evaled", num_tasks, training_steps - ) + log_writer.add_scalar(num_tasks_key, num_tasks, training_steps) + callback_metric_means[num_tasks_key] = num_tasks message = [f"{mode} {training_steps} steps:"] for k in sorted(metric_means.keys()): + metrics_key = f"{mode}-metrics/{k}" if log_writer is not None: - log_writer.add_scalar( - f"{mode}-metrics/{k}", metric_means[k], training_steps - ) + log_writer.add_scalar(metrics_key, metric_means[k], training_steps) + callback_metric_means[metrics_key] = metric_means[k] message.append(f"{k} {metric_means[k]}") + results = copy.deepcopy(metric_means) + results.update({"training_steps": training_steps, "tasks": task_outputs}) if all_results is not None: - results = copy.deepcopy(metric_means) - results.update({"training_steps": training_steps, "tasks": task_outputs}) all_results.append(results) message.append(f"tasks {num_tasks} checkpoint {checkpoint_file_name}") get_logger().info(" ".join(message)) + for callback in self.callbacks: + callback.on_valid_log( + metric_means=callback_metric_means, + metrics=results, + step=training_steps, + checkpoint_file_name=checkpoint_file_name, + tasks_data=tasks_callback_data, + ) + if self.visualizer is not None: self.visualizer.log( log_writer=log_writer, @@ -865,6 +932,7 @@ def process_train_packages( last_time: float, ): assert self.mode == TRAIN_MODE_STR + callback_metric_means = dict() current_time = time.time() @@ -877,13 +945,17 @@ def process_train_packages( scalar_value=pkgs[0].pipeline_stage, global_step=training_steps, ) + callback_metric_means[f"train-misc/pipeline_stage"] = pkgs[0].pipeline_stage - for storage_uuid, val in storage_uuid_to_total_experiences.items(): + for storage_uuid, val in storage_uuid_to_total_experiences.items(): + total_experiences_key = f"train-misc/{storage_uuid}_total_experiences" + if log_writer is not None: log_writer.add_scalar( - tag=f"train-misc/{storage_uuid}_total_experiences", + tag=total_experiences_key, scalar_value=val, global_step=training_steps, ) + callback_metric_means[total_experiences_key] = val def add_prefix( d: Union[Dict[str, Any], str], @@ -907,11 +979,13 @@ def _convert(key: str): metrics_and_train_info_tracker = ScalarMeanTracker() scalar_name_to_total_storage_experience = {} storage_uuid_to_stage_component_uuids = defaultdict(lambda: set()) + tasks_callback_data = [] for pkg in pkgs: metrics_and_train_info_tracker.add_scalars( scalars=add_prefix(pkg.metrics_tracker.means(), "metrics", None), n=add_prefix(pkg.metrics_tracker.counts(), "metrics", None), ) + tasks_callback_data.extend(pkg.task_callback_data) for ( (stage_component_uuid, storage_uuid), @@ -963,6 +1037,7 @@ def _convert(key: str): f"TRAIN: {training_steps} rollout steps ({pkgs[0].storage_uuid_to_total_experiences})" ] means = metrics_and_train_info_tracker.means() + 
callback_metric_means.update(means) for k in sorted( means.keys(), key=lambda mean_key: (mean_key.count("/"), mean_key) @@ -984,10 +1059,10 @@ def _convert(key: str): if last_steps > 0: fps = (training_steps - last_steps) / (current_time - last_time) message += [f"approx_fps {fps:.3g}"] + approx_fps_key = add_prefix("approx_fps", "misc", None) if log_writer is not None: - log_writer.add_scalar( - add_prefix("approx_fps", "misc", None), fps, training_steps - ) + log_writer.add_scalar(approx_fps_key, fps, training_steps) + callback_metric_means[approx_fps_key] = fps for ( storage_uuid, @@ -997,22 +1072,34 @@ def _convert(key: str): cur_total_exp = storage_uuid_to_total_experiences[storage_uuid] eps = (cur_total_exp - last_total_exp) / (current_time - last_time) message += [f"{storage_uuid}/approx_eps {eps:.3g}"] - if log_writer is not None: - for stage_component_uuid in storage_uuid_to_stage_component_uuids[ - storage_uuid - ]: + for stage_component_uuid in storage_uuid_to_stage_component_uuids[ + storage_uuid + ]: + approx_eps_key = add_prefix( + f"approx_eps", + "misc", + stage_component_uuid=stage_component_uuid, + ) + callback_metric_means[approx_eps_key] = eps + if log_writer is not None: log_writer.add_scalar( - add_prefix( - f"approx_eps", - "misc", - stage_component_uuid=stage_component_uuid, - ), - eps, - cur_total_exp, + approx_eps_key, eps, cur_total_exp, ) get_logger().info(" ".join(message)) + metrics = [] + for pkg in pkgs: + metrics.extend(pkg.metric_dicts) + + for callback in self.callbacks: + callback.on_train_log( + metrics=metrics, + metric_means=callback_metric_means, + step=training_steps, + tasks_data=tasks_callback_data, + ) + return training_steps, storage_uuid_to_total_experiences, current_time def process_test_packages( @@ -1028,10 +1115,12 @@ def process_test_packages( all_metrics_tracker = ScalarMeanTracker() metric_dicts_list, render, checkpoint_file_name = [], {}, [] + tasks_callback_data = [] for pkg in pkgs: all_metrics_tracker.add_scalars( scalars=pkg.metrics_tracker.means(), n=pkg.metrics_tracker.counts() ) + tasks_callback_data.extend(pkg.task_callback_data) metric_dicts_list.extend(pkg.metric_dicts) if pkg.viz_data is not None: render.update(pkg.viz_data) @@ -1042,11 +1131,12 @@ def process_test_packages( message = [f"{mode} {training_steps} steps:"] metric_means = all_metrics_tracker.means() + callback_metric_means = dict() for k in sorted(metric_means.keys()): + metrics_key = f"{mode}-metrics/{k}" if log_writer is not None: - log_writer.add_scalar( - f"{mode}-metrics/{k}", metric_means[k], training_steps - ) + log_writer.add_scalar(metrics_key, metric_means[k], training_steps) + callback_metric_means[metrics_key] = metric_means[k] message.append(k + f" {metric_means[k]:.3g}") if all_results is not None: @@ -1057,14 +1147,24 @@ def process_test_packages( all_results.append(results) num_tasks = sum([pkg.num_non_empty_metrics_dicts_added for pkg in pkgs]) + + num_tasks_evaled_key = f"{mode}-misc/num_tasks_evaled" if log_writer is not None: - log_writer.add_scalar( - f"{mode}-misc/num_tasks_evaled", num_tasks, training_steps - ) + log_writer.add_scalar(num_tasks_evaled_key, num_tasks, training_steps) + callback_metric_means[num_tasks_evaled_key] = num_tasks message.append(f"tasks {num_tasks} checkpoint {checkpoint_file_name[0]}") get_logger().info(" ".join(message)) + for callback in self.callbacks: + callback.on_test_log( + metric_means=callback_metric_means, + metrics=all_results[-1], + step=training_steps, + checkpoint_file_name=checkpoint_file_name[0], + 
tasks_data=tasks_callback_data, + ) + if self.visualizer is not None: self.visualizer.log( log_writer=log_writer, @@ -1153,7 +1253,7 @@ def log_and_close( if ( package.training_steps is not None ): # no validation samplers - self.process_eval_package( + self.process_valid_package( log_writer=log_writer, pkg=package, all_results=eval_results diff --git a/allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py b/allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py index 981a0fff0..864ee3ecb 100644 --- a/allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py +++ b/allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py @@ -13,15 +13,15 @@ from typing import ( Any, Callable, + Dict, + Generator, + Iterator, List, Optional, Sequence, Set, Tuple, Union, - Dict, - Generator, - Iterator, cast, ) @@ -30,6 +30,7 @@ from setproctitle import setproctitle as ptitle from allenact.base_abstractions.misc import RLStepResult +from allenact.base_abstractions.sensor import SensorSuite, Sensor from allenact.base_abstractions.task import TaskSampler from allenact.utils.misc_utils import partition_sequence from allenact.utils.system import get_logger @@ -46,6 +47,7 @@ DEFAULT_MP_CONTEXT_TYPE = "forkserver" COMPLETE_TASK_METRICS_KEY = "__AFTER_TASK_METRICS__" +COMPLETE_TASK_CALLBACK_KEY = "__AFTER_TASK_CALLBACK__" STEP_COMMAND = "step" NEXT_TASK_COMMAND = "next_task" @@ -150,6 +152,7 @@ def __init__( self, make_sampler_fn: Callable[..., TaskSampler], sampler_fn_args: Sequence[Dict[str, Any]] = None, + callback_sensors: Optional[Sequence[Sensor]] = None, auto_resample_when_done: bool = True, multiprocessing_start_method: Optional[str] = "forkserver", mp_ctx: Optional[BaseContext] = None, @@ -196,7 +199,7 @@ def __init__( ] = None self._reset_sampler_index_to_process_ind_and_subprocess_ind() - self._workers: Optional[List] = None + self._workers: Optional[List[Union[mp.Process, Thread, BaseProcess]]] = None for args in sampler_fn_args: args["mp_ctx"] = self._mp_ctx ( @@ -208,6 +211,11 @@ def __init__( sampler_fn_args_list=[ args_list for args_list in self._partition_to_processes(sampler_fn_args) ], + callback_sensor_suite=( + SensorSuite(callback_sensors) + if isinstance(callback_sensors, Sequence) + else callback_sensors + ), ) self._connection_read_fns = [ @@ -222,8 +230,11 @@ def __init__( for write_fn in self._connection_write_fns: write_fn((OBSERVATION_SPACE_COMMAND, None)) + # Note that we increase the read timeout below as initialization can take some time observation_spaces = [ - space for read_fn in self._connection_read_fns for space in read_fn() + space + for read_fn in self._connection_read_fns + for space in read_fn(timeout_to_use=5 * self.read_timeout if self.read_timeout is not None else None) # type: ignore ] if any(os is None for os in observation_spaces): @@ -259,7 +270,7 @@ def read_with_timeout(timeout_to_use: Optional[float] = timeout): # noinspection PyArgumentList if not poll_fn(timeout=timeout_to_use): raise TimeoutError( - f"Did not recieve output from `VectorSampledTask` worker for {timeout_to_use} seconds." + f"Did not receive output from `VectorSampledTask` worker for {timeout_to_use} seconds." 
) return read_fn() @@ -321,6 +332,7 @@ def _task_sampling_loop_worker( connection_write_fn: Callable, make_sampler_fn: Callable[..., TaskSampler], sampler_fn_args_list: List[Dict[str, Any]], + callback_sensor_suite: Optional[SensorSuite], auto_resample_when_done: bool, should_log: bool, child_pipe: Optional[Connection] = None, @@ -334,6 +346,7 @@ def _task_sampling_loop_worker( sp_vector_sampled_tasks = SingleProcessVectorSampledTasks( make_sampler_fn=make_sampler_fn, sampler_fn_args_list=sampler_fn_args_list, + callback_sensor_suite=callback_sensor_suite, auto_resample_when_done=auto_resample_when_done, should_log=should_log, ) @@ -344,57 +357,53 @@ def _task_sampling_loop_worker( while True: read_input = connection_read_fn() - with DelaySignalHandling(): - # Delaying signal handling here is necessary to ensure that we don't - # (when processing a SIGTERM/SIGINT signal) attempt to send data to - # a generator while it is already processing other data. - if len(read_input) == 3: - sampler_index, command, data = read_input - - assert ( - command != CLOSE_COMMAND - ), "Must close all processes at once." - assert ( - command != RESUME_COMMAND - ), "Must resume all task samplers at once." - - if command == PAUSE_COMMAND: - sp_vector_sampled_tasks.pause_at( - sampler_index=sampler_index - ) - connection_write_fn("done") - else: - connection_write_fn( - sp_vector_sampled_tasks.command_at( - sampler_index=sampler_index, - command=command, - data=data, - ) + # TODO: Was the below necessary? + # with DelaySignalHandling(): + # # Delaying signal handling here is necessary to ensure that we don't + # # (when processing a SIGTERM/SIGINT signal) attempt to send data to + # # a generator while it is already processing other data. + if len(read_input) == 3: + sampler_index, command, data = read_input + + assert command != CLOSE_COMMAND, "Must close all processes at once." + assert ( + command != RESUME_COMMAND + ), "Must resume all task samplers at once." + + if command == PAUSE_COMMAND: + sp_vector_sampled_tasks.pause_at(sampler_index=sampler_index) + connection_write_fn("done") + else: + connection_write_fn( + sp_vector_sampled_tasks.command_at( + sampler_index=sampler_index, command=command, data=data, ) + ) + else: + commands, data_list = read_input + + assert ( + commands != PAUSE_COMMAND + ), "Cannot pause all task samplers at once." + + if commands == CLOSE_COMMAND: + # Will close the `sp_vector_sampled_tasks` in the `finally` clause below + break + + elif commands == RESUME_COMMAND: + sp_vector_sampled_tasks.resume_all() + connection_write_fn("done") else: - commands, data_list = read_input - - assert ( - commands != PAUSE_COMMAND - ), "Cannot pause all task samplers at once." 
- - if commands == CLOSE_COMMAND: - sp_vector_sampled_tasks.close() - break - elif commands == RESUME_COMMAND: - sp_vector_sampled_tasks.resume_all() - connection_write_fn("done") - else: - if isinstance(commands, str): - commands = [ - commands - ] * sp_vector_sampled_tasks.num_unpaused_tasks - - connection_write_fn( - sp_vector_sampled_tasks.command( - commands=commands, data_list=data_list - ) + if isinstance(commands, str): + commands = [ + commands + ] * sp_vector_sampled_tasks.num_unpaused_tasks + + connection_write_fn( + sp_vector_sampled_tasks.command( + commands=commands, data_list=data_list ) + ) except KeyboardInterrupt: if should_log: @@ -405,6 +414,11 @@ def _task_sampling_loop_worker( ) raise e finally: + try: + sp_vector_sampled_tasks.close() + except Exception: + pass + if child_pipe is not None: child_pipe.close() if should_log: @@ -414,6 +428,7 @@ def _spawn_workers( self, make_sampler_fn: Callable[..., TaskSampler], sampler_fn_args_list: Sequence[Sequence[Dict[str, Any]]], + callback_sensor_suite: Optional[SensorSuite], ) -> Tuple[ List[Callable[[], bool]], List[Callable[[], Any]], List[Callable[[Any], None]] ]: @@ -443,6 +458,7 @@ def _spawn_workers( connection_write_fn=worker_conn.send, make_sampler_fn=make_sampler_fn, sampler_fn_args_list=current_sampler_fn_args_list, + callback_sensor_suite=callback_sensor_suite, auto_resample_when_done=self._auto_resample_when_done, should_log=self.should_log, child_pipe=worker_conn, @@ -642,6 +658,10 @@ def close(self) -> None: except Exception: pass + for process in self._workers: + if process.is_alive(): + process.kill() + self._is_closed = True def pause_at(self, sampler_index: int) -> None: @@ -858,6 +878,7 @@ def __init__( self, make_sampler_fn: Callable[..., TaskSampler], sampler_fn_args_list: Sequence[Dict[str, Any]] = None, + callback_sensor_suite: Optional[SensorSuite] = None, auto_resample_when_done: bool = True, should_log: bool = True, ) -> None: @@ -876,6 +897,7 @@ def __init__( self._vector_task_generators: List[Generator] = self._create_generators( make_sampler_fn=make_sampler_fn, sampler_fn_args=[{"mp_ctx": None, **args} for args in sampler_fn_args_list], + callback_sensor_suite=callback_sensor_suite, ) self._is_closed = False @@ -930,6 +952,7 @@ def _task_sampling_loop_generator_fn( worker_id: int, make_sampler_fn: Callable[..., TaskSampler], sampler_fn_args: Dict[str, Any], + callback_sensor_suite: Optional[SensorSuite], auto_resample_when_done: bool, should_log: bool, ) -> Generator: @@ -959,6 +982,16 @@ def _task_sampling_loop_generator_fn( step_result = step_result.clone({"info": {}}) step_result.info[COMPLETE_TASK_METRICS_KEY] = metrics + if callback_sensor_suite is not None: + task_callback_data = callback_sensor_suite.get_observations( + env=current_task.env, task=current_task + ) + if step_result.info is None: + step_result = step_result.clone({"info": {}}) + step_result.info[ + COMPLETE_TASK_CALLBACK_KEY + ] = task_callback_data + if auto_resample_when_done: current_task = task_sampler.next_task() if current_task is None: @@ -1060,21 +1093,21 @@ def _create_generators( self, make_sampler_fn: Callable[..., TaskSampler], sampler_fn_args: Sequence[Dict[str, Any]], + callback_sensor_suite: Optional[SensorSuite], ) -> List[Generator]: generators = [] for id, current_sampler_fn_args in enumerate(sampler_fn_args): if self.should_log: get_logger().info( - "Starting {}-th SingleProcessVectorSampledTasks generator with args {}".format( - id, current_sampler_fn_args - ) + f"Starting {id}-th 
SingleProcessVectorSampledTasks generator with args {current_sampler_fn_args}."
             )
         generators.append(
             self._task_sampling_loop_generator_fn(
                 worker_id=id,
                 make_sampler_fn=make_sampler_fn,
                 sampler_fn_args=current_sampler_fn_args,
+                callback_sensor_suite=callback_sensor_suite,
                 auto_resample_when_done=self._auto_resample_when_done,
                 should_log=self.should_log,
             )
diff --git a/allenact/base_abstractions/callbacks.py b/allenact/base_abstractions/callbacks.py
new file mode 100644
index 000000000..5f1c0f476
--- /dev/null
+++ b/allenact/base_abstractions/callbacks.py
@@ -0,0 +1,59 @@
+from typing import List, Dict, Any, Sequence, Optional
+
+from allenact.base_abstractions.experiment_config import ExperimentConfig
+from allenact.base_abstractions.sensor import Sensor
+
+try:
+    from typing import Literal
+except ImportError:
+    from typing_extensions import Literal
+
+
+class Callback:
+    def setup(
+        self,
+        name: str,
+        config: ExperimentConfig,
+        mode: Literal["train", "valid", "test"],
+        **kwargs,
+    ) -> None:
+        """Called once before training begins."""
+
+    def on_train_log(
+        self,
+        metrics: List[Dict[str, Any]],
+        metric_means: Dict[str, float],
+        step: int,
+        tasks_data: List[Any],
+        **kwargs,
+    ) -> None:
+        """Called each time the training loop is about to log metrics."""
+
+    def on_valid_log(
+        self,
+        metrics: Dict[str, Any],
+        metric_means: Dict[str, float],
+        checkpoint_file_name: str,
+        tasks_data: List[Any],
+        step: int,
+        **kwargs,
+    ) -> None:
+        """Called after validation ends."""
+
+    def on_test_log(
+        self,
+        checkpoint_file_name: str,
+        metrics: Dict[str, Any],
+        metric_means: Dict[str, float],
+        tasks_data: List[Any],
+        step: int,
+        **kwargs,
+    ) -> None:
+        """Called after testing ends."""
+
+    def after_save_project_state(self, base_dir: str) -> None:
+        """Called after saving the project state to `base_dir`."""
+
+    def callback_sensors(self) -> Optional[Sequence[Sensor]]:
+        """Returns sensors whose observations populate the `tasks_data`
+        parameter of the `*_log` methods above."""
diff --git a/allenact/base_abstractions/task.py b/allenact/base_abstractions/task.py
index ec7c3117e..69937094a 100644
--- a/allenact/base_abstractions/task.py
+++ b/allenact/base_abstractions/task.py
@@ -7,7 +7,7 @@
 environment."""

 import abc
-from typing import Dict, Any, Tuple, Generic, Union, Optional, TypeVar, Sequence, List
+from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union

 import gym
 import numpy as np
diff --git a/allenact/embodiedai/aux_losses/losses.py b/allenact/embodiedai/aux_losses/losses.py
index 5ddb51bdb..37cc591cb 100644
--- a/allenact/embodiedai/aux_losses/losses.py
+++ b/allenact/embodiedai/aux_losses/losses.py
@@ -15,6 +15,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from allenact.algorithms.onpolicy_sync.losses.abstract_loss import (
     AbstractActorCriticLoss,
@@ -47,7 +48,7 @@ def loss(  # type: ignore
         batch: ObservationType,
         actor_critic_output: ActorCriticOutput[CategoricalDistr],
         *args,
-        **kwargs
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Dict[str, float]]:
         task_weights = actor_critic_output.extras[self.UUID]
         task_weights = task_weights.view(-1, self.num_tasks)
@@ -86,7 +87,7 @@ def loss(  # type: ignore
         batch: ObservationType,
         actor_critic_output: ActorCriticOutput[CategoricalDistr],
         *args,
-        **kwargs
+        **kwargs,
     ) -> Tuple[torch.Tensor, Dict[str, float]]:

         # auxiliary loss
@@ -107,7 +108,7 @@ def get_aux_loss(
         beliefs: torch.Tensor,
         masks: torch.Tensor,
         *args,
-        **kwargs
+        **kwargs,
     ):
         raise NotImplementedError()
@@ -167,7 +168,7 @@ def 
get_aux_loss( beliefs: torch.FloatTensor, masks: torch.FloatTensor, *args, - **kwargs + **kwargs, ): ## we discard the last action in the batch num_steps, num_sampler = actions.shape # T, B @@ -260,7 +261,7 @@ def get_aux_loss( beliefs: torch.FloatTensor, masks: torch.FloatTensor, *args, - **kwargs + **kwargs, ): ## we discard the last action in the batch num_steps, num_sampler = actions.shape # T, B @@ -313,7 +314,7 @@ def get_aux_loss( np.repeat(locs[:, [1]], 2 * self.num_pairs, axis=-1), # (M, 2*k) np.repeat(locs[:, [2]] + 1, 2 * self.num_pairs, axis=-1), # (M, 2*k) ).reshape( - (-1, self.num_pairs, 2) + -1, self.num_pairs, 2 ) # (M, k, 2) sampled_pairs_batch = torch.from_numpy(sampled_pairs).to( locs_batch @@ -347,7 +348,7 @@ def get_aux_loss( ).float() # (M, k) pred_error = (pred_temp_dist - true_temp_dist) * normalizer.unsqueeze(1) - loss = 0.5 * pred_error.pow(2) + loss = 0.5 * (pred_error).pow(2) avg_loss = loss.mean() return ( @@ -381,7 +382,7 @@ def get_aux_loss( beliefs: torch.Tensor, masks: torch.Tensor, *args, - **kwargs + **kwargs, ): # prepare for autoregressive inputs: c_{t+1:t+k} = GRU(b_t, a_{t:t+k-1}) <-> z_{t+k} ## where b_t = RNN(b_{t-1}, z_t, a_{t-1}), prev action is optional @@ -412,18 +413,22 @@ def get_aux_loss( action_padded = torch.cat( (action_embedding, action_padding), dim=0 ) # (T+k-1, N, -1) + ## unfold function will create consecutive action sequences action_seq = ( action_padded.unfold(dimension=0, size=self.planning_steps, step=1) .permute(3, 0, 1, 2) .view(self.planning_steps, num_steps * num_sampler, action_embed_size) ) # (k, T*N, -1) + + ## beliefs GRU output beliefs = beliefs.view(num_steps * num_sampler, -1).unsqueeze(0) # (1, T*N, -1) # get future contexts c_{t+1:t+k} = GRU(b_t, a_{t:t+k-1}) future_contexts_all, _ = aux_model.context_model( action_seq, beliefs ) # (k, T*N, -1) + ## NOTE: future_contexts_all starting from next step t+1 to t+k, not t to t+k-1 future_contexts_all = future_contexts_all.view( self.planning_steps, num_steps, num_sampler, -1 @@ -546,6 +551,192 @@ def get_aux_loss( ) +class CPCASoftMaxLoss(AuxiliaryLoss): + """Auxiliary task of CPC|A with multi class softmax.""" + + UUID = "cpcA_SOFTMAX" + + def __init__( + self, + planning_steps: int = 8, + subsample_rate: float = 1, + allow_skipping: bool = True, + *args, + **kwargs, + ): + super().__init__(auxiliary_uuid=self.UUID, *args, **kwargs) + self.planning_steps = planning_steps + self.subsample_rate = subsample_rate + self.cross_entropy_loss = nn.CrossEntropyLoss( + reduction="none" + ) # nn.BCEWithLogitsLoss(reduction="none") + self.allow_skipping = allow_skipping + + def get_aux_loss( + self, + aux_model: nn.Module, + observations: ObservationType, + obs_embeds: torch.Tensor, + actions: torch.Tensor, + beliefs: torch.Tensor, + masks: torch.Tensor, + *args, + **kwargs, + ): + # prepare for autoregressive inputs: c_{t+1:t+k} = GRU(b_t, a_{t:t+k-1}) <-> z_{t+k} + ## where b_t = RNN(b_{t-1}, z_t, a_{t-1}), prev action is optional + num_steps, num_samplers, obs_embed_size = obs_embeds.shape # T, N, H_O + ##visual observation of all num_steps + + if not (0 < self.planning_steps <= num_steps): + if self.allow_skipping: + return 0, {} + else: + raise RuntimeError( + f"Insufficient planning steps: self.planning_steps {self.planning_steps} must" + f" be greater than zero and less than or equal to num_steps {num_steps}." 
+ ) + + ## prepare action sequences and initial beliefs + action_embedding = aux_model.action_embedder(actions) # (T, N, -1) + action_embed_size = action_embedding.size(-1) + action_padding = torch.zeros( + self.planning_steps - 1, + num_samplers, + action_embed_size, + device=action_embedding.device, + ) # (k-1, N, -1) + action_padded = torch.cat( + (action_embedding, action_padding), dim=0 + ) # (T+k-1, N, -1) + + ## unfold function will create consecutive action sequences + action_seq = ( + action_padded.unfold(dimension=0, size=self.planning_steps, step=1) + .permute(3, 0, 1, 2) + .view(self.planning_steps, num_steps * num_samplers, action_embed_size) + ) # (k, T*N, -1) + + ## beliefs GRU output + obs_embeds = aux_model.visual_mlp(obs_embeds) # (T, N, 128) + + beliefs = beliefs.view(1, num_steps * num_samplers, -1) # (1, T*N, -1) + + # get future contexts c_{t+1:t+k} = GRU(b_t, a_{t:t+k-1}) + future_contexts_all, _ = aux_model.context_model( + action_seq, beliefs + ) # (k, T*N, -1) + + future_contexts_all = aux_model.belief_mlp(future_contexts_all) # (k, T*N, 128) + future_contexts_all = future_contexts_all.view(-1, 128) # (k*T*N, 128) + + obs_embeds = obs_embeds.view( + num_steps * num_samplers, obs_embeds.shape[-1] + ).permute( + 1, 0 + ) # (-1, T*N) + + visual_logits = torch.matmul(future_contexts_all, obs_embeds) + visual_log_probs = F.log_softmax(visual_logits, dim=1) ## (k*T*N, T*N) + + target = torch.zeros( + (self.planning_steps, num_steps, num_samplers), + dtype=torch.long, + device=beliefs.device, + ) # (k, T, N) + loss_mask = torch.zeros( + (self.planning_steps, num_steps, num_samplers), device=beliefs.device + ) # (k, T, N) + + num_valid_before = 0 + for j in range(num_samplers): + for i in range(num_steps): + index = i * num_samplers + j + + if i == 0 or masks[i, j].item() == 0: + num_valid_before = 0 + continue + + num_valid_before += 1 + for back in range(min(num_valid_before, self.planning_steps)): + target[back, i - (back + 1), j] = index + loss_mask[back, i - (back + 1), j] = 1.0 + + target = target.view(-1) # (k*T*N,) + + loss_value = self.cross_entropy_loss(visual_log_probs, target) + loss_value = loss_value.view( + self.planning_steps, num_steps, num_samplers, 1 + ) # (k, T, N, 1) + + loss_mask = loss_mask.unsqueeze(-1) # (k, T, N, 1) + loss_valid_masks = loss_mask * _bernoulli_subsample_mask_like( + loss_mask, self.subsample_rate + ) # (k, T, N, 1) + + num_valid_losses = torch.count_nonzero(loss_valid_masks) + + avg_multi_class_loss = (loss_value * loss_valid_masks).sum() / torch.clamp( + num_valid_losses, min=1.0 + ) + + return ( + avg_multi_class_loss, + {"total": cast(torch.Tensor, avg_multi_class_loss).item(),}, + ) + + +######## CPCA Softmax variants ###### + + +class CPCA1SoftMaxLoss(CPCASoftMaxLoss): + UUID = "cpcA_SOFTMAX_1" + + def __init__(self, subsample_rate: float = 1, *args, **kwargs): + super().__init__( + planning_steps=1, subsample_rate=subsample_rate, *args, **kwargs + ) + + +class CPCA2SoftMaxLoss(CPCASoftMaxLoss): + UUID = "cpcA_SOFTMAX_2" + + def __init__(self, subsample_rate: float = 1, *args, **kwargs): + super().__init__( + planning_steps=2, subsample_rate=subsample_rate, *args, **kwargs + ) + + +class CPCA4SoftMaxLoss(CPCASoftMaxLoss): + UUID = "cpcA_SOFTMAX_4" + + def __init__(self, subsample_rate: float = 1, *args, **kwargs): + super().__init__( + planning_steps=4, subsample_rate=subsample_rate, *args, **kwargs + ) + + +class CPCA8SoftMaxLoss(CPCASoftMaxLoss): + UUID = "cpcA_SOFTMAX_8" + + def __init__(self, subsample_rate: float = 1, 
*args, **kwargs):
+        super().__init__(
+            planning_steps=8, subsample_rate=subsample_rate, *args, **kwargs
+        )
+
+
+class CPCA16SoftMaxLoss(CPCASoftMaxLoss):
+    UUID = "cpcA_SOFTMAX_16"
+
+    def __init__(self, subsample_rate: float = 1, *args, **kwargs):
+        super().__init__(
+            planning_steps=16, subsample_rate=subsample_rate, *args, **kwargs
+        )
+
+
+###########
+
+
 class CPCA1Loss(CPCALoss):
     UUID = "CPCA_1"
diff --git a/allenact/embodiedai/models/aux_models.py b/allenact/embodiedai/models/aux_models.py
index b234401b4..567906ad9 100644
--- a/allenact/embodiedai/models/aux_models.py
+++ b/allenact/embodiedai/models/aux_models.py
@@ -6,16 +6,16 @@
 found in https://github.com/joel99/habitat-pointnav-
 aux/blob/master/habitat_baselines/"""
-
 import torch
 import torch.nn as nn
-from allenact.utils.model_utils import FeatureEmbedding
 from allenact.embodiedai.aux_losses.losses import (
     InverseDynamicsLoss,
     TemporalDistanceLoss,
     CPCALoss,
+    CPCASoftMaxLoss,
 )
+from allenact.utils.model_utils import FeatureEmbedding


 class AuxiliaryModel(nn.Module):
@@ -30,6 +30,7 @@ def __init__(
         belief_dim: int,
         action_embed_size: int = 4,
         cpca_classifier_hidden_dim: int = 32,
+        cpca_softmax_dim: int = 128,
     ):
         super().__init__()
         self.aux_uuid = aux_uuid
@@ -60,6 +61,29 @@ def __init__(
                 nn.Linear(cpca_classifier_hidden_dim, 1),
             )

+        elif CPCASoftMaxLoss.UUID in self.aux_uuid:
+            ###
+            # Same as CPCA, but with extra MLP heads for the contrastive (softmax) loss.
+            ###
+            self.action_embedder = FeatureEmbedding(
+                self.action_dim + 1, action_embed_size
+            )
+            # NOTE: one extra embedding row is needed as zero-padded actions are appended
+            self.context_model = nn.GRU(action_embed_size, self.belief_dim)
+
+            ## MLPs projecting observations and beliefs into the space used to estimate mutual information
+            self.visual_mlp = nn.Sequential(
+                nn.Linear(obs_embed_dim, cpca_classifier_hidden_dim),
+                nn.ReLU(),
+                nn.Linear(cpca_classifier_hidden_dim, cpca_softmax_dim),
+            )
+
+            self.belief_mlp = nn.Sequential(
+                nn.Linear(self.belief_dim, cpca_classifier_hidden_dim),
+                nn.ReLU(),
+                nn.Linear(cpca_classifier_hidden_dim, cpca_softmax_dim),
+            )
+
         else:
             raise ValueError("Unknown Auxiliary Loss UUID")
diff --git a/allenact/embodiedai/sensors/vision_sensors.py b/allenact/embodiedai/sensors/vision_sensors.py
index 3ca2a3cb5..7983382c3 100644
--- a/allenact/embodiedai/sensors/vision_sensors.py
+++ b/allenact/embodiedai/sensors/vision_sensors.py
@@ -19,8 +19,8 @@ class VisionSensor(Sensor[EnvType, SubTaskType]):
     def __init__(
         self,
-        mean: Optional[np.ndarray] = None,
-        stdev: Optional[np.ndarray] = None,
+        mean: Union[Sequence[float], np.ndarray, None] = None,
+        stdev: Union[Sequence[float], np.ndarray, None] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         uuid: str = "vision",
@@ -50,8 +50,8 @@ def __init__(
         kwargs : Extra kwargs. Currently unused.
         """
-        self._norm_means = mean
-        self._norm_sds = stdev
+        self._norm_means = np.array(mean) if mean is not None else None
+        self._norm_sds = np.array(stdev) if stdev is not None else None
         assert (self._norm_means is None) == (self._norm_sds is None), (
             "In VisionSensor's config, "
             "either both mean/stdev must be None or neither."
diff --git a/allenact/main.py b/allenact/main.py index 2598c8c72..d1ad6d0b1 100755 --- a/allenact/main.py +++ b/allenact/main.py @@ -1,24 +1,29 @@ """Entry point to training/validating/testing for a user given experiment name.""" +import os + +if "CUDA_DEVICE_ORDER" not in os.environ: + # Necessary to order GPUs correctly in some cases + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + import argparse import ast import importlib import inspect import json -import os -from typing import Dict, Tuple, List, Optional, Type +from typing import Dict, List, Optional, Tuple, Type from setproctitle import setproctitle as ptitle from allenact import __version__ from allenact.algorithms.onpolicy_sync.runner import ( - OnPolicyRunner, CONFIG_KWARGS_STR, + OnPolicyRunner, SaveDirFormat, ) from allenact.base_abstractions.experiment_config import ExperimentConfig -from allenact.utils.system import get_logger, init_logging, HUMAN_LOG_LEVELS +from allenact.utils.system import HUMAN_LOG_LEVELS, get_logger, init_logging def get_argument_parser(): @@ -263,6 +268,24 @@ def get_argument_parser(): " tutorial https://allenact.org/tutorials/distributed-objectnav-tutorial/", ) + parser.add_argument( + "--callbacks", + dest="callbacks", + required=False, + type=str, + default="", + help="Comma-separated list of files with Callback classes to use.", + ) + + parser.add_argument( + "--enable_crash_recovery", + dest="enable_crash_recovery", + default=False, + action="store_true", + required=False, + help="Whether or not to try recovering when a task crashes (use at your own risk).", + ) + ### DEPRECATED FLAGS parser.add_argument( "-t", @@ -447,12 +470,14 @@ def main(): disable_config_saving=args.disable_config_saving, distributed_ip_and_port=args.distributed_ip_and_port, machine_id=args.machine_id, + callbacks_paths=args.callbacks, ).start_train( checkpoint=args.checkpoint, restart_pipeline=args.restart_pipeline, max_sampler_processes_per_worker=args.max_sampler_processes_per_worker, collect_valid_results=args.collect_valid_results, valid_on_initial_weights=args.valid_on_initial_weights, + try_restart_after_task_error=args.enable_crash_recovery, ) else: OnPolicyRunner( @@ -469,6 +494,7 @@ def main(): disable_config_saving=args.disable_config_saving, distributed_ip_and_port=args.distributed_ip_and_port, machine_id=args.machine_id, + callbacks_paths=args.callbacks, ).start_test( checkpoint_path_dir_or_pattern=args.checkpoint, infer_output_dir=args.infer_output_dir, diff --git a/allenact/utils/experiment_utils.py b/allenact/utils/experiment_utils.py index ff0bc2651..22345db76 100644 --- a/allenact/utils/experiment_utils.py +++ b/allenact/utils/experiment_utils.py @@ -259,6 +259,7 @@ def __init__( self.metric_dicts: List[Any] = [] self.viz_data: Optional[Dict[str, List[Dict[str, Any]]]] = None self.checkpoint_file_name: Optional[str] = None + self.task_callback_data: List[Any] = [] self.num_empty_metrics_dicts_added: int = 0 diff --git a/allenact/utils/inference.py b/allenact/utils/inference.py index 744c45d9e..874583a18 100644 --- a/allenact/utils/inference.py +++ b/allenact/utils/inference.py @@ -1,4 +1,4 @@ -from typing import Optional, cast, Tuple +from typing import Optional, cast, Tuple, Any, Dict import attr import torch @@ -46,8 +46,12 @@ def from_experiment_config( exp_config: ExperimentConfig, device: torch.device, checkpoint_path: Optional[str] = None, + model_state_dict: Optional[Dict[str, Any]] = None, mode: str = "test", ): + assert ( + checkpoint_path is None or model_state_dict is None + ), "Cannot have 
`checkpoint_path` and `model_state_dict` both non-None." rollout_storage = exp_config.training_pipeline().rollout_storage machine_params = exp_config.machine_params(mode) @@ -67,6 +71,12 @@ def from_experiment_config( actor_critic.load_state_dict( torch.load(checkpoint_path, map_location="cpu")["model_state_dict"] ) + elif model_state_dict is not None: + actor_critic.load_state_dict( + model_state_dict + if "model_state_dict" not in model_state_dict + else model_state_dict["model_state_dict"] + ) return cls( actor_critic=actor_critic, diff --git a/allenact/utils/misc_utils.py b/allenact/utils/misc_utils.py index 611fbf531..e68f2ba1c 100644 --- a/allenact/utils/misc_utils.py +++ b/allenact/utils/misc_utils.py @@ -321,3 +321,13 @@ def partition_limits(num_items: int, num_parts: int): .astype(np.int32) .tolist() ) + + +def str2bool(v: str): + v = v.lower().strip() + if v in ("yes", "true", "t", "y", "1"): + return True + elif v in ("no", "false", "f", "n", "0"): + return False + else: + raise ValueError(f"{v} cannot be converted to a bool") diff --git a/allenact_plugins/clip_plugin/__init__.py b/allenact_plugins/clip_plugin/__init__.py index e134f0c18..547f91b00 100644 --- a/allenact_plugins/clip_plugin/__init__.py +++ b/allenact_plugins/clip_plugin/__init__.py @@ -2,7 +2,7 @@ with ImportChecker( "Cannot `import clip`. Please install clip from the openai/CLIP git repository:" - "\n`pip install git+https://github.com/openai/CLIP.git@3b473b0e682c091a9e53623eebc1ca1657385717`" + "\n`pip install git+https://github.com/openai/CLIP.git@b46f5ac7587d2e1862f8b7b1573179d80dcdd620`" ): # noinspection PyUnresolvedReferences import clip diff --git a/allenact_plugins/clip_plugin/clip_preprocessors.py b/allenact_plugins/clip_plugin/clip_preprocessors.py index 9c414f6b5..63c011ede 100644 --- a/allenact_plugins/clip_plugin/clip_preprocessors.py +++ b/allenact_plugins/clip_plugin/clip_preprocessors.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Any, cast, Dict +from typing import List, Optional, Any, cast, Dict, Tuple import clip import gym @@ -50,14 +50,18 @@ def __init__( pool: bool, device: Optional[torch.device] = None, device_ids: Optional[List[torch.device]] = None, + input_img_height_width: Tuple[int, int] = (224, 224), **kwargs: Any, ): assert clip_model_type in clip.available_models() + assert pool == False or input_img_height_width == (224, 224) + assert all(iis % 32 == 0 for iis in input_img_height_width) + output_height_width = tuple(iis // 32 for iis in input_img_height_width) if clip_model_type == "RN50": - output_shape = (2048, 7, 7) + output_shape = (2048,) + output_height_width elif clip_model_type == "RN50x16": - output_shape = (3072, 7, 7) + output_shape = (3072,) + output_height_width else: raise NotImplementedError( f"Currently `clip_model_type` must be one of 'RN50' or 'RN50x16'" @@ -113,3 +117,125 @@ def process(self, obs: Dict[str, Any], *args: Any, **kwargs: Any) -> Any: x = x.repeat(1, 3, 1, 1) x = self.resnet(x).float() return x + + +class ClipViTEmbedder(nn.Module): + def __init__(self, model: CLIP, class_emb_only: bool = False): + super().__init__() + self.model = model + self.model.visual.transformer.resblocks = nn.Sequential( + *list(self.model.visual.transformer.resblocks)[:-1] + ) + self.class_emb_only = class_emb_only + + self.eval() + + def forward(self, x): + m = self.model.visual + with torch.no_grad(): + x = m.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 
1)  # shape = [*, grid ** 2, width]
+            x = torch.cat(
+                [
+                    m.class_embedding.to(x.dtype)
+                    + torch.zeros(
+                        x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
+                    ),
+                    x,
+                ],
+                dim=1,
+            )  # shape = [*, grid ** 2 + 1, width]
+            x = x + m.positional_embedding.to(x.dtype)
+            x = m.ln_pre(x)
+
+            x = x.permute(1, 0, 2)  # NLD -> LND
+            x = m.transformer(x)
+            x = x.permute(1, 0, 2)  # LND -> NLD
+
+            if self.class_emb_only:
+                return x[:, 0, :]
+            else:
+                return x
+
+
+class ClipViTPreprocessor(Preprocessor):
+    """Preprocess RGB or depth images using a ViT model with CLIP model
+    weights."""
+
+    CLIP_RGB_MEANS = (0.48145466, 0.4578275, 0.40821073)
+    CLIP_RGB_STDS = (0.26862954, 0.26130258, 0.27577711)
+
+    def __init__(
+        self,
+        rgb_input_uuid: str,
+        clip_model_type: str,
+        class_emb_only: bool,
+        device: Optional[torch.device] = None,
+        device_ids: Optional[List[torch.device]] = None,
+        **kwargs: Any,
+    ):
+        assert clip_model_type in clip.available_models()
+
+        if clip_model_type == "ViT-B/32":
+            output_shape = (7 * 7 + 1, 768)
+        elif clip_model_type == "ViT-B/16":
+            output_shape = (14 * 14 + 1, 768)
+        elif clip_model_type == "ViT-L/14":
+            output_shape = (16 * 16 + 1, 1024)
+        else:
+            raise NotImplementedError(
+                f"Currently `clip_model_type` must be one of 'ViT-B/32', 'ViT-B/16', or 'ViT-L/14'"
+            )
+
+        if class_emb_only:
+            output_shape = output_shape[1:]
+
+        self.clip_model_type = clip_model_type
+
+        self.class_emb_only = class_emb_only
+
+        self.device = torch.device("cpu") if device is None else device
+        self.device_ids = device_ids or cast(
+            List[torch.device], list(range(torch.cuda.device_count()))
+        )
+        self._vit: Optional[ClipViTEmbedder] = None
+
+        low = -np.inf
+        high = np.inf
+        shape = output_shape
+
+        input_uuids = [rgb_input_uuid]
+        assert (
+            len(input_uuids) == 1
+        ), "preprocessor can only consume one observation type"
+
+        observation_space = gym.spaces.Box(low=low, high=high, shape=shape)
+
+        super().__init__(**prepare_locals_for_super(locals()))
+
+    @property
+    def vit(self) -> ClipViTEmbedder:
+        if self._vit is None:
+            self._vit = ClipViTEmbedder(
+                model=clip.load(self.clip_model_type, device=self.device)[0],
+                class_emb_only=self.class_emb_only,
+            ).to(self.device)
+            for module in self._vit.modules():
+                if "BatchNorm" in type(module).__name__:
+                    module.momentum = 0.0
+            self._vit.eval()
+        return self._vit
+
+    def to(self, device: torch.device) -> "ClipViTPreprocessor":
+        self._vit = self.vit.to(device)
+        self.device = device
+        return self
+
+    def process(self, obs: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
+        x = obs[self.input_uuids[0]].to(self.device).permute(0, 3, 1, 2)  # bhwc -> bchw
+        # If the input is depth, repeat it across all 3 channels
+        if x.shape[1] == 1:
+            x = x.repeat(1, 3, 1, 1)
+        x = self.vit(x).float()
+        return x
diff --git a/allenact_plugins/habitat_plugin/habitat_task_samplers.py b/allenact_plugins/habitat_plugin/habitat_task_samplers.py
index 509a1eecb..60ee6889e 100644
--- a/allenact_plugins/habitat_plugin/habitat_task_samplers.py
+++ b/allenact_plugins/habitat_plugin/habitat_task_samplers.py
@@ -1,14 +1,14 @@
 from typing import List, Optional, Union, Callable, Any, Dict, Type

 import gym
-import habitat
-from allenact.utils.experiment_utils import Builder
-from habitat.config import Config

+import habitat
 from allenact.base_abstractions.sensor import Sensor
 from allenact.base_abstractions.task import TaskSampler
+from allenact.utils.experiment_utils import Builder
 from allenact_plugins.habitat_plugin.habitat_environment import 
HabitatEnvironment from allenact_plugins.habitat_plugin.habitat_tasks import PointNavTask, ObjectNavTask # type: ignore +from habitat.config import Config class PointNavTaskSampler(TaskSampler): @@ -208,6 +208,9 @@ def next_task(self, force_advance_scene=False) -> Optional[ObjectNavTask]: return None if self.env is not None: + if force_advance_scene: + self.env.env._episode_iterator._forced_scene_switch() + self.env.env._episode_iterator._set_shuffle_intervals() self.env.reset() else: self.env = self._create_environment() diff --git a/allenact_plugins/habitat_plugin/habitat_tasks.py b/allenact_plugins/habitat_plugin/habitat_tasks.py index 547fb9f23..2cd75b9d3 100644 --- a/allenact_plugins/habitat_plugin/habitat_tasks.py +++ b/allenact_plugins/habitat_plugin/habitat_tasks.py @@ -3,10 +3,6 @@ import gym import numpy as np - -from allenact_plugins.habitat_plugin.habitat_sensors import ( - AgentCoordinatesSensorHabitat, -) from habitat.sims.habitat_simulator.actions import HabitatSimActions from habitat.sims.habitat_simulator.habitat_simulator import HabitatSim from habitat.tasks.nav.shortest_path_follower import ShortestPathFollower @@ -24,6 +20,9 @@ LOOK_DOWN, ) from allenact_plugins.habitat_plugin.habitat_environment import HabitatEnvironment +from allenact_plugins.habitat_plugin.habitat_sensors import ( + AgentCoordinatesSensorHabitat, +) class HabitatTask(Task[HabitatEnvironment], ABC): @@ -33,7 +32,7 @@ def __init__( sensors: List[Sensor], task_info: Dict[str, Any], max_steps: int, - **kwargs + **kwargs, ) -> None: super().__init__( env=env, sensors=sensors, task_info=task_info, max_steps=max_steps, **kwargs @@ -91,7 +90,7 @@ def __init__( task_info: Dict[str, Any], max_steps: int, failed_end_reward: float = 0.0, - **kwargs + **kwargs, ) -> None: super().__init__( env=env, sensors=sensors, task_info=task_info, max_steps=max_steps, **kwargs @@ -241,11 +240,15 @@ def __init__( sensors: List[Sensor], task_info: Dict[str, Any], max_steps: int, - **kwargs + look_constraints: Optional[Tuple[int, int]] = None, + **kwargs, ) -> None: super().__init__( env=env, sensors=sensors, task_info=task_info, max_steps=max_steps, **kwargs ) + self.look_constraints = look_constraints + self._look_state = 0 + self._took_end_action: bool = False self._success: Optional[bool] = False self._subsampled_locations_from_which_obj_visible = None @@ -308,7 +311,26 @@ def _step(self, action: Union[int, Sequence[int]]) -> RLStepResult: action_str = self.action_names()[action] self._actions_taken.append(action_str) - self.env.step({"action": action_str}) + skip_action = False + if self.look_constraints is not None: + max_look_up, max_look_down = self.look_constraints + + if action_str == LOOK_UP: + num_look_ups = self._look_state + # assert num_look_ups <= max_look_up + skip_action = num_look_ups >= max_look_up + self._look_state += 1 + + if action_str == LOOK_DOWN: + num_look_downs = -self._look_state + # assert num_look_downs <= max_look_down + skip_action = num_look_downs >= max_look_down + self._look_state -= 1 + + self._look_state = min(max(self._look_state, -max_look_down), max_look_up) + + if not skip_action: + self.env.step({"action": action_str}) if action_str == END: self._took_end_action = True diff --git a/allenact_plugins/ithor_plugin/ithor_environment.py b/allenact_plugins/ithor_plugin/ithor_environment.py index 9c7bbc777..beda87058 100644 --- a/allenact_plugins/ithor_plugin/ithor_environment.py +++ b/allenact_plugins/ithor_plugin/ithor_environment.py @@ -657,7 +657,7 @@ def step( last_frame = 
diff --git a/allenact_plugins/ithor_plugin/ithor_environment.py b/allenact_plugins/ithor_plugin/ithor_environment.py
index 9c7bbc777..beda87058 100644
--- a/allenact_plugins/ithor_plugin/ithor_environment.py
+++ b/allenact_plugins/ithor_plugin/ithor_environment.py
@@ -657,7 +657,7 @@ def step(
         last_frame = self.current_frame
 
         if self.simplify_physics:
-            action_dict["simplifyOPhysics"] = True
+            action_dict["simplifyPhysics"] = True
 
         if "Move" in action and "Hand" not in action:  # type: ignore
             action_dict = {
diff --git a/allenact_plugins/ithor_plugin/ithor_sensors.py b/allenact_plugins/ithor_plugin/ithor_sensors.py
index c11fa4ef2..d0d0cc741 100644
--- a/allenact_plugins/ithor_plugin/ithor_sensors.py
+++ b/allenact_plugins/ithor_plugin/ithor_sensors.py
@@ -22,12 +22,18 @@
 from allenact_plugins.robothor_plugin.robothor_environment import RoboThorEnvironment
 from allenact_plugins.robothor_plugin.robothor_tasks import PointNavTask, ObjectNavTask
 
+THOR_ENV_TYPE = Union[
+    ai2thor.controller.Controller, IThorEnvironment, RoboThorEnvironment
+]
+THOR_TASK_TYPE = Union[
+    Task[ai2thor.controller.Controller],
+    Task[IThorEnvironment],
+    Task[RoboThorEnvironment],
+]
+
 
 class RGBSensorThor(
-    RGBSensor[
-        Union[IThorEnvironment, RoboThorEnvironment],
-        Union[Task[IThorEnvironment], Task[RoboThorEnvironment]],
-    ]
+    RGBSensor[THOR_ENV_TYPE, THOR_TASK_TYPE]
 ):
     """Sensor for RGB images in THOR.
 
@@ -36,9 +42,12 @@ class RGBSensorThor(
     """
 
     def frame_from_env(
-        self, env: IThorEnvironment, task: Task[IThorEnvironment]
+        self, env: THOR_ENV_TYPE, task: Optional[THOR_TASK_TYPE],
     ) -> np.ndarray:  # type:ignore
-        return env.current_frame.copy()
+        if isinstance(env, ai2thor.controller.Controller):
+            return env.last_event.frame.copy()
+        else:
+            return env.current_frame.copy()
 
 
 class GoalObjectTypeThorSensor(Sensor):
diff --git a/allenact_plugins/manipulathor_plugin/arm_calculation_utils.py b/allenact_plugins/manipulathor_plugin/arm_calculation_utils.py
index fe6d8ad8b..827a2d844 100644
--- a/allenact_plugins/manipulathor_plugin/arm_calculation_utils.py
+++ b/allenact_plugins/manipulathor_plugin/arm_calculation_utils.py
@@ -4,9 +4,10 @@
 import numpy as np
 import torch
-from allenact.utils.system import get_logger
 from scipy.spatial.transform import Rotation as R
 
+from allenact.utils.system import get_logger
+
 
 def state_dict_to_tensor(state: Dict):
     result = []
@@ -103,15 +104,17 @@ def matrix_to_position_rotation(matrix):
     return result
 
 
-def find_closest_inverse(deg):
-    for k in _saved_inverse_rotation_mats.keys():
-        if abs(k - deg) < 5:
-            return _saved_inverse_rotation_mats[k]
+def find_closest_inverse(deg, use_cache):
+    if use_cache:
+        for k in _saved_inverse_rotation_mats.keys():
+            if abs(k - deg) < 5:
+                return _saved_inverse_rotation_mats[k]
     # if it reaches here it means it had not calculated the degree before
     rotation = R.from_euler("xyz", [0, deg, 0], degrees=True)
    result = rotation.as_matrix()
     inverse = inverse_rot_trans_matrix(result)
-    get_logger().warning(f"Had to calculate the matrix for {deg}")
+    if use_cache:
+        get_logger().warning(f"Had to calculate the matrix for {deg}")
     return inverse
 
 
@@ -126,12 +129,12 @@ def calc_inverse(deg):
 _saved_inverse_rotation_mats[360] = _saved_inverse_rotation_mats[0]
 
 
-def world_coords_to_agent_coords(world_obj, agent_state):
+def world_coords_to_agent_coords(world_obj, agent_state, use_cache=True):
     position = agent_state["position"]
     rotation = agent_state["rotation"]
     agent_translation = [position["x"], position["y"], position["z"]]
     assert abs(rotation["x"]) < 0.01 and abs(rotation["z"]) < 0.01
-    inverse_agent_rotation = find_closest_inverse(rotation["y"])
+    inverse_agent_rotation = find_closest_inverse(rotation["y"], use_cache=use_cache)
     obj_matrix = position_rotation_to_matrix(
         world_obj["position"], world_obj["rotation"]
     )
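Aside (illustrative, not part of the patch): the 5-degree cache lookup in `find_closest_inverse` is safe when agents rotate in fixed increments, since `rotation["y"]` then lands within 5 degrees of a precomputed key. A simplified sketch, assuming a hypothetical 45-degree rotation step and using a plain matrix inverse in place of `inverse_rot_trans_matrix`:

import numpy as np
from scipy.spatial.transform import Rotation as R

# Precompute inverses for every reachable yaw (hypothetical 45-degree step).
_cache = {
    deg: np.linalg.inv(R.from_euler("xyz", [0, deg, 0], degrees=True).as_matrix())
    for deg in range(0, 360, 45)
}

def closest_cached_inverse(deg):
    for k, inv in _cache.items():
        if abs(k - deg) < 5:  # simulator yaw may drift slightly off the grid
            return inv
    # Cache miss: fall back to the exact (slower) computation.
    return np.linalg.inv(R.from_euler("xyz", [0, deg, 0], degrees=True).as_matrix())

# A yaw of 44.9 (simulator noise around 45) hits the cached entry for 45.
assert closest_cached_inverse(44.9) is _cache[45]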
a/allenact_plugins/manipulathor_plugin/armpointnav_constants.py b/allenact_plugins/manipulathor_plugin/armpointnav_constants.py
index 897606f91..07383ab4d 100644
--- a/allenact_plugins/manipulathor_plugin/armpointnav_constants.py
+++ b/allenact_plugins/manipulathor_plugin/armpointnav_constants.py
@@ -1,5 +1,6 @@
 import json
 import os
+from typing import Dict, Optional, Any
 
 from constants import ABS_PATH_OF_TOP_LEVEL_DIR
 
@@ -14,8 +15,17 @@
 dataset_json_file = os.path.join(
     ABS_PATH_OF_TOP_LEVEL_DIR, "datasets", "apnd-dataset", "starting_pose.json"
 )
-try:
-    with open(dataset_json_file) as f:
-        ARM_START_POSITIONS = json.load(f)
-except Exception:
-    raise Exception("Dataset not found in {}".format(dataset_json_file))
+
+_ARM_START_POSITIONS: Optional[Dict[str, Any]] = None
+
+
+def get_agent_start_positions():
+    global _ARM_START_POSITIONS
+    if _ARM_START_POSITIONS is None:
+        try:
+            with open(dataset_json_file) as f:
+                _ARM_START_POSITIONS = json.load(f)
+        except Exception:
+            raise Exception(f"Dataset not found in {dataset_json_file}")
+
+    return _ARM_START_POSITIONS
diff --git a/allenact_plugins/manipulathor_plugin/manipulathor_constants.py b/allenact_plugins/manipulathor_plugin/manipulathor_constants.py
index e74f752f1..5832ae7a0 100644
--- a/allenact_plugins/manipulathor_plugin/manipulathor_constants.py
+++ b/allenact_plugins/manipulathor_plugin/manipulathor_constants.py
@@ -13,6 +13,7 @@
 }
 
 MOVE_AHEAD = "MoveAheadContinuous"
+MOVE_BACK = "MoveBackContinuous"
 ROTATE_LEFT = "RotateLeftContinuous"
 ROTATE_RIGHT = "RotateRightContinuous"
 MOVE_ARM_HEIGHT_P = "MoveArmHeightP"
@@ -27,11 +28,14 @@
 ROTATE_WRIST_PITCH_M = "RotateArmWristPitchM"
 ROTATE_WRIST_YAW_P = "RotateArmWristYawP"
 ROTATE_WRIST_YAW_M = "RotateArmWristYawM"
+ROTATE_WRIST_ROLL_P = "RotateArmWristRollP"
+ROTATE_WRIST_ROLL_M = "RotateArmWristRollM"
 ROTATE_ELBOW_P = "RotateArmElbowP"
 ROTATE_ELBOW_M = "RotateArmElbowM"
 LOOK_UP = "LookUp"
 LOOK_DOWN = "LookDown"
 PICKUP = "PickUpMidLevel"
+DROP = "DropMidLevel"
 DONE = "DoneMidLevel"
diff --git a/allenact_plugins/manipulathor_plugin/manipulathor_environment.py b/allenact_plugins/manipulathor_plugin/manipulathor_environment.py
index 5895de383..00988c81b 100644
--- a/allenact_plugins/manipulathor_plugin/manipulathor_environment.py
+++ b/allenact_plugins/manipulathor_plugin/manipulathor_environment.py
@@ -222,7 +222,8 @@ def object_in_hand(self):
         else:
             raise AttributeError("Must be <= 1 inventory objects.")
 
-    def correct_nan_inf(self, flawed_dict, extra_tag=""):
+    @classmethod
+    def correct_nan_inf(cls, flawed_dict, extra_tag=""):
         corrected_dict = copy.deepcopy(flawed_dict)
         for (k, v) in corrected_dict.items():
             if math.isnan(v) or math.isinf(v):
@@ -374,7 +375,7 @@ def step(
         last_frame = self.current_frame
 
         if self.simplify_physics:
-            action_dict["simplifyOPhysics"] = True
+            action_dict["simplifyPhysics"] = True
 
         if action in [PICKUP, DONE]:
             if action == PICKUP:
                 object_id = action_dict["object_id"]
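Aside (illustrative, not part of the patch): `get_agent_start_positions` above is a lazy-load-and-memoize accessor, so importing the constants module no longer fails when the dataset is absent; only the first actual call touches the filesystem. A minimal standalone sketch of the same pattern (the file name here is a placeholder):

import json
from typing import Any, Dict, Optional

_POSITIONS: Optional[Dict[str, Any]] = None

def get_positions(path: str = "starting_pose.json") -> Dict[str, Any]:
    global _POSITIONS
    if _POSITIONS is None:  # first call: actually read the file
        try:
            with open(path) as f:
                _POSITIONS = json.load(f)
        except Exception:
            raise Exception(f"Dataset not found in {path}")
    return _POSITIONS  # later calls reuse the cached dict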
ADDITIONAL_ARM_ARGS,
@@ -51,7 +51,7 @@ def transport_wrapper(controller, target_object, target_location):
 def initialize_arm(controller):
     # for start arm from high up,
     scene = controller.last_event.metadata["sceneName"]
-    initial_pose = ARM_START_POSITIONS[scene]
+    initial_pose = get_agent_start_positions()[scene]
     event1 = controller.step(
         dict(
             action="TeleportFull",
diff --git a/allenact_plugins/robothor_plugin/robothor_sensors.py b/allenact_plugins/robothor_plugin/robothor_sensors.py
index 57c66c381..5891318e8 100644
--- a/allenact_plugins/robothor_plugin/robothor_sensors.py
+++ b/allenact_plugins/robothor_plugin/robothor_sensors.py
@@ -1,16 +1,20 @@
-from typing import Any, Tuple, Optional, Union
+from typing import Any, Tuple, Optional
 
+import ai2thor.controller
 import gym
 import numpy as np
 import quaternion  # noqa # pylint: disable=unused-import
 
 from allenact.base_abstractions.sensor import Sensor
-from allenact.embodiedai.sensors.vision_sensors import RGBSensor, DepthSensor
 from allenact.base_abstractions.task import Task
+from allenact.embodiedai.sensors.vision_sensors import RGBSensor, DepthSensor
 from allenact.utils.misc_utils import prepare_locals_for_super
 from allenact.utils.system import get_logger
-from allenact_plugins.ithor_plugin.ithor_environment import IThorEnvironment
-from allenact_plugins.ithor_plugin.ithor_sensors import RGBSensorThor
+from allenact_plugins.ithor_plugin.ithor_sensors import (
+    RGBSensorThor,
+    THOR_ENV_TYPE,
+    THOR_TASK_TYPE,
+)
 from allenact_plugins.robothor_plugin.robothor_environment import RoboThorEnvironment
 from allenact_plugins.robothor_plugin.robothor_tasks import PointNavTask
 
@@ -147,12 +151,7 @@ def get_observation(
         )
 
 
-class DepthSensorThor(
-    DepthSensor[
-        Union[IThorEnvironment, RoboThorEnvironment],
-        Union[Task[IThorEnvironment], Task[RoboThorEnvironment]],
-    ],
-):
+class DepthSensorThor(DepthSensor[THOR_ENV_TYPE, THOR_TASK_TYPE]):
     def __init__(
         self,
         use_resnet_normalization: Optional[bool] = None,
@@ -178,9 +177,12 @@ def __init__(
         super().__init__(**prepare_locals_for_super(locals()))
 
     def frame_from_env(
-        self, env: RoboThorEnvironment, task: Optional[Task[RoboThorEnvironment]]
+        self, env: THOR_ENV_TYPE, task: Optional[THOR_TASK_TYPE]
     ) -> np.ndarray:
-        return env.controller.last_event.depth_frame
+        if not isinstance(env, ai2thor.controller.Controller):
+            env = env.controller
+
+        return env.last_event.depth_frame
 
 
 class DepthSensorRoboThor(DepthSensorThor):
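Aside (illustrative, not part of the patch): the sensors above dispatch on the concrete environment type in the `THOR_ENV_TYPE` union — a bare controller exposes frames via `last_event`, while the AllenAct wrappers expose the same data through their `.controller` attribute. A toy sketch of the idiom with stand-in classes (not the real ai2thor API):

class RawController:
    class _Event:
        frame = "raw-frame"
    last_event = _Event()

class WrappedEnv:
    controller = RawController()

def frame_from_env(env):
    # Unwrap to the controller when given a wrapper, then read the frame.
    if isinstance(env, RawController):
        return env.last_event.frame
    return env.controller.last_event.frame

assert frame_from_env(RawController()) == frame_from_env(WrappedEnv())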
ClipResNetPreprocessor
 from allenact_plugins.navigation_plugin.objectnav.models import (
     ResnetTensorNavActorCritic,
 )
 
 
+class LookDownFirstResnetTensorNavActorCritic(ResnetTensorNavActorCritic):
+    def __init__(self, look_down_action_index: int, **kwargs):
+        super().__init__(**kwargs)
+
+        self.look_down_action_index = look_down_action_index
+        self.register_buffer(
+            "look_down_delta", torch.zeros(1, 1, self.action_space.n), persistent=False
+        )
+        self.look_down_delta[0, 0, self.look_down_action_index] = 99999
+
+    def forward(  # type:ignore
+        self,
+        observations: ObservationType,
+        memory: Memory,
+        prev_actions: torch.Tensor,
+        masks: torch.FloatTensor,
+    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
+        ac_out, memory = super(LookDownFirstResnetTensorNavActorCritic, self).forward(
+            **prepare_locals_for_super(locals())
+        )
+
+        logits = ac_out.distributions.logits * masks + self.look_down_delta * (
+            1 - masks
+        )
+        ac_out = ActorCriticOutput(
+            distributions=CategoricalDistr(logits=logits),
+            values=ac_out.values,
+            extras=ac_out.extras,
+        )
+
+        return ac_out, memory
+
+
 @attr.s(kw_only=True)
 class ClipResNetPreprocessGRUActorCriticMixin:
     sensors: Sequence[Sensor] = attr.ib()
@@ -49,6 +91,7 @@ def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
                 clip_model_type=self.clip_model_type,
                 pool=self.pool,
                 output_uuid="rgb_clip_resnet",
+                input_img_height_width=(rgb_sensor.height, rgb_sensor.width),
             )
         )
 
@@ -62,12 +105,23 @@ def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
                 clip_model_type=self.clip_model_type,
                 pool=self.pool,
                 output_uuid="depth_clip_resnet",
+                input_img_height_width=(depth_sensor.height, depth_sensor.width),
             )
         )
 
         return preprocessors
 
-    def create_model(self, num_actions: int, add_prev_actions: bool, **kwargs) -> nn.Module:
+    def create_model(
+        self,
+        num_actions: int,
+        add_prev_actions: bool,
+        look_down_first: bool = False,
+        look_down_action_index: Optional[int] = None,
+        hidden_size: int = 512,
+        rnn_type: str = "GRU",
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs
+    ) -> nn.Module:
         has_rgb = any(isinstance(s, RGBSensor) for s in self.sensors)
         has_depth = any(isinstance(s, DepthSensor) for s in self.sensors)
 
@@ -76,13 +130,25 @@ def create_model(self, num_actions: int, add_prev_actions: bool, **kwargs) -> nn
             None,
         )
 
-        return ResnetTensorNavActorCritic(
+        if model_kwargs is None:
+            model_kwargs = {}
+
+        model_kwargs = dict(
             action_space=gym.spaces.Discrete(num_actions),
             observation_space=kwargs["sensor_preprocessor_graph"].observation_spaces,
             goal_sensor_uuid=goal_sensor_uuid,
             rgb_resnet_preprocessor_uuid="rgb_clip_resnet" if has_rgb else None,
             depth_resnet_preprocessor_uuid="depth_clip_resnet" if has_depth else None,
-            hidden_size=512,
+            hidden_size=hidden_size,
             goal_dims=32,
-            add_prev_actions=add_prev_actions
+            add_prev_actions=add_prev_actions,
+            rnn_type=rnn_type,
+            **model_kwargs
         )
+
+        if not look_down_first:
+            return ResnetTensorNavActorCritic(**model_kwargs)
+        else:
+            return LookDownFirstResnetTensorNavActorCritic(
+                look_down_action_index=look_down_action_index, **model_kwargs
+            )
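Aside (illustrative, not part of the patch): `LookDownFirstResnetTensorNavActorCritic` forces the first action because `masks` is 0 at the first step of an episode and 1 afterwards, so the large look-down delta dominates the logits only at step 0. A minimal sketch of that masking arithmetic:

import torch

num_actions, look_down_index = 4, 2
look_down_delta = torch.zeros(1, 1, num_actions)
look_down_delta[0, 0, look_down_index] = 99999

logits = torch.randn(1, 1, num_actions)
first_step = logits * 0 + look_down_delta * (1 - 0)  # masks == 0: delta wins
later_step = logits * 1 + look_down_delta * (1 - 1)  # masks == 1: logits pass through

assert first_step.argmax(-1).item() == look_down_index
assert torch.allclose(later_step, logits)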
b/projects/objectnav_baselines/experiments/habitat/clip/objectnav_habitat_rgb_clipresnet50gru_ddppo.py
@@ -54,7 +54,7 @@ def SENSORS(self):
     def training_pipeline(self, **kwargs) -> TrainingPipeline:
         return ObjectNavPPOMixin.training_pipeline(
             lr=self.lr,
-            auxiliary_uuids=[],
+            auxiliary_uuids=self.auxiliary_uuids,
             multiple_beliefs=False,
             advance_scene_rollout_period=self.ADVANCE_SCENE_ROLLOUT_PERIOD,
         )
@@ -64,7 +64,10 @@ def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
 
     def create_model(self, **kwargs) -> nn.Module:
         return self.preprocessing_and_model.create_model(
-            num_actions=self.ACTION_SPACE.n, add_prev_actions=self.add_prev_actions, **kwargs
+            num_actions=self.ACTION_SPACE.n,
+            add_prev_actions=self.add_prev_actions,
+            auxiliary_uuids=self.auxiliary_uuids,
+            **kwargs,
         )
 
     def tag(self):
diff --git a/projects/objectnav_baselines/experiments/habitat/clip/objectnav_habitat_rgb_clipresnet50gru_ddppo_increasingrollouts.py b/projects/objectnav_baselines/experiments/habitat/clip/objectnav_habitat_rgb_clipresnet50gru_ddppo_increasingrollouts.py
index 4792f1a55..af57c5079 100644
--- a/projects/objectnav_baselines/experiments/habitat/clip/objectnav_habitat_rgb_clipresnet50gru_ddppo_increasingrollouts.py
+++ b/projects/objectnav_baselines/experiments/habitat/clip/objectnav_habitat_rgb_clipresnet50gru_ddppo_increasingrollouts.py
@@ -23,7 +23,7 @@ def __init__(self, lr=1e-4, **kwargs):
         self.lr = lr
 
     def training_pipeline(self, **kwargs) -> TrainingPipeline:
-        auxiliary_uuids = []
+        auxiliary_uuids = self.auxiliary_uuids
         multiple_beliefs = False
         normalize_advantage = False
         advance_scene_rollout_period = self.ADVANCE_SCENE_ROLLOUT_PERIOD
@@ -72,21 +72,21 @@ def training_pipeline(self, **kwargs) -> TrainingPipeline:
             advance_scene_rollout_period=advance_scene_rollout_period,
             pipeline_stages=[
                 PipelineStage(
-                    loss_names=["ppo_loss"],
+                    loss_names=list(named_losses.keys()),
                     max_stage_steps=batch_steps_0,
                     training_settings=TrainingSettings(
                         num_steps=32, metric_accumulate_interval=log_interval_small
                     ),
                 ),
                 PipelineStage(
-                    loss_names=["ppo_loss"],
+                    loss_names=list(named_losses.keys()),
                     max_stage_steps=batch_steps_1,
                     training_settings=TrainingSettings(
                         num_steps=64, metric_accumulate_interval=log_interval_med,
                     ),
                 ),
                 PipelineStage(
-                    loss_names=["ppo_loss"],
+                    loss_names=list(named_losses.keys()),
                     max_stage_steps=batch_steps_2,
                     training_settings=TrainingSettings(
                         num_steps=128, metric_accumulate_interval=log_interval_large,
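Aside (illustrative, not part of the patch): the `loss_names` change above matters because a `PipelineStage` only steps the losses it names; with auxiliary losses registered in `named_losses` but every stage listing only `"ppo_loss"`, the auxiliary losses would silently never train. A sketch of the invariant and of the increasing-rollout schedule:

named_losses = {"ppo_loss": ("<PPO>", 1.0), "CPCA|1": ("<aux loss>", 0.1)}

# After the fix, each stage names every configured loss; only num_steps varies.
stages = [
    (list(named_losses.keys()), 32),   # short rollouts early in training
    (list(named_losses.keys()), 64),
    (list(named_losses.keys()), 128),  # long rollouts later
]
for loss_names, _num_steps in stages:
    assert set(loss_names) == set(named_losses)  # nothing silently dropped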
-    AUXILIARY_UUIDS = [
+    _AUXILIARY_UUIDS = [
         # InverseDynamicsLoss.UUID,
         # TemporalDistanceLoss.UUID,
         # CPCA1Loss.UUID,
@@ -148,6 +148,7 @@ def __init__(
         val_gpu_ids: Optional[Sequence[int]] = None,
         test_gpu_ids: Optional[Sequence[int]] = None,
         add_prev_actions: bool = False,
+        look_constraints: Optional[Tuple[int, int]] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -155,6 +156,14 @@ def __init__(
         self.scene_dataset = scene_dataset
         self.debug = debug
 
+        assert look_constraints is None or all(
+            lc in [0, 1, 2, 3] for lc in look_constraints
+        ), "Look constraint values must be integers in [0, 3]; they bound how many times the agent can look up/down from the horizon line."
+        assert (
+            look_constraints is None or look_constraints[1] > 0
+        ), "The agent must be allowed to look down from the horizon at least once."
+        self.look_constraints = look_constraints
+
         def v_or_default(v, default):
             return v if v is not None else default
 
@@ -173,6 +182,8 @@ def v_or_default(v, default):
         self.test_gpu_ids = v_or_default(test_gpu_ids, self.DEFAULT_TEST_GPU_IDS)
         self.add_prev_actions = add_prev_actions
 
+        self.auxiliary_uuids = self._AUXILIARY_UUIDS
+
     def _create_config(
         self,
         mode: str,
@@ -356,9 +367,13 @@ def test_scenes_path(self):
     def tag(self):
         t = f"ObjectNav-Habitat-{self.scene_dataset.upper()}"
 
-        if not self.add_prev_actions:
-            return t
-        return f"{t}-PrevActions"
+        if self.add_prev_actions:
+            t = f"{t}-PrevActions"
+
+        if self.look_constraints is not None:
+            t = f"{t}-Look{','.join(map(str, self.look_constraints))}"
+
+        return t
 
     def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
         return tuple()
@@ -402,10 +417,10 @@ def machine_params(self, mode="train", **kwargs):
             sensor_preprocessor_graph=sensor_preprocessor_graph,
         )
 
-    @classmethod
-    def make_sampler_fn(cls, **kwargs) -> TaskSampler:
+    def make_sampler_fn(self, **kwargs) -> TaskSampler:
         return ObjectNavTaskSampler(
-            **{"failed_end_reward": cls.FAILED_END_REWARD, **kwargs}  # type: ignore
+            task_kwargs={"look_constraints": self.look_constraints},
+            **{"failed_end_reward": self.FAILED_END_REWARD, **kwargs},  # type: ignore
         )
 
     def train_task_sampler_args(
diff --git a/projects/objectnav_baselines/experiments/objectnav_thor_base.py b/projects/objectnav_baselines/experiments/objectnav_thor_base.py
index ae649a87f..13e925271 100644
--- a/projects/objectnav_baselines/experiments/objectnav_thor_base.py
+++ b/projects/objectnav_baselines/experiments/objectnav_thor_base.py
@@ -29,9 +29,11 @@
 from allenact_plugins.robothor_plugin.robothor_tasks import ObjectNavTask
 from projects.objectnav_baselines.experiments.objectnav_base import ObjectNavBaseConfig
 
-if ai2thor.__version__ not in ["0.0.1", None] and version.parse(
-    ai2thor.__version__
-) < version.parse("3.2.0"):
+if (
+    ai2thor.__version__ not in ["0.0.1", None]
+    and not ai2thor.__version__.startswith("0+")
+    and version.parse(ai2thor.__version__) < version.parse("3.2.0")
+):
     raise ImportError(
         "To run the AI2-THOR ObjectNav baseline experiments you must use"
         " ai2thor version 3.2.0 or higher."
     )
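Aside (illustrative, not part of the patch): the extra `startswith("0+")` clause above exists because development installs of ai2thor report PEP 440 local versions like `0+<commit>`, which parse as less than 3.2.0 and would wrongly trigger the ImportError. A sketch of the guard (requires the `packaging` package):

from packaging import version

def needs_upgrade(v):
    return (
        v not in ["0.0.1", None]
        and not v.startswith("0+")  # dev builds are exempt from the check
        and version.parse(v) < version.parse("3.2.0")
    )

assert needs_upgrade("3.1.0")
assert not needs_upgrade("3.2.0")
assert not needs_upgrade("0+391b65fce2aa2e0c79d65fbaa5344978711bbb06")  # dev build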
diff --git a/projects/objectnav_baselines/experiments/robothor/clip/objectnav_robothor_rgb_clipresnet50gru_ddppo.py b/projects/objectnav_baselines/experiments/robothor/clip/objectnav_robothor_rgb_clipresnet50gru_ddppo.py
index b027e27b9..d00dc70de 100644
--- a/projects/objectnav_baselines/experiments/robothor/clip/objectnav_robothor_rgb_clipresnet50gru_ddppo.py
+++ b/projects/objectnav_baselines/experiments/robothor/clip/objectnav_robothor_rgb_clipresnet50gru_ddppo.py
@@ -61,7 +61,9 @@ def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
 
     def create_model(self, **kwargs) -> nn.Module:
         return self.preprocessing_and_model.create_model(
-            num_actions=self.ACTION_SPACE.n, add_prev_actions=self.add_prev_actions, **kwargs
+            num_actions=self.ACTION_SPACE.n,
+            add_prev_actions=self.add_prev_actions,
+            **kwargs
         )
 
     @classmethod
diff --git a/projects/objectnav_baselines/experiments/robothor/clip/objectnav_robothor_rgb_clipresnet50x16gru_ddppo.py b/projects/objectnav_baselines/experiments/robothor/clip/objectnav_robothor_rgb_clipresnet50x16gru_ddppo.py
index e772ae9f5..2cebef6e3 100644
--- a/projects/objectnav_baselines/experiments/robothor/clip/objectnav_robothor_rgb_clipresnet50x16gru_ddppo.py
+++ b/projects/objectnav_baselines/experiments/robothor/clip/objectnav_robothor_rgb_clipresnet50x16gru_ddppo.py
@@ -61,7 +61,9 @@ def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
 
     def create_model(self, **kwargs) -> nn.Module:
         return self.preprocessing_and_model.create_model(
-            num_actions=self.ACTION_SPACE.n, add_prev_actions=self.add_prev_actions, **kwargs
+            num_actions=self.ACTION_SPACE.n,
+            add_prev_actions=self.add_prev_actions,
+            **kwargs
        )
 
     @classmethod
diff --git a/projects/objectnav_baselines/mixins.py b/projects/objectnav_baselines/mixins.py
index 70ba0a54c..b07b0a3b7 100644
--- a/projects/objectnav_baselines/mixins.py
+++ b/projects/objectnav_baselines/mixins.py
@@ -25,6 +25,11 @@
     CPCA8Loss,
     CPCA16Loss,
     MultiAuxTaskNegEntropyLoss,
+    CPCA1SoftMaxLoss,
+    CPCA2SoftMaxLoss,
+    CPCA4SoftMaxLoss,
+    CPCA8SoftMaxLoss,
+    CPCA16SoftMaxLoss,
 )
 from allenact.embodiedai.preprocessors.resnet import ResNetPreprocessor
 from allenact.embodiedai.sensors.vision_sensors import RGBSensor, DepthSensor
@@ -261,6 +266,26 @@ def update_with_auxiliary_losses(
             CPCA16Loss(subsample_rate=0.2,),  # TODO: test its effects
             0.05 * aux_loss_total_weight,  # should times 2
         ),
+        CPCA1SoftMaxLoss.UUID: (
+            CPCA1SoftMaxLoss(subsample_rate=1.0,),
+            0.05 * aux_loss_total_weight,  # should times 2
+        ),
+        CPCA2SoftMaxLoss.UUID: (
+            CPCA2SoftMaxLoss(subsample_rate=1.0,),
+            0.05 * aux_loss_total_weight,  # should times 2
+        ),
+        CPCA4SoftMaxLoss.UUID: (
+            CPCA4SoftMaxLoss(subsample_rate=1.0,),
+            0.05 * aux_loss_total_weight,  # should times 2
+        ),
+        CPCA8SoftMaxLoss.UUID: (
+            CPCA8SoftMaxLoss(subsample_rate=1.0,),
+            0.05 * aux_loss_total_weight,  # should times 2
+        ),
+        CPCA16SoftMaxLoss.UUID: (
+            CPCA16SoftMaxLoss(subsample_rate=1.0,),
+            0.05 * aux_loss_total_weight,  # should times 2
+        ),
     }
 
     named_losses.update({uuid: total_aux_losses[uuid] for uuid in auxiliary_uuids})
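Aside (illustrative, not part of the patch): each new `CPCA*SoftMaxLoss` is registered with weight `0.05 * aux_loss_total_weight`, and only the uuids an experiment lists in `auxiliary_uuids` actually make it into `named_losses`. A sketch with hypothetical uuid strings and placeholder loss objects:

aux_loss_total_weight = 2.0

# Stand-ins for the real loss classes; uuid strings here are illustrative only.
total_aux_losses = {
    f"CPCA|{k}_SOFTMAX": (f"<CPCA{k}SoftMaxLoss(subsample_rate=1.0)>",
                          0.05 * aux_loss_total_weight)
    for k in [1, 2, 4, 8, 16]
}

auxiliary_uuids = ["CPCA|4_SOFTMAX"]  # hypothetical experiment selection
named_losses = {"ppo_loss": ("<PPO>", 1.0)}
named_losses.update({uuid: total_aux_losses[uuid] for uuid in auxiliary_uuids})
assert named_losses["CPCA|4_SOFTMAX"][1] == 0.1  # 0.05 * 2.0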
@@ -290,11 +315,17 @@ def training_pipeline(
         use_gae=True,
         gae_lambda=0.95,
         max_grad_norm=0.5,
+        anneal_lr: bool = True,
+        extra_losses: Optional[Dict[str, Tuple[AbstractActorCriticLoss, float]]] = None,
     ) -> TrainingPipeline:
         ppo_steps = int(300000000)
 
         named_losses = {
-            "ppo_loss": (PPO(**PPOConfig, normalize_advantage=normalize_advantage), 1.0)
+            "ppo_loss": (
+                PPO(**PPOConfig, normalize_advantage=normalize_advantage),
+                1.0,
+            ),
+            **({} if extra_losses is None else extra_losses),
         }
         named_losses = update_with_auxiliary_losses(
             named_losses=named_losses,
@@ -324,5 +355,7 @@ def training_pipeline(
             ],
             lr_scheduler_builder=Builder(
                 LambdaLR, {"lr_lambda": LinearDecay(steps=ppo_steps)}
-            ),
+            )
+            if anneal_lr
+            else None,
         )
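Aside (illustrative, not part of the patch): with `anneal_lr=True` the pipeline wraps the optimizer in a `LambdaLR` whose multiplier decays linearly to zero over `ppo_steps`; with `anneal_lr=False` no scheduler is built and the learning rate stays constant. A minimal sketch using a stand-in for `LinearDecay`:

import torch

def linear_decay(steps):
    # Analogue of allenact's LinearDecay: multiplier 1 -> 0 over `steps`.
    return lambda epoch: 1 - min(epoch / float(steps), 1.0)

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.Adam(params, lr=3e-4)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=linear_decay(100))

for _ in range(50):
    opt.step()
    sched.step()

# Halfway through the schedule, the lr has decayed to half its base value.
assert abs(opt.param_groups[0]["lr"] - 1.5e-4) < 1e-9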